"""FQI (Fitted Q-Iteration) + CQL (Conservative Q-Learning) service."""
import random
import time
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd

from app.models.schemas import CardType, ExperienceData
class FQICQLLearner:
    """FQI (Fitted Q-Iteration) + CQL (Conservative Q-Learning) engine.

    Learns a Q-function from a fixed batch of experience. Each update
    combines the standard Bellman regression target with a conservative
    penalty that pushes down Q-values of actions not present in the data
    (the core idea of CQL). The Q-function is stored as a pandas DataFrame
    (rows = states, columns = action values) as a tabular stand-in for a
    neural network.
    """

    def __init__(
        self,
        states: List[str],
        actions: List[CardType],
        alpha: float = 1.0,  # CQL conservatism parameter
        gamma: float = 0.95,  # discount factor
        learning_rate: float = 0.01
    ):
        """
        Args:
            states: Identifiers of all states (Q-table rows).
            actions: All selectable actions; columns are keyed by ``action.value``.
            alpha: CQL conservatism weight (0 disables the penalty).
            gamma: Discount factor for bootstrapped targets.
            learning_rate: Step size of the tabular gradient updates.
        """
        self.states = states
        self.actions = actions
        self.alpha = alpha
        self.gamma = gamma
        self.learning_rate = learning_rate

        # Tabular stand-in for a Q-network (a real system would use a neural
        # net); small random init breaks argmax ties arbitrarily.
        self.q_network = self._new_q_network()

        # Per-batch training log and number of batches processed so far.
        self.training_history: List[Dict[str, Any]] = []
        self.batch_count = 0

    def _new_q_network(self) -> pd.DataFrame:
        """Create a freshly initialized Q-table (shared by __init__ and reset)."""
        action_names = [action.value for action in self.actions]
        return pd.DataFrame(
            np.random.uniform(0, 0.1, (len(self.states), len(self.actions))),
            index=self.states,
            columns=action_names
        )

    def fitted_q_iteration(self, experience_batch: List[ExperienceData]) -> Dict[str, float]:
        """Run one FQI sweep with a CQL penalty over an experience batch.

        Args:
            experience_batch: Transitions exposing ``state``, ``action``,
                ``reward``, ``next_state`` and ``done``.

        Returns:
            Batch averages: ``bellman_loss``, ``cql_penalty``, plus
            ``batch_size``.
        """
        if not experience_batch:
            return {"bellman_loss": 0.0, "cql_penalty": 0.0, "batch_size": 0}

        bellman_losses = []
        cql_penalties = []

        for exp in experience_batch:
            state = exp.state
            action = exp.action

            # States outside the table cannot be learned from; skip them.
            if state not in self.q_network.index:
                continue

            # Bellman target: bootstrap from the best next-state Q unless the
            # episode ended (or the next state is unknown to the table).
            if exp.done or exp.next_state not in self.q_network.index:
                target = exp.reward
            else:
                target = exp.reward + self.gamma * self.q_network.loc[exp.next_state].max()

            current_q = self.q_network.loc[state, action.value]

            # Squared Bellman error (logged for monitoring only).
            bellman_losses.append((current_q - target) ** 2)

            # CQL conservative penalty: how much the best out-of-distribution
            # action (any action other than the one observed in the data)
            # exceeds the in-distribution Q-value. Positive gap -> penalize,
            # which keeps unseen actions' values conservative.
            row = self.q_network.loc[state]
            ood_q_values = [row[other.value] for other in self.actions if other != action]
            if ood_q_values:
                cql_penalty = self.alpha * max(0, max(ood_q_values) - current_q)
            else:
                cql_penalty = 0.0
            cql_penalties.append(cql_penalty)

            # Tabular "gradient" step: move toward the Bellman target and
            # subtract the conservative correction (a real CQL implementation
            # would backpropagate this through the network).
            update = (self.learning_rate * (target - current_q)
                      - self.learning_rate * cql_penalty)
            self.q_network.loc[state, action.value] += update

        # Cast to plain float: np.mean returns a numpy scalar, but the API
        # promises Dict[str, float].
        avg_bellman_loss = float(np.mean(bellman_losses)) if bellman_losses else 0.0
        avg_cql_penalty = float(np.mean(cql_penalties)) if cql_penalties else 0.0

        self.training_history.append({
            'batch': self.batch_count,
            'avg_bellman_loss': avg_bellman_loss,
            'avg_cql_penalty': avg_cql_penalty,
            'batch_size': len(experience_batch),
            'timestamp': time.time()
        })

        self.batch_count += 1

        return {
            "bellman_loss": avg_bellman_loss,
            "cql_penalty": avg_cql_penalty,
            "batch_size": len(experience_batch)
        }

    def train_multiple_iterations(
        self,
        experience_batch: List[ExperienceData],
        num_iterations: int = 10
    ) -> Dict[str, Any]:
        """Run several FQI sweeps, shuffling the batch before each one.

        Args:
            experience_batch: Experience batch reused every iteration.
            num_iterations: Number of FQI sweeps to perform.

        Returns:
            Aggregate statistics plus per-iteration details.
        """
        iteration_results = []

        for i in range(num_iterations):
            # Shuffle a copy of the batch so update order varies per sweep.
            # random.sample avoids round-tripping objects through a numpy
            # object array (as np.random.permutation(...).tolist() would).
            shuffled_batch = random.sample(experience_batch, len(experience_batch))
            result = self.fitted_q_iteration(shuffled_batch)

            iteration_results.append({
                'iteration': i,
                **result
            })

        # Guard the averages too: np.mean([]) yields NaN with a warning when
        # num_iterations == 0.
        if iteration_results:
            avg_bellman = float(np.mean([r['bellman_loss'] for r in iteration_results]))
            avg_cql = float(np.mean([r['cql_penalty'] for r in iteration_results]))
            final_bellman = iteration_results[-1]['bellman_loss']
            final_cql = iteration_results[-1]['cql_penalty']
        else:
            avg_bellman = avg_cql = final_bellman = final_cql = 0.0

        return {
            'total_iterations': num_iterations,
            'avg_bellman_loss': avg_bellman,
            'avg_cql_penalty': avg_cql,
            'final_bellman_loss': final_bellman,
            'final_cql_penalty': final_cql,
            'iteration_details': iteration_results
        }

    def get_q_value(self, state: str, action: CardType) -> float:
        """Return Q(state, action); 0.0 for states not in the table."""
        if state in self.q_network.index:
            # Cast: .loc returns a numpy scalar, but the API promises float.
            return float(self.q_network.loc[state, action.value])
        return 0.0

    def get_optimal_action(self, state: str) -> CardType:
        """Return the greedy action for ``state`` (random for unknown states)."""
        if state not in self.q_network.index:
            return random.choice(self.actions)

        best_action_name = self.q_network.loc[state].idxmax()
        for action in self.actions:
            if action.value == best_action_name:
                return action

        # Defensive fallback: the winning column matched no known action.
        return random.choice(self.actions)

    def get_state_q_values(self, state: str) -> Dict[str, float]:
        """Return all Q-values for ``state`` keyed by action value."""
        if state not in self.q_network.index:
            return {action.value: 0.0 for action in self.actions}

        # Cast numpy scalars to plain floats so the result is JSON-friendly.
        return {name: float(q) for name, q in self.q_network.loc[state].items()}

    def get_q_network_copy(self) -> pd.DataFrame:
        """Return a copy of the Q-table (safe for callers to mutate)."""
        return self.q_network.copy()

    def compare_with_behavior_policy(
        self,
        experience_batch: List[ExperienceData]
    ) -> Dict[str, Any]:
        """Compare the learned policy with the behavior (data-collection) policy.

        Args:
            experience_batch: Experience batch to compare against.

        Returns:
            Dict with ``policy_divergence`` (mean |Q difference|),
            ``action_agreement`` (fraction of matching action choices),
            and the underlying counts.
        """
        if not experience_batch:
            return {"policy_divergence": 0.0, "action_agreement": 0.0}

        agreements = 0
        total_comparisons = 0
        q_value_differences = []

        for exp in experience_batch:
            state = exp.state
            behavior_action = exp.action  # action chosen when data was collected

            if state not in self.q_network.index:
                continue

            # Greedy action under the currently learned Q-function.
            learned_action = self.get_optimal_action(state)

            if behavior_action == learned_action:
                agreements += 1

            # Absolute Q gap between the two policies' choices.
            behavior_q = self.get_q_value(state, behavior_action)
            learned_q = self.get_q_value(state, learned_action)
            q_value_differences.append(abs(learned_q - behavior_q))

            total_comparisons += 1

        if total_comparisons == 0:
            return {"policy_divergence": 0.0, "action_agreement": 0.0}

        action_agreement = agreements / total_comparisons
        avg_q_difference = float(np.mean(q_value_differences))

        return {
            "policy_divergence": avg_q_difference,
            "action_agreement": action_agreement,
            "total_comparisons": total_comparisons,
            "agreements": agreements
        }

    def get_training_statistics(self) -> Dict[str, Any]:
        """Return training statistics over the most recent batches."""
        if not self.training_history:
            return {
                "total_batches": 0,
                "avg_bellman_loss": 0.0,
                "avg_cql_penalty": 0.0,
                "convergence_trend": "unknown"
            }

        recent_history = self.training_history[-10:]  # last 10 batches

        # Convergence trend: classify the last 5 losses as monotonically
        # non-increasing (improving), non-decreasing (deteriorating), or mixed.
        if len(self.training_history) >= 5:
            recent_losses = [h['avg_bellman_loss'] for h in self.training_history[-5:]]
            if all(recent_losses[i] >= recent_losses[i + 1] for i in range(len(recent_losses) - 1)):
                convergence_trend = "improving"
            elif all(recent_losses[i] <= recent_losses[i + 1] for i in range(len(recent_losses) - 1)):
                convergence_trend = "deteriorating"
            else:
                convergence_trend = "fluctuating"
        else:
            convergence_trend = "insufficient_data"

        return {
            "total_batches": self.batch_count,
            "avg_bellman_loss": float(np.mean([h['avg_bellman_loss'] for h in recent_history])),
            "avg_cql_penalty": float(np.mean([h['avg_cql_penalty'] for h in recent_history])),
            "convergence_trend": convergence_trend,
            "q_network_stats": {
                "min": float(self.q_network.min().min()),
                "max": float(self.q_network.max().max()),
                "mean": float(self.q_network.mean().mean()),
                # NOTE: mean of the per-column stds, not a global std.
                "std": float(self.q_network.std().mean())
            }
        }

    def reset(self):
        """Reset the Q-table and clear all training records."""
        self.q_network = self._new_q_network()
        self.training_history.clear()
        self.batch_count = 0

    def set_hyperparameters(
        self,
        alpha: Optional[float] = None,
        gamma: Optional[float] = None,
        learning_rate: Optional[float] = None
    ):
        """Update hyperparameters; parameters left as None keep their value."""
        if alpha is not None:
            self.alpha = alpha
        if gamma is not None:
            self.gamma = gamma
        if learning_rate is not None:
            self.learning_rate = learning_rate