# KT_Q_Table/agents/policy.py
from abc import ABC, abstractmethod
import numpy as np
import random
class Policy(ABC):
    """Abstract interface for action-selection policies.

    Concrete policies decide which action to take given a vector of
    Q-values and an optional availability mask.
    """

    @abstractmethod
    def select_action(self, q_values, action_mask=None):
        """Return an action index chosen from *q_values*.

        *action_mask*, when given, marks which actions are available.
        """
        ...
class EpisodePolicy(Policy):
    """Epsilon-greedy policy that never repeats an action within an episode.

    Actions already taken this episode are masked out. When every remaining
    Q-value is effectively zero (untrained table), actions are taken in
    index order so early episodes sweep the whole action space.

    Args:
        epsilon: exploration rate for epsilon-greedy selection.
        n_actions: size of the action space (default 9, matching the
            previous hard-coded value).
    """

    def __init__(self, epsilon=0.1, n_actions=9):
        self.epsilon = epsilon
        self.n_actions = n_actions
        self.episode_actions = set()  # actions already taken this episode
        self.current_idx = 0  # cursor for sequential (all-zero-Q) selection

    def get_action_mask(self):
        """Return a 0/1 vector over actions; 0 marks already-taken actions."""
        action_mask = np.ones(self.n_actions)
        for action in self.episode_actions:
            action_mask[action] = 0
        return action_mask

    def select_action(self, q_values, action_mask=None):
        """Pick an action index from *q_values*, honoring the mask.

        Returns None (and resets the episode state) when no valid action
        remains.
        """
        if action_mask is None:
            action_mask = self.get_action_mask()

        valid_actions = np.where(action_mask)[0]
        if len(valid_actions) == 0:
            self.reset_episode()
            return None

        # BUGFIX: mask invalid actions with -inf instead of multiplying by
        # the mask. With negative Q-values, a zeroed-out (unavailable)
        # entry could otherwise win the argmax below.
        masked_q_values = np.where(action_mask > 0, q_values, -np.inf)
        max_q = np.max(masked_q_values)

        # Untrained case: all valid Q-values are effectively zero, so pick
        # the first available action in index order instead of at random.
        if np.allclose(masked_q_values[valid_actions], 0, atol=1e-10):
            # Skip actions already taken or excluded by the caller's mask
            # (the original consulted only episode_actions here, so an
            # externally masked action could be returned). Bounds check
            # first to avoid indexing past the end.
            while self.current_idx < len(q_values) and (
                self.current_idx in self.episode_actions
                or action_mask[self.current_idx] == 0
            ):
                self.current_idx += 1
            if self.current_idx >= len(q_values):
                self.reset_episode()
                return None
            action = self.current_idx
            self.episode_actions.add(action)
            return action

        # Epsilon-greedy over the remaining valid actions.
        if random.uniform(0, 1) < self.epsilon:
            action = np.random.choice(valid_actions)
        else:
            max_actions = np.where(masked_q_values == max_q)[0]
            action = np.random.choice(max_actions)
        self.episode_actions.add(action)
        return action

    def reset_episode(self):
        """Clear per-episode state for a fresh episode."""
        self.episode_actions.clear()
        self.current_idx = 0