906 lines
34 KiB
Python
906 lines
34 KiB
Python
"""
|
||
Streamlit 기반 Q-Table 협상 전략 데모 프론트엔드
|
||
"""
|
||
import streamlit as st
|
||
import requests
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
import plotly.graph_objects as go
|
||
from plotly.subplots import make_subplots
|
||
import time
|
||
import json
|
||
from typing import Dict, Any, Optional
|
||
|
||
# 페이지 설정
|
||
st.set_page_config(
|
||
page_title="Q-Table 협상 전략 데모",
|
||
page_icon="🎯",
|
||
layout="wide",
|
||
initial_sidebar_state="expanded"
|
||
)
|
||
|
||
# API 기본 URL
|
||
API_BASE_URL = "http://localhost:8000/api/v1"
|
||
|
||
# 세션 상태 초기화
|
||
if 'current_state' not in st.session_state:
|
||
st.session_state.current_state = "C0S0P0"
|
||
if 'anchor_price' not in st.session_state:
|
||
st.session_state.anchor_price = 100
|
||
|
||
|
||
class APIClient:
|
||
"""API 클라이언트"""
|
||
|
||
@staticmethod
|
||
def get(endpoint: str) -> Optional[Dict[str, Any]]:
|
||
"""GET 요청"""
|
||
try:
|
||
response = requests.get(f"{API_BASE_URL}{endpoint}")
|
||
response.raise_for_status()
|
||
return response.json()
|
||
except requests.exceptions.RequestException as e:
|
||
st.error(f"API 요청 오류: {e}")
|
||
return None
|
||
|
||
@staticmethod
|
||
def post(endpoint: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||
"""POST 요청"""
|
||
try:
|
||
response = requests.post(f"{API_BASE_URL}{endpoint}", json=data)
|
||
response.raise_for_status()
|
||
return response.json()
|
||
except requests.exceptions.RequestException as e:
|
||
st.error(f"API 요청 오류: {e}")
|
||
return None
|
||
|
||
|
||
def display_header():
|
||
"""헤더 제목 표시"""
|
||
st.title("🎯 Q-Table 기반 협상 전략 데모")
|
||
st.markdown("""
|
||
### 강화학습으로 배우는 협상 전략의 진화
|
||
|
||
이 데모는 **콜드 스타트 문제**부터 **학습된 정책**까지 Q-Learning의 전체 여정을 보여줍니다.
|
||
|
||
**핵심 보상함수:** `R(s,a) = W × (A/P) + (1-W) × End`
|
||
""")
|
||
st.markdown("---")
|
||
|
||
|
||
def display_sidebar():
|
||
"""사이드바 설정"""
|
||
st.sidebar.header("⚙️ 데모 설정")
|
||
|
||
# 시스템 상태 조회
|
||
status_response = APIClient.get("/status")
|
||
if status_response and status_response.get("success") is not False:
|
||
status = status_response
|
||
st.sidebar.metric("총 경험 데이터", status.get("total_experiences", 0))
|
||
st.sidebar.metric("Q-Table 업데이트", status.get("q_table_updates", 0))
|
||
st.sidebar.metric("고유 상태", status.get("unique_states", 0))
|
||
st.sidebar.metric("평균 보상", f"{status.get('average_reward', 0):.3f}")
|
||
st.sidebar.metric("성공률", f"{status.get('success_rate', 0)*100:.1f}%")
|
||
|
||
st.sidebar.markdown("---")
|
||
|
||
# 글로벌 설정
|
||
st.session_state.anchor_price = st.sidebar.number_input(
|
||
"목표가 (A)",
|
||
value=st.session_state.anchor_price,
|
||
min_value=50,
|
||
max_value=300
|
||
)
|
||
|
||
# 시스템 초기화
|
||
if st.sidebar.button("🔄 시스템 초기화", type="secondary"):
|
||
with st.spinner("시스템 초기화 중..."):
|
||
reset_response = APIClient.post("/reset", {})
|
||
if reset_response and reset_response.get("success"):
|
||
st.sidebar.success("시스템이 초기화되었습니다!")
|
||
st.rerun()
|
||
else:
|
||
st.sidebar.error("초기화 실패")
|
||
|
||
return status_response
|
||
|
||
|
||
def tab_cold_start():
|
||
"""콜드 스타트 탭"""
|
||
st.header("🏁 콜드 스타트 문제")
|
||
|
||
st.markdown("""
|
||
### 강화학습의 첫 번째 난관
|
||
|
||
새로운 강화학습 에이전트가 직면하는 가장 큰 문제는 **"아무것도 모른다"**는 것입니다.
|
||
모든 Q값이 0으로 초기화되어 있어, 어떤 행동이 좋은지 전혀 알 수 없습니다.
|
||
""")
|
||
|
||
# Q-Table 현재 상태 조회
|
||
qtable_response = APIClient.get("/qtable")
|
||
if qtable_response and qtable_response.get("success"):
|
||
qtable_data = qtable_response["data"]
|
||
q_table_dict = qtable_data["q_table"]
|
||
|
||
# DataFrame으로 변환
|
||
q_table_df = pd.DataFrame(q_table_dict)
|
||
|
||
st.subheader("📋 현재 Q-Table 상태")
|
||
|
||
# 통계 표시
|
||
col1, col2, col3, col4 = st.columns(4)
|
||
with col1:
|
||
non_zero_count = (q_table_df != 0).sum().sum()
|
||
st.metric("비어있지 않은 Q값", non_zero_count)
|
||
with col2:
|
||
total_count = q_table_df.size
|
||
st.metric("전체 Q값", total_count)
|
||
with col3:
|
||
sparsity = (1 - non_zero_count / total_count) * 100
|
||
st.metric("희소성", f"{sparsity:.1f}%")
|
||
with col4:
|
||
st.metric("업데이트 횟수", qtable_data["update_count"])
|
||
|
||
# Q-Table 표시 (상위 20개 상태만)
|
||
display_rows = min(20, len(q_table_df))
|
||
st.dataframe(
|
||
q_table_df.head(display_rows).style.format("{:.3f}").highlight_max(axis=1),
|
||
use_container_width=True
|
||
)
|
||
|
||
if len(q_table_df) > 20:
|
||
st.info(f"전체 {len(q_table_df)}개 상태 중 상위 20개만 표시됩니다.")
|
||
|
||
# 문제점과 해결방법
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
st.subheader("❌ 핵심 문제점")
|
||
st.markdown("""
|
||
- **무지상태**: 모든 Q값이 동일 (보통 0)
|
||
- **행동 선택 불가**: 어떤 행동이 좋은지 모름
|
||
- **무작위 탐험**: 비효율적인 학습 초기 단계
|
||
- **데이터 부족**: 학습할 경험이 없음
|
||
""")
|
||
|
||
with col2:
|
||
st.subheader("✅ 해결 방법")
|
||
st.markdown("""
|
||
- **탐험 전략**: Epsilon-greedy로 무작위 탐험
|
||
- **경험 수집**: (상태, 행동, 보상, 다음상태) 튜플 저장
|
||
- **점진적 학습**: 수집된 경험으로 Q값 업데이트
|
||
- **정책 개선**: 학습을 통한 점진적 성능 향상
|
||
""")
|
||
|
||
# Q-Learning 공식 설명
|
||
st.subheader("🧮 Q-Learning 업데이트 공식")
|
||
st.latex(r"Q(s,a) \leftarrow Q(s,a) + lpha [r + \gamma \max_{a'} Q(s',a') - Q(s,a)]")
|
||
|
||
st.markdown("""
|
||
**공식 설명:**
|
||
- **Q(s,a)**: 상태 s에서 행동 a의 Q값
|
||
- **α (알파)**: 학습률 (0 < α ≤ 1)
|
||
- **r**: 즉시 보상
|
||
- **γ (감마)**: 할인율 (0 ≤ γ < 1)
|
||
- **max Q(s',a')**: 다음 상태에서의 최대 Q값
|
||
""")
|
||
|
||
|
||
def tab_data_collection():
|
||
"""데이터 수집 탭"""
|
||
st.header("📊 경험 데이터 수집")
|
||
|
||
st.markdown("""
|
||
### 학습의 연료: 경험 데이터
|
||
|
||
강화학습 에이전트는 환경과 상호작용하면서 **경험 튜플**을 수집합니다.
|
||
각 경험은 `(상태, 행동, 보상, 다음상태, 종료여부)` 형태로 저장됩니다.
|
||
""")
|
||
|
||
# 설정 섹션
|
||
col1, col2 = st.columns([1, 2])
|
||
|
||
with col1:
|
||
st.subheader("⚙️ 에피소드 생성 설정")
|
||
|
||
num_episodes = st.slider("생성할 에피소드 수", 1, 50, 10)
|
||
max_steps = st.slider("에피소드당 최대 스텝", 3, 15, 8)
|
||
exploration_rate = st.slider("탐험율 (Epsilon)", 0.0, 1.0, 0.4, 0.1)
|
||
|
||
st.markdown(f"""
|
||
**현재 설정:**
|
||
- 목표가: {st.session_state.anchor_price}
|
||
- 탐험율: {exploration_rate*100:.0f}%
|
||
- 총 예상 경험: ~{num_episodes * max_steps}개
|
||
""")
|
||
|
||
if st.button("🎲 자동 에피소드 생성", type="primary"):
|
||
with st.spinner("에피소드 생성 중..."):
|
||
request_data = {
|
||
"num_episodes": num_episodes,
|
||
"max_steps": max_steps,
|
||
"anchor_price": st.session_state.anchor_price,
|
||
"exploration_rate": exploration_rate
|
||
}
|
||
|
||
response = APIClient.post("/episodes/generate", request_data)
|
||
if response and response.get("success"):
|
||
result = response["data"]
|
||
st.success(f"✅ {result['new_experiences']}개의 새로운 경험 데이터 생성!")
|
||
|
||
# 에피소드 결과 표시
|
||
episode_results = result["episode_results"]
|
||
success_count = sum(1 for ep in episode_results if ep["success"])
|
||
|
||
col_a, col_b, col_c = st.columns(3)
|
||
with col_a:
|
||
st.metric("생성된 에피소드", result["episodes_generated"])
|
||
with col_b:
|
||
st.metric("성공한 협상", success_count)
|
||
with col_c:
|
||
success_rate = (success_count / len(episode_results)) * 100
|
||
st.metric("성공률", f"{success_rate:.1f}%")
|
||
|
||
time.sleep(1) # UI 업데이트를 위한 잠시 대기
|
||
st.rerun()
|
||
else:
|
||
st.error("에피소드 생성 실패")
|
||
|
||
with col2:
|
||
st.subheader("📈 수집된 데이터 현황")
|
||
|
||
# 경험 데이터 조회
|
||
exp_response = APIClient.get("/experiences")
|
||
if exp_response and exp_response.get("success"):
|
||
exp_data = exp_response["data"]
|
||
stats = exp_data["statistics"]
|
||
recent_data = exp_data["recent_data"]
|
||
|
||
# 통계 표시
|
||
col_a, col_b, col_c, col_d = st.columns(4)
|
||
with col_a:
|
||
st.metric("총 경험 수", stats["total_count"])
|
||
with col_b:
|
||
st.metric("평균 보상", f"{stats['avg_reward']:.3f}")
|
||
with col_c:
|
||
st.metric("성공률", f"{stats['success_rate']*100:.1f}%")
|
||
with col_d:
|
||
st.metric("고유 상태", stats["unique_states"])
|
||
|
||
# 최근 경험 데이터 표시
|
||
if recent_data:
|
||
st.subheader("🔍 최근 경험 데이터")
|
||
recent_df = pd.DataFrame(recent_data)
|
||
|
||
# 필요한 컬럼만 선택
|
||
display_columns = ['state', 'action', 'reward', 'next_state', 'done']
|
||
available_columns = [col for col in display_columns if col in recent_df.columns]
|
||
|
||
if available_columns:
|
||
display_df = recent_df[available_columns].tail(10)
|
||
st.dataframe(
|
||
display_df.style.format({'reward': '{:.3f}'}),
|
||
use_container_width=True
|
||
)
|
||
|
||
# 데이터 분포 시각화
|
||
if len(recent_df) > 5:
|
||
st.subheader("📊 데이터 분포 분석")
|
||
|
||
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
||
|
||
# 보상 분포
|
||
axes[0,0].hist(recent_df['reward'], bins=15, alpha=0.7, color='skyblue', edgecolor='black')
|
||
axes[0,0].set_title('보상 분포')
|
||
axes[0,0].set_xlabel('보상값')
|
||
axes[0,0].set_ylabel('빈도')
|
||
|
||
# 행동 분포
|
||
if 'action' in recent_df.columns:
|
||
action_counts = recent_df['action'].value_counts()
|
||
axes[0,1].bar(action_counts.index, action_counts.values, color='lightgreen', edgecolor='black')
|
||
axes[0,1].set_title('행동 선택 빈도')
|
||
axes[0,1].set_xlabel('행동')
|
||
axes[0,1].set_ylabel('빈도')
|
||
|
||
# 상태 분포 (상위 10개)
|
||
if 'state' in recent_df.columns:
|
||
state_counts = recent_df['state'].value_counts().head(10)
|
||
axes[1,0].bar(range(len(state_counts)), state_counts.values, color='orange', edgecolor='black')
|
||
axes[1,0].set_title('상위 상태 빈도')
|
||
axes[1,0].set_xlabel('상태 순위')
|
||
axes[1,0].set_ylabel('빈도')
|
||
axes[1,0].set_xticks(range(len(state_counts)))
|
||
axes[1,0].set_xticklabels([f"{i+1}" for i in range(len(state_counts))])
|
||
|
||
# 성공/실패 분포
|
||
if 'done' in recent_df.columns:
|
||
done_counts = recent_df['done'].value_counts()
|
||
labels = ['진행중' if not k else '완료' for k in done_counts.index]
|
||
axes[1,1].pie(done_counts.values, labels=labels, autopct='%1.1f%%', colors=['lightcoral', 'lightblue'])
|
||
axes[1,1].set_title('협상 완료 비율')
|
||
|
||
plt.tight_layout()
|
||
st.pyplot(fig)
|
||
else:
|
||
st.info("아직 수집된 데이터가 없습니다. 왼쪽에서 에피소드를 생성해보세요!")
|
||
else:
|
||
st.warning("경험 데이터를 불러올 수 없습니다.")
|
||
|
||
# 경험 데이터 구조 설명
|
||
st.subheader("📋 경험 데이터 구조")
|
||
st.markdown("""
|
||
각 경험 튜플은 다음 정보를 포함합니다:
|
||
|
||
| 항목 | 설명 | 예시 |
|
||
|------|------|------|
|
||
| **상태 (State)** | 현재 협상 상황 | "C1APZ2" (카드1, 시나리오A, 가격구간2) |
|
||
| **행동 (Action)** | 선택한 협상 카드 | "C3" |
|
||
| **보상 (Reward)** | 행동에 대한 평가 | 0.85 |
|
||
| **다음상태 (Next State)** | 행동 후 새로운 상황 | "C3APZ1" |
|
||
| **종료 (Done)** | 협상 완료 여부 | true/false |
|
||
""")
|
||
|
||
|
||
# 더 많은 탭 함수들을 다음 파트에서 계속...
|
||
|
||
|
||
def tab_q_learning():
|
||
"""Q-Learning 탭"""
|
||
st.header("🔄 Q-Learning 실시간 학습")
|
||
|
||
st.markdown("""
|
||
### 경험으로부터 학습하기
|
||
|
||
수집된 경험 데이터를 사용하여 Q-Table을 업데이트합니다.
|
||
각 경험에서 **TD(Temporal Difference) 오차**를 계산하고 Q값을 조정합니다.
|
||
""")
|
||
|
||
col1, col2 = st.columns([1, 2])
|
||
|
||
with col1:
|
||
st.subheader("⚙️ 학습 설정")
|
||
|
||
learning_rate = st.slider("학습률 (α)", 0.01, 0.5, 0.1, 0.01)
|
||
discount_factor = st.slider("할인율 (γ)", 0.8, 0.99, 0.9, 0.01)
|
||
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
|
||
|
||
st.markdown(f"""
|
||
**하이퍼파라미터:**
|
||
- **학습률 (α)**: {learning_rate} - 새로운 정보의 반영 정도
|
||
- **할인율 (γ)**: {discount_factor} - 미래 보상의 중요도
|
||
- **배치 크기**: {batch_size} - 한 번에 학습할 경험 수
|
||
""")
|
||
|
||
if st.button("🚀 Q-Learning 실행", type="primary"):
|
||
with st.spinner("Q-Learning 업데이트 중..."):
|
||
request_data = {
|
||
"learning_rate": learning_rate,
|
||
"discount_factor": discount_factor,
|
||
"batch_size": batch_size
|
||
}
|
||
|
||
response = APIClient.post("/learning/q-learning", request_data)
|
||
if response and response.get("success"):
|
||
result = response["data"]
|
||
st.success(f"✅ {result['updates']}개 Q값 업데이트 완료!")
|
||
|
||
col_a, col_b = st.columns(2)
|
||
with col_a:
|
||
st.metric("배치 크기", result["batch_size"])
|
||
with col_b:
|
||
st.metric("평균 TD 오차", f"{result.get('avg_td_error', 0):.4f}")
|
||
|
||
time.sleep(1)
|
||
st.rerun()
|
||
else:
|
||
st.error("Q-Learning 업데이트 실패")
|
||
|
||
with col2:
|
||
st.subheader("📊 Q-Table 현황")
|
||
|
||
# Q-Table 데이터 조회
|
||
qtable_response = APIClient.get("/qtable")
|
||
if qtable_response and qtable_response.get("success"):
|
||
qtable_data = qtable_response["data"]
|
||
statistics = qtable_data["statistics"]
|
||
|
||
# 통계 표시
|
||
col_a, col_b, col_c, col_d = st.columns(4)
|
||
with col_a:
|
||
st.metric("총 업데이트", statistics.get("total_updates", 0))
|
||
with col_b:
|
||
st.metric("평균 TD 오차", f"{statistics.get('avg_td_error', 0):.4f}")
|
||
with col_c:
|
||
st.metric("평균 보상", f"{statistics.get('avg_reward', 0):.3f}")
|
||
with col_d:
|
||
sparsity = statistics.get('q_table_sparsity', 1.0) * 100
|
||
st.metric("Q-Table 희소성", f"{sparsity:.1f}%")
|
||
|
||
# Q값 범위 표시
|
||
q_range = statistics.get('q_value_range', {})
|
||
if q_range:
|
||
st.subheader("📈 Q값 분포")
|
||
col_a, col_b, col_c = st.columns(3)
|
||
with col_a:
|
||
st.metric("최솟값", f"{q_range.get('min', 0):.3f}")
|
||
with col_b:
|
||
st.metric("평균값", f"{q_range.get('mean', 0):.3f}")
|
||
with col_c:
|
||
st.metric("최댓값", f"{q_range.get('max', 0):.3f}")
|
||
|
||
# Q-Table 표시 (비어있지 않은 상태들만)
|
||
q_table_dict = qtable_data["q_table"]
|
||
q_table_df = pd.DataFrame(q_table_dict)
|
||
|
||
# 0이 아닌 값이 있는 행만 필터링
|
||
non_zero_rows = (q_table_df != 0).any(axis=1)
|
||
if non_zero_rows.any():
|
||
st.subheader("🎯 학습된 Q값들")
|
||
learned_qtable = q_table_df[non_zero_rows].head(15)
|
||
st.dataframe(
|
||
learned_qtable.style.format("{:.3f}").highlight_max(axis=1),
|
||
use_container_width=True
|
||
)
|
||
|
||
learned_count = non_zero_rows.sum()
|
||
total_count = len(q_table_df)
|
||
st.info(f"전체 {total_count}개 상태 중 {learned_count}개 상태가 학습되었습니다.")
|
||
else:
|
||
st.info("아직 학습된 Q값이 없습니다. 위에서 Q-Learning을 실행해보세요!")
|
||
|
||
# TD 오차 설명
|
||
st.subheader("🧮 TD(Temporal Difference) 오차")
|
||
st.markdown("""
|
||
**TD 오차**는 현재 Q값과 목표값의 차이입니다:
|
||
|
||
`TD 오차 = [r + γ max Q(s',a')] - Q(s,a)`
|
||
|
||
- **양수**: 현재 Q값이 너무 낮음 → Q값 증가
|
||
- **음수**: 현재 Q값이 너무 높음 → Q값 감소
|
||
- **0에 가까움**: Q값이 적절함 → 학습 수렴
|
||
""")
|
||
|
||
|
||
def tab_fqi_cql():
|
||
"""FQI+CQL 탭"""
|
||
st.header("🧠 FQI + CQL 오프라인 학습")
|
||
|
||
st.markdown("""
|
||
### 오프라인 강화학습의 핵심
|
||
|
||
**FQI (Fitted Q-Iteration)**와 **CQL (Conservative Q-Learning)**을 결합한
|
||
오프라인 강화학습 방법입니다. 수집된 데이터만으로 안전하고 보수적인 정책을 학습합니다.
|
||
""")
|
||
|
||
col1, col2 = st.columns([1, 2])
|
||
|
||
with col1:
|
||
st.subheader("⚙️ FQI+CQL 설정")
|
||
|
||
alpha = st.slider("CQL 보수성 파라미터 (α)", 0.0, 3.0, 1.0, 0.1)
|
||
gamma = st.slider("할인율 (γ)", 0.8, 0.99, 0.95, 0.01)
|
||
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
|
||
num_iterations = st.slider("반복 횟수", 1, 50, 10, 1)
|
||
|
||
st.markdown(f"""
|
||
**설정값:**
|
||
- **α (Alpha)**: {alpha} - 보수성 강도
|
||
- **γ (Gamma)**: {gamma} - 미래 보상 할인
|
||
- **배치 크기**: {batch_size}
|
||
- **반복 횟수**: {num_iterations}
|
||
""")
|
||
|
||
st.markdown("""
|
||
**CQL 특징:**
|
||
- 🛡️ **보수적 추정**: 불확실한 행동의 Q값을 낮게 유지
|
||
- 📊 **데이터 기반**: 수집된 경험만 활용
|
||
- 🎯 **안전한 정책**: 분포 이동 문제 해결
|
||
""")
|
||
|
||
if st.button("🚀 FQI+CQL 실행", type="primary"):
|
||
with st.spinner("FQI+CQL 학습 중..."):
|
||
request_data = {
|
||
"alpha": alpha,
|
||
"gamma": gamma,
|
||
"batch_size": batch_size,
|
||
"num_iterations": num_iterations
|
||
}
|
||
|
||
response = APIClient.post("/learning/fqi-cql", request_data)
|
||
if response and response.get("success"):
|
||
result = response["data"]
|
||
training_result = result["training_result"]
|
||
policy_comparison = result["policy_comparison"]
|
||
|
||
st.success(f"✅ {training_result['total_iterations']}회 반복 학습 완료!")
|
||
|
||
# 학습 결과 표시
|
||
col_a, col_b = st.columns(2)
|
||
with col_a:
|
||
st.metric("평균 벨만 손실", f"{training_result['avg_bellman_loss']:.4f}")
|
||
with col_b:
|
||
st.metric("평균 CQL 페널티", f"{training_result['avg_cql_penalty']:.4f}")
|
||
|
||
# 정책 비교
|
||
st.metric("행동 정책과의 일치율", f"{policy_comparison['action_agreement']*100:.1f}%")
|
||
|
||
time.sleep(1)
|
||
st.rerun()
|
||
else:
|
||
st.error("FQI+CQL 학습 실패")
|
||
|
||
with col2:
|
||
st.subheader("📊 FQI+CQL 결과")
|
||
|
||
# FQI+CQL 결과 조회
|
||
fqi_response = APIClient.get("/fqi-cql")
|
||
if fqi_response and fqi_response.get("success"):
|
||
fqi_data = fqi_response["data"]
|
||
statistics = fqi_data["statistics"]
|
||
|
||
# 통계 표시
|
||
col_a, col_b, col_c = st.columns(3)
|
||
with col_a:
|
||
st.metric("학습 배치", statistics.get("total_batches", 0))
|
||
with col_b:
|
||
st.metric("벨만 손실", f"{statistics.get('avg_bellman_loss', 0):.4f}")
|
||
with col_c:
|
||
st.metric("CQL 페널티", f"{statistics.get('avg_cql_penalty', 0):.4f}")
|
||
|
||
# 수렴 경향
|
||
convergence = statistics.get("convergence_trend", "unknown")
|
||
convergence_color = {
|
||
"improving": "🟢",
|
||
"deteriorating": "🔴",
|
||
"fluctuating": "🟡",
|
||
"insufficient_data": "⚪"
|
||
}
|
||
st.info(f"수렴 경향: {convergence_color.get(convergence, '❓')} {convergence}")
|
||
|
||
# Q-Network 통계
|
||
q_stats = statistics.get("q_network_stats", {})
|
||
if q_stats:
|
||
st.subheader("📈 Q-Network 분포")
|
||
col_a, col_b, col_c, col_d = st.columns(4)
|
||
with col_a:
|
||
st.metric("최솟값", f"{q_stats.get('min', 0):.3f}")
|
||
with col_b:
|
||
st.metric("평균값", f"{q_stats.get('mean', 0):.3f}")
|
||
with col_c:
|
||
st.metric("최댓값", f"{q_stats.get('max', 0):.3f}")
|
||
with col_d:
|
||
st.metric("표준편차", f"{q_stats.get('std', 0):.3f}")
|
||
|
||
# Q-Network 표시 (상위 15개 상태)
|
||
q_network_dict = fqi_data["q_network"]
|
||
q_network_df = pd.DataFrame(q_network_dict)
|
||
|
||
st.subheader("🎯 학습된 Q-Network")
|
||
display_df = q_network_df.head(15)
|
||
st.dataframe(
|
||
display_df.style.format("{:.3f}").highlight_max(axis=1),
|
||
use_container_width=True
|
||
)
|
||
else:
|
||
st.info("FQI+CQL을 먼저 실행해주세요!")
|
||
|
||
# FQI+CQL 알고리즘 설명
|
||
st.subheader("🔬 FQI + CQL 알고리즘")
|
||
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
st.markdown("""
|
||
**FQI (Fitted Q-Iteration)**
|
||
- 배치 기반 Q-Learning
|
||
- 전체 데이터셋을 한 번에 활용
|
||
- 함수 근사 (신경망) 사용
|
||
- 안정적인 학습 과정
|
||
""")
|
||
|
||
with col2:
|
||
st.markdown("""
|
||
**CQL (Conservative Q-Learning)**
|
||
- 보수적 Q값 추정
|
||
- Out-of-Distribution 행동 억제
|
||
- 데이터에 없는 행동의 Q값 하향 조정
|
||
- 안전한 정책 학습
|
||
""")
|
||
|
||
|
||
def tab_learned_policy():
|
||
"""학습된 정책 탭"""
|
||
st.header("🎯 학습된 정책 비교 및 활용")
|
||
|
||
st.markdown("""
|
||
### 학습 완료: 정책의 진화
|
||
|
||
Q-Learning과 FQI+CQL로 학습된 정책을 비교하고,
|
||
실제 협상 상황에서 어떤 행동을 추천하는지 확인해보세요.
|
||
""")
|
||
|
||
# 상태 선택
|
||
col1, col2 = st.columns([1, 2])
|
||
|
||
with col1:
|
||
st.subheader("🎮 협상 시뮬레이션")
|
||
|
||
# 상태 구성 요소 선택
|
||
current_card = st.selectbox("현재 카드", ["C1", "C2", "C3", "C4"])
|
||
scenario = st.selectbox("시나리오", ["A", "B", "C", "D"])
|
||
price_zone = st.selectbox("가격 구간", ["PZ1", "PZ2", "PZ3"])
|
||
|
||
# 상태 ID 생성
|
||
selected_state = f"{current_card}{scenario}{price_zone}"
|
||
st.session_state.current_state = selected_state
|
||
|
||
st.info(f"선택된 상태: **{selected_state}**")
|
||
|
||
# 상태 해석
|
||
state_interpretation = {
|
||
"A": "어려운 협상 (높은 가중치)",
|
||
"B": "쉬운 협상 (낮은 가중치)",
|
||
"C": "보통 협상 (중간 가중치)",
|
||
"D": "매우 쉬운 협상 (낮은 가중치)"
|
||
}
|
||
|
||
price_interpretation = {
|
||
"PZ1": "목표가 이하 (좋은 구간)",
|
||
"PZ2": "목표가~임계값 (보통 구간)",
|
||
"PZ3": "임계값 이상 (나쁜 구간)"
|
||
}
|
||
|
||
st.markdown(f"""
|
||
**상태 해석:**
|
||
- **카드**: {current_card}
|
||
- **시나리오**: {scenario} - {state_interpretation.get(scenario, "알 수 없음")}
|
||
- **가격구간**: {price_zone} - {price_interpretation.get(price_zone, "알 수 없음")}
|
||
""")
|
||
|
||
# 행동 추천 요청
|
||
use_epsilon = st.checkbox("엡실론 그리디 사용", value=False)
|
||
epsilon = 0.1
|
||
if use_epsilon:
|
||
epsilon = st.slider("엡실론 값", 0.0, 0.5, 0.1, 0.05)
|
||
|
||
if st.button("🎯 행동 추천 받기", type="primary"):
|
||
request_data = {
|
||
"current_state": selected_state,
|
||
"use_epsilon_greedy": use_epsilon,
|
||
"epsilon": epsilon
|
||
}
|
||
|
||
response = APIClient.post("/action/recommend", request_data)
|
||
if response and response.get("success") is not False:
|
||
# response가 직접 ActionRecommendationResponse 형태인 경우
|
||
recommendation = response
|
||
|
||
st.success(f"🎯 추천 행동: **{recommendation.get('recommended_action', 'N/A')}**")
|
||
|
||
if recommendation.get('exploration', False):
|
||
st.warning("🎲 탐험 행동 (무작위 선택)")
|
||
else:
|
||
confidence = recommendation.get('confidence', 0) * 100
|
||
st.info(f"🎯 활용 행동 (신뢰도: {confidence:.1f}%)")
|
||
|
||
# Q값들 표시
|
||
q_values = recommendation.get('q_values', {})
|
||
if q_values:
|
||
st.subheader("📊 현재 상태의 Q값들")
|
||
q_df = pd.DataFrame([q_values]).T
|
||
q_df.columns = ['Q값']
|
||
q_df = q_df.sort_values('Q값', ascending=False)
|
||
|
||
# 추천 행동 하이라이트
|
||
def highlight_recommended(s):
|
||
return ['background-color: lightgreen' if x == recommendation.get('recommended_action')
|
||
else '' for x in s.index]
|
||
|
||
st.dataframe(
|
||
q_df.style.format({'Q값': '{:.3f}'}).apply(highlight_recommended, axis=0),
|
||
use_container_width=True
|
||
)
|
||
else:
|
||
st.error("행동 추천 실패")
|
||
|
||
with col2:
|
||
st.subheader("⚖️ 정책 비교")
|
||
|
||
# 정책 비교 요청
|
||
compare_response = APIClient.get(f"/compare/{selected_state}")
|
||
if compare_response and compare_response.get("success"):
|
||
comparison = compare_response["data"]
|
||
|
||
# 정책 일치 여부
|
||
agreement = comparison["policy_agreement"]
|
||
if agreement:
|
||
st.success("✅ Q-Learning과 FQI+CQL 정책이 일치합니다!")
|
||
else:
|
||
st.warning("⚠️ Q-Learning과 FQI+CQL 정책이 다릅니다.")
|
||
|
||
# 각 정책의 추천 행동
|
||
col_a, col_b = st.columns(2)
|
||
|
||
with col_a:
|
||
st.subheader("🔄 Q-Learning 정책")
|
||
q_learning = comparison["q_learning"]
|
||
st.metric("추천 행동", q_learning["action"])
|
||
|
||
# Q값들
|
||
q_values_ql = q_learning["q_values"]
|
||
if q_values_ql:
|
||
q_df_ql = pd.DataFrame([q_values_ql]).T
|
||
q_df_ql.columns = ['Q값']
|
||
st.dataframe(q_df_ql.style.format({'Q값': '{:.3f}'}))
|
||
|
||
with col_b:
|
||
st.subheader("🧠 FQI+CQL 정책")
|
||
fqi_cql = comparison["fqi_cql"]
|
||
st.metric("추천 행동", fqi_cql["action"])
|
||
|
||
# Q값들
|
||
q_values_fqi = fqi_cql["q_values"]
|
||
if q_values_fqi:
|
||
q_df_fqi = pd.DataFrame([q_values_fqi]).T
|
||
q_df_fqi.columns = ['Q값']
|
||
st.dataframe(q_df_fqi.style.format({'Q값': '{:.3f}'}))
|
||
|
||
# Q값 차이 분석
|
||
differences = comparison["q_value_differences"]
|
||
max_diff = comparison["max_difference"]
|
||
|
||
st.subheader("📊 Q값 차이 분석")
|
||
st.metric("최대 차이", f"{max_diff:.3f}")
|
||
|
||
if differences:
|
||
diff_df = pd.DataFrame([differences]).T
|
||
diff_df.columns = ['차이']
|
||
st.dataframe(diff_df.style.format({'차이': '{:.3f}'}))
|
||
|
||
else:
|
||
st.info("정책 비교를 위해 상태를 선택하고 학습을 진행해주세요.")
|
||
|
||
# 보상 계산 시뮬레이션
|
||
st.subheader("🧮 보상 계산 시뮬레이션")
|
||
|
||
col1, col2 = st.columns(2)
|
||
|
||
with col1:
|
||
st.subheader("📋 시뮬레이션 설정")
|
||
proposed_price = st.number_input("상대방 제안가", value=120.0, min_value=50.0, max_value=500.0)
|
||
is_negotiation_end = st.checkbox("협상 종료", value=False)
|
||
|
||
if st.button("💰 보상 계산", type="secondary"):
|
||
request_data = {
|
||
"scenario": scenario,
|
||
"price_zone": price_zone,
|
||
"anchor_price": st.session_state.anchor_price,
|
||
"proposed_price": proposed_price,
|
||
"is_end": is_negotiation_end
|
||
}
|
||
|
||
response = APIClient.post("/reward/calculate", request_data)
|
||
if response and response.get("success") is not False:
|
||
# response가 직접 RewardCalculationResponse 형태인 경우
|
||
reward_result = response
|
||
|
||
col_a, col_b, col_c = st.columns(3)
|
||
with col_a:
|
||
st.metric("보상", f"{reward_result.get('reward', 0):.3f}")
|
||
with col_b:
|
||
st.metric("가중치 (W)", f"{reward_result.get('weight', 0):.3f}")
|
||
with col_c:
|
||
st.metric("가격 비율 (A/P)", f"{reward_result.get('price_ratio', 0):.3f}")
|
||
|
||
# 공식 분해 표시
|
||
formula = reward_result.get('formula_breakdown', '')
|
||
if formula:
|
||
st.subheader('📝 계산 과정')
|
||
st.text(formula)
|
||
else:
|
||
st.error("보상 계산 실패")
|
||
|
||
with col2:
|
||
st.subheader("📈 학습 진행 상황")
|
||
|
||
# 시스템 상태 조회
|
||
status_response = APIClient.get("/status")
|
||
if status_response and status_response.get("success") is not False:
|
||
status = status_response
|
||
|
||
# 진행 상황 메트릭
|
||
total_exp = status.get("total_experiences", 0)
|
||
updates = status.get("q_table_updates", 0)
|
||
success_rate = status.get("success_rate", 0) * 100
|
||
|
||
progress_metrics = [
|
||
("데이터 수집", total_exp, 1000, "개"),
|
||
("Q-Table 업데이트", updates, 500, "회"),
|
||
("협상 성공률", success_rate, 100, "%")
|
||
]
|
||
|
||
for name, value, target, unit in progress_metrics:
|
||
progress = min(value / target, 1.0)
|
||
st.metric(
|
||
name,
|
||
f"{value}{unit}",
|
||
delta=f"목표: {target}{unit}"
|
||
)
|
||
st.progress(progress)
|
||
|
||
st.subheader("🎓 학습 완성도")
|
||
|
||
# Q-Table 완성도
|
||
qtable_response = APIClient.get("/qtable")
|
||
if qtable_response and qtable_response.get("success"):
|
||
qtable_data = qtable_response["data"]
|
||
statistics = qtable_data["statistics"]
|
||
|
||
sparsity = statistics.get('q_table_sparsity', 1.0)
|
||
completeness = (1 - sparsity) * 100
|
||
|
||
st.metric("Q-Table 완성도", f"{completeness:.1f}%")
|
||
st.progress(completeness / 100)
|
||
|
||
if completeness > 80:
|
||
st.success("🎉 충분히 학습되었습니다!")
|
||
elif completeness > 50:
|
||
st.info("📖 적당히 학습되었습니다.")
|
||
else:
|
||
st.warning("📚 더 많은 학습이 필요합니다.")
|
||
|
||
|
||
def main():
|
||
"""메인 함수"""
|
||
# 헤더 표시
|
||
display_header()
|
||
|
||
# 사이드바 표시
|
||
sidebar_status = display_sidebar()
|
||
|
||
# 탭 생성
|
||
tab1, tab2, tab3, tab4, tab5 = st.tabs([
|
||
"🏁 1. 콜드 스타트",
|
||
"📊 2. 데이터 수집",
|
||
"🔄 3. Q-Learning",
|
||
"🧠 4. FQI+CQL",
|
||
"🎯 5. 학습된 정책"
|
||
])
|
||
|
||
with tab1:
|
||
tab_cold_start()
|
||
|
||
with tab2:
|
||
tab_data_collection()
|
||
|
||
with tab3:
|
||
tab_q_learning()
|
||
|
||
with tab4:
|
||
tab_fqi_cql()
|
||
|
||
with tab5:
|
||
tab_learned_policy()
|
||
|
||
# 푸터
|
||
st.markdown("---")
|
||
st.markdown("""
|
||
<div style='text-align: center; color: #666;'>
|
||
<p>🎯 Q-Table 기반 협상 전략 데모 | 강화학습의 전체 여정을 경험해보세요</p>
|
||
<p>💡 문의사항이 있으시면 API 문서를 참고해주세요: <a href="http://localhost:8000/docs" target="_blank">http://localhost:8000/docs</a></p>
|
||
</div>
|
||
""", unsafe_allow_html=True)
|
||
|
||
|
||
def start_frontend():
|
||
"""프론트엔드 시작 (Poetry 스크립트용)"""
|
||
import subprocess
|
||
subprocess.run(["streamlit", "run", "frontend/app.py", "--server.port", "8501"])
|
||
|
||
|
||
if __name__ == "__main__":
|
||
main()
|