59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
import numpy as np
|
|
import os
|
|
from .policy import EpisodePolicy
|
|
|
|
|
|
class QLearningAgent:
    """Tabular Q-learning agent with an episode-scoped action policy.

    Maintains a ``(state_size, action_size)`` Q-table updated with the
    standard Q-learning rule:

        Q(s, a) += lr * (r + gamma * max_a' Q(s', a') * (1 - done) - Q(s, a))

    Action selection is delegated to an ``EpisodePolicy``, which tracks
    which actions have already been taken in the current episode.
    """

    def __init__(self, agent_params, state_size, action_size):
        """Build the agent and zero-initialize its Q-table.

        Args:
            agent_params: dict with required keys ``"learning_rate"`` and
                ``"discount_factor"``; optional ``"epsilon"`` (default 0.1)
                passed through to the policy.
            state_size: number of discrete states (Q-table rows).
            action_size: number of discrete actions (Q-table columns).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params["learning_rate"]
        self.gamma = agent_params["discount_factor"]

        # Initialize policy; epsilon falls back to 0.1 when not configured.
        self.episode_policy = EpisodePolicy(epsilon=agent_params.get("epsilon", 0.1))
        self.q_table = np.zeros((state_size, action_size))

    def get_action(self, state, action_mask=None):
        """Select an action for ``state`` via the episode policy.

        Args:
            state: integer state index into the Q-table.
            action_mask: optional mask of permitted actions; when None, the
                policy's own per-episode mask is used.

        Returns:
            The selected action, or None when the policy reports that every
            action has already been taken this episode.
        """
        q_values = self.q_table[state, :]
        if action_mask is None:
            action_mask = self.episode_policy.get_action_mask()
        # select_action returns None once all actions are exhausted for the
        # episode; we propagate that sentinel unchanged to the caller.
        return self.episode_policy.select_action(q_values, action_mask)

    def learn(self, batch):
        """Apply one Q-learning update per transition in ``batch``.

        Args:
            batch: dict of parallel sequences under the keys
                ``"observations"``, ``"actions"``, ``"rewards"``,
                ``"next_observations"``, ``"terminals"``.
        """
        for state, action, reward, next_state, terminated in zip(
            batch["observations"],
            batch["actions"],
            batch["rewards"],
            batch["next_observations"],
            batch["terminals"],
        ):
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])

            # (1 - terminated) zeroes the bootstrap term on terminal
            # transitions, so only the immediate reward is learned.
            new_value = old_value + self.lr * (
                reward + self.gamma * next_max * (1 - terminated) - old_value
            )
            self.q_table[state, action] = new_value

    def save_model(self, path):
        """Persist the Q-table to ``path`` in NumPy ``.npy`` format."""
        np.save(path, self.q_table)
        print(f"Q-Table saved to {path}")

    def load_q_table(self, file_path):
        """Load a previously saved Q-table, replacing the current one.

        Prints an error message (does not raise) when ``file_path`` does
        not exist, leaving the current Q-table untouched.
        """
        if os.path.exists(file_path):
            self.q_table = np.load(file_path)
            print(f"Q-Table loaded from {file_path}")
        else:
            print(f"Error: No Q-Table found at {file_path}")

    def reset_episode(self):
        """Reset agent for new episode"""
        # Bug fix: previously referenced `self.policy`, an attribute that is
        # never assigned (always raised AttributeError). The policy is stored
        # as `self.episode_policy` in __init__.
        self.episode_policy.reset_episode()
|