import os

import numpy as np

from .policy import EpisodePolicy


class QLearningAgent:
    """Tabular Q-learning agent with an episode-aware action policy.

    Maintains a (state_size x action_size) Q-table and delegates
    exploration / per-episode action masking to an ``EpisodePolicy``.
    """

    def __init__(self, agent_params, state_size, action_size):
        """Initialize the agent.

        Args:
            agent_params: dict with keys ``"learning_rate"``,
                ``"discount_factor"`` and optionally ``"epsilon"``
                (defaults to 0.1).
            state_size: number of discrete states (rows of the Q-table).
            action_size: number of discrete actions (columns of the Q-table).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params["learning_rate"]
        self.gamma = agent_params["discount_factor"]
        # Policy handles epsilon-greedy exploration and tracks which
        # actions have already been used within the current episode.
        self.episode_policy = EpisodePolicy(epsilon=agent_params.get("epsilon", 0.1))
        self.q_table = np.zeros((state_size, action_size))

    def get_action(self, state, action_mask=None):
        """Select an action for ``state`` via the episode policy.

        Args:
            state: integer state index into the Q-table.
            action_mask: optional mask of allowed actions; when None,
                the policy's own per-episode mask is used.

        Returns:
            The selected action, or None when the policy reports that
            all actions have already been taken this episode.
        """
        q_values = self.q_table[state, :]
        if action_mask is None:
            action_mask = self.episode_policy.get_action_mask()
        action = self.episode_policy.select_action(q_values, action_mask)
        if action is None:
            # All actions have been taken in this episode.
            return None
        return action

    def learn(self, batch):
        """Apply one Q-learning update per transition in ``batch``.

        Args:
            batch: mapping with array-like values under the keys
                ``"observations"``, ``"actions"``, ``"rewards"``,
                ``"next_observations"`` and ``"terminals"``.
        """
        for state, action, reward, next_state, terminated in zip(
            batch["observations"],
            batch["actions"],
            batch["rewards"],
            batch["next_observations"],
            batch["terminals"],
        ):
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])
            # Standard Q-learning target; the (1 - terminated) factor
            # zeroes out the bootstrap term on terminal transitions.
            new_value = old_value + self.lr * (
                reward + self.gamma * next_max * (1 - terminated) - old_value
            )
            self.q_table[state, action] = new_value

    def save_model(self, path):
        """Save the Q-table to ``path`` via ``np.save``."""
        np.save(path, self.q_table)
        print(f"Q-Table saved to {path}")

    def load_q_table(self, file_path):
        """Load a Q-table from ``file_path`` if it exists; warn otherwise."""
        if os.path.exists(file_path):
            self.q_table = np.load(file_path)
            print(f"Q-Table loaded from {file_path}")
        else:
            print(f"Error: No Q-Table found at {file_path}")

    def reset_episode(self):
        """Reset agent for new episode"""
        # BUG FIX: the attribute is ``episode_policy`` (set in __init__);
        # the original referenced the nonexistent ``self.policy`` and
        # always raised AttributeError.
        self.episode_policy.reset_episode()