"""
|
||
Q-Table 학습 엔진 서비스
|
||
"""
|
||
import pandas as pd
|
||
import numpy as np
|
||
import random
|
||
import time
|
||
from typing import Dict, List, Optional, Tuple, Any
|
||
from collections import deque
|
||
from app.models.schemas import CardType, ExperienceData
|
||
|
||
|
||
class ExperienceBuffer:
    """Stores and manages experience data."""

    def __init__(self, max_size: int = 10000):
        self.buffer = deque(maxlen=max_size)
        self.max_size = max_size

    def add_experience(
        self,
        state: str,
        action: CardType,
        reward: float,
        next_state: str,
        done: bool,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Appends a single experience record to the buffer."""
        experience = ExperienceData(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            timestamp=time.time(),
            metadata=metadata or {}
        )
        self.buffer.append(experience)

    def get_experiences(self) -> List[ExperienceData]:
        """Returns all buffered experiences as a list."""
        return list(self.buffer)

    def get_dataframe(self) -> pd.DataFrame:
        """Returns the buffered experiences as a DataFrame (one row each)."""
        if not self.buffer:
            return pd.DataFrame()

        data = []
        for exp in self.buffer:
            data.append({
                'state': exp.state,
                'action': exp.action.value,
                'reward': exp.reward,
                'next_state': exp.next_state,
                'done': exp.done,
                'timestamp': exp.timestamp,
                **exp.metadata
            })
        return pd.DataFrame(data)

    def sample_batch(self, batch_size: int = 32) -> List[ExperienceData]:
        """Samples a random batch without replacement."""
        if len(self.buffer) <= batch_size:
            return list(self.buffer)

        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        return [self.buffer[i] for i in indices]

    def clear(self):
        """Empties the buffer."""
        self.buffer.clear()

    def size(self) -> int:
        """Returns the number of buffered experiences."""
        return len(self.buffer)

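# Usage sketch (illustrative only; `card` stands for any CardType member,
# which this module does not name):
#
#     buffer = ExperienceBuffer(max_size=1000)
#     buffer.add_experience("state_a", card, 1.0, "state_b", done=False)
#     df = buffer.get_dataframe()      # one row per buffered experience
#     batch = buffer.sample_batch(32)  # random batch for training

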
class QTableLearner:
    """Tabular Q-learning engine."""

    def __init__(
        self,
        states: List[str],
        actions: List[CardType],
        learning_rate: float = 0.1,
        discount_factor: float = 0.9,
        epsilon: float = 0.1
    ):
        self.states = states
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon

        # Initialize the Q-table with all values set to 0.
        action_names = [action.value for action in actions]
        self.q_table = pd.DataFrame(
            0.0,
            index=states,
            columns=action_names
        )

        # Learning bookkeeping.
        self.learning_history = []
        self.update_count = 0
        self.total_reward = 0.0

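    # Layout sketch of self.q_table (hypothetical action names; the actual
    # columns come from the CardType values):
    #
    #              action_a  action_b
    #     state_1       0.0       0.0
    #     state_2       0.0       0.0
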
    def get_q_value(self, state: str, action: CardType) -> float:
        """Looks up a Q-value; unknown states default to 0."""
        if state in self.q_table.index:
            return float(self.q_table.loc[state, action.value])
        return 0.0

    def set_q_value(self, state: str, action: CardType, value: float):
        """Sets a Q-value; unknown states are ignored."""
        if state in self.q_table.index:
            self.q_table.loc[state, action.value] = value

    def get_optimal_action(self, state: str) -> CardType:
        """Picks the greedy (highest-Q) action for the given state."""
        if state not in self.q_table.index:
            return random.choice(self.actions)

        q_values = self.q_table.loc[state]
        best_action_name = q_values.idxmax()

        # Map the column name back to its CardType member.
        for action in self.actions:
            if action.value == best_action_name:
                return action

        return random.choice(self.actions)

    def select_action(self, state: str, use_epsilon_greedy: bool = True) -> Tuple[CardType, bool]:
        """
        Selects an action (epsilon-greedy or pure greedy). With the default
        epsilon of 0.1, roughly 10% of calls explore.

        Returns:
            (action, is_exploration): the chosen action and whether it was
            an exploratory (random) choice.
        """
        if use_epsilon_greedy and random.random() < self.epsilon:
            # Explore: pick a random action.
            return random.choice(self.actions), True
        else:
            # Exploit: pick the greedy action.
            return self.get_optimal_action(state), False

    def update_q_value(
        self,
        state: str,
        action: CardType,
        reward: float,
        next_state: str,
        done: bool
    ) -> float:
        """
        Applies the Q-learning update rule:
        Q(s,a) ← Q(s,a) + α[r + γ max_a' Q(s',a') - Q(s,a)]

        Returns:
            The TD error for this update.
        """
        if state not in self.q_table.index:
            return 0.0

        current_q = self.get_q_value(state, action)

        if done or next_state not in self.q_table.index:
            target = reward
        else:
            max_next_q = self.q_table.loc[next_state].max()
            target = reward + self.discount_factor * max_next_q

        # TD error: how far the current estimate is from the target.
        td_error = target - current_q

        # Move the Q-value a fraction (learning_rate) toward the target.
        new_q = current_q + self.learning_rate * td_error
        self.set_q_value(state, action, new_q)

        # Record this update for diagnostics.
        self.learning_history.append({
            'update': self.update_count,
            'state': state,
            'action': action.value,
            'old_q': current_q,
            'new_q': new_q,
            'reward': reward,
            'target': target,
            'td_error': abs(td_error),
            'timestamp': time.time()
        })

        self.update_count += 1
        self.total_reward += reward

        return td_error

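    # Worked example of the update rule above (illustrative numbers only):
    # with α = 0.1, γ = 0.9, Q(s,a) = 0.5, r = 1.0 and max_a' Q(s',a') = 2.0,
    # target = 1.0 + 0.9 * 2.0 = 2.8, TD error = 2.8 - 0.5 = 2.3, and the
    # updated value is 0.5 + 0.1 * 2.3 = 0.73.
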
    def batch_update(self, experiences: List[ExperienceData]) -> Dict[str, float]:
        """Applies the Q-learning update to a batch of experiences."""
        if not experiences:
            return {"avg_td_error": 0.0, "updates": 0}

        td_errors = []
        updates = 0

        for exp in experiences:
            td_error = self.update_q_value(
                exp.state,
                exp.action,
                exp.reward,
                exp.next_state,
                exp.done
            )
            if abs(td_error) > 1e-8:  # Count only meaningful updates.
                td_errors.append(abs(td_error))
                updates += 1

        return {
            "avg_td_error": float(np.mean(td_errors)) if td_errors else 0.0,
            "updates": updates,
            "total_experiences": len(experiences)
        }

    def get_q_table_copy(self) -> pd.DataFrame:
        """Returns a copy of the Q-table."""
        return self.q_table.copy()

    def get_state_q_values(self, state: str) -> Dict[str, float]:
        """Returns the Q-values for one state, keyed by action name."""
        if state not in self.q_table.index:
            return {action.value: 0.0 for action in self.actions}

        return self.q_table.loc[state].to_dict()

    def get_learning_statistics(self) -> Dict[str, Any]:
        """Returns summary statistics about learning progress."""
        if not self.learning_history:
            return {
                "total_updates": 0,
                "avg_td_error": 0.0,
                "avg_reward": 0.0,
                "q_table_sparsity": 1.0
            }

        recent_history = self.learning_history[-100:]  # Last 100 updates.

        # Q-table sparsity: the fraction of entries that are still zero.
        non_zero_values = (self.q_table != 0).sum().sum()
        total_values = self.q_table.size
        sparsity = 1.0 - (non_zero_values / total_values)

        return {
            "total_updates": self.update_count,
            "avg_td_error": float(np.mean([h['td_error'] for h in recent_history])),
            "avg_reward": float(np.mean([h['reward'] for h in recent_history])),
            "q_table_sparsity": float(sparsity),
            "q_value_range": {
                "min": float(self.q_table.min().min()),
                "max": float(self.q_table.max().max()),
                "mean": float(self.q_table.mean().mean())
            }
        }

    def reset(self):
        """Resets all learning state."""
        # Re-initialize the Q-table to all zeros.
        self.q_table = pd.DataFrame(
            0.0,
            index=self.states,
            columns=[action.value for action in self.actions]
        )

        # Clear the bookkeeping.
        self.learning_history.clear()
        self.update_count = 0
        self.total_reward = 0.0

    def set_hyperparameters(
        self,
        learning_rate: Optional[float] = None,
        discount_factor: Optional[float] = None,
        epsilon: Optional[float] = None
    ):
        """Updates any hyperparameters that are provided."""
        if learning_rate is not None:
            self.learning_rate = learning_rate
        if discount_factor is not None:
            self.discount_factor = discount_factor
        if epsilon is not None:
            self.epsilon = epsilon
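

# Minimal end-to-end sketch (illustrative only). It assumes CardType is a
# standard Enum, so list(CardType) yields its members without naming them,
# and that the app package is importable when this file is run directly.
if __name__ == "__main__":
    actions = list(CardType)
    states = ["s0", "s1"]
    learner = QTableLearner(states, actions)
    buffer = ExperienceBuffer(max_size=100)

    # Collect a few transitions chosen epsilon-greedily from state "s0".
    for _ in range(10):
        action, explored = learner.select_action("s0")
        buffer.add_experience("s0", action, reward=1.0, next_state="s1", done=True)

    # Learn from a sampled batch and report progress.
    print(learner.batch_update(buffer.sample_batch(8)))
    print(learner.get_learning_statistics())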