87 lines
3.0 KiB
Python
87 lines
3.0 KiB
Python
import pytest
|
|
import numpy as np
|
|
from agents.policy import EpisodePolicy
|
|
|
|
@pytest.fixture
def policy():
    """Provide a fully greedy EpisodePolicy.

    epsilon is pinned to 0.0 so that action selection is deterministic and
    the tests below can assert exact action indices.
    """
    greedy_epsilon = 0.0
    return EpisodePolicy(epsilon=greedy_epsilon)
|
|
|
|
@pytest.fixture
def q_values():
    """Provide sample Q-values for four actions.

    Ranked from best to worst the indices are 2 (3.0), 1 (2.0), 0 (1.0),
    and 3 (0.5).
    """
    values = [1.0, 2.0, 3.0, 0.5]
    return np.array(values)
|
|
|
|
def test_select_action_without_mask(policy, q_values):
    """Test action selection without any mask"""
    # With a greedy policy, successive calls within one episode should walk
    # down the Q-value ranking: first index 2 (3.0), then index 1 (2.0),
    # because an action already taken this episode is excluded.
    for expected in (2, 1):
        assert policy.select_action(q_values) == expected
|
|
|
|
def test_select_action_with_mask(policy, q_values):
    """Test action selection with explicit action mask"""
    # A zero in the mask disables index 2, the overall best action, so the
    # greedy choice has to fall back to index 1 (value 2.0).
    mask = np.array([1, 1, 0, 1])
    chosen = policy.select_action(q_values, mask)
    assert chosen == 1
|
|
|
|
def test_episode_tracking(policy, q_values):
    """Test if actions are properly tracked within an episode"""
    steps = 3
    for _ in range(steps):
        policy.select_action(q_values)

    # One entry should have been recorded per selection made above.
    assert len(policy.episode_actions) == steps
|
|
|
|
def test_reset_episode(policy, q_values):
    """Test episode reset functionality"""
    # Build up some per-episode history first.
    for _ in range(2):
        policy.select_action(q_values)

    policy.reset_episode()

    # After a reset the recorded history must be empty again.
    assert len(policy.episode_actions) == 0
|
|
|
|
def test_all_actions_taken(policy, q_values):
    """Test behavior when all actions have been taken"""
    n_actions = len(q_values)
    taken = []

    # Exhaust the action space: every selection must succeed.
    while len(taken) < n_actions:
        chosen = policy.select_action(q_values)
        assert chosen is not None
        taken.append(chosen)

    # No action may have been repeated within the episode.
    assert len(set(taken)) == n_actions

    # With nothing left to pick, the policy should signal exhaustion with
    # None and clear its episode history on its own.
    assert policy.select_action(q_values) is None
    assert len(policy.episode_actions) == 0
|
|
|
|
@pytest.mark.parametrize("epsilon,min_unique_actions", [
    (0.0, 1),  # Deterministic - should always take best action first
    (1.0, 3),  # Random - should see multiple different actions
])
def test_epsilon_greedy(q_values, epsilon, min_unique_actions):
    """Test epsilon-greedy behavior with different epsilon values.

    Every trial selects one action and then resets the episode, so each
    selection is the first of its episode. The per-episode exclusion of
    already-taken actions therefore never applies, and the set of observed
    actions reflects only the epsilon setting.
    """
    policy = EpisodePolicy(epsilon=epsilon)
    actions = set()

    # Repeat the first-of-episode selection many times. With epsilon=1.0 the
    # chance of seeing fewer than 3 of the 4 actions in 50 trials is
    # negligible, assuming exploration samples uniformly over all actions
    # (TODO confirm against EpisodePolicy).
    for _ in range(50):
        action = policy.select_action(q_values)
        if action is not None:
            actions.add(action)
        policy.reset_episode()

    assert len(actions) >= min_unique_actions
    # Every selected action must be a valid index into q_values.
    assert actions <= set(range(len(q_values)))
    if epsilon == 0.0:
        # A purely greedy policy must never deviate from the argmax. The
        # ">= 1 unique action" check above cannot detect such a deviation.
        assert actions == {int(np.argmax(q_values))}
|