# Listing metadata: 76 lines, 2.4 KiB, Python
from abc import ABC, abstractmethod
|
|
import numpy as np
|
|
import random
|
|
|
|
|
|
class Policy(ABC):
    """Abstract interface for action-selection policies.

    Concrete policies must implement :meth:`select_action`.
    """

    @abstractmethod
    def select_action(self, q_values, action_mask=None):
        """Choose an action given per-action Q-values.

        Args:
            q_values: Value estimates, one entry per action.
            action_mask: Optional per-action mask; how it is interpreted is
                up to the concrete policy.

        Returns:
            The selected action. This base implementation does nothing and
            returns None.
        """
|
|
|
|
|
|
class EpisodePolicy(Policy):
    """Epsilon-greedy policy that tracks the actions taken in an episode.

    Every action returned by :meth:`select_action` is recorded in
    ``episode_actions``. The default action mask excludes recorded actions,
    so an action is not repeated until :meth:`reset_episode` is called.

    Attributes:
        epsilon: Probability of picking a uniformly random valid action
            instead of a greedy one.
        episode_actions: Set of action indices chosen in the current episode.
        current_idx: Index of the action most recently chosen by the
            sequential (all-zero Q-values) fallback; 0 after a reset.
    """

    def __init__(self, epsilon=0.1):
        self.epsilon = epsilon
        self.episode_actions = set()  # Actions taken in the current episode
        self.current_idx = 0  # Cursor for sequential action selection

    def get_action_mask(self, n_actions=9):
        """Return a mask with 1 for untaken actions and 0 for taken ones.

        Args:
            n_actions: Size of the action space. Defaults to 9, the size the
                policy historically assumed.

        Returns:
            A float ``numpy`` array of length ``n_actions``.
        """
        action_mask = np.ones(n_actions)
        for action in self.episode_actions:
            # Ignore recorded actions that fall outside this action space
            # instead of raising IndexError (or wrapping around when negative).
            if 0 <= action < n_actions:
                action_mask[action] = 0
        return action_mask

    def select_action(self, q_values, action_mask=None):
        """Select an action and record it as taken for this episode.

        Args:
            q_values: 1-D sequence of Q-values, one per action.
            action_mask: Optional 1-D mask of the same length; non-zero
                entries mark selectable actions. When omitted, the mask from
                :meth:`get_action_mask` is used, which excludes the actions
                already taken this episode. A supplied mask is used as given.

        Returns:
            The chosen action index as an ``int``, or ``None`` when no action
            can be selected; in that case the episode state is reset.

        Raises:
            ValueError: If ``action_mask`` and ``q_values`` differ in shape.
        """
        q_values = np.asarray(q_values, dtype=float)
        n_actions = len(q_values)

        if action_mask is None:
            # Size the default mask from the Q-values, not a fixed constant.
            available = self.get_action_mask(n_actions).astype(bool)
        else:
            available = np.asarray(action_mask).astype(bool)
        if available.shape != q_values.shape:
            raise ValueError(
                "action_mask shape %s does not match q_values shape %s"
                % (available.shape, q_values.shape)
            )

        valid_actions = np.flatnonzero(available)
        if valid_actions.size == 0:
            self.reset_episode()
            return None

        # Only valid entries are ever compared. Multiplying by the mask would
        # turn masked entries into 0, which beats every negative Q-value and
        # lets the greedy branch return a masked action.
        valid_q = q_values[valid_actions]

        # When all valid Q-values are effectively zero, select sequentially:
        # the lowest-indexed valid action not yet taken this episode.
        if np.allclose(valid_q, 0, atol=1e-10):
            untaken = (
                int(candidate)
                for candidate in valid_actions
                if int(candidate) not in self.episode_actions
            )
            action = next(untaken, None)
            if action is None:
                self.reset_episode()
                return None
            self.current_idx = action
            self.episode_actions.add(action)
            return action

        # Epsilon-greedy over the valid actions only.
        if random.uniform(0, 1) < self.epsilon:
            action = int(np.random.choice(valid_actions))
        else:
            # Break ties between equally good actions uniformly at random.
            best_actions = valid_actions[valid_q == valid_q.max()]
            action = int(np.random.choice(best_actions))

        self.episode_actions.add(action)
        return action

    def reset_episode(self):
        """Reset for new episode: forget taken actions and rewind the cursor."""
        self.episode_actions.clear()
        self.current_idx = 0  # Reset sequential index
|