# Source: 41 lines, 1.5 KiB, Python
import numpy as np
|
|
import random
|
|
import os
|
|
|
|
class QLearningAgent:
    """Tabular Q-learning agent with an epsilon-greedy policy.

    Maintains a (state_size, action_size) Q-table and updates it with the
    standard one-step Q-learning rule:

        Q(s, a) += lr * (r + gamma * max_a' Q(s', a') * (1 - terminated) - Q(s, a))
    """

    def __init__(self, agent_params, state_size, action_size):
        """Initialize the agent.

        Args:
            agent_params: dict with 'learning_rate', 'discount_factor',
                and optionally 'epsilon' (exploration rate, default 0.1).
            state_size: number of discrete states (rows of the Q-table).
            action_size: number of discrete actions (columns of the Q-table).
        """
        self.state_size = state_size
        self.action_size = action_size
        self.lr = agent_params['learning_rate']
        self.gamma = agent_params['discount_factor']
        self.epsilon = agent_params.get('epsilon', 0.1)  # Add epsilon for exploration/evaluation

        # All state-action values start at zero.
        self.q_table = np.zeros((state_size, action_size))

    def get_action(self, state):
        """Return an action for *state* using an epsilon-greedy policy.

        With probability epsilon a uniformly random action is chosen
        (exploration); otherwise the current greedy action (exploitation).
        """
        if random.uniform(0, 1) < self.epsilon:
            return random.randint(0, self.action_size - 1)
        else:
            return np.argmax(self.q_table[state, :])

    def learn(self, batch):
        """Apply one Q-learning update per transition in *batch*.

        Args:
            batch: dict of parallel sequences with keys 'observations',
                'actions', 'rewards', 'next_observations', 'terminals'.
                Observations are assumed to be integer state indices.
        """
        for state, action, reward, next_state, terminated in zip(
            batch['observations'], batch['actions'], batch['rewards'], batch['next_observations'], batch['terminals']
        ):
            old_value = self.q_table[state, action]
            next_max = np.max(self.q_table[next_state, :])

            # (1 - terminated) zeroes the bootstrap term on terminal transitions.
            new_value = old_value + self.lr * (reward + self.gamma * next_max * (1 - terminated) - old_value)
            self.q_table[state, action] = new_value

    def save_model(self, path):
        """Persist the Q-table with np.save.

        Note: np.save appends '.npy' to *path* if it lacks that extension.
        """
        np.save(path, self.q_table)
        print(f"Q-Table saved to {path}")

    def load_q_table(self, file_path):
        """Load a Q-table previously written by save_model.

        Accepts either the exact path or the '.npy'-suffixed file that
        np.save actually produced; prints an error if neither exists.
        """
        # BUGFIX: np.save appends '.npy' when the path has no extension, so a
        # round trip through save_model(path)/load_q_table(path) used to fail
        # for extension-less paths. Fall back to the suffixed file.
        if not os.path.exists(file_path) and os.path.exists(file_path + '.npy'):
            file_path = file_path + '.npy'
        if os.path.exists(file_path):
            self.q_table = np.load(file_path)
            print(f"Q-Table loaded from {file_path}")
        else:
            print(f"Error: No Q-Table found at {file_path}")