q_table_demo/app/services/qtable_learner.py

294 lines
8.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Q-Table 학습 엔진 서비스
"""
import pandas as pd
import numpy as np
import random
import time
from typing import Dict, List, Optional, Tuple, Any
from collections import deque
from app.models.schemas import CardType, ExperienceData
class ExperienceBuffer:
    """Fixed-capacity store for RL transition experiences.

    Backed by a bounded deque: once ``max_size`` entries are held, appending
    a new experience silently evicts the oldest one (FIFO).
    """

    def __init__(self, max_size: int = 10000):
        self.buffer = deque(maxlen=max_size)
        self.max_size = max_size

    def add_experience(
        self,
        state: str,
        action: CardType,
        reward: float,
        next_state: str,
        done: bool,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """Record one (state, action, reward, next_state, done) transition.

        The wall-clock timestamp is captured at call time; ``metadata``
        defaults to an empty dict.
        """
        self.buffer.append(ExperienceData(
            state=state,
            action=action,
            reward=reward,
            next_state=next_state,
            done=done,
            timestamp=time.time(),
            metadata=metadata or {},
        ))

    def get_experiences(self) -> List[ExperienceData]:
        """Return every stored experience as a list, oldest first."""
        return list(self.buffer)

    def get_dataframe(self) -> pd.DataFrame:
        """Flatten the buffer into a DataFrame, one row per experience.

        Metadata keys are spread into additional columns. An empty buffer
        yields an empty DataFrame.
        """
        if not self.buffer:
            return pd.DataFrame()
        rows = [
            {
                'state': exp.state,
                'action': exp.action.value,
                'reward': exp.reward,
                'next_state': exp.next_state,
                'done': exp.done,
                'timestamp': exp.timestamp,
                **exp.metadata,
            }
            for exp in self.buffer
        ]
        return pd.DataFrame(rows)

    def sample_batch(self, batch_size: int = 32) -> List[ExperienceData]:
        """Draw up to ``batch_size`` distinct experiences uniformly at random.

        When the buffer holds ``batch_size`` items or fewer, the whole
        buffer is returned instead of a sample.
        """
        count = len(self.buffer)
        if count <= batch_size:
            return list(self.buffer)
        picks = np.random.choice(count, batch_size, replace=False)
        return [self.buffer[idx] for idx in picks]

    def clear(self):
        """Drop every stored experience."""
        self.buffer.clear()

    def size(self) -> int:
        """Return the number of experiences currently held."""
        return len(self.buffer)
class QTableLearner:
    """Tabular Q-Learning engine.

    The Q-Table is a pandas DataFrame whose rows are state labels and whose
    columns are action names (``CardType.value``), initialised to all zeros.
    Each call to :meth:`update_q_value` applies the standard Q-Learning rule
    and appends a record to ``learning_history``.
    """

    def __init__(
        self,
        states: List[str],
        actions: List[CardType],
        learning_rate: float = 0.1,
        discount_factor: float = 0.9,
        epsilon: float = 0.1
    ):
        """
        Args:
            states: All state labels (Q-Table rows).
            actions: All selectable actions (Q-Table columns via ``.value``).
            learning_rate: Step size alpha.
            discount_factor: Future-reward discount gamma.
            epsilon: Exploration probability for epsilon-greedy selection.
        """
        self.states = states
        self.actions = actions
        self.learning_rate = learning_rate
        self.discount_factor = discount_factor
        self.epsilon = epsilon

        # Q-Table initialised with all values 0.0.
        action_names = [action.value for action in actions]
        self.q_table = pd.DataFrame(
            0.0,
            index=states,
            columns=action_names
        )

        # O(1) reverse lookup from column name back to the action object;
        # avoids a linear scan over self.actions on every greedy pick.
        self._action_by_name = {action.value: action for action in actions}

        # Learning bookkeeping.
        self.learning_history: List[Dict[str, Any]] = []
        self.update_count = 0
        self.total_reward = 0.0

    def get_q_value(self, state: str, action: CardType) -> float:
        """Return Q(state, action); 0.0 for an unknown state."""
        if state in self.q_table.index:
            # float() so callers never receive a numpy scalar.
            return float(self.q_table.loc[state, action.value])
        return 0.0

    def set_q_value(self, state: str, action: CardType, value: float):
        """Overwrite Q(state, action); silently ignored for unknown states."""
        if state in self.q_table.index:
            self.q_table.loc[state, action.value] = value

    def get_optimal_action(self, state: str) -> CardType:
        """Return the greedy action for ``state`` (random for unknown states).

        Ties are broken by column order via ``idxmax``.
        """
        if state not in self.q_table.index:
            return random.choice(self.actions)
        best_action_name = self.q_table.loc[state].idxmax()
        action = self._action_by_name.get(best_action_name)
        # Defensive fallback, mirroring the original behaviour when the
        # column name does not map back to a known action.
        return action if action is not None else random.choice(self.actions)

    def select_action(self, state: str, use_epsilon_greedy: bool = True) -> Tuple[CardType, bool]:
        """Pick an action (epsilon-greedy by default, pure greedy otherwise).

        Returns:
            (action, is_exploration): the chosen action and whether it was
            a random exploratory pick.
        """
        if use_epsilon_greedy and random.random() < self.epsilon:
            # Explore: uniform random action.
            return random.choice(self.actions), True
        else:
            # Exploit: greedy action.
            return self.get_optimal_action(state), False

    def update_q_value(
        self,
        state: str,
        action: CardType,
        reward: float,
        next_state: str,
        done: bool
    ) -> float:
        """Apply one Q-Learning update.

        Q(s,a) <- Q(s,a) + alpha * [r + gamma * max_a' Q(s',a') - Q(s,a)]

        Returns:
            The signed TD error; 0.0 when ``state`` is unknown (no update).
        """
        if state not in self.q_table.index:
            return 0.0
        current_q = self.get_q_value(state, action)
        # Terminal transitions (or unknown next states) bootstrap nothing.
        if done or next_state not in self.q_table.index:
            target = reward
        else:
            max_next_q = float(self.q_table.loc[next_state].max())
            target = reward + self.discount_factor * max_next_q
        # TD error and incremental Q-value update.
        td_error = target - current_q
        new_q = current_q + self.learning_rate * td_error
        self.set_q_value(state, action, new_q)
        # Record the update for later statistics/inspection.
        self.learning_history.append({
            'update': self.update_count,
            'state': state,
            'action': action.value,
            'old_q': current_q,
            'new_q': new_q,
            'reward': reward,
            'target': target,
            'td_error': abs(td_error),
            'timestamp': time.time()
        })
        self.update_count += 1
        self.total_reward += reward
        return td_error

    def batch_update(self, experiences: List[ExperienceData]) -> Dict[str, float]:
        """Replay a batch of experiences through :meth:`update_q_value`.

        Returns:
            Summary with the mean absolute TD error, the count of
            "meaningful" updates (|TD| > 1e-8) and the batch size.
        """
        if not experiences:
            return {"avg_td_error": 0.0, "updates": 0}
        td_errors = []
        updates = 0
        for exp in experiences:
            td_error = self.update_q_value(
                exp.state,
                exp.action,
                exp.reward,
                exp.next_state,
                exp.done
            )
            if abs(td_error) > 1e-8:  # count only meaningful updates
                td_errors.append(abs(td_error))
                updates += 1
        return {
            # float() so the summary carries plain Python floats.
            "avg_td_error": float(np.mean(td_errors)) if td_errors else 0.0,
            "updates": updates,
            "total_experiences": len(experiences)
        }

    def get_q_table_copy(self) -> pd.DataFrame:
        """Return a defensive copy of the Q-Table."""
        return self.q_table.copy()

    def get_state_q_values(self, state: str) -> Dict[str, float]:
        """Return Q-values per action name for ``state`` (0.0s when unknown)."""
        if state not in self.q_table.index:
            return {action.value: 0.0 for action in self.actions}
        # Cast to plain floats so the mapping is JSON/schema friendly.
        return {name: float(value) for name, value in self.q_table.loc[state].items()}

    def get_learning_statistics(self) -> Dict[str, Any]:
        """Aggregate learning statistics.

        Averages cover the most recent 100 history entries; sparsity is the
        fraction of Q-Table entries that are still exactly zero.
        """
        if not self.learning_history:
            return {
                "total_updates": 0,
                "avg_td_error": 0.0,
                "avg_reward": 0.0,
                "q_table_sparsity": 1.0
            }
        recent_history = self.learning_history[-100:]  # last 100 updates
        # Sparsity: share of entries never moved away from zero.
        non_zero_values = (self.q_table != 0).sum().sum()
        total_values = self.q_table.size
        sparsity = 1.0 - (non_zero_values / total_values)
        return {
            "total_updates": self.update_count,
            # float() for consistency with q_value_range (no numpy scalars).
            "avg_td_error": float(np.mean([h['td_error'] for h in recent_history])),
            "avg_reward": float(np.mean([h['reward'] for h in recent_history])),
            "q_table_sparsity": float(sparsity),
            "q_value_range": {
                "min": float(self.q_table.min().min()),
                "max": float(self.q_table.max().max()),
                "mean": float(self.q_table.mean().mean())
            }
        }

    def reset(self):
        """Zero the Q-Table and clear all learning bookkeeping."""
        self.q_table = pd.DataFrame(
            0.0,
            index=self.states,
            columns=[action.value for action in self.actions]
        )
        self.learning_history.clear()
        self.update_count = 0
        self.total_reward = 0.0

    def set_hyperparameters(
        self,
        learning_rate: Optional[float] = None,
        discount_factor: Optional[float] = None,
        epsilon: Optional[float] = None
    ):
        """Update any subset of the hyperparameters (None keeps the current value)."""
        if learning_rate is not None:
            self.learning_rate = learning_rate
        if discount_factor is not None:
            self.discount_factor = discount_factor
        if epsilon is not None:
            self.epsilon = epsilon