# KT_Q_Table/tests/test_qlearning_agent.py
import pytest
import numpy as np
from agents.offline_agent import QLearningAgent
@pytest.fixture
def agent_params():
    """Hyperparameters for building a deterministic test agent.

    epsilon is 0.0 so action selection is pure greedy argmax,
    keeping every test reproducible.
    """
    return dict(
        learning_rate=0.1,
        discount_factor=0.99,
        epsilon=0.0,  # no exploration -> deterministic for testing
    )
@pytest.fixture
def agent(agent_params):
    """A small 4-state / 3-action QLearningAgent built from `agent_params`."""
    return QLearningAgent(agent_params, state_size=4, action_size=3)
def test_agent_initialization(agent):
    """A fresh agent exposes its sizes, hyperparameters, and a zero Q-table."""
    # Dimensions passed at construction time are stored as-is.
    assert (agent.state_size, agent.action_size) == (4, 3)
    # Hyperparameters come from the agent_params fixture.
    assert agent.lr == 0.1
    assert agent.gamma == 0.99
    # Q-table is (states, actions) and starts out all zeros.
    assert agent.q_table.shape == (4, 3)
    assert not agent.q_table.any()
def test_get_action_with_mask(agent):
    """Greedy action selection should respect an optional action mask."""
    # Seed state 0 with distinct Q-values so the argmax is unambiguous.
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Unmasked: greedy pick is the largest Q-value (index 2).
    assert agent.get_action(0) == 2

    # Masked: disallow the best action; the runner-up (index 1) wins.
    mask = np.array([1, 1, 0])
    assert agent.get_action(0, mask) == 1
def test_episode_tracking(agent):
    """Within one episode each action is handed out once, then None."""
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Drain the whole action space for state 0.
    taken = []
    for _ in range(agent.action_size):
        chosen = agent.get_action(0)
        assert chosen is not None
        taken.append(chosen)

    # No repeats: every action appeared exactly once.
    assert len(set(taken)) == agent.action_size
    # Action space exhausted -> the agent signals with None.
    assert agent.get_action(0) is None
def test_episode_reset(agent):
    """reset_episode() clears the per-episode taken-action bookkeeping."""
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Consume a couple of actions...
    for _ in range(2):
        agent.get_action(0)

    # ...then reset: the greedy (highest-Q) action is available again.
    agent.reset_episode()
    assert agent.get_action(0) == 2
def test_learning(agent):
    """A single transition updates Q(s, a) by the tabular Q-learning rule."""
    # One-step batch: s=0, a=1, r=1.0, s'=1, non-terminal.
    transition = {
        'observations': np.array([0]),
        'actions': np.array([1]),
        'rewards': np.array([1.0]),
        'next_observations': np.array([1]),
        'terminals': np.array([False]),
    }

    # Known next-state values; the max over them is 0.8.
    agent.q_table[1] = np.array([0.5, 0.8, 0.3])
    before = agent.q_table[0, 1]

    agent.learn(transition)

    # Q(s,a) <- Q(s,a) + lr * (r + gamma * max_a' Q(s',a') - Q(s,a))
    target = before + agent.lr * (1.0 + agent.gamma * 0.8 - before)
    assert np.isclose(agent.q_table[0, 1], target)
def test_save_and_load(agent, agent_params, tmp_path):
    """The Q-table survives a save/load round trip unchanged.

    FIX: the original called the fixture function directly
    (`agent_params()`), which pytest >= 4.0 rejects with
    "Fixtures are not meant to be called directly". The fixture is
    now requested as a test parameter so pytest injects its value.
    """
    # Give the table distinctive values so loading an empty/default
    # table would be detected.
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Persist the trained table.
    save_path = tmp_path / "q_table.npy"
    agent.save_model(save_path)

    # Restore into a brand-new agent and compare element-wise.
    new_agent = QLearningAgent(agent_params, state_size=4, action_size=3)
    new_agent.load_q_table(save_path)
    assert np.all(agent.q_table == new_agent.q_table)