q_table_demo/test_basic.py

181 lines
5.2 KiB
Python

"""
기본 테스트 모듈
"""
import pytest
import requests
from app.services.negotiation_env import NegotiationEnvironment
from app.services.qtable_learner import QTableLearner, ExperienceBuffer
from app.models.schemas import CardType, ScenarioType, PriceZoneType
class TestNegotiationEnvironment:
    """Tests for the negotiation environment."""

    def setup_method(self):
        # Runs before every test method, so each test gets its own environment.
        self.env = NegotiationEnvironment()

    def test_reward_calculation(self):
        """Reward calculation: reward is positive and weight lies in [0, 1]."""
        outcome = self.env.calculate_reward(
            scenario=ScenarioType.A,
            price_zone=PriceZoneType.PZ1,
            anchor_price=100,
            proposed_price=95,
            is_end=True,
        )
        reward, weight = outcome

        assert reward > 0
        assert 0 <= weight <= 1

    def test_price_zone_determination(self):
        """Price zone determination for three prices against the same second argument (100)."""
        # (price, expected zone); labels follow the original inline comments.
        cases = (
            (90, PriceZoneType.PZ1),   # at or below the target price
            (110, PriceZoneType.PZ2),  # middle zone
            (150, PriceZoneType.PZ3),  # high zone
        )
        for price, expected_zone in cases:
            assert self.env.get_price_zone(price, 100) == expected_zone

    def test_opponent_response_simulation(self):
        """Opponent response simulation returns a positive float price."""
        request = {
            "current_card": CardType.C1,
            "scenario": ScenarioType.A,
            "anchor_price": 100,
            "step": 0,
        }
        simulated_price = self.env.simulate_opponent_response(**request)

        assert simulated_price > 0
        assert isinstance(simulated_price, float)
class TestQTableLearner:
    """Tests for Q-table learning."""

    def setup_method(self):
        # Fresh learner per test: three states, two actions.
        self.learner = QTableLearner(
            ["S1", "S2", "S3"],
            [CardType.C1, CardType.C2],
        )

    def test_initialization(self):
        """Initialization: the table is 3 x 2 and every entry is zero."""
        table = self.learner.q_table
        assert table.shape == (3, 2)
        # Two reductions collapse the element-wise comparison to a single
        # truth value (presumably a pandas DataFrame -- confirm in QTableLearner).
        assert (table == 0).all().all()

    def test_q_value_update(self):
        """Q-value update: one rewarded transition yields non-zero TD error and Q-value."""
        transition = {
            "state": "S1",
            "action": CardType.C1,
            "reward": 1.0,
            "next_state": "S2",
            "done": False,
        }
        delta = self.learner.update_q_value(**transition)

        assert delta != 0
        assert self.learner.get_q_value("S1", CardType.C1) != 0

    def test_action_selection(self):
        """Action selection: any valid action first, then the one with the highest Q-value."""
        # Per the original comment the choice is random in the initial state,
        # so only membership is checked; the exploration flag is not inspected.
        chosen, _ = self.learner.select_action("S1")
        assert chosen in [CardType.C1, CardType.C2]

        # After setting a Q-value, the optimal action must be that action.
        self.learner.set_q_value("S1", CardType.C2, 1.0)
        assert self.learner.get_optimal_action("S1") == CardType.C2
class TestExperienceBuffer:
    """Tests for the experience buffer."""

    def setup_method(self):
        # Fresh buffer per test, capped at 10 entries.
        self.buffer = ExperienceBuffer(max_size=10)

    def _add_transitions(self, count):
        """Add ``count`` experiences S0->S1, S1->S2, ... with reward 1.0."""
        for i in range(count):
            self.buffer.add_experience(
                state=f"S{i}",
                action=CardType.C1,
                reward=1.0,
                next_state=f"S{i+1}",
                done=False,
            )

    def test_add_experience(self):
        """Adding one experience makes the reported size 1."""
        experience = {
            "state": "S1",
            "action": CardType.C1,
            "reward": 1.0,
            "next_state": "S2",
            "done": False,
        }
        self.buffer.add_experience(**experience)

        assert self.buffer.size() == 1

    def test_buffer_overflow(self):
        """Buffer overflow: adding 15 experiences leaves the size at the maximum of 10."""
        # Add more than the maximum size.
        self._add_transitions(15)

        # The size stays at the maximum.
        assert self.buffer.size() == 10

    def test_sampling(self):
        """Sampling: a batch of 3 drawn from 5 stored experiences has length 3."""
        self._add_transitions(5)

        sampled = self.buffer.sample_batch(3)
        assert len(sampled) == 3
# API 통합 테스트 (선택사항)
class TestAPIIntegration:
    """Optional integration tests against a locally running API server.

    Each test is skipped (not failed) when the connection to ``BASE_URL``
    cannot be established.
    """

    # Common prefix of every endpoint exercised below.
    BASE_URL = "http://localhost:8000/api/v1"

    # Seconds to wait for the server. Without an explicit timeout, requests
    # waits indefinitely and a stuck server would hang the whole test run.
    TIMEOUT = 5

    def _call(self, method, path, **kwargs):
        """Send ``method`` to ``BASE_URL + path`` and return the response.

        Args:
            method: HTTP verb understood by ``requests.request`` (e.g. "GET").
            path: Endpoint path appended to ``BASE_URL``, starting with "/".
            **kwargs: Forwarded to ``requests.request``; ``timeout`` defaults
                to ``TIMEOUT`` when not supplied.

        Returns:
            The ``requests.Response`` from the server.

        The current test is skipped when the connection fails. A read timeout
        is deliberately not caught: a server that accepts the connection but
        does not answer in time is reported as a failure.
        """
        kwargs.setdefault("timeout", self.TIMEOUT)
        # Only the network call sits inside the try, so assertion failures in
        # the tests can never be confused with connectivity problems.
        try:
            return requests.request(method, f"{self.BASE_URL}{path}", **kwargs)
        except requests.exceptions.ConnectionError:
            # ConnectTimeout subclasses ConnectionError, so it lands here too.
            pytest.skip("API 서버가 실행되지 않음")

    def test_health_check(self):
        """Health check endpoint answers 200 with status "healthy"."""
        response = self._call("GET", "/health")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"

    def test_reward_calculation_endpoint(self):
        """Reward calculation endpoint answers 200 with "reward" and "weight" keys."""
        payload = {
            "scenario": "A",
            "price_zone": "PZ1",
            "anchor_price": 100,
            "proposed_price": 95,
            "is_end": True,
        }
        response = self._call("POST", "/reward/calculate", json=payload)

        assert response.status_code == 200
        data = response.json()
        assert "reward" in data
        assert "weight" in data