# KT_Q_Table/tests/test_episode_policy.py
#
# File-viewer metadata: 87 lines, 3.0 KiB, Python

import pytest
import numpy as np
from agents.policy import EpisodePolicy
@pytest.fixture
def policy():
    """Provide an ``EpisodePolicy`` built with epsilon=0.0 for deterministic tests."""
    greedy_policy = EpisodePolicy(epsilon=0.0)
    return greedy_policy
@pytest.fixture
def q_values():
    """Provide an example Q-value vector whose largest entry (3.0) sits at index 2."""
    example_values = [1.0, 2.0, 3.0, 0.5]
    return np.array(example_values)
def test_select_action_without_mask(policy, q_values):
    """Select twice without a mask: expect the best action, then the runner-up."""
    first_choice = policy.select_action(q_values)
    # Index 2 holds the highest Q-value (3.0).
    assert first_choice == 2

    second_choice = policy.select_action(q_values)
    # The previously chosen action is excluded, leaving index 1 (2.0) as best.
    assert second_choice == 1
def test_select_action_with_mask(policy, q_values):
    """Select with an explicit action mask that disables the top-valued action."""
    # The zero at position 2 masks out the action with the overall best Q-value.
    mask = np.array([1, 1, 0, 1])

    chosen = policy.select_action(q_values, mask)

    # Among the unmasked actions, index 1 has the highest value.
    assert chosen == 1
def test_episode_tracking(policy, q_values):
    """Verify that every selection made in an episode is recorded on the policy."""
    expected_count = 3
    for _ in range(expected_count):
        policy.select_action(q_values)

    # One tracked entry per selection made above.
    assert len(policy.episode_actions) == expected_count
def test_reset_episode(policy, q_values):
    """Verify that ``reset_episode`` clears the actions tracked so far."""
    # Populate the episode history with a couple of selections.
    for _ in range(2):
        policy.select_action(q_values)

    policy.reset_episode()

    # Nothing should remain tracked after the reset.
    assert len(policy.episode_actions) == 0
def test_all_actions_taken(policy, q_values):
    """Exhaust every action in one episode, then check the exhausted-state behavior."""
    action_count = len(q_values)
    distinct_actions = set()

    # Draw as many actions as there are Q-values; each draw must yield an action.
    for _ in range(action_count):
        picked = policy.select_action(q_values)
        assert picked is not None
        distinct_actions.add(picked)

    # No action may have been handed out twice.
    assert len(distinct_actions) == action_count

    # One further request has nothing left to offer and should yield None.
    extra = policy.select_action(q_values)
    assert extra is None

    # The episode history is expected to be cleared automatically at that point.
    assert len(policy.episode_actions) == 0
@pytest.mark.parametrize("epsilon,min_unique_actions", [
    (0.0, 1),  # Deterministic - should always take best action first
    (1.0, 3),  # Random - should see multiple different actions
])
def test_epsilon_greedy(q_values, epsilon, min_unique_actions):
    """Test epsilon-greedy behavior with different epsilon values.

    A fresh policy is built for each parametrized epsilon. The episode is
    reset after every selection, so each of the 50 draws is the first
    selection of a new episode.

    Args:
        q_values: Example Q-value vector supplied by the ``q_values`` fixture.
        epsilon: Exploration rate passed to ``EpisodePolicy``.
        min_unique_actions: Lower bound on the number of distinct actions
            that must be observed across all draws.
    """
    # Named differently from the ``policy`` fixture to avoid confusion with it.
    eps_policy = EpisodePolicy(epsilon=epsilon)
    actions = set()

    # Repeat many times so the random case reliably shows several actions.
    for _ in range(50):
        action = eps_policy.select_action(q_values)
        if action is not None:
            actions.add(action)
        eps_policy.reset_episode()

    assert len(actions) >= min_unique_actions

    if epsilon == 0.0:
        # The lower bound alone would also pass for a random policy; with no
        # exploration, every fresh episode must start with the argmax action.
        assert actions == {int(np.argmax(q_values))}