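"""
기본 테스트 모듈 (basic test module)

Unit tests for NegotiationEnvironment, QTableLearner and ExperienceBuffer,
plus optional API integration tests that are skipped when the server is down.
"""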

import os

import pytest
import requests

from app.models.schemas import CardType, PriceZoneType, ScenarioType
from app.services.negotiation_env import NegotiationEnvironment
from app.services.qtable_learner import ExperienceBuffer, QTableLearner
|
|
|
|
class TestNegotiationEnvironment:
    """Tests for the negotiation environment (NegotiationEnvironment)."""

    def setup_method(self):
        # pytest calls setup_method before every test, so each test gets its
        # own freshly constructed environment.
        self.env = NegotiationEnvironment()

    def test_reward_calculation(self):
        """An end-of-episode offer of 95 against an anchor of 100 (scenario A, PZ1)
        yields a positive reward and a weight within [0, 1]."""
        outcome = self.env.calculate_reward(
            scenario=ScenarioType.A,
            price_zone=PriceZoneType.PZ1,
            anchor_price=100,
            proposed_price=95,
            is_end=True,
        )
        reward, weight = outcome

        assert reward > 0
        assert 0 <= weight <= 1

    def test_price_zone_determination(self):
        """get_price_zone classifies prices relative to a target price of 100."""
        target = 100
        # (price, expected zone): at/below target, middle band, high band.
        cases = (
            (90, PriceZoneType.PZ1),
            (110, PriceZoneType.PZ2),
            (150, PriceZoneType.PZ3),
        )
        for price, expected_zone in cases:
            assert self.env.get_price_zone(price, target) == expected_zone

    def test_opponent_response_simulation(self):
        """The simulated opponent answers a C1 card at step 0 with a positive float."""
        counter_offer = self.env.simulate_opponent_response(
            current_card=CardType.C1,
            scenario=ScenarioType.A,
            anchor_price=100,
            step=0,
        )

        assert counter_offer > 0
        assert isinstance(counter_offer, float)
|
|
|
|
class TestQTableLearner:
    """Tests for Q-table learning (QTableLearner)."""

    def setup_method(self):
        # Three states x two card actions; rebuilt before every test.
        self.learner = QTableLearner(["S1", "S2", "S3"], [CardType.C1, CardType.C2])

    def test_initialization(self):
        """A new table is 3x2 (states x actions) and every entry is zero."""
        table = self.learner.q_table
        assert table.shape == (3, 2)

        is_zero = table == 0
        assert is_zero.all().all()

    def test_q_value_update(self):
        """A single rewarded update yields a non-zero TD error and moves Q(S1, C1) off zero."""
        transition = {
            "state": "S1",
            "action": CardType.C1,
            "reward": 1.0,
            "next_state": "S2",
            "done": False,
        }
        td_error = self.learner.update_q_value(**transition)

        assert td_error != 0
        assert self.learner.get_q_value(transition["state"], transition["action"]) != 0

    def test_action_selection(self):
        """select_action returns a known card; get_optimal_action picks C2 once it holds Q=1.0."""
        # Only validity of the returned card is checked here; whether the
        # choice was exploratory (second return value) is not asserted.
        chosen, _ = self.learner.select_action("S1")
        assert chosen in [CardType.C1, CardType.C2]

        # After giving C2 a positive Q-value in S1, it is expected to be the
        # optimal action for that state.
        self.learner.set_q_value("S1", CardType.C2, 1.0)
        assert self.learner.get_optimal_action("S1") == CardType.C2
|
|
|
|
class TestExperienceBuffer:
    """Tests for the experience replay buffer (ExperienceBuffer)."""

    def setup_method(self):
        # Small capacity so the overflow test can exceed it cheaply.
        self.buffer = ExperienceBuffer(max_size=10)

    def _fill(self, count, start=0):
        """Add `count` transitions S{i} -> S{i+1} (action C1, reward 1.0, not done),
        with i running from `start` to `start + count - 1`."""
        for idx in range(start, start + count):
            self.buffer.add_experience(
                state=f"S{idx}",
                action=CardType.C1,
                reward=1.0,
                next_state=f"S{idx + 1}",
                done=False,
            )

    def test_add_experience(self):
        """Adding one transition (S1 -> S2) makes the buffer size 1."""
        self._fill(1, start=1)
        assert self.buffer.size() == 1

    def test_buffer_overflow(self):
        """Adding 15 transitions to a 10-slot buffer leaves its size at 10."""
        self._fill(15)
        assert self.buffer.size() == 10

    def test_sampling(self):
        """sample_batch(3) on a buffer holding 5 transitions returns 3 items."""
        self._fill(5)
        sampled = self.buffer.sample_batch(3)
        assert len(sampled) == 3
|
|
|
|
# API integration tests (optional): they talk to a live server and are
# skipped when no connection can be established.
class TestAPIIntegration:
    """API integration tests."""

    # Base URL of the API under test. Defaults to the local server; set
    # TEST_API_BASE_URL to run the suite against another deployment.
    BASE_URL = os.environ.get(
        "TEST_API_BASE_URL", "http://localhost:8000/api/v1"
    ).rstrip("/")
    # Seconds to wait on connect/read. requests has no default timeout, so
    # without this a server that accepts but never answers hangs the run.
    REQUEST_TIMEOUT = 5

    def _request(self, method, path, **kwargs):
        """Send an HTTP `method` request to BASE_URL + `path` and return the response.

        Extra keyword arguments (e.g. ``json=``) are forwarded to
        ``requests.request``. If the connection cannot be established the test
        is skipped. Only the network call sits inside the ``try``, so nothing
        but a connection failure can trigger the skip. A read timeout is not a
        ConnectionError, so it surfaces as a test error instead of a skip.
        """
        url = f"{self.BASE_URL}{path}"
        try:
            return requests.request(method, url, timeout=self.REQUEST_TIMEOUT, **kwargs)
        except requests.exceptions.ConnectionError:
            pytest.skip("API 서버가 실행되지 않음")

    def test_health_check(self):
        """GET /health responds 200 with status "healthy"."""
        response = self._request("GET", "/health")

        assert response.status_code == 200
        data = response.json()
        assert data["status"] == "healthy"

    def test_reward_calculation_endpoint(self):
        """POST /reward/calculate responds 200 with "reward" and "weight" fields."""
        payload = {
            "scenario": "A",
            "price_zone": "PZ1",
            "anchor_price": 100,
            "proposed_price": 95,
            "is_end": True,
        }
        response = self._request("POST", "/reward/calculate", json=payload)

        assert response.status_code == 200
        data = response.json()
        assert "reward" in data
        assert "weight" in data