q_table_demo/frontend/app.py

906 lines
34 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters!

This file contains invisible Unicode characters that may be processed differently from what appears below. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to reveal hidden characters.

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
Streamlit 기반 Q-Table 협상 전략 데모 프론트엔드
"""
import streamlit as st
import requests
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import time
import json
from typing import Dict, Any, Optional
# 페이지 설정
st.set_page_config(
page_title="Q-Table 협상 전략 데모",
page_icon="🎯",
layout="wide",
initial_sidebar_state="expanded"
)
# API 기본 URL
API_BASE_URL = "http://localhost:8000/api/v1"
# 세션 상태 초기화
if 'current_state' not in st.session_state:
st.session_state.current_state = "C0S0P0"
if 'anchor_price' not in st.session_state:
st.session_state.anchor_price = 100
class APIClient:
"""API 클라이언트"""
@staticmethod
def get(endpoint: str) -> Optional[Dict[str, Any]]:
"""GET 요청"""
try:
response = requests.get(f"{API_BASE_URL}{endpoint}")
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
st.error(f"API 요청 오류: {e}")
return None
@staticmethod
def post(endpoint: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""POST 요청"""
try:
response = requests.post(f"{API_BASE_URL}{endpoint}", json=data)
response.raise_for_status()
return response.json()
except requests.exceptions.RequestException as e:
st.error(f"API 요청 오류: {e}")
return None
def display_header():
"""헤더 제목 표시"""
st.title("🎯 Q-Table 기반 협상 전략 데모")
st.markdown("""
### 강화학습으로 배우는 협상 전략의 진화
이 데모는 **콜드 스타트 문제**부터 **학습된 정책**까지 Q-Learning의 전체 여정을 보여줍니다.
**핵심 보상함수:** `R(s,a) = W × (A/P) + (1-W) × End`
""")
st.markdown("---")
def display_sidebar():
"""사이드바 설정"""
st.sidebar.header("⚙️ 데모 설정")
# 시스템 상태 조회
status_response = APIClient.get("/status")
if status_response and status_response.get("success") is not False:
status = status_response
st.sidebar.metric("총 경험 데이터", status.get("total_experiences", 0))
st.sidebar.metric("Q-Table 업데이트", status.get("q_table_updates", 0))
st.sidebar.metric("고유 상태", status.get("unique_states", 0))
st.sidebar.metric("평균 보상", f"{status.get('average_reward', 0):.3f}")
st.sidebar.metric("성공률", f"{status.get('success_rate', 0)*100:.1f}%")
st.sidebar.markdown("---")
# 글로벌 설정
st.session_state.anchor_price = st.sidebar.number_input(
"목표가 (A)",
value=st.session_state.anchor_price,
min_value=50,
max_value=300
)
# 시스템 초기화
if st.sidebar.button("🔄 시스템 초기화", type="secondary"):
with st.spinner("시스템 초기화 중..."):
reset_response = APIClient.post("/reset", {})
if reset_response and reset_response.get("success"):
st.sidebar.success("시스템이 초기화되었습니다!")
st.rerun()
else:
st.sidebar.error("초기화 실패")
return status_response
def tab_cold_start():
"""콜드 스타트 탭"""
st.header("🏁 콜드 스타트 문제")
st.markdown("""
### 강화학습의 첫 번째 난관
새로운 강화학습 에이전트가 직면하는 가장 큰 문제는 **"아무것도 모른다"**는 것입니다.
모든 Q값이 0으로 초기화되어 있어, 어떤 행동이 좋은지 전혀 알 수 없습니다.
""")
# Q-Table 현재 상태 조회
qtable_response = APIClient.get("/qtable")
if qtable_response and qtable_response.get("success"):
qtable_data = qtable_response["data"]
q_table_dict = qtable_data["q_table"]
# DataFrame으로 변환
q_table_df = pd.DataFrame(q_table_dict)
st.subheader("📋 현재 Q-Table 상태")
# 통계 표시
col1, col2, col3, col4 = st.columns(4)
with col1:
non_zero_count = (q_table_df != 0).sum().sum()
st.metric("비어있지 않은 Q값", non_zero_count)
with col2:
total_count = q_table_df.size
st.metric("전체 Q값", total_count)
with col3:
sparsity = (1 - non_zero_count / total_count) * 100
st.metric("희소성", f"{sparsity:.1f}%")
with col4:
st.metric("업데이트 횟수", qtable_data["update_count"])
# Q-Table 표시 (상위 20개 상태만)
display_rows = min(20, len(q_table_df))
st.dataframe(
q_table_df.head(display_rows).style.format("{:.3f}").highlight_max(axis=1),
use_container_width=True
)
if len(q_table_df) > 20:
st.info(f"전체 {len(q_table_df)}개 상태 중 상위 20개만 표시됩니다.")
# 문제점과 해결방법
col1, col2 = st.columns(2)
with col1:
st.subheader("❌ 핵심 문제점")
st.markdown("""
- **무지상태**: 모든 Q값이 동일 (보통 0)
- **행동 선택 불가**: 어떤 행동이 좋은지 모름
- **무작위 탐험**: 비효율적인 학습 초기 단계
- **데이터 부족**: 학습할 경험이 없음
""")
with col2:
st.subheader("✅ 해결 방법")
st.markdown("""
- **탐험 전략**: Epsilon-greedy로 무작위 탐험
- **경험 수집**: (상태, 행동, 보상, 다음상태) 튜플 저장
- **점진적 학습**: 수집된 경험으로 Q값 업데이트
- **정책 개선**: 학습을 통한 점진적 성능 향상
""")
# Q-Learning 공식 설명
st.subheader("🧮 Q-Learning 업데이트 공식")
st.latex(r"Q(s,a) \leftarrow Q(s,a) + lpha [r + \gamma \max_{a'} Q(s',a') - Q(s,a)]")
st.markdown("""
**공식 설명:**
- **Q(s,a)**: 상태 s에서 행동 a의 Q값
- **α (알파)**: 학습률 (0 < α ≤ 1)
- **r**: 즉시 보상
- **γ (감마)**: 할인율 (0 ≤ γ < 1)
- **max Q(s',a')**: 다음 상태에서의 최대 Q값
""")
def tab_data_collection():
"""데이터 수집 탭"""
st.header("📊 경험 데이터 수집")
st.markdown("""
### 학습의 연료: 경험 데이터
강화학습 에이전트는 환경과 상호작용하면서 **경험 튜플**을 수집합니다.
각 경험은 `(상태, 행동, 보상, 다음상태, 종료여부)` 형태로 저장됩니다.
""")
# 설정 섹션
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("⚙️ 에피소드 생성 설정")
num_episodes = st.slider("생성할 에피소드 수", 1, 50, 10)
max_steps = st.slider("에피소드당 최대 스텝", 3, 15, 8)
exploration_rate = st.slider("탐험율 (Epsilon)", 0.0, 1.0, 0.4, 0.1)
st.markdown(f"""
**현재 설정:**
- 목표가: {st.session_state.anchor_price}
- 탐험율: {exploration_rate*100:.0f}%
- 총 예상 경험: ~{num_episodes * max_steps}
""")
if st.button("🎲 자동 에피소드 생성", type="primary"):
with st.spinner("에피소드 생성 중..."):
request_data = {
"num_episodes": num_episodes,
"max_steps": max_steps,
"anchor_price": st.session_state.anchor_price,
"exploration_rate": exploration_rate
}
response = APIClient.post("/episodes/generate", request_data)
if response and response.get("success"):
result = response["data"]
st.success(f"{result['new_experiences']}개의 새로운 경험 데이터 생성!")
# 에피소드 결과 표시
episode_results = result["episode_results"]
success_count = sum(1 for ep in episode_results if ep["success"])
col_a, col_b, col_c = st.columns(3)
with col_a:
st.metric("생성된 에피소드", result["episodes_generated"])
with col_b:
st.metric("성공한 협상", success_count)
with col_c:
success_rate = (success_count / len(episode_results)) * 100
st.metric("성공률", f"{success_rate:.1f}%")
time.sleep(1) # UI 업데이트를 위한 잠시 대기
st.rerun()
else:
st.error("에피소드 생성 실패")
with col2:
st.subheader("📈 수집된 데이터 현황")
# 경험 데이터 조회
exp_response = APIClient.get("/experiences")
if exp_response and exp_response.get("success"):
exp_data = exp_response["data"]
stats = exp_data["statistics"]
recent_data = exp_data["recent_data"]
# 통계 표시
col_a, col_b, col_c, col_d = st.columns(4)
with col_a:
st.metric("총 경험 수", stats["total_count"])
with col_b:
st.metric("평균 보상", f"{stats['avg_reward']:.3f}")
with col_c:
st.metric("성공률", f"{stats['success_rate']*100:.1f}%")
with col_d:
st.metric("고유 상태", stats["unique_states"])
# 최근 경험 데이터 표시
if recent_data:
st.subheader("🔍 최근 경험 데이터")
recent_df = pd.DataFrame(recent_data)
# 필요한 컬럼만 선택
display_columns = ['state', 'action', 'reward', 'next_state', 'done']
available_columns = [col for col in display_columns if col in recent_df.columns]
if available_columns:
display_df = recent_df[available_columns].tail(10)
st.dataframe(
display_df.style.format({'reward': '{:.3f}'}),
use_container_width=True
)
# 데이터 분포 시각화
if len(recent_df) > 5:
st.subheader("📊 데이터 분포 분석")
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# 보상 분포
axes[0,0].hist(recent_df['reward'], bins=15, alpha=0.7, color='skyblue', edgecolor='black')
axes[0,0].set_title('보상 분포')
axes[0,0].set_xlabel('보상값')
axes[0,0].set_ylabel('빈도')
# 행동 분포
if 'action' in recent_df.columns:
action_counts = recent_df['action'].value_counts()
axes[0,1].bar(action_counts.index, action_counts.values, color='lightgreen', edgecolor='black')
axes[0,1].set_title('행동 선택 빈도')
axes[0,1].set_xlabel('행동')
axes[0,1].set_ylabel('빈도')
# 상태 분포 (상위 10개)
if 'state' in recent_df.columns:
state_counts = recent_df['state'].value_counts().head(10)
axes[1,0].bar(range(len(state_counts)), state_counts.values, color='orange', edgecolor='black')
axes[1,0].set_title('상위 상태 빈도')
axes[1,0].set_xlabel('상태 순위')
axes[1,0].set_ylabel('빈도')
axes[1,0].set_xticks(range(len(state_counts)))
axes[1,0].set_xticklabels([f"{i+1}" for i in range(len(state_counts))])
# 성공/실패 분포
if 'done' in recent_df.columns:
done_counts = recent_df['done'].value_counts()
labels = ['진행중' if not k else '완료' for k in done_counts.index]
axes[1,1].pie(done_counts.values, labels=labels, autopct='%1.1f%%', colors=['lightcoral', 'lightblue'])
axes[1,1].set_title('협상 완료 비율')
plt.tight_layout()
st.pyplot(fig)
else:
st.info("아직 수집된 데이터가 없습니다. 왼쪽에서 에피소드를 생성해보세요!")
else:
st.warning("경험 데이터를 불러올 수 없습니다.")
# 경험 데이터 구조 설명
st.subheader("📋 경험 데이터 구조")
st.markdown("""
각 경험 튜플은 다음 정보를 포함합니다:
| 항목 | 설명 | 예시 |
|------|------|------|
| **상태 (State)** | 현재 협상 상황 | "C1APZ2" (카드1, 시나리오A, 가격구간2) |
| **행동 (Action)** | 선택한 협상 카드 | "C3" |
| **보상 (Reward)** | 행동에 대한 평가 | 0.85 |
| **다음상태 (Next State)** | 행동 후 새로운 상황 | "C3APZ1" |
| **종료 (Done)** | 협상 완료 여부 | true/false |
""")
# 더 많은 탭 함수들을 다음 파트에서 계속...
def tab_q_learning():
"""Q-Learning 탭"""
st.header("🔄 Q-Learning 실시간 학습")
st.markdown("""
### 경험으로부터 학습하기
수집된 경험 데이터를 사용하여 Q-Table을 업데이트합니다.
각 경험에서 **TD(Temporal Difference) 오차**를 계산하고 Q값을 조정합니다.
""")
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("⚙️ 학습 설정")
learning_rate = st.slider("학습률 (α)", 0.01, 0.5, 0.1, 0.01)
discount_factor = st.slider("할인율 (γ)", 0.8, 0.99, 0.9, 0.01)
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
st.markdown(f"""
**하이퍼파라미터:**
- **학습률 (α)**: {learning_rate} - 새로운 정보의 반영 정도
- **할인율 (γ)**: {discount_factor} - 미래 보상의 중요도
- **배치 크기**: {batch_size} - 한 번에 학습할 경험 수
""")
if st.button("🚀 Q-Learning 실행", type="primary"):
with st.spinner("Q-Learning 업데이트 중..."):
request_data = {
"learning_rate": learning_rate,
"discount_factor": discount_factor,
"batch_size": batch_size
}
response = APIClient.post("/learning/q-learning", request_data)
if response and response.get("success"):
result = response["data"]
st.success(f"{result['updates']}개 Q값 업데이트 완료!")
col_a, col_b = st.columns(2)
with col_a:
st.metric("배치 크기", result["batch_size"])
with col_b:
st.metric("평균 TD 오차", f"{result.get('avg_td_error', 0):.4f}")
time.sleep(1)
st.rerun()
else:
st.error("Q-Learning 업데이트 실패")
with col2:
st.subheader("📊 Q-Table 현황")
# Q-Table 데이터 조회
qtable_response = APIClient.get("/qtable")
if qtable_response and qtable_response.get("success"):
qtable_data = qtable_response["data"]
statistics = qtable_data["statistics"]
# 통계 표시
col_a, col_b, col_c, col_d = st.columns(4)
with col_a:
st.metric("총 업데이트", statistics.get("total_updates", 0))
with col_b:
st.metric("평균 TD 오차", f"{statistics.get('avg_td_error', 0):.4f}")
with col_c:
st.metric("평균 보상", f"{statistics.get('avg_reward', 0):.3f}")
with col_d:
sparsity = statistics.get('q_table_sparsity', 1.0) * 100
st.metric("Q-Table 희소성", f"{sparsity:.1f}%")
# Q값 범위 표시
q_range = statistics.get('q_value_range', {})
if q_range:
st.subheader("📈 Q값 분포")
col_a, col_b, col_c = st.columns(3)
with col_a:
st.metric("최솟값", f"{q_range.get('min', 0):.3f}")
with col_b:
st.metric("평균값", f"{q_range.get('mean', 0):.3f}")
with col_c:
st.metric("최댓값", f"{q_range.get('max', 0):.3f}")
# Q-Table 표시 (비어있지 않은 상태들만)
q_table_dict = qtable_data["q_table"]
q_table_df = pd.DataFrame(q_table_dict)
# 0이 아닌 값이 있는 행만 필터링
non_zero_rows = (q_table_df != 0).any(axis=1)
if non_zero_rows.any():
st.subheader("🎯 학습된 Q값들")
learned_qtable = q_table_df[non_zero_rows].head(15)
st.dataframe(
learned_qtable.style.format("{:.3f}").highlight_max(axis=1),
use_container_width=True
)
learned_count = non_zero_rows.sum()
total_count = len(q_table_df)
st.info(f"전체 {total_count}개 상태 중 {learned_count}개 상태가 학습되었습니다.")
else:
st.info("아직 학습된 Q값이 없습니다. 위에서 Q-Learning을 실행해보세요!")
# TD 오차 설명
st.subheader("🧮 TD(Temporal Difference) 오차")
st.markdown("""
**TD 오차**는 현재 Q값과 목표값의 차이입니다:
`TD 오차 = [r + γ max Q(s',a')] - Q(s,a)`
- **양수**: 현재 Q값이 너무 낮음 → Q값 증가
- **음수**: 현재 Q값이 너무 높음 → Q값 감소
- **0에 가까움**: Q값이 적절함 → 학습 수렴
""")
def tab_fqi_cql():
"""FQI+CQL 탭"""
st.header("🧠 FQI + CQL 오프라인 학습")
st.markdown("""
### 오프라인 강화학습의 핵심
**FQI (Fitted Q-Iteration)**와 **CQL (Conservative Q-Learning)**을 결합한
오프라인 강화학습 방법입니다. 수집된 데이터만으로 안전하고 보수적인 정책을 학습합니다.
""")
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("⚙️ FQI+CQL 설정")
alpha = st.slider("CQL 보수성 파라미터 (α)", 0.0, 3.0, 1.0, 0.1)
gamma = st.slider("할인율 (γ)", 0.8, 0.99, 0.95, 0.01)
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
num_iterations = st.slider("반복 횟수", 1, 50, 10, 1)
st.markdown(f"""
**설정값:**
- **α (Alpha)**: {alpha} - 보수성 강도
- **γ (Gamma)**: {gamma} - 미래 보상 할인
- **배치 크기**: {batch_size}
- **반복 횟수**: {num_iterations}
""")
st.markdown("""
**CQL 특징:**
- 🛡️ **보수적 추정**: 불확실한 행동의 Q값을 낮게 유지
- 📊 **데이터 기반**: 수집된 경험만 활용
- 🎯 **안전한 정책**: 분포 이동 문제 해결
""")
if st.button("🚀 FQI+CQL 실행", type="primary"):
with st.spinner("FQI+CQL 학습 중..."):
request_data = {
"alpha": alpha,
"gamma": gamma,
"batch_size": batch_size,
"num_iterations": num_iterations
}
response = APIClient.post("/learning/fqi-cql", request_data)
if response and response.get("success"):
result = response["data"]
training_result = result["training_result"]
policy_comparison = result["policy_comparison"]
st.success(f"{training_result['total_iterations']}회 반복 학습 완료!")
# 학습 결과 표시
col_a, col_b = st.columns(2)
with col_a:
st.metric("평균 벨만 손실", f"{training_result['avg_bellman_loss']:.4f}")
with col_b:
st.metric("평균 CQL 페널티", f"{training_result['avg_cql_penalty']:.4f}")
# 정책 비교
st.metric("행동 정책과의 일치율", f"{policy_comparison['action_agreement']*100:.1f}%")
time.sleep(1)
st.rerun()
else:
st.error("FQI+CQL 학습 실패")
with col2:
st.subheader("📊 FQI+CQL 결과")
# FQI+CQL 결과 조회
fqi_response = APIClient.get("/fqi-cql")
if fqi_response and fqi_response.get("success"):
fqi_data = fqi_response["data"]
statistics = fqi_data["statistics"]
# 통계 표시
col_a, col_b, col_c = st.columns(3)
with col_a:
st.metric("학습 배치", statistics.get("total_batches", 0))
with col_b:
st.metric("벨만 손실", f"{statistics.get('avg_bellman_loss', 0):.4f}")
with col_c:
st.metric("CQL 페널티", f"{statistics.get('avg_cql_penalty', 0):.4f}")
# 수렴 경향
convergence = statistics.get("convergence_trend", "unknown")
convergence_color = {
"improving": "🟢",
"deteriorating": "🔴",
"fluctuating": "🟡",
"insufficient_data": ""
}
st.info(f"수렴 경향: {convergence_color.get(convergence, '')} {convergence}")
# Q-Network 통계
q_stats = statistics.get("q_network_stats", {})
if q_stats:
st.subheader("📈 Q-Network 분포")
col_a, col_b, col_c, col_d = st.columns(4)
with col_a:
st.metric("최솟값", f"{q_stats.get('min', 0):.3f}")
with col_b:
st.metric("평균값", f"{q_stats.get('mean', 0):.3f}")
with col_c:
st.metric("최댓값", f"{q_stats.get('max', 0):.3f}")
with col_d:
st.metric("표준편차", f"{q_stats.get('std', 0):.3f}")
# Q-Network 표시 (상위 15개 상태)
q_network_dict = fqi_data["q_network"]
q_network_df = pd.DataFrame(q_network_dict)
st.subheader("🎯 학습된 Q-Network")
display_df = q_network_df.head(15)
st.dataframe(
display_df.style.format("{:.3f}").highlight_max(axis=1),
use_container_width=True
)
else:
st.info("FQI+CQL을 먼저 실행해주세요!")
# FQI+CQL 알고리즘 설명
st.subheader("🔬 FQI + CQL 알고리즘")
col1, col2 = st.columns(2)
with col1:
st.markdown("""
**FQI (Fitted Q-Iteration)**
- 배치 기반 Q-Learning
- 전체 데이터셋을 한 번에 활용
- 함수 근사 (신경망) 사용
- 안정적인 학습 과정
""")
with col2:
st.markdown("""
**CQL (Conservative Q-Learning)**
- 보수적 Q값 추정
- Out-of-Distribution 행동 억제
- 데이터에 없는 행동의 Q값 하향 조정
- 안전한 정책 학습
""")
def tab_learned_policy():
"""학습된 정책 탭"""
st.header("🎯 학습된 정책 비교 및 활용")
st.markdown("""
### 학습 완료: 정책의 진화
Q-Learning과 FQI+CQL로 학습된 정책을 비교하고,
실제 협상 상황에서 어떤 행동을 추천하는지 확인해보세요.
""")
# 상태 선택
col1, col2 = st.columns([1, 2])
with col1:
st.subheader("🎮 협상 시뮬레이션")
# 상태 구성 요소 선택
current_card = st.selectbox("현재 카드", ["C1", "C2", "C3", "C4"])
scenario = st.selectbox("시나리오", ["A", "B", "C", "D"])
price_zone = st.selectbox("가격 구간", ["PZ1", "PZ2", "PZ3"])
# 상태 ID 생성
selected_state = f"{current_card}{scenario}{price_zone}"
st.session_state.current_state = selected_state
st.info(f"선택된 상태: **{selected_state}**")
# 상태 해석
state_interpretation = {
"A": "어려운 협상 (높은 가중치)",
"B": "쉬운 협상 (낮은 가중치)",
"C": "보통 협상 (중간 가중치)",
"D": "매우 쉬운 협상 (낮은 가중치)"
}
price_interpretation = {
"PZ1": "목표가 이하 (좋은 구간)",
"PZ2": "목표가~임계값 (보통 구간)",
"PZ3": "임계값 이상 (나쁜 구간)"
}
st.markdown(f"""
**상태 해석:**
- **카드**: {current_card}
- **시나리오**: {scenario} - {state_interpretation.get(scenario, "알 수 없음")}
- **가격구간**: {price_zone} - {price_interpretation.get(price_zone, "알 수 없음")}
""")
# 행동 추천 요청
use_epsilon = st.checkbox("엡실론 그리디 사용", value=False)
epsilon = 0.1
if use_epsilon:
epsilon = st.slider("엡실론 값", 0.0, 0.5, 0.1, 0.05)
if st.button("🎯 행동 추천 받기", type="primary"):
request_data = {
"current_state": selected_state,
"use_epsilon_greedy": use_epsilon,
"epsilon": epsilon
}
response = APIClient.post("/action/recommend", request_data)
if response and response.get("success") is not False:
# response가 직접 ActionRecommendationResponse 형태인 경우
recommendation = response
st.success(f"🎯 추천 행동: **{recommendation.get('recommended_action', 'N/A')}**")
if recommendation.get('exploration', False):
st.warning("🎲 탐험 행동 (무작위 선택)")
else:
confidence = recommendation.get('confidence', 0) * 100
st.info(f"🎯 활용 행동 (신뢰도: {confidence:.1f}%)")
# Q값들 표시
q_values = recommendation.get('q_values', {})
if q_values:
st.subheader("📊 현재 상태의 Q값들")
q_df = pd.DataFrame([q_values]).T
q_df.columns = ['Q값']
q_df = q_df.sort_values('Q값', ascending=False)
# 추천 행동 하이라이트
def highlight_recommended(s):
return ['background-color: lightgreen' if x == recommendation.get('recommended_action')
else '' for x in s.index]
st.dataframe(
q_df.style.format({'Q값': '{:.3f}'}).apply(highlight_recommended, axis=0),
use_container_width=True
)
else:
st.error("행동 추천 실패")
with col2:
st.subheader("⚖️ 정책 비교")
# 정책 비교 요청
compare_response = APIClient.get(f"/compare/{selected_state}")
if compare_response and compare_response.get("success"):
comparison = compare_response["data"]
# 정책 일치 여부
agreement = comparison["policy_agreement"]
if agreement:
st.success("✅ Q-Learning과 FQI+CQL 정책이 일치합니다!")
else:
st.warning("⚠️ Q-Learning과 FQI+CQL 정책이 다릅니다.")
# 각 정책의 추천 행동
col_a, col_b = st.columns(2)
with col_a:
st.subheader("🔄 Q-Learning 정책")
q_learning = comparison["q_learning"]
st.metric("추천 행동", q_learning["action"])
# Q값들
q_values_ql = q_learning["q_values"]
if q_values_ql:
q_df_ql = pd.DataFrame([q_values_ql]).T
q_df_ql.columns = ['Q값']
st.dataframe(q_df_ql.style.format({'Q값': '{:.3f}'}))
with col_b:
st.subheader("🧠 FQI+CQL 정책")
fqi_cql = comparison["fqi_cql"]
st.metric("추천 행동", fqi_cql["action"])
# Q값들
q_values_fqi = fqi_cql["q_values"]
if q_values_fqi:
q_df_fqi = pd.DataFrame([q_values_fqi]).T
q_df_fqi.columns = ['Q값']
st.dataframe(q_df_fqi.style.format({'Q값': '{:.3f}'}))
# Q값 차이 분석
differences = comparison["q_value_differences"]
max_diff = comparison["max_difference"]
st.subheader("📊 Q값 차이 분석")
st.metric("최대 차이", f"{max_diff:.3f}")
if differences:
diff_df = pd.DataFrame([differences]).T
diff_df.columns = ['차이']
st.dataframe(diff_df.style.format({'차이': '{:.3f}'}))
else:
st.info("정책 비교를 위해 상태를 선택하고 학습을 진행해주세요.")
# 보상 계산 시뮬레이션
st.subheader("🧮 보상 계산 시뮬레이션")
col1, col2 = st.columns(2)
with col1:
st.subheader("📋 시뮬레이션 설정")
proposed_price = st.number_input("상대방 제안가", value=120.0, min_value=50.0, max_value=500.0)
is_negotiation_end = st.checkbox("협상 종료", value=False)
if st.button("💰 보상 계산", type="secondary"):
request_data = {
"scenario": scenario,
"price_zone": price_zone,
"anchor_price": st.session_state.anchor_price,
"proposed_price": proposed_price,
"is_end": is_negotiation_end
}
response = APIClient.post("/reward/calculate", request_data)
if response and response.get("success") is not False:
# response가 직접 RewardCalculationResponse 형태인 경우
reward_result = response
col_a, col_b, col_c = st.columns(3)
with col_a:
st.metric("보상", f"{reward_result.get('reward', 0):.3f}")
with col_b:
st.metric("가중치 (W)", f"{reward_result.get('weight', 0):.3f}")
with col_c:
st.metric("가격 비율 (A/P)", f"{reward_result.get('price_ratio', 0):.3f}")
# 공식 분해 표시
formula = reward_result.get('formula_breakdown', '')
if formula:
st.subheader('📝 계산 과정')
st.text(formula)
else:
st.error("보상 계산 실패")
with col2:
st.subheader("📈 학습 진행 상황")
# 시스템 상태 조회
status_response = APIClient.get("/status")
if status_response and status_response.get("success") is not False:
status = status_response
# 진행 상황 메트릭
total_exp = status.get("total_experiences", 0)
updates = status.get("q_table_updates", 0)
success_rate = status.get("success_rate", 0) * 100
progress_metrics = [
("데이터 수집", total_exp, 1000, ""),
("Q-Table 업데이트", updates, 500, ""),
("협상 성공률", success_rate, 100, "%")
]
for name, value, target, unit in progress_metrics:
progress = min(value / target, 1.0)
st.metric(
name,
f"{value}{unit}",
delta=f"목표: {target}{unit}"
)
st.progress(progress)
st.subheader("🎓 학습 완성도")
# Q-Table 완성도
qtable_response = APIClient.get("/qtable")
if qtable_response and qtable_response.get("success"):
qtable_data = qtable_response["data"]
statistics = qtable_data["statistics"]
sparsity = statistics.get('q_table_sparsity', 1.0)
completeness = (1 - sparsity) * 100
st.metric("Q-Table 완성도", f"{completeness:.1f}%")
st.progress(completeness / 100)
if completeness > 80:
st.success("🎉 충분히 학습되었습니다!")
elif completeness > 50:
st.info("📖 적당히 학습되었습니다.")
else:
st.warning("📚 더 많은 학습이 필요합니다.")
def main():
"""메인 함수"""
# 헤더 표시
display_header()
# 사이드바 표시
sidebar_status = display_sidebar()
# 탭 생성
tab1, tab2, tab3, tab4, tab5 = st.tabs([
"🏁 1. 콜드 스타트",
"📊 2. 데이터 수집",
"🔄 3. Q-Learning",
"🧠 4. FQI+CQL",
"🎯 5. 학습된 정책"
])
with tab1:
tab_cold_start()
with tab2:
tab_data_collection()
with tab3:
tab_q_learning()
with tab4:
tab_fqi_cql()
with tab5:
tab_learned_policy()
# 푸터
st.markdown("---")
st.markdown("""
<div style='text-align: center; color: #666;'>
<p>🎯 Q-Table 기반 협상 전략 데모 | 강화학습의 전체 여정을 경험해보세요</p>
<p>💡 문의사항이 있으시면 API 문서를 참고해주세요: <a href="http://localhost:8000/docs" target="_blank">http://localhost:8000/docs</a></p>
</div>
""", unsafe_allow_html=True)
def start_frontend():
"""프론트엔드 시작 (Poetry 스크립트용)"""
import subprocess
subprocess.run(["streamlit", "run", "frontend/app.py", "--server.port", "8501"])
if __name__ == "__main__":
main()