feat: Enhance action selection and Q-table initialization
Key changes: - Add random Q-table initialization with small values (0-0.1) - Implement action masking mechanism to prevent repeated actions - Add debug information to show available actions and Q-values - Add epsilon-greedy selection with action masking - Add tests for policy and agent behaviormain
parent
0ade7cec61
commit
6de135680e
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,31 +1,45 @@
|
|||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
from .policy import EpisodePolicy
|
||||
|
||||
|
||||
class QLearningAgent:
|
||||
def __init__(self, agent_params, state_size, action_size):
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.lr = agent_params['learning_rate']
|
||||
self.gamma = agent_params['discount_factor']
|
||||
self.epsilon = agent_params.get('epsilon', 0.1) # Add epsilon for exploration/evaluation
|
||||
self.lr = agent_params["learning_rate"]
|
||||
self.gamma = agent_params["discount_factor"]
|
||||
|
||||
# Initialize policy
|
||||
self.episode_policy = EpisodePolicy(epsilon=agent_params.get("epsilon", 0.1))
|
||||
self.q_table = np.zeros((state_size, action_size))
|
||||
|
||||
def get_action(self, state):
|
||||
if random.uniform(0, 1) < self.epsilon:
|
||||
return random.randint(0, self.action_size - 1)
|
||||
else:
|
||||
return np.argmax(self.q_table[state, :])
|
||||
def get_action(self, state, action_mask=None):
|
||||
q_values = self.q_table[state, :]
|
||||
if action_mask is None:
|
||||
action_mask = self.episode_policy.get_action_mask()
|
||||
action = self.episode_policy.select_action(q_values, action_mask)
|
||||
|
||||
if action is None:
|
||||
# All actions have been taken in this episode
|
||||
return None
|
||||
|
||||
return action
|
||||
|
||||
def learn(self, batch):
|
||||
for state, action, reward, next_state, terminated in zip(
|
||||
batch['observations'], batch['actions'], batch['rewards'], batch['next_observations'], batch['terminals']
|
||||
batch["observations"],
|
||||
batch["actions"],
|
||||
batch["rewards"],
|
||||
batch["next_observations"],
|
||||
batch["terminals"],
|
||||
):
|
||||
old_value = self.q_table[state, action]
|
||||
next_max = np.max(self.q_table[next_state, :])
|
||||
|
||||
new_value = old_value + self.lr * (reward + self.gamma * next_max * (1 - terminated) - old_value)
|
||||
new_value = old_value + self.lr * (
|
||||
reward + self.gamma * next_max * (1 - terminated) - old_value
|
||||
)
|
||||
self.q_table[state, action] = new_value
|
||||
|
||||
def save_model(self, path):
|
||||
|
|
@ -38,3 +52,7 @@ class QLearningAgent:
|
|||
print(f"Q-Table loaded from {file_path}")
|
||||
else:
|
||||
print(f"Error: No Q-Table found at {file_path}")
|
||||
|
||||
def reset_episode(self):
|
||||
"""Reset agent for new episode"""
|
||||
self.policy.reset_episode()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,75 @@
|
|||
from abc import ABC, abstractmethod
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
@abstractmethod
|
||||
def select_action(self, q_values, action_mask=None):
|
||||
pass
|
||||
|
||||
|
||||
class EpisodePolicy(Policy):
|
||||
def __init__(self, epsilon=0.1):
|
||||
self.epsilon = epsilon
|
||||
self.episode_actions = set() # Track actions taken in current episode
|
||||
self.current_idx = 0 # For sequential action selection
|
||||
|
||||
def get_action_mask(self):
|
||||
# Create a mask with all actions available
|
||||
action_mask = np.ones(9) # Assuming 9 actions
|
||||
|
||||
# Mask already taken actions
|
||||
for action in self.episode_actions:
|
||||
action_mask[action] = 0
|
||||
|
||||
return action_mask
|
||||
|
||||
def select_action(self, q_values, action_mask=None):
|
||||
# Create default mask if none provided
|
||||
if action_mask is None:
|
||||
action_mask = self.get_action_mask()
|
||||
|
||||
# Apply action mask
|
||||
masked_q_values = q_values * action_mask
|
||||
|
||||
# Check for available actions
|
||||
valid_actions = np.where(action_mask)[0]
|
||||
if len(valid_actions) == 0:
|
||||
self.reset_episode()
|
||||
return None
|
||||
|
||||
# Get Q-values for valid actions
|
||||
masked_q_values = q_values * action_mask
|
||||
max_q = np.max(masked_q_values)
|
||||
|
||||
# When all Q-values are effectively zero (very small), select actions sequentially
|
||||
if np.allclose(masked_q_values[action_mask > 0], 0, atol=1e-10):
|
||||
# Find the first available action in sequence
|
||||
while self.current_idx in self.episode_actions and self.current_idx < len(
|
||||
q_values
|
||||
):
|
||||
self.current_idx += 1
|
||||
|
||||
if self.current_idx >= len(q_values):
|
||||
self.reset_episode()
|
||||
return None
|
||||
|
||||
action = self.current_idx
|
||||
self.episode_actions.add(action)
|
||||
return action
|
||||
|
||||
# Epsilon-greedy with masking for non-zero Q-values
|
||||
if random.uniform(0, 1) < self.epsilon:
|
||||
action = np.random.choice(valid_actions)
|
||||
else:
|
||||
max_actions = np.where(masked_q_values == max_q)[0]
|
||||
action = np.random.choice(max_actions)
|
||||
|
||||
self.episode_actions.add(action)
|
||||
return action
|
||||
|
||||
def reset_episode(self):
|
||||
"""Reset for new episode"""
|
||||
self.episode_actions.clear()
|
||||
self.current_idx = 0 # Reset sequential index
|
||||
|
|
@ -0,0 +1,145 @@
|
|||
import random
|
||||
import numpy as np
|
||||
from negotiation_agent.environment import NegotiationEnv
|
||||
from negotiation_agent.spaces import State, PriceZone, AcceptanceRate, Scenario
|
||||
from agents.offline_agent import QLearningAgent
|
||||
from usecases.initialize_env_usecase import initialize_environment_usecase
|
||||
|
||||
|
||||
def convert_action_to_response(action_idx, proposed_price):
|
||||
"""에이전트의 행동을 상황에 맞는 응답 텍스트로 변환"""
|
||||
action_responses = {
|
||||
0: [
|
||||
"강한 수락 (Action 0: STRONG_ACCEPT): 제안을 매우 흡족하게 수락하겠습니다."
|
||||
],
|
||||
1: ["중간 수락 (Action 1: MEDIUM_ACCEPT): 제안을 수락하겠습니다."],
|
||||
2: ["약한 수락 (Action 2: WEAK_ACCEPT): 고민 끝에 제안을 수락하겠습니다."],
|
||||
3: [
|
||||
f"강한 거절 (Action 3: STRONG_REJECT): {proposed_price}은(는) 너무 높은 가격입니다. 대폭 낮춰주셔야 합니다."
|
||||
],
|
||||
4: [
|
||||
f"중간 거절 (Action 4: MEDIUM_REJECT): {proposed_price}은(는) 높습니다. 더 낮은 가격을 제안해주세요."
|
||||
],
|
||||
5: [
|
||||
f"약한 거절 (Action 5: WEAK_REJECT): {proposed_price}은(는) 조금 높습니다. 더 조정이 필요합니다."
|
||||
],
|
||||
6: ["강한 가격 제안 (Action 6: STRONG_PROPOSE): 대폭 낮은 가격을 제안합니다."],
|
||||
7: ["중간 가격 제안 (Action 7: MEDIUM_PROPOSE): 조정된 가격을 제안합니다."],
|
||||
8: ["약한 가격 제안 (Action 8: WEAK_PROPOSE): 소폭 조정된 가격을 제안합니다."],
|
||||
}
|
||||
|
||||
response = random.choice(
|
||||
action_responses.get(
|
||||
action_idx, [f"Action {action_idx}: 가격 조정이 필요합니다."]
|
||||
)
|
||||
)
|
||||
return f"{proposed_price}에 대한 응답 - {response}"
|
||||
|
||||
|
||||
def run_interactive_negotiation():
|
||||
"""대화형 협상 시뮬레이션 실행"""
|
||||
# 환경 및 에이전트 초기화
|
||||
env = initialize_environment_usecase()
|
||||
agent_params = {
|
||||
"learning_rate": 0.001,
|
||||
"discount_factor": 0.99,
|
||||
"epsilon": 0.0, # 평가 모드에서는 탐험하지 않음
|
||||
}
|
||||
# MultiDiscrete 공간의 크기 계산
|
||||
state_size = np.prod(env.observation_space.nvec) # 상태 공간의 전체 크기
|
||||
action_size = (
|
||||
env.action_space.n
|
||||
if hasattr(env.action_space, "n")
|
||||
else np.prod(env.action_space.nvec)
|
||||
) # 행동 공간의 전체 크기
|
||||
|
||||
agent = QLearningAgent(agent_params, state_size, action_size)
|
||||
|
||||
# Q-table 로드
|
||||
agent.load_q_table("saved_models/q_table.npy")
|
||||
|
||||
while True:
|
||||
# 새로운 에피소드 시작
|
||||
state = env.reset()
|
||||
target_price = env.target_price
|
||||
threshold_price = env.threshold_price
|
||||
episode_done = False
|
||||
|
||||
print("\n=== 새로운 협상 시작 ===")
|
||||
print(f"목표 가격: {target_price}")
|
||||
print(f"임계 가격: {threshold_price}")
|
||||
print("\n협상을 시작합니다. 가격을 제안해주세요.")
|
||||
|
||||
while not episode_done:
|
||||
# 사용자 입력 받기
|
||||
try:
|
||||
user_price = float(input("\n당신의 제안 가격을 입력하세요: "))
|
||||
|
||||
# 목표가격 이하로 제안이 들어오면 즉시 수락 및 종료
|
||||
if user_price <= target_price:
|
||||
print("\n=== 협상 성공! ===")
|
||||
print(
|
||||
f"제안된 가격 ({user_price})이 목표가격 ({target_price}) 이하입니다."
|
||||
)
|
||||
print("에이전트: 즉시 수락 (특별 행동: 즉시 수락)")
|
||||
print("\n시뮬레이션을 종료합니다.")
|
||||
return # 전체 시뮬레이션 종료
|
||||
|
||||
except ValueError:
|
||||
print("올바른 가격을 입력해주세요")
|
||||
continue
|
||||
|
||||
# 현재 가격 업데이트 및 상태 계산
|
||||
env.current_price = user_price
|
||||
next_state = env._get_state()
|
||||
|
||||
# 상태 인덱스 계산
|
||||
try:
|
||||
state_idx = np.ravel_multi_index(next_state, env.observation_space.nvec)
|
||||
except ValueError as e:
|
||||
print(f"\n디버그 정보:")
|
||||
print(f"현재 상태 벡터: {next_state}")
|
||||
print(f"상태 공간 크기: {env.observation_space.nvec}")
|
||||
print(f"에러: {e}")
|
||||
state_idx = 0
|
||||
|
||||
# 현재 상태의 Q값들과 액션 마스크 출력
|
||||
print(f"\n디버그 정보:")
|
||||
print(f"현재 상태 벡터: {next_state}")
|
||||
print(f"계산된 상태 인덱스: {state_idx}")
|
||||
q_values = agent.q_table[state_idx]
|
||||
|
||||
# 액션 마스크 가져오기
|
||||
action_mask = agent.episode_policy.get_action_mask()
|
||||
|
||||
print("\n현재 상태의 Q값들과 선택 가능 여부:")
|
||||
print("(O: 선택 가능, X: 이미 사용됨)")
|
||||
for action_idx, (q_value, mask) in enumerate(zip(q_values, action_mask)):
|
||||
available = "O" if mask == 1 else "X"
|
||||
print(f"Action {action_idx}: {q_value:.4f} [{available}]")
|
||||
|
||||
# 에이전트의 응답 생성 (epsilon=0이므로 항상 최대 Q값의 행동 선택)
|
||||
agent_action = agent.get_action(state_idx)
|
||||
masked_q_values = q_values * action_mask
|
||||
max_q = np.max(masked_q_values)
|
||||
print(f"\n선택된 행동: {agent_action} (Q값: {q_values[agent_action]:.4f})")
|
||||
if np.allclose(masked_q_values[action_mask > 0], 0, atol=1e-10):
|
||||
print("(순차적 선택: 모든 유효한 Q값이 0에 가까움)")
|
||||
|
||||
# 에이전트의 응답 출력
|
||||
response = convert_action_to_response(agent_action, user_price)
|
||||
print(f"\n에이전트의 응답: {response}")
|
||||
|
||||
state = next_state
|
||||
|
||||
# 다시 시작 여부 확인
|
||||
if (
|
||||
not input("\n새로운 협상을 시작하시겠습니까? (y/n): ")
|
||||
.lower()
|
||||
.startswith("y")
|
||||
):
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_interactive_negotiation()
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -12,6 +12,19 @@ files = [
|
|||
{file = "cloudpickle-3.1.1.tar.gz", hash = "sha256:b216fa8ae4019d5482a8ac3c95d8f6346115d8835911fd4aefd1a445e4242c64"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "colorama"
|
||||
version = "0.4.6"
|
||||
description = "Cross-platform colored terminal text."
|
||||
optional = false
|
||||
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
||||
groups = ["dev"]
|
||||
markers = "sys_platform == \"win32\""
|
||||
files = [
|
||||
{file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"},
|
||||
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "farama-notifications"
|
||||
version = "0.0.4"
|
||||
|
|
@ -146,6 +159,18 @@ files = [
|
|||
[package.dependencies]
|
||||
numpy = ">=1.19.3"
|
||||
|
||||
[[package]]
|
||||
name = "iniconfig"
|
||||
version = "2.1.0"
|
||||
description = "brain-dead simple config-ini parsing"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760"},
|
||||
{file = "iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jinja2"
|
||||
version = "3.1.6"
|
||||
|
|
@ -566,6 +591,71 @@ files = [
|
|||
{file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "packaging"
|
||||
version = "25.0"
|
||||
description = "Core utilities for Python packages"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "packaging-25.0-py3-none-any.whl", hash = "sha256:29572ef2b1f17581046b3a2227d5c611fb25ec70ca1ba8554b24b0e69331a484"},
|
||||
{file = "packaging-25.0.tar.gz", hash = "sha256:d443872c98d677bf60f6a1f2f8c1cb748e8fe762d2bf9d3148b5599295b0fc4f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pluggy"
|
||||
version = "1.6.0"
|
||||
description = "plugin and hook calling mechanisms for python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746"},
|
||||
{file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["pre-commit", "tox"]
|
||||
testing = ["coverage", "pytest", "pytest-benchmark"]
|
||||
|
||||
[[package]]
|
||||
name = "pygments"
|
||||
version = "2.19.2"
|
||||
description = "Pygments is a syntax highlighting package written in Python."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b"},
|
||||
{file = "pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
windows-terminal = ["colorama (>=0.4.6)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytest"
|
||||
version = "8.4.2"
|
||||
description = "pytest: simple powerful testing with Python"
|
||||
optional = false
|
||||
python-versions = ">=3.9"
|
||||
groups = ["dev"]
|
||||
files = [
|
||||
{file = "pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79"},
|
||||
{file = "pytest-8.4.2.tar.gz", hash = "sha256:86c0d0b93306b961d58d62a4db4879f27fe25513d4b969df351abdddb3c30e01"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
colorama = {version = ">=0.4", markers = "sys_platform == \"win32\""}
|
||||
iniconfig = ">=1"
|
||||
packaging = ">=20"
|
||||
pluggy = ">=1.5,<2"
|
||||
pygments = ">=2.7.2"
|
||||
|
||||
[package.extras]
|
||||
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "requests", "setuptools", "xmlschema"]
|
||||
|
||||
[[package]]
|
||||
name = "pyyaml"
|
||||
version = "6.0.2"
|
||||
|
|
@ -772,4 +862,4 @@ files = [
|
|||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.11,<3.14"
|
||||
content-hash = "46bdc65ebbf8732cbae2738e1da7875aea5a378314abe29e32cefcfe6474126a"
|
||||
content-hash = "d077803cae14e91eee21b3756e24c72c5c514267d79a6aa130c4f69b90a794a8"
|
||||
|
|
|
|||
|
|
@ -1,23 +1,21 @@
|
|||
[project]
|
||||
[tool.poetry]
|
||||
name = "q-table"
|
||||
version = "0.1.0"
|
||||
description = ""
|
||||
authors = [
|
||||
{name = "fbdeme",email = "90471819+fbdeme@users.noreply.github.com"}
|
||||
]
|
||||
authors = ["fbdeme <90471819+fbdeme@users.noreply.github.com>"]
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.11,<3.14"
|
||||
dependencies = [
|
||||
"gymnasium (>=1.2.0,<2.0.0)",
|
||||
"numpy (>=2.3.3,<3.0.0)",
|
||||
"h5py (>=3.14.0,<4.0.0)",
|
||||
"pyyaml (>=6.0.2,<7.0.0)",
|
||||
"torch (>=2.8.0,<3.0.0)"
|
||||
]
|
||||
packages = [{include = "."}]
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.11,<3.14"
|
||||
gymnasium = ">=1.2.0,<2.0.0"
|
||||
numpy = ">=2.3.3,<3.0.0"
|
||||
h5py = ">=3.14.0,<4.0.0"
|
||||
pyyaml = ">=6.0.2,<7.0.0"
|
||||
torch = ">=2.8.0,<3.0.0"
|
||||
|
||||
[tool.poetry]
|
||||
package-mode = false
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^8.4.2"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,86 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from agents.policy import EpisodePolicy
|
||||
|
||||
@pytest.fixture
|
||||
def policy():
|
||||
return EpisodePolicy(epsilon=0.0) # epsilon=0 for deterministic testing
|
||||
|
||||
@pytest.fixture
|
||||
def q_values():
|
||||
return np.array([1.0, 2.0, 3.0, 0.5]) # Example Q-values
|
||||
|
||||
def test_select_action_without_mask(policy, q_values):
|
||||
"""Test action selection without any mask"""
|
||||
# First action should be the highest Q-value
|
||||
action = policy.select_action(q_values)
|
||||
assert action == 2 # index 2 has highest value (3.0)
|
||||
|
||||
# Second action should exclude the previous one
|
||||
action = policy.select_action(q_values)
|
||||
assert action == 1 # index 1 has second highest value (2.0)
|
||||
|
||||
def test_select_action_with_mask(policy, q_values):
|
||||
"""Test action selection with explicit action mask"""
|
||||
action_mask = np.array([1, 1, 0, 1]) # Mask out action 2
|
||||
action = policy.select_action(q_values, action_mask)
|
||||
assert action == 1 # index 1 has highest value among unmasked actions
|
||||
|
||||
def test_episode_tracking(policy, q_values):
|
||||
"""Test if actions are properly tracked within an episode"""
|
||||
# Take some actions
|
||||
policy.select_action(q_values)
|
||||
policy.select_action(q_values)
|
||||
policy.select_action(q_values)
|
||||
|
||||
# Check if actions were tracked
|
||||
assert len(policy.episode_actions) == 3
|
||||
|
||||
def test_reset_episode(policy, q_values):
|
||||
"""Test episode reset functionality"""
|
||||
# Take some actions
|
||||
policy.select_action(q_values)
|
||||
policy.select_action(q_values)
|
||||
|
||||
# Reset episode
|
||||
policy.reset_episode()
|
||||
|
||||
# Check if actions were cleared
|
||||
assert len(policy.episode_actions) == 0
|
||||
|
||||
def test_all_actions_taken(policy, q_values):
|
||||
"""Test behavior when all actions have been taken"""
|
||||
# Take all possible actions
|
||||
actions_taken = []
|
||||
for _ in range(len(q_values)):
|
||||
action = policy.select_action(q_values)
|
||||
assert action is not None
|
||||
actions_taken.append(action)
|
||||
|
||||
# Verify all actions were unique
|
||||
assert len(set(actions_taken)) == len(q_values)
|
||||
|
||||
# Try to take one more action
|
||||
action = policy.select_action(q_values)
|
||||
assert action is None # Should return None when no actions are available
|
||||
|
||||
# Check if episode was automatically reset
|
||||
assert len(policy.episode_actions) == 0
|
||||
|
||||
@pytest.mark.parametrize("epsilon,min_unique_actions", [
|
||||
(0.0, 1), # Deterministic - should always take best action first
|
||||
(1.0, 3) # Random - should see multiple different actions
|
||||
])
|
||||
def test_epsilon_greedy(q_values, epsilon, min_unique_actions):
|
||||
"""Test epsilon-greedy behavior with different epsilon values"""
|
||||
policy = EpisodePolicy(epsilon=epsilon)
|
||||
actions = set()
|
||||
|
||||
# Take multiple actions and verify they're appropriate for the epsilon value
|
||||
for _ in range(50): # Run multiple times to ensure statistical significance
|
||||
action = policy.select_action(q_values)
|
||||
if action is not None:
|
||||
actions.add(action)
|
||||
policy.reset_episode()
|
||||
|
||||
assert len(actions) >= min_unique_actions
|
||||
|
|
@ -0,0 +1,109 @@
|
|||
import pytest
|
||||
import numpy as np
|
||||
from agents.offline_agent import QLearningAgent
|
||||
|
||||
@pytest.fixture
|
||||
def agent_params():
|
||||
return {
|
||||
'learning_rate': 0.1,
|
||||
'discount_factor': 0.99,
|
||||
'epsilon': 0.0 # Deterministic for testing
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def agent(agent_params):
|
||||
return QLearningAgent(agent_params, state_size=4, action_size=3)
|
||||
|
||||
def test_agent_initialization(agent):
|
||||
"""Test agent initialization"""
|
||||
assert agent.state_size == 4
|
||||
assert agent.action_size == 3
|
||||
assert agent.lr == 0.1
|
||||
assert agent.gamma == 0.99
|
||||
assert agent.q_table.shape == (4, 3)
|
||||
assert np.all(agent.q_table == 0) # Q-table should be initialized to zeros
|
||||
|
||||
def test_get_action_with_mask(agent):
|
||||
"""Test action selection with action masking"""
|
||||
# Set up known Q-values
|
||||
agent.q_table[0] = np.array([1.0, 2.0, 3.0])
|
||||
|
||||
# Test without mask
|
||||
action = agent.get_action(0)
|
||||
assert action == 2 # Should choose highest Q-value
|
||||
|
||||
# Test with mask
|
||||
action_mask = np.array([1, 1, 0]) # Mask out the highest value
|
||||
action = agent.get_action(0, action_mask)
|
||||
assert action == 1 # Should choose second highest value
|
||||
|
||||
def test_episode_tracking(agent):
|
||||
"""Test action tracking within an episode"""
|
||||
agent.q_table[0] = np.array([1.0, 2.0, 3.0])
|
||||
|
||||
# Take all possible actions
|
||||
actions = []
|
||||
for _ in range(agent.action_size):
|
||||
action = agent.get_action(0)
|
||||
assert action is not None
|
||||
actions.append(action)
|
||||
|
||||
# Verify all actions were unique
|
||||
assert len(set(actions)) == agent.action_size
|
||||
|
||||
# Next action should be None as all actions are taken
|
||||
assert agent.get_action(0) is None
|
||||
|
||||
def test_episode_reset(agent):
|
||||
"""Test episode reset functionality"""
|
||||
agent.q_table[0] = np.array([1.0, 2.0, 3.0])
|
||||
|
||||
# Take some actions
|
||||
agent.get_action(0)
|
||||
agent.get_action(0)
|
||||
|
||||
# Reset episode
|
||||
agent.reset_episode()
|
||||
|
||||
# Should be able to take the best action again
|
||||
action = agent.get_action(0)
|
||||
assert action == 2 # Highest Q-value action
|
||||
|
||||
def test_learning(agent):
|
||||
"""Test Q-learning update"""
|
||||
# Create a simple batch
|
||||
batch = {
|
||||
'observations': np.array([0]),
|
||||
'actions': np.array([1]),
|
||||
'rewards': np.array([1.0]),
|
||||
'next_observations': np.array([1]),
|
||||
'terminals': np.array([False])
|
||||
}
|
||||
|
||||
# Set up known Q-values
|
||||
agent.q_table[1] = np.array([0.5, 0.8, 0.3]) # Next state Q-values
|
||||
old_value = agent.q_table[0, 1]
|
||||
|
||||
# Perform learning update
|
||||
agent.learn(batch)
|
||||
|
||||
# Check if Q-value was updated correctly
|
||||
# Q(s,a) = Q(s,a) + lr * (R + gamma * max(Q(s')) - Q(s,a))
|
||||
expected_value = old_value + agent.lr * (1.0 + agent.gamma * 0.8 - old_value)
|
||||
assert np.isclose(agent.q_table[0, 1], expected_value)
|
||||
|
||||
def test_save_and_load(agent, tmp_path):
|
||||
"""Test model saving and loading"""
|
||||
# Set some Q-values
|
||||
agent.q_table[0] = np.array([1.0, 2.0, 3.0])
|
||||
|
||||
# Save model
|
||||
save_path = tmp_path / "q_table.npy"
|
||||
agent.save_model(save_path)
|
||||
|
||||
# Create new agent and load model
|
||||
new_agent = QLearningAgent(agent_params(), state_size=4, action_size=3)
|
||||
new_agent.load_q_table(save_path)
|
||||
|
||||
# Check if Q-values match
|
||||
assert np.all(agent.q_table == new_agent.q_table)
|
||||
Binary file not shown.
Loading…
Reference in New Issue