# KT_Q_Table/agents/offline_agent.py
import numpy as np
import random
import os
class QLearningAgent:
    """Tabular Q-learning agent for discrete state and action spaces.

    Values are stored in a dense ``(state_size, action_size)`` numpy
    Q-table and updated from offline batches of transitions.
    """

    def __init__(self, agent_params, state_size, action_size):
        """Initialize the agent.

        Args:
            agent_params: dict with keys ``'learning_rate'``,
                ``'discount_factor'`` and optionally ``'epsilon'``.
            state_size: number of discrete states (Q-table rows).
            action_size: number of discrete actions (Q-table columns).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params['learning_rate']
        self.gamma = agent_params['discount_factor']
        # Epsilon controls exploration; defaults to 0.1 when not supplied
        # so the same class can be reused for evaluation (epsilon=0).
        self.epsilon = agent_params.get('epsilon', 0.1)
        self.q_table = np.zeros((state_size, action_size))

    def get_action(self, state):
        """Return an epsilon-greedy action for integer state ``state``."""
        if random.uniform(0, 1) < self.epsilon:
            # Explore: uniform random action.
            return random.randint(0, self.action_size - 1)
        # Exploit: greedy action. Cast np.int64 -> int so the result is a
        # plain Python integer (e.g. JSON-serializable); value is unchanged.
        return int(np.argmax(self.q_table[state, :]))

    def learn(self, batch):
        """Apply one Q-learning update per transition in ``batch``.

        Args:
            batch: dict of parallel sequences under the keys
                ``'observations'``, ``'actions'``, ``'rewards'``,
                ``'next_observations'`` and ``'terminals'``.
        """
        for state, action, reward, next_state, terminated in zip(
            batch['observations'], batch['actions'], batch['rewards'],
            batch['next_observations'], batch['terminals']
        ):
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])
            # (1 - terminated) zeroes the bootstrap term on terminal
            # transitions, so the target collapses to the immediate reward.
            target = reward + self.gamma * next_max * (1 - terminated)
            self.q_table[state, action] = old_value + self.lr * (target - old_value)

    def save_model(self, path):
        """Persist the Q-table via ``np.save``.

        NOTE: ``np.save`` appends ``'.npy'`` when ``path`` lacks the
        extension; ``load_q_table`` accounts for that.
        """
        np.save(path, self.q_table)
        print(f"Q-Table saved to {path}")

    def load_q_table(self, file_path):
        """Load a Q-table previously written by ``save_model``.

        Bug fix: ``np.save`` silently appends ``'.npy'`` to extensionless
        paths, so a table saved as ``save_model('q')`` lives at ``q.npy``
        and could never be found by the original lookup. We now also probe
        ``file_path + '.npy'`` before reporting failure.
        """
        for candidate in (file_path, file_path + '.npy'):
            if os.path.exists(candidate):
                self.q_table = np.load(candidate)
                print(f"Q-Table loaded from {file_path}")
                return
        print(f"Error: No Q-Table found at {file_path}")