# Tests for QLearningAgent (agents/offline_agent.py).
import pytest
|
|
import numpy as np
|
|
from agents.offline_agent import QLearningAgent
|
|
|
|
@pytest.fixture
def agent_params():
    """Hyperparameters shared by all agent fixtures.

    epsilon is pinned to 0.0 so action selection is greedy and the
    tests are deterministic.
    """
    params = {
        'learning_rate': 0.1,
        'discount_factor': 0.99,
        'epsilon': 0.0,  # no exploration -> deterministic behavior under test
    }
    return params
|
|
|
|
@pytest.fixture
def agent(agent_params):
    """A fresh QLearningAgent with 4 states and 3 actions."""
    fresh_agent = QLearningAgent(agent_params, state_size=4, action_size=3)
    return fresh_agent
|
|
|
|
def test_agent_initialization(agent):
    """A new agent exposes its configured sizes, hyperparameters, and a zeroed Q-table."""
    # Dimensions and hyperparameters come straight from the fixture config.
    assert agent.state_size == 4
    assert agent.action_size == 3
    assert agent.lr == 0.1
    assert agent.gamma == 0.99

    # The Q-table must be (state_size, action_size) and start out all zeros.
    expected_shape = (4, 3)
    assert agent.q_table.shape == expected_shape
    assert not np.any(agent.q_table)
|
|
|
|
def test_get_action_with_mask(agent):
    """Greedy selection honors an optional action mask."""
    # Seed state 0 with known Q-values so the argmax is predictable.
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Unmasked: the agent should pick the globally best action.
    chosen = agent.get_action(0)
    assert chosen == 2

    # Masked: disallow the best action (index 2); the runner-up must win.
    mask = np.array([1, 1, 0])
    chosen = agent.get_action(0, mask)
    assert chosen == 1
|
|
|
|
def test_episode_tracking(agent):
    """Within one episode the agent never repeats an action for a state."""
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Exhaust the action space for state 0; each call must yield a real action.
    taken = []
    for _ in range(agent.action_size):
        picked = agent.get_action(0)
        assert picked is not None
        taken.append(picked)

    # No repeats: every action index appeared exactly once.
    assert len(set(taken)) == agent.action_size

    # With all actions spent, the agent signals exhaustion with None.
    assert agent.get_action(0) is None
|
|
|
|
def test_episode_reset(agent):
    """reset_episode clears per-episode action tracking."""
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Consume a couple of actions for state 0.
    for _ in range(2):
        agent.get_action(0)

    # After a reset the tracking is empty again...
    agent.reset_episode()

    # ...so the greedy (highest-Q) action is available once more.
    assert agent.get_action(0) == 2
|
|
|
|
def test_learning(agent):
    """A single-transition batch produces the textbook Q-learning update."""
    # One transition: state 0, action 1, reward 1.0 -> state 1, not terminal.
    batch = {
        'observations': np.array([0]),
        'actions': np.array([1]),
        'rewards': np.array([1.0]),
        'next_observations': np.array([1]),
        'terminals': np.array([False]),
    }

    # Known Q-values for the successor state; max is 0.8 at action 1.
    agent.q_table[1] = np.array([0.5, 0.8, 0.3])
    q_before = agent.q_table[0, 1]

    agent.learn(batch)

    # Bellman target: Q(s,a) <- Q(s,a) + lr * (r + gamma * max_a' Q(s',a') - Q(s,a))
    td_target = 1.0 + agent.gamma * 0.8
    expected = q_before + agent.lr * (td_target - q_before)
    assert np.isclose(agent.q_table[0, 1], expected)
|
|
|
|
def test_save_and_load(agent, agent_params, tmp_path):
    """A saved Q-table round-trips into a freshly constructed agent.

    Bug fix: the original called the ``agent_params`` fixture function
    directly (``agent_params()``), which modern pytest rejects with
    "Fixtures are not meant to be called directly". The fixture value is
    now injected as a test parameter instead.
    """
    # Give the trained agent a recognizable Q-row to round-trip.
    agent.q_table[0] = np.array([1.0, 2.0, 3.0])

    # Persist to a temp file so the test leaves no residue.
    save_path = tmp_path / "q_table.npy"
    agent.save_model(save_path)

    # A brand-new agent loading that file must reproduce the table exactly.
    new_agent = QLearningAgent(agent_params, state_size=4, action_size=3)
    new_agent.load_q_table(save_path)

    assert np.all(agent.q_table == new_agent.q_table)
|