import os
import random

import numpy as np


class QLearningAgent:
    """Tabular Q-learning agent backed by a dense NumPy Q-table.

    The table has shape ``(state_size, action_size)`` and starts at zero.
    Actions are chosen epsilon-greedily and the table is updated from
    batches of transitions via the standard one-step Q-learning rule.
    """

    def __init__(self, agent_params, state_size, action_size):
        """Create an agent with a zero-initialised Q-table.

        Args:
            agent_params: Mapping that must contain ``'learning_rate'`` and
                ``'discount_factor'``; ``'epsilon'`` is optional and
                defaults to ``0.1``.
            state_size: Number of rows (states) in the Q-table.
            action_size: Number of columns (actions) in the Q-table.

        Raises:
            KeyError: If a required key is missing from ``agent_params``.
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params['learning_rate']
        self.gamma = agent_params['discount_factor']
        # Probability of picking a uniformly random action in get_action.
        self.epsilon = agent_params.get('epsilon', 0.1)
        self.q_table = np.zeros((state_size, action_size))

    @staticmethod
    def _npy_path(path):
        """Return ``path`` as ``np.save`` writes it.

        ``np.save`` appends ``'.npy'`` to string / path-like targets that
        do not already end with it; other targets (e.g. open file objects)
        are returned unchanged.
        """
        if isinstance(path, (str, os.PathLike)):
            path = os.fspath(path)
            if isinstance(path, str) and not path.endswith('.npy'):
                path += '.npy'
        return path

    def get_action(self, state):
        """Pick an action for ``state`` using an epsilon-greedy policy.

        With probability ``self.epsilon`` a uniformly random action index
        is returned; otherwise the index of the largest Q-value in the
        state's row.

        Returns:
            The chosen action index as a plain Python ``int`` on both
            branches.
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.randint(0, self.action_size - 1)
        # np.argmax yields a NumPy integer; convert so both branches
        # return the same type.
        return int(np.argmax(self.q_table[state, :]))

    def learn(self, batch):
        """Apply the Q-learning update for every transition in ``batch``.

        Args:
            batch: Mapping with the keys ``'observations'``, ``'actions'``,
                ``'rewards'``, ``'next_observations'`` and ``'terminals'``.
                The five sequences are iterated in lockstep; iteration
                stops at the shortest one.
        """
        transitions = zip(
            batch['observations'],
            batch['actions'],
            batch['rewards'],
            batch['next_observations'],
            batch['terminals'],
        )
        for state, action, reward, next_state, terminated in transitions:
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])
            # (1 - terminated) removes the bootstrap term on terminal
            # transitions, leaving the reward as the whole target.
            td_target = reward + self.gamma * next_max * (1 - terminated)
            self.q_table[state, action] = (
                old_value + self.lr * (td_target - old_value)
            )

    def save_model(self, path):
        """Write the Q-table to ``path`` with ``np.save``.

        ``np.save`` appends ``'.npy'`` when ``path`` is a file name without
        that suffix; the printed message shows the name actually written.
        """
        np.save(path, self.q_table)
        print(f"Q-Table saved to {self._npy_path(path)}")

    def load_q_table(self, file_path):
        """Replace the Q-table with the array stored at ``file_path``.

        If ``file_path`` does not exist but the same name with ``'.npy'``
        appended does (the name ``save_model`` produces for a suffix-less
        path), that file is loaded instead. When neither exists an error
        is printed and the current table is left untouched.
        """
        candidates = [file_path]
        suffixed = self._npy_path(file_path)
        if suffixed != file_path:
            candidates.append(suffixed)

        for candidate in candidates:
            if os.path.exists(candidate):
                self.q_table = np.load(candidate)
                print(f"Q-Table loaded from {candidate}")
                return

        print(f"Error: No Q-Table found at {file_path}")