"""Tests for ``agents.policy.EpisodePolicy``.

The behaviour pinned down by this module:

* with ``epsilon=0.0`` the policy picks actions greedily by Q-value,
* an action already chosen in the current episode is not chosen again,
* an explicit action mask removes actions from consideration,
* chosen actions are recorded in ``policy.episode_actions`` and cleared by
  ``policy.reset_episode()``,
* once every action has been used, ``select_action`` returns ``None`` and
  the episode is reset.
"""

import numpy as np
import pytest

from agents.policy import EpisodePolicy


@pytest.fixture
def policy():
    """Return an ``EpisodePolicy`` with ``epsilon=0.0`` for deterministic tests."""
    return EpisodePolicy(epsilon=0.0)


@pytest.fixture
def q_values():
    """Return example Q-values; by descending value the indices are 2, 1, 0, 3."""
    return np.array([1.0, 2.0, 3.0, 0.5])


def test_select_action_without_mask(policy, q_values):
    """Without a mask, successive calls walk down the Q-value ranking."""
    # First action should be the highest Q-value.
    action = policy.select_action(q_values)
    assert action == 2  # index 2 has the highest value (3.0)

    # Second action should exclude the previous one.
    action = policy.select_action(q_values)
    assert action == 1  # index 1 has the second highest value (2.0)


def test_select_action_with_mask(policy, q_values):
    """An explicit action mask excludes the masked-out action."""
    action_mask = np.array([1, 1, 0, 1])  # mask out action 2
    action = policy.select_action(q_values, action_mask)
    assert action == 1  # index 1 has the highest value among unmasked actions


def test_episode_tracking(policy, q_values):
    """Every selected action is tracked within the episode."""
    # Take some actions.
    policy.select_action(q_values)
    policy.select_action(q_values)
    policy.select_action(q_values)

    # Check that the actions were tracked.
    assert len(policy.episode_actions) == 3


def test_reset_episode(policy, q_values):
    """``reset_episode`` clears the tracked actions."""
    # Take some actions.
    policy.select_action(q_values)
    policy.select_action(q_values)

    # Reset the episode.
    policy.reset_episode()

    # Check that the actions were cleared.
    assert len(policy.episode_actions) == 0


def test_all_actions_taken(policy, q_values):
    """After exhausting all actions the policy returns ``None`` and resets."""
    # Take all possible actions.
    actions_taken = []
    for _ in range(len(q_values)):
        action = policy.select_action(q_values)
        assert action is not None
        actions_taken.append(action)

    # Verify all actions were unique.
    assert len(set(actions_taken)) == len(q_values)

    # Try to take one more action.
    action = policy.select_action(q_values)
    assert action is None  # should return None when no actions are available

    # Check that the episode was automatically reset.
    assert len(policy.episode_actions) == 0


@pytest.mark.parametrize(
    "epsilon,min_unique_actions",
    [
        (0.0, 1),  # Deterministic - should always take the best action first
        (1.0, 3),  # Random - should see multiple different actions
    ],
)
def test_epsilon_greedy(q_values, epsilon, min_unique_actions):
    """Check the variety of first actions seen for different epsilon values.

    Each iteration takes one action and then resets the episode, so every
    sample is the first action of a fresh episode.
    """
    policy = EpisodePolicy(epsilon=epsilon)
    actions = set()

    # Repeat many times so the random case has ample chance to show variety.
    for _ in range(50):
        action = policy.select_action(q_values)
        if action is not None:
            actions.add(action)
        policy.reset_episode()

    assert len(actions) >= min_unique_actions

    if epsilon == 0.0:
        # The lower bound alone is satisfied by any policy that returns an
        # action; the greedy case must only ever yield the argmax, matching
        # what test_select_action_without_mask expects of a fresh episode.
        assert actions == {int(np.argmax(q_values))}