"""
|
||
Q-Table 협상 전략 데모 메인 서비스
|
||
"""
|
||
import random
import time
from typing import Dict, Any

from app.models.schemas import (
    CardType, ScenarioType, PriceZoneType,
    ExperienceData, EpisodeGenerationRequest,
    RewardCalculationRequest, RewardCalculationResponse,
    ActionRecommendationRequest, ActionRecommendationResponse,
    SystemStatus
)
from app.services.negotiation_env import NegotiationEnvironment
from app.services.qtable_learner import QTableLearner, ExperienceBuffer
from app.services.fqi_cql import FQICQLLearner


class DemoService:
    """Main service for the Q-Table negotiation strategy demo."""

    def __init__(self):
        # Initialize the negotiation environment
        self.env = NegotiationEnvironment()

        # State and action spaces
        self.states = self.env.get_all_states()
        self.actions = self.env.get_all_actions()

        # Learning engines
        self.experience_buffer = ExperienceBuffer(max_size=10000)
        self.q_learner = QTableLearner(
            states=self.states,
            actions=self.actions,
            learning_rate=0.1,
            discount_factor=0.9,
            epsilon=0.1
        )
        self.fqi_cql_learner = FQICQLLearner(
            states=self.states,
            actions=self.actions,
            alpha=1.0,
            gamma=0.95
        )

        # Run statistics
        self.episode_count = 0
        self.start_time = time.time()

    def calculate_reward(self, request: RewardCalculationRequest) -> RewardCalculationResponse:
        """Calculate the reward for a proposal and return a step-by-step breakdown."""
        reward, weight = self.env.calculate_reward(
            scenario=request.scenario,
            price_zone=request.price_zone,
            anchor_price=request.anchor_price,
            proposed_price=request.proposed_price,
            is_end=request.is_end
        )

        # Scenario and price-zone weights
        scenario_weight = self.env.scenario_weights[request.scenario]
        price_zone_weight = self.env.price_zone_weights[request.price_zone]

        # Price ratio (anchor / proposed), guarding against division by zero
        price_ratio = request.anchor_price / request.proposed_price if request.proposed_price > 0 else float('inf')

        # Human-readable breakdown of the reward formula
        formula_breakdown = (
            f"R(s,a) = W × (A/P) + (1-W) × End\n"
            f"W = (S_n + PZ_n) / 2 = ({scenario_weight} + {price_zone_weight}) / 2 = {weight:.3f}\n"
            f"A/P = {request.anchor_price}/{request.proposed_price} = {price_ratio:.3f}\n"
            f"End = {1 if request.is_end else 0}\n"
            f"R(s,a) = {weight:.3f} × {price_ratio:.3f} + {1-weight:.3f} × {1 if request.is_end else 0} = {reward:.3f}"
        )

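        # Worked example of the formula above (hypothetical numbers, assuming a
        # scenario weight of 0.8 and a price-zone weight of 0.6 from the
        # environment): W = (0.8 + 0.6) / 2 = 0.7; with anchor 100 and proposed
        # price 90, A/P = 100/90 ≈ 1.111; a closed deal gives End = 1, so
        # R(s,a) = 0.7 × 1.111 + 0.3 × 1 ≈ 1.078.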
        return RewardCalculationResponse(
            reward=reward,
            weight=weight,
            scenario_weight=scenario_weight,
            price_zone_weight=price_zone_weight,
            price_ratio=price_ratio,
            formula_breakdown=formula_breakdown
        )

    def generate_episodes(self, request: EpisodeGenerationRequest) -> Dict[str, Any]:
        """Generate simulated negotiation episodes and collect experiences."""
        new_experiences = 0
        episode_results = []

        for episode in range(request.num_episodes):
            episode_result = self._generate_single_episode(
                max_steps=request.max_steps,
                anchor_price=request.anchor_price,
                exploration_rate=request.exploration_rate,
                episode_id=self.episode_count + episode
            )
            episode_results.append(episode_result)
            new_experiences += episode_result['steps']

        self.episode_count += request.num_episodes

        return {
            "episodes_generated": request.num_episodes,
            "new_experiences": new_experiences,
            "episode_results": episode_results,
            "total_episodes": self.episode_count
        }

    def _generate_single_episode(
        self,
        max_steps: int,
        anchor_price: float,
        exploration_rate: float,
        episode_id: int
    ) -> Dict[str, Any]:
        """Generate a single episode."""
        # Initial state
        current_state = "C0S0P0"
        scenario = random.choice(list(ScenarioType))

        episode_reward = 0.0
        steps = 0
        success = False
        proposed_price = anchor_price  # fallback final price if the loop never runs

        for step in range(max_steps):
            # Action selection (epsilon-greedy)
            if random.random() < exploration_rate:
                action = random.choice(self.actions)
                is_exploration = True
            else:
                action = self.q_learner.get_optimal_action(current_state)
                is_exploration = False

            # Opponent's response from the environment
            proposed_price = self.env.simulate_opponent_response(
                current_card=action,
                scenario=scenario,
                anchor_price=anchor_price,
                step=step
            )

            # Determine the price zone
            price_zone = self.env.get_price_zone(proposed_price, anchor_price)

            # Next state
            next_state = f"{action.value}{scenario.value}{price_zone.value}"

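            # The state string concatenates the enum codes, mirroring the initial
            # state "C0S0P0" (card, scenario, price zone); e.g. hypothetical
            # codes "C2", "S1", "P3" would yield "C2S1P3".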
            # Check termination: deal closed or step budget exhausted
            deal_closed = self.env.is_negotiation_successful(proposed_price, anchor_price)
            is_done = deal_closed or (step >= max_steps - 1)
            if deal_closed:
                success = True

            # Reward calculation
            reward, weight = self.env.calculate_reward(
                scenario=scenario,
                price_zone=price_zone,
                anchor_price=anchor_price,
                proposed_price=proposed_price,
                is_end=is_done
            )

            # Store the experience
            metadata = {
                'episode': episode_id,
                'step': step,
                'scenario': scenario.value,
                'proposed_price': proposed_price,
                'weight': weight,
                'is_exploration': is_exploration,
                'anchor_price': anchor_price
            }

            self.experience_buffer.add_experience(
                state=current_state,
                action=action,
                reward=reward,
                next_state=next_state,
                done=is_done,
                metadata=metadata
            )

            episode_reward += reward
            steps += 1
            current_state = next_state

            if is_done:
                break

        return {
            'episode_id': episode_id,
            'steps': steps,
            'total_reward': episode_reward,
            'success': success,
            'final_price': proposed_price,
            'scenario': scenario.value
        }

    def update_q_learning(self, learning_rate: float, discount_factor: float, batch_size: int) -> Dict[str, Any]:
        """Run a Q-Learning update over a batch of stored experiences."""
        # Set hyperparameters
        self.q_learner.set_hyperparameters(
            learning_rate=learning_rate,
            discount_factor=discount_factor
        )

        # Fetch experience data
        experiences = self.experience_buffer.get_experiences()
        if not experiences:
            return {"message": "No experience data available", "updates": 0}

        # Sample a batch
        if len(experiences) > batch_size:
            batch = self.experience_buffer.sample_batch(batch_size)
        else:
            batch = experiences

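        # Sketch of the per-experience update, assuming QTableLearner.batch_update
        # applies the standard tabular Q-learning rule:
        #   Q(s,a) ← Q(s,a) + α · (r + γ · max_a' Q(s',a') − Q(s,a))
        # The parenthesized term is the TD error reported as "avg_td_error" below.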
        # Batch update
        result = self.q_learner.batch_update(batch)

        return {
            "message": "Q-Learning update completed",
            "batch_size": len(batch),
            "updates": result["updates"],
            "avg_td_error": result["avg_td_error"],
            "total_updates": self.q_learner.update_count
        }

    def run_fqi_cql(self, alpha: float, gamma: float, batch_size: int, num_iterations: int) -> Dict[str, Any]:
        """Run FQI+CQL training on a batch of stored experiences."""
        # Set hyperparameters
        self.fqi_cql_learner.set_hyperparameters(
            alpha=alpha,
            gamma=gamma
        )

        # Fetch experience data
        experiences = self.experience_buffer.get_experiences()
        if not experiences:
            return {"message": "No experience data available", "iterations": 0}

        # Sample a batch
        if len(experiences) > batch_size:
            batch = self.experience_buffer.sample_batch(batch_size)
        else:
            batch = experiences

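        # Note: standard CQL augments the fitted Q-iteration (FQI) objective with
        # a conservatism penalty, commonly
        #   L = TD-loss + α · E_s[logsumexp_a Q(s,a) − Q(s, a_data)],
        # which pushes down Q-values for actions unseen in the batch. The alpha
        # set above is assumed to scale that penalty inside FQICQLLearner.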
        # FQI+CQL training
        result = self.fqi_cql_learner.train_multiple_iterations(
            experience_batch=batch,
            num_iterations=num_iterations
        )

        # Policy comparison
        policy_comparison = self.fqi_cql_learner.compare_with_behavior_policy(batch)

        return {
            "message": "FQI+CQL training completed",
            "training_result": result,
            "policy_comparison": policy_comparison,
            "batch_size": len(batch)
        }

    def get_action_recommendation(self, request: ActionRecommendationRequest) -> ActionRecommendationResponse:
        """Recommend an action for the given state."""
        # Fetch Q-values for the state
        q_values = self.q_learner.get_state_q_values(request.current_state)

        # Action selection
        if request.use_epsilon_greedy:
            # Temporarily apply the requested epsilon for this selection
            original_epsilon = self.q_learner.epsilon
            self.q_learner.epsilon = request.epsilon
            action, is_exploration = self.q_learner.select_action(
                state=request.current_state,
                use_epsilon_greedy=True
            )
            self.q_learner.epsilon = original_epsilon
        else:
            action = self.q_learner.get_optimal_action(request.current_state)
            is_exploration = False

        # Confidence heuristic based on the spread of Q-values
        if q_values and len(q_values) > 1:
            q_vals = list(q_values.values())
            max_q = max(q_vals)
            q_range = max(q_vals) - min(q_vals)
            confidence = max_q / (q_range + 1e-8) if q_range > 0 else 1.0
            confidence = max(0.0, min(confidence, 1.0))  # clamp to [0, 1]
        else:
            confidence = 0.0

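        # Worked example of the heuristic above (hypothetical Q-values):
        # {CARD_A: 0.9, CARD_B: 0.3} → max_q = 0.9, q_range = 0.6,
        # confidence = 0.9 / 0.6 = 1.5, clamped to 1.0; a lower maximum relative
        # to the spread yields proportionally lower confidence.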
        return ActionRecommendationResponse(
            recommended_action=action,
            q_values=q_values,
            confidence=confidence,
            exploration=is_exploration
        )

    def get_system_status(self) -> SystemStatus:
        """Report overall system status."""
        exp_df = self.experience_buffer.get_dataframe()

        if not exp_df.empty:
            avg_reward = exp_df['reward'].mean()
            success_count = exp_df['done'].sum()
            success_rate = success_count / len(exp_df)
            unique_states = exp_df['state'].nunique()
        else:
            avg_reward = 0.0
            success_rate = 0.0
            unique_states = 0

        return SystemStatus(
            total_experiences=self.experience_buffer.size(),
            q_table_updates=self.q_learner.update_count,
            unique_states=unique_states,
            average_reward=avg_reward,
            success_rate=success_rate,
            last_update=time.time()
        )

    def get_q_table(self) -> Dict[str, Any]:
        """Return the Q-Table data."""
        q_table_df = self.q_learner.get_q_table_copy()
        stats = self.q_learner.get_learning_statistics()

        return {
            "q_table": q_table_df.to_dict(),
            "statistics": stats,
            "update_count": self.q_learner.update_count,
            "hyperparameters": {
                "learning_rate": self.q_learner.learning_rate,
                "discount_factor": self.q_learner.discount_factor,
                "epsilon": self.q_learner.epsilon
            }
        }

    def get_fqi_cql_results(self) -> Dict[str, Any]:
        """Return FQI+CQL training results."""
        q_network_df = self.fqi_cql_learner.get_q_network_copy()
        stats = self.fqi_cql_learner.get_training_statistics()

        return {
            "q_network": q_network_df.to_dict(),
            "statistics": stats,
            "batch_count": self.fqi_cql_learner.batch_count,
            "hyperparameters": {
                "alpha": self.fqi_cql_learner.alpha,
                "gamma": self.fqi_cql_learner.gamma,
                "learning_rate": self.fqi_cql_learner.learning_rate
            }
        }

    def get_experience_data(self) -> Dict[str, Any]:
        """Return experience buffer data and summary statistics."""
        exp_df = self.experience_buffer.get_dataframe()

        if not exp_df.empty:
            # Basic statistics
            stats = {
                "total_count": len(exp_df),
                "avg_reward": exp_df['reward'].mean(),
                "reward_std": exp_df['reward'].std(),
                "success_rate": exp_df['done'].sum() / len(exp_df),
                "unique_states": exp_df['state'].nunique(),
                "unique_actions": exp_df['action'].nunique()
            }

            # Most recent entries
            recent_data = exp_df.tail(20).to_dict('records')
        else:
            stats = {
                "total_count": 0,
                "avg_reward": 0.0,
                "reward_std": 0.0,
                "success_rate": 0.0,
                "unique_states": 0,
                "unique_actions": 0
            }
            recent_data = []

        return {
            "statistics": stats,
            "recent_data": recent_data,
            "buffer_size": self.experience_buffer.size(),
            "max_size": self.experience_buffer.max_size
        }

    def reset_all(self):
        """Reset all learning state."""
        self.experience_buffer.clear()
        self.q_learner.reset()
        self.fqi_cql_learner.reset()
        self.episode_count = 0
        self.start_time = time.time()

    def compare_policies(self, state: str) -> Dict[str, Any]:
        """Compare the Q-Learning and FQI+CQL policies for a state."""
        # Q-Learning policy
        q_learning_action = self.q_learner.get_optimal_action(state)
        q_learning_values = self.q_learner.get_state_q_values(state)

        # FQI+CQL policy
        fqi_cql_action = self.fqi_cql_learner.get_optimal_action(state)
        fqi_cql_values = self.fqi_cql_learner.get_state_q_values(state)

        # Do the two policies agree?
        policy_agreement = (q_learning_action == fqi_cql_action)

        # Per-action Q-value differences
        q_value_differences = {}
        for action_name in q_learning_values:
            diff = abs(q_learning_values[action_name] - fqi_cql_values.get(action_name, 0.0))
            q_value_differences[action_name] = diff

        return {
            "state": state,
            "q_learning": {
                "action": q_learning_action.value,
                "q_values": q_learning_values
            },
            "fqi_cql": {
                "action": fqi_cql_action.value,
                "q_values": fqi_cql_values
            },
            "policy_agreement": policy_agreement,
            "q_value_differences": q_value_differences,
            "max_difference": max(q_value_differences.values()) if q_value_differences else 0.0
        }


# Global service instance
demo_service = DemoService()
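

# Minimal smoke-test sketch (assumes the app package is importable and that
# EpisodeGenerationRequest accepts these keyword arguments; the field names
# mirror how the request is consumed above). Run from the project root, e.g.
# `python -m app.services.demo_service`, where the module path is assumed.
if __name__ == "__main__":
    req = EpisodeGenerationRequest(
        num_episodes=10,
        max_steps=5,
        anchor_price=100.0,
        exploration_rate=0.3
    )
    print(demo_service.generate_episodes(req))
    print(demo_service.update_q_learning(learning_rate=0.1, discount_factor=0.9, batch_size=64))
    print(demo_service.get_system_status())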