# KT_Q_Table/agents/offline_agent.py
import numpy as np
import os
from .policy import EpisodePolicy
class QLearningAgent:
    """Tabular Q-learning agent for offline (batch) training.

    Maintains a ``(state_size, action_size)`` Q-table updated with the
    standard Q-learning rule. Action selection is delegated to an
    ``EpisodePolicy``, which can mask out actions already taken in the
    current episode.
    """

    def __init__(self, agent_params, state_size, action_size):
        """Initialize the agent.

        agent_params: dict with required keys "learning_rate" and
            "discount_factor", and optional "epsilon" (default 0.1)
            forwarded to the episode policy.
        state_size: number of discrete states (Q-table rows).
        action_size: number of discrete actions (Q-table columns).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params["learning_rate"]
        self.gamma = agent_params["discount_factor"]
        # Exploration is handled inside the policy; epsilon defaults to 0.1.
        self.episode_policy = EpisodePolicy(epsilon=agent_params.get("epsilon", 0.1))
        self.q_table = np.zeros((state_size, action_size))

    def get_action(self, state, action_mask=None):
        """Select an action for ``state`` via the episode policy.

        When ``action_mask`` is None, the policy's own mask is used.
        Returns the chosen action index, or None when the policy reports
        that all actions have already been taken this episode.
        """
        q_values = self.q_table[state, :]
        if action_mask is None:
            action_mask = self.episode_policy.get_action_mask()
        # select_action returns None once every action is exhausted, which
        # we simply propagate to the caller.
        return self.episode_policy.select_action(q_values, action_mask)

    def learn(self, batch):
        """Apply one Q-learning update per transition in ``batch``.

        batch: dict of parallel sequences under the keys "observations",
            "actions", "rewards", "next_observations", and "terminals".
        """
        for state, action, reward, next_state, terminated in zip(
            batch["observations"],
            batch["actions"],
            batch["rewards"],
            batch["next_observations"],
            batch["terminals"],
        ):
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])
            # Standard TD update; (1 - terminated) zeroes the bootstrap
            # term on terminal transitions.
            self.q_table[state, action] = old_value + self.lr * (
                reward + self.gamma * next_max * (1 - terminated) - old_value
            )

    def save_model(self, path):
        """Persist the Q-table to ``path`` with np.save."""
        np.save(path, self.q_table)
        print(f"Q-Table saved to {path}")

    def load_q_table(self, file_path):
        """Load a Q-table from ``file_path`` if the file exists.

        Prints an error (without raising) when the file is missing,
        matching the original best-effort behavior.
        """
        if os.path.exists(file_path):
            self.q_table = np.load(file_path)
            print(f"Q-Table loaded from {file_path}")
        else:
            print(f"Error: No Q-Table found at {file_path}")

    def reset_episode(self):
        """Reset the agent's policy state for a new episode."""
        # BUG FIX: the attribute set in __init__ is `episode_policy`;
        # the original `self.policy` did not exist and raised AttributeError.
        self.episode_policy.reset_episode()