from abc import ABC, abstractmethod
import random

import numpy as np


class Policy(ABC):
    """Abstract interface for action-selection policies."""

    @abstractmethod
    def select_action(self, q_values, action_mask=None):
        """Return an action index for the given Q-values, or None if none is available."""


class EpisodePolicy(Policy):
    """Epsilon-greedy policy that never repeats an action within an episode.

    Actions already taken during the current episode are masked out. When all
    available Q-values are effectively zero (e.g. an untrained table), actions
    are taken sequentially (0, 1, 2, ...) instead of greedily. When every
    action has been taken, ``select_action`` resets the episode state and
    returns ``None``.
    """

    def __init__(self, epsilon=0.1, n_actions=9):
        """
        Args:
            epsilon: exploration probability for the epsilon-greedy step.
            n_actions: size of the action space used to build the default
                mask (was previously hard-coded to 9).
        """
        self.epsilon = epsilon
        self.n_actions = n_actions
        self.episode_actions = set()  # actions already taken this episode
        self.current_idx = 0  # cursor for the sequential (all-zero-Q) fallback

    def get_action_mask(self):
        """Return a 0/1 mask of length ``n_actions``; taken actions are zeroed."""
        action_mask = np.ones(self.n_actions)
        for action in self.episode_actions:
            action_mask[action] = 0
        return action_mask

    def select_action(self, q_values, action_mask=None):
        """Pick an action not yet taken this episode.

        Args:
            q_values: 1-D array of Q-values, one per action.
            action_mask: optional 0/1 array the same length as ``q_values``;
                zeros mark unavailable actions. Defaults to the internal
                per-episode mask.

        Returns:
            The chosen action index, or ``None`` when no action remains
            (in which case the episode state is reset as a side effect).
        """
        if action_mask is None:
            action_mask = self.get_action_mask()

        valid_actions = np.where(action_mask)[0]
        if len(valid_actions) == 0:
            # Every action has been taken: reset and signal the caller.
            self.reset_episode()
            return None

        valid_q = q_values[valid_actions]

        # When every available Q-value is effectively zero (untrained),
        # select actions sequentially, skipping already-taken ones.
        if np.allclose(valid_q, 0, atol=1e-10):
            while self.current_idx < len(q_values) and self.current_idx in self.episode_actions:
                self.current_idx += 1
            if self.current_idx >= len(q_values):
                self.reset_episode()
                return None
            action = self.current_idx
            self.episode_actions.add(action)
            return action

        # Epsilon-greedy restricted to valid actions. Taking the max over
        # valid_actions (rather than over q_values * action_mask) fixes a
        # bug where, with all-negative valid Q-values, masked entries forced
        # to 0 would win the argmax and an already-taken action was returned.
        if random.uniform(0, 1) < self.epsilon:
            action = np.random.choice(valid_actions)
        else:
            max_q = np.max(valid_q)
            best_actions = valid_actions[valid_q == max_q]
            action = np.random.choice(best_actions)

        action = int(action)  # normalize NumPy scalar to a plain int
        self.episode_actions.add(action)
        return action

    def reset_episode(self):
        """Clear per-episode state so a new episode can begin."""
        self.episode_actions.clear()
        self.current_idx = 0  # restart the sequential cursor