173 lines
5.4 KiB
Python
173 lines
5.4 KiB
Python
"""
|
|
FastAPI 엔드포인트 정의
|
|
"""
|
|
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
|
from typing import Dict, Any
|
|
|
|
from app.models.schemas import (
|
|
RewardCalculationRequest, RewardCalculationResponse,
|
|
EpisodeGenerationRequest, LearningUpdateRequest,
|
|
FQICQLRequest, ActionRecommendationRequest,
|
|
ActionRecommendationResponse, SystemStatus
|
|
)
|
|
from app.services.demo_service import demo_service
|
|
|
|
# Router on which the endpoints declared in this module are registered.
router = APIRouter()
|
|
|
|
|
|
@router.get("/health")
async def health_check():
    """Health check.

    Returns:
        A dict with a fixed ``status`` of ``"healthy"`` and the service name.
    """
    return dict(status="healthy", service="Q-Table Negotiation Demo")
|
|
|
|
|
|
@router.post("/reward/calculate", response_model=RewardCalculationResponse)
async def calculate_reward(request: RewardCalculationRequest):
    """Calculate a reward.

    Delegates to ``demo_service.calculate_reward`` and returns its result
    under the declared ``RewardCalculationResponse`` response model.

    Args:
        request: The ``RewardCalculationRequest`` payload.

    Returns:
        The value returned by ``demo_service.calculate_reward``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        return demo_service.calculate_reward(request)
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
|
|
@router.post("/episodes/generate")
async def generate_episodes(request: EpisodeGenerationRequest):
    """Generate episodes.

    Delegates to ``demo_service.generate_episodes`` and wraps the result in
    a success envelope.

    Args:
        request: The ``EpisodeGenerationRequest`` payload.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        episodes = demo_service.generate_episodes(request)
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"success": True, "data": episodes}
|
|
|
|
|
|
@router.post("/learning/q-learning")
async def update_q_learning(request: LearningUpdateRequest):
    """Run a Q-Learning update.

    Forwards ``learning_rate``, ``discount_factor`` and ``batch_size`` from
    the request to ``demo_service.update_q_learning``.

    Args:
        request: The ``LearningUpdateRequest`` payload.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when reading the
            request fields or the service call raises; an ``HTTPException``
            raised by the service propagates unchanged.
    """
    try:
        outcome = demo_service.update_q_learning(
            learning_rate=request.learning_rate,
            discount_factor=request.discount_factor,
            batch_size=request.batch_size,
        )
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"success": True, "data": outcome}
|
|
|
|
|
|
@router.post("/learning/fqi-cql")
async def run_fqi_cql(request: FQICQLRequest):
    """Run FQI+CQL training.

    Forwards ``alpha``, ``gamma``, ``batch_size`` and ``num_iterations``
    from the request to ``demo_service.run_fqi_cql``.

    Args:
        request: The ``FQICQLRequest`` payload.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when reading the
            request fields or the service call raises; an ``HTTPException``
            raised by the service propagates unchanged.
    """
    try:
        outcome = demo_service.run_fqi_cql(
            alpha=request.alpha,
            gamma=request.gamma,
            batch_size=request.batch_size,
            num_iterations=request.num_iterations,
        )
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"success": True, "data": outcome}
|
|
|
|
|
|
@router.post("/action/recommend", response_model=ActionRecommendationResponse)
async def recommend_action(request: ActionRecommendationRequest):
    """Recommend an action.

    Delegates to ``demo_service.get_action_recommendation`` and returns its
    result under the declared ``ActionRecommendationResponse`` response
    model.

    Args:
        request: The ``ActionRecommendationRequest`` payload.

    Returns:
        The value returned by ``demo_service.get_action_recommendation``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        return demo_service.get_action_recommendation(request)
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
|
|
@router.get("/status", response_model=SystemStatus)
async def get_system_status():
    """Retrieve the system status.

    Delegates to ``demo_service.get_system_status`` and returns its result
    under the declared ``SystemStatus`` response model.

    Returns:
        The value returned by ``demo_service.get_system_status``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        return demo_service.get_system_status()
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
|
|
|
|
@router.get("/qtable")
async def get_q_table():
    """Retrieve the Q-Table data.

    Delegates to ``demo_service.get_q_table`` and wraps the result in a
    success envelope.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        table = demo_service.get_q_table()
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": table}
|
|
|
|
|
|
@router.get("/fqi-cql")
async def get_fqi_cql_results():
    """Retrieve the FQI+CQL results.

    Delegates to ``demo_service.get_fqi_cql_results`` and wraps the result
    in a success envelope.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        results = demo_service.get_fqi_cql_results()
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": results}
|
|
|
|
|
|
@router.get("/experiences")
async def get_experience_data():
    """Retrieve the experience data.

    Delegates to ``demo_service.get_experience_data`` and wraps the result
    in a success envelope.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        experiences = demo_service.get_experience_data()
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": experiences}
|
|
|
|
|
|
@router.get("/compare/{state}")
async def compare_policies(state: str):
    """Compare policies for a state.

    Passes the ``state`` path parameter to ``demo_service.compare_policies``
    and wraps the result in a success envelope.

    Args:
        state: The ``{state}`` segment of the request path.

    Returns:
        ``{"success": True, "data": <service result>}``.

    Raises:
        HTTPException: Status 400 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        comparison = demo_service.compare_policies(state)
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 400.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"success": True, "data": comparison}
|
|
|
|
|
|
@router.post("/reset")
async def reset_system():
    """Reset the system.

    Calls ``demo_service.reset_all`` and reports completion.

    Returns:
        ``{"success": True, "message": "System reset completed"}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when the service
            raises; an ``HTTPException`` raised by the service propagates
            unchanged.
    """
    try:
        demo_service.reset_all()
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "message": "System reset completed"}
|
|
|
|
|
|
@router.get("/states")
async def get_all_states():
    """List all states.

    Reads the states from ``demo_service.env.get_all_states`` and returns
    them together with their count.

    Returns:
        ``{"success": True, "data": {"states": [...], "count": <len>}}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when fetching or
            counting the states raises; an ``HTTPException`` raised by the
            service propagates unchanged.
    """
    try:
        states = demo_service.env.get_all_states()
        # len() stays inside the try so a result without a length is still
        # reported as a 500, as before.
        payload = {"states": states, "count": len(states)}
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": payload}
|
|
|
|
|
|
@router.get("/actions")
async def get_all_actions():
    """List all actions.

    Collects the ``.value`` of every item returned by
    ``demo_service.env.get_all_actions`` and returns them with their count.

    Returns:
        ``{"success": True, "data": {"actions": [...], "count": <len>}}``.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when fetching or
            converting the actions raises; an ``HTTPException`` raised by
            the service propagates unchanged.
    """
    try:
        actions = [action.value for action in demo_service.env.get_all_actions()]
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": {"actions": actions, "count": len(actions)}}
|
|
|
|
|
|
@router.get("/config")
async def get_configuration():
    """Retrieve the environment configuration.

    Reads four mappings from ``demo_service.env`` (``scenario_weights``,
    ``price_zone_weights``, ``card_effects``, ``scenario_difficulty``) and
    re-keys each one by its key's ``.value``.

    Returns:
        ``{"success": True, "data": <config dict>}`` where the config dict
        holds the four re-keyed mappings under their attribute names.

    Raises:
        HTTPException: Status 500 carrying ``str(error)`` when reading or
            converting any mapping raises; an ``HTTPException`` raised
            while doing so propagates unchanged.
    """
    try:
        # Look the environment up once rather than once per mapping.
        env = demo_service.env
        config = {
            "scenario_weights": {k.value: v for k, v in env.scenario_weights.items()},
            "price_zone_weights": {k.value: v for k, v in env.price_zone_weights.items()},
            "card_effects": {k.value: v for k, v in env.card_effects.items()},
            "scenario_difficulty": {k.value: v for k, v in env.scenario_difficulty.items()},
        }
    except HTTPException:
        # Preserve a status code chosen deeper in the stack instead of
        # flattening it into a generic 500.
        raise
    except Exception as exc:
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    return {"success": True, "data": config}
|