Initial commit - Q-Table 협상 전략 강화학습 데모 프로젝트
commit
bbea033cdc
|
|
@ -0,0 +1,19 @@
|
|||
# 개발 환경 설정
|
||||
API_HOST=localhost
|
||||
API_PORT=8000
|
||||
FRONTEND_HOST=localhost
|
||||
FRONTEND_PORT=8501
|
||||
|
||||
# 강화학습 하이퍼파라미터
|
||||
DEFAULT_LEARNING_RATE=0.1
|
||||
DEFAULT_DISCOUNT_FACTOR=0.9
|
||||
DEFAULT_EPSILON=0.1
|
||||
|
||||
# 협상 환경 설정
|
||||
DEFAULT_ANCHOR_PRICE=100
|
||||
MAX_EPISODES=1000
|
||||
MAX_STEPS_PER_EPISODE=10
|
||||
|
||||
# 로깅 설정
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE=app.log
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# Python 3.9 기반 이미지
|
||||
FROM python:3.9-slim
|
||||
|
||||
# 작업 디렉토리 설정
|
||||
WORKDIR /app
|
||||
|
||||
# 시스템 의존성 설치
|
||||
RUN apt-get update && apt-get install -y \
|
||||
gcc \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Poetry 설치
|
||||
RUN pip install poetry
|
||||
|
||||
# Poetry 설정 (가상환경을 컨테이너 내부에 생성하지 않음)
|
||||
RUN poetry config virtualenvs.create false
|
||||
|
||||
# 의존성 파일 복사
|
||||
COPY pyproject.toml poetry.lock* ./
|
||||
|
||||
# 의존성 설치
|
||||
RUN poetry install --no-dev
|
||||
|
||||
# 애플리케이션 코드 복사
|
||||
COPY . .
|
||||
|
||||
# 포트 노출
|
||||
EXPOSE 8000 8501
|
||||
|
||||
# 환경 변수 설정
|
||||
ENV PYTHONPATH=/app
|
||||
ENV API_HOST=0.0.0.0
|
||||
ENV FRONTEND_HOST=0.0.0.0
|
||||
|
||||
# 기본 명령어 (API 서버 실행)
|
||||
CMD ["python", "run_api.py"]
|
||||
|
|
@ -0,0 +1,168 @@
|
|||
# 🚀 Q-Table 협상 전략 데모 시작 가이드
|
||||
|
||||
## 빠른 시작
|
||||
|
||||
### 1. 환경 설정
|
||||
|
||||
#### Poetry 사용 (권장)
|
||||
```bash
|
||||
# Poetry 설치 (미설치시)
|
||||
curl -sSL https://install.python-poetry.org | python3 -
|
||||
|
||||
# 의존성 설치
|
||||
poetry install
|
||||
|
||||
# 가상환경 활성화
|
||||
poetry shell
|
||||
```
|
||||
|
||||
#### pip 사용
|
||||
```bash
|
||||
# 가상환경 생성 (권장)
|
||||
python -m venv venv
|
||||
source venv/bin/activate # Linux/Mac
|
||||
# venv\Scripts\activate # Windows
|
||||
|
||||
# 의존성 설치
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 환경 변수 설정
|
||||
```bash
|
||||
# .env 파일 생성
|
||||
cp .env.example .env
|
||||
|
||||
# 필요에 따라 .env 파일 수정
|
||||
nano .env
|
||||
```
|
||||
|
||||
### 3. 시스템 실행
|
||||
|
||||
#### 방법 1: 전체 시스템 동시 실행
|
||||
```bash
|
||||
python run_both.py
|
||||
```
|
||||
|
||||
#### 방법 2: 개별 실행
|
||||
```bash
|
||||
# 터미널 1 - API 서버
|
||||
python run_api.py
|
||||
|
||||
# 터미널 2 - 프론트엔드
|
||||
python run_frontend.py
|
||||
```
|
||||
|
||||
#### 방법 3: Make 사용
|
||||
```bash
|
||||
make run-both
|
||||
```
|
||||
|
||||
### 4. 접속
|
||||
- **프론트엔드**: http://localhost:8501
|
||||
- **API 문서**: http://localhost:8000/docs
|
||||
|
||||
## 주요 기능 사용법
|
||||
|
||||
### 1. 콜드 스타트 확인
|
||||
- 프론트엔드의 "콜드 스타트" 탭에서 초기 Q-Table 상태 확인
|
||||
- 모든 Q값이 0인 상태에서 시작
|
||||
|
||||
### 2. 데이터 수집
|
||||
- "데이터 수집" 탭에서 에피소드 생성
|
||||
- 탐험율과 에피소드 수를 조정하여 경험 데이터 수집
|
||||
|
||||
### 3. Q-Learning 학습
|
||||
- "Q-Learning" 탭에서 수집된 데이터로 Q-Table 업데이트
|
||||
- 학습률, 할인율 등 하이퍼파라미터 조정 가능
|
||||
|
||||
### 4. FQI+CQL 학습
|
||||
- "FQI+CQL" 탭에서 오프라인 학습 수행
|
||||
- 보수적 Q-Learning으로 안전한 정책 학습
|
||||
|
||||
### 5. 정책 비교
|
||||
- "학습된 정책" 탭에서 Q-Learning과 FQI+CQL 정책 비교
|
||||
- 실제 협상 시나리오에서 행동 추천 받기
|
||||
|
||||
## Docker 사용
|
||||
|
||||
### Docker Compose로 실행
|
||||
```bash
|
||||
# 이미지 빌드
|
||||
docker-compose build
|
||||
|
||||
# 시스템 실행
|
||||
docker-compose up
|
||||
|
||||
# 백그라운드 실행
|
||||
docker-compose up -d
|
||||
|
||||
# 중지
|
||||
docker-compose down
|
||||
```
|
||||
|
||||
## 개발 환경 설정
|
||||
|
||||
### 코드 품질 도구
|
||||
```bash
|
||||
# 코드 포맷팅
|
||||
make format
|
||||
|
||||
# 린팅
|
||||
make lint
|
||||
|
||||
# 테스트
|
||||
make test
|
||||
```
|
||||
|
||||
### 개발 서버 실행
|
||||
```bash
|
||||
# API 서버 (Hot Reload)
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
|
||||
# 프론트엔드 (Hot Reload)
|
||||
streamlit run frontend/app.py --server.port 8501
|
||||
```
|
||||
|
||||
## 문제 해결
|
||||
|
||||
### 포트 충돌
|
||||
- 8000번 포트나 8501번 포트가 사용 중인 경우 .env 파일에서 포트 변경
|
||||
|
||||
### 의존성 오류
|
||||
```bash
|
||||
# 의존성 재설치
|
||||
pip install --force-reinstall -r requirements.txt
|
||||
|
||||
# 또는 Poetry 사용시
|
||||
poetry install --no-cache
|
||||
```
|
||||
|
||||
### API 연결 오류
|
||||
- API 서버가 실행 중인지 확인
|
||||
- http://localhost:8000/api/v1/health 에서 상태 확인
|
||||
|
||||
## 추가 정보
|
||||
|
||||
### 프로젝트 구조
|
||||
```
|
||||
qtable_negotiation_demo/
|
||||
├── app/ # FastAPI 백엔드
|
||||
│ ├── api/ # API 엔드포인트
|
||||
│ ├── core/ # 설정 및 유틸리티
|
||||
│ ├── models/ # 데이터 모델
|
||||
│ └── services/ # 비즈니스 로직
|
||||
├── frontend/ # Streamlit 프론트엔드
|
||||
├── tests/ # 테스트 코드
|
||||
├── run_*.py # 실행 스크립트
|
||||
└── docker-compose.yml # Docker 설정
|
||||
```
|
||||
|
||||
### 핵심 개념
|
||||
- **보상함수**: R(s,a) = W × (A/P) + (1-W) × End
|
||||
- **상태공간**: (카드, 시나리오, 가격구간) 조합
|
||||
- **Q-Learning**: 온라인 강화학습
|
||||
- **FQI+CQL**: 오프라인 보수적 강화학습
|
||||
|
||||
### 지원 및 문의
|
||||
- 이슈 발생시 API 문서 참조: http://localhost:8000/docs
|
||||
- 로그 파일 확인: app.log
|
||||
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,89 @@
|
|||
# Q-Table 협상 전략 데모 Makefile
|
||||
|
||||
.PHONY: help install run-api run-frontend run-both test clean docker-build docker-run
|
||||
|
||||
help: ## 도움말 표시
|
||||
@echo "사용 가능한 명령어:"
|
||||
@awk 'BEGIN {FS = ":.*?## "} /^[a-zA-Z_-]+:.*?## / {printf " \033[36m%-15s\033[0m %s\n", $$1, $$2}' $(MAKEFILE_LIST)
|
||||
|
||||
install: ## 의존성 설치
|
||||
@echo "🔧 의존성을 설치합니다..."
|
||||
@if command -v poetry >/dev/null 2>&1; then \
|
||||
poetry install; \
|
||||
else \
|
||||
pip install -r requirements.txt; \
|
||||
fi
|
||||
|
||||
run-api: ## API 서버 실행
|
||||
@echo "🚀 API 서버를 시작합니다..."
|
||||
python run_api.py
|
||||
|
||||
run-frontend: ## 프론트엔드 실행
|
||||
@echo "🎯 프론트엔드를 시작합니다..."
|
||||
python run_frontend.py
|
||||
|
||||
run-both: ## API와 프론트엔드 동시 실행
|
||||
@echo "🚀 전체 시스템을 시작합니다..."
|
||||
python run_both.py
|
||||
|
||||
test: ## 테스트 실행
|
||||
@echo "🧪 테스트를 실행합니다..."
|
||||
@if command -v poetry >/dev/null 2>&1; then \
|
||||
poetry run pytest tests/ -v; \
|
||||
else \
|
||||
pytest tests/ -v; \
|
||||
fi
|
||||
|
||||
clean: ## 캐시 및 임시 파일 정리
|
||||
@echo "🧹 정리 작업을 수행합니다..."
|
||||
find . -type f -name "*.pyc" -delete
|
||||
find . -type d -name "__pycache__" -delete
|
||||
find . -type d -name ".pytest_cache" -exec rm -rf {} +
|
||||
find . -type d -name "*.egg-info" -exec rm -rf {} +
|
||||
|
||||
docker-build: ## Docker 이미지 빌드
|
||||
@echo "🐳 Docker 이미지를 빌드합니다..."
|
||||
docker-compose build
|
||||
|
||||
docker-run: ## Docker로 실행
|
||||
@echo "🐳 Docker로 시스템을 시작합니다..."
|
||||
docker-compose up
|
||||
|
||||
docker-stop: ## Docker 컨테이너 중지
|
||||
@echo "🛑 Docker 컨테이너를 중지합니다..."
|
||||
docker-compose down
|
||||
|
||||
format: ## 코드 포맷팅
|
||||
@echo "✨ 코드를 포맷팅합니다..."
|
||||
@if command -v poetry >/dev/null 2>&1; then \
|
||||
poetry run black app/ frontend/ tests/; \
|
||||
else \
|
||||
black app/ frontend/ tests/; \
|
||||
fi
|
||||
|
||||
lint: ## 코드 린팅
|
||||
@echo "🔍 코드를 검사합니다..."
|
||||
@if command -v poetry >/dev/null 2>&1; then \
|
||||
poetry run flake8 app/ frontend/ tests/; \
|
||||
else \
|
||||
flake8 app/ frontend/ tests/; \
|
||||
fi
|
||||
|
||||
setup-dev: ## 개발 환경 설정
|
||||
@echo "🔧 개발 환경을 설정합니다..."
|
||||
cp .env.example .env
|
||||
@echo "✅ .env 파일이 생성되었습니다. 필요에 따라 수정해주세요."
|
||||
|
||||
demo: ## 데모 데이터 생성
|
||||
@echo "🎲 데모 데이터를 생성합니다..."
|
||||
@echo "API 서버가 실행 중이어야 합니다."
|
||||
curl -X POST "http://localhost:8000/api/v1/episodes/generate" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{"num_episodes": 10, "max_steps": 8, "anchor_price": 100, "exploration_rate": 0.4}'
|
||||
|
||||
status: ## 시스템 상태 확인
|
||||
@echo "📊 시스템 상태를 확인합니다..."
|
||||
@echo "API 서버 상태:"
|
||||
@curl -s http://localhost:8000/api/v1/health || echo "API 서버가 실행되지 않습니다."
|
||||
@echo "\n프론트엔드 상태:"
|
||||
@curl -s http://localhost:8501 >/dev/null && echo "프론트엔드가 실행 중입니다." || echo "프론트엔드가 실행되지 않습니다."
|
||||
|
|
@ -0,0 +1,134 @@
|
|||
# Q-Table 협상 전략 강화학습 데모
|
||||
|
||||
기업 간 협상 시뮬레이션을 위한 강화학습 에이전트의 Q-Table 기반 전략 학습 데모입니다.
|
||||
|
||||
## 주요 기능
|
||||
|
||||
- ✅ Q-Table 기반 협상 전략 시뮬레이션
|
||||
- ✅ 보상함수 R(s,a) = W × (A/P) + (1-W) × End 정확한 구현
|
||||
- ✅ FQI (Fitted Q-Iteration) + CQL (Conservative Q-Learning) 시뮬레이션
|
||||
- ✅ 실시간 경험 데이터 수집 및 축적
|
||||
- ✅ Q-Table 학습 과정 시각화
|
||||
- ✅ 콜드 스타트 문제부터 학습된 정책까지 전체 여정 관찰
|
||||
|
||||
## 시스템 구성
|
||||
|
||||
### 백엔드 (FastAPI)
|
||||
- RESTful API 서버
|
||||
- Q-Table 학습 엔진
|
||||
- 협상 환경 시뮬레이터
|
||||
- 경험 데이터 관리
|
||||
|
||||
### 프론트엔드 (Streamlit)
|
||||
- 대화형 웹 인터페이스
|
||||
- 실시간 시각화
|
||||
- 단계별 학습 과정 관찰
|
||||
|
||||
## 설치 및 실행
|
||||
|
||||
### 1. 의존성 설치
|
||||
```bash
|
||||
# Poetry를 사용한 설치
|
||||
poetry install
|
||||
|
||||
# 또는 pip 사용
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 2. 환경 설정
|
||||
```bash
|
||||
# .env 파일 복사 및 수정
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
### 3. 백엔드 서버 실행
|
||||
```bash
|
||||
# Poetry 사용
|
||||
poetry run uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
|
||||
# 또는 직접 실행
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
||||
```
|
||||
|
||||
### 4. 프론트엔드 실행
|
||||
```bash
|
||||
# Poetry 사용
|
||||
poetry run streamlit run frontend/app.py --server.port 8501
|
||||
|
||||
# 또는 직접 실행
|
||||
streamlit run frontend/app.py --server.port 8501
|
||||
```
|
||||
|
||||
### 5. 데모 접속
|
||||
- 프론트엔드: http://localhost:8501
|
||||
- API 문서: http://localhost:8000/docs
|
||||
|
||||
## 핵심 개념
|
||||
|
||||
### 상태 공간 (State Space)
|
||||
상태는 (현재 카드, 시나리오, 가격 구간)의 조합으로 구성됩니다:
|
||||
- **카드**: C1, C2, C3, C4 (협상 전략 카드)
|
||||
- **시나리오**: A, B, C, D (협상 상황별 가중치)
|
||||
- **가격 구간**: PZ1, PZ2, PZ3 (목표가 대비 제안가 구간)
|
||||
|
||||
### 보상함수 (Reward Function)
|
||||
```
|
||||
R(s,a) = W × (A/P) + (1-W) × End
|
||||
```
|
||||
- **A**: 앵커링 목표가 (구매자 목표 가격)
|
||||
- **P**: 상대방 제안가 (판매자 제시 가격)
|
||||
- **End**: 협상 종료 여부 (0 또는 1)
|
||||
- **W**: 가중치 = (S_n + PZ_n) / 2
|
||||
|
||||
### 가중치 시스템
|
||||
- **시나리오별 가중치**: S_1=A, S_2=D, S_3=C, S_4=B
|
||||
- **가격 구간별 가중치**: 목표가와 제안가 차이에 따른 영향도
|
||||
|
||||
## 학습 알고리즘
|
||||
|
||||
### 1. Q-Learning
|
||||
전통적인 온라인 강화학습 알고리즘으로 실시간 Q-Table 업데이트
|
||||
|
||||
### 2. FQI (Fitted Q-Iteration)
|
||||
배치 기반 오프라인 강화학습으로 수집된 경험 데이터 활용
|
||||
|
||||
### 3. CQL (Conservative Q-Learning)
|
||||
분포 이동 문제를 해결하기 위한 보수적 Q-Learning
|
||||
|
||||
## 데모 흐름
|
||||
|
||||
1. **콜드 스타트**: 초기 Q-Table 상태 (모든 값 0)
|
||||
2. **데이터 수집**: 무작위 탐험을 통한 경험 데이터 축적
|
||||
3. **Q-Learning**: 수집된 데이터로 Q-Table 실시간 업데이트
|
||||
4. **FQI+CQL**: 오프라인 배치 학습 시뮬레이션
|
||||
5. **학습된 정책**: 최적화된 협상 전략 활용
|
||||
|
||||
## 기술 스택
|
||||
|
||||
- **Backend**: FastAPI, Uvicorn
|
||||
- **Frontend**: Streamlit
|
||||
- **Data**: Pandas, NumPy
|
||||
- **Visualization**: Matplotlib, Seaborn, Plotly
|
||||
- **Environment**: Poetry, python-dotenv
|
||||
- **Testing**: Pytest
|
||||
|
||||
## 프로젝트 구조
|
||||
|
||||
```
|
||||
qtable_negotiation_demo/
|
||||
├── app/ # FastAPI 백엔드
|
||||
│ ├── api/ # API 라우터
|
||||
│ ├── core/ # 핵심 설정
|
||||
│ ├── models/ # 데이터 모델
|
||||
│ └── services/ # 비즈니스 로직
|
||||
├── frontend/ # Streamlit 프론트엔드
|
||||
├── tests/ # 테스트 코드
|
||||
├── data/ # 데이터 파일
|
||||
├── pyproject.toml # Poetry 설정
|
||||
├── .env # 환경 변수
|
||||
└── README.md # 이 파일
|
||||
```
|
||||
|
||||
## 라이선스
|
||||
|
||||
이 프로젝트는 교육 및 데모 목적으로 제작되었습니다.
|
||||
|
|
@ -0,0 +1,161 @@
|
|||
# Q-Table 데모 프로젝트 검토 및 수정 보고서
|
||||
|
||||
## 검토 개요
|
||||
Q-Table 기반 협상 전략 강화학습 데모 프로젝트를 검토하고 실행 가능한 상태로 수정하였습니다.
|
||||
|
||||
## 발견된 주요 문제점
|
||||
|
||||
### 1. 프로젝트 구조 불일치
|
||||
**문제**: README.md에서 언급된 디렉토리 구조와 실제 파일 구조가 일치하지 않음
|
||||
- 모든 파일이 루트 디렉토리에 평면적으로 배치됨
|
||||
- app/, frontend/ 등의 디렉토리가 존재하지 않음
|
||||
|
||||
**해결**: 올바른 디렉토리 구조로 재구성
|
||||
```
|
||||
qtable_project/
|
||||
├── app/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py
|
||||
│ ├── models/
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── schemas.py
|
||||
│ ├── services/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── demo_service.py
|
||||
│ │ ├── negotiation_env.py
|
||||
│ │ ├── qtable_learner.py
|
||||
│ │ └── fqi_cql.py
|
||||
│ ├── api/
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── endpoints.py
|
||||
│ └── core/
|
||||
│ ├── __init__.py
|
||||
│ └── config.py
|
||||
├── frontend/
|
||||
│ ├── __init__.py
|
||||
│ └── app.py
|
||||
├── run_api.py
|
||||
├── run_frontend.py
|
||||
├── run_both.py
|
||||
├── test_basic.py
|
||||
├── requirements.txt
|
||||
├── pyproject.toml
|
||||
├── .env
|
||||
└── README.md
|
||||
```
|
||||
|
||||
### 2. Import 경로 오류
|
||||
**문제**: 모든 파일에서 `from app.models.schemas import` 등의 import 경로가 실제 파일 구조와 맞지 않음
|
||||
|
||||
**해결**: 모든 파일의 import 경로를 올바르게 수정
|
||||
- `app/services/demo_service.py`: import 경로 수정
|
||||
- `app/services/qtable_learner.py`: import 경로 수정
|
||||
- `app/services/negotiation_env.py`: import 경로 수정
|
||||
- `app/api/endpoints.py`: import 경로 수정
|
||||
- `run_api.py`, `run_frontend.py`, `run_both.py`: import 경로 수정
|
||||
|
||||
### 3. 파일명 중복 문제
|
||||
**문제**: `main(1).py`, `__init__(1).py` 등 중복된 파일명 존재
|
||||
|
||||
**해결**: 파일명을 올바르게 정리
|
||||
- `main(1).py` → `app/main.py`
|
||||
- `__init__(1).py` → `app/__init__.py`
|
||||
- `__init__(2).py` → `app/models/__init__.py`
|
||||
- 기타 __init__ 파일들을 적절한 위치로 이동
|
||||
|
||||
### 4. 환경 설정 파일 누락
|
||||
**문제**: `.env` 파일이 없어 환경 변수 로드 실패 가능성
|
||||
|
||||
**해결**: `env` 파일을 `.env`로 복사하여 환경 변수 설정 완료
|
||||
|
||||
## 수정 완료 사항
|
||||
|
||||
### ✅ 의존성 설치
|
||||
- `requirements.txt`의 모든 패키지 설치 완료
|
||||
- FastAPI, Streamlit, pandas, numpy 등 필수 패키지 정상 설치
|
||||
|
||||
### ✅ 프로젝트 구조 정리
|
||||
- 올바른 디렉토리 구조로 재구성
|
||||
- 모든 파일을 적절한 위치로 이동
|
||||
|
||||
### ✅ Import 경로 수정
|
||||
- 모든 Python 파일의 import 경로 수정 완료
|
||||
- 모듈 간 의존성 문제 해결
|
||||
|
||||
### ✅ 실행 테스트 성공
|
||||
- 모든 모듈 import 테스트 통과
|
||||
- FastAPI 서버 시작 테스트 성공
|
||||
- 기본 테스트 스크립트 실행 성공
|
||||
|
||||
## 실행 방법
|
||||
|
||||
### 1. API 서버 실행
|
||||
```bash
|
||||
cd qtable_project
|
||||
python3 run_api.py
|
||||
```
|
||||
- 주소: http://localhost:8000
|
||||
- API 문서: http://localhost:8000/docs
|
||||
|
||||
### 2. 프론트엔드 실행
|
||||
```bash
|
||||
cd qtable_project
|
||||
python3 run_frontend.py
|
||||
```
|
||||
- 주소: http://localhost:8501
|
||||
|
||||
### 3. 통합 실행 (API + 프론트엔드)
|
||||
```bash
|
||||
cd qtable_project
|
||||
python3 run_both.py
|
||||
```
|
||||
|
||||
## 핵심 기능 확인
|
||||
|
||||
### ✅ Q-Table 학습 엔진
|
||||
- `QTableLearner` 클래스 정상 작동
|
||||
- 경험 데이터 수집 및 관리 기능
|
||||
|
||||
### ✅ 협상 환경 시뮬레이터
|
||||
- `NegotiationEnvironment` 클래스 정상 작동
|
||||
- 보상함수 R(s,a) = W × (A/P) + (1-W) × End 구현
|
||||
|
||||
### ✅ FQI+CQL 학습
|
||||
- `FQICQLLearner` 클래스 정상 작동
|
||||
- 오프라인 강화학습 기능
|
||||
|
||||
### ✅ FastAPI 백엔드
|
||||
- RESTful API 서버 정상 시작
|
||||
- 모든 엔드포인트 정상 로드
|
||||
|
||||
### ✅ Streamlit 프론트엔드
|
||||
- 웹 인터페이스 모듈 정상 로드
|
||||
- 실시간 시각화 기능
|
||||
|
||||
## 테스트 결과
|
||||
|
||||
### Import 테스트
|
||||
- ✅ schemas 모듈 import 성공
|
||||
- ✅ config 모듈 import 성공
|
||||
- ✅ negotiation_env 모듈 import 성공
|
||||
- ✅ demo_service 모듈 import 성공
|
||||
- ✅ FastAPI app import 성공
|
||||
|
||||
### 서버 실행 테스트
|
||||
- ✅ API 서버 정상 시작 (10초 테스트)
|
||||
- ✅ Streamlit import 성공
|
||||
- ✅ 기본 테스트 스크립트 실행 성공
|
||||
|
||||
## 결론
|
||||
|
||||
**프로젝트가 성공적으로 수정되어 실행 가능한 상태가 되었습니다.**
|
||||
|
||||
주요 수정 사항:
|
||||
1. 프로젝트 구조를 README에 명시된 대로 정리
|
||||
2. 모든 import 경로 오류 수정
|
||||
3. 파일명 중복 문제 해결
|
||||
4. 환경 설정 파일 정리
|
||||
5. 의존성 설치 완료
|
||||
|
||||
이제 Q-Table 기반 협상 전략 강화학습 데모를 정상적으로 실행할 수 있습니다.
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,172 @@
|
|||
"""
|
||||
FastAPI 엔드포인트 정의
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.models.schemas import (
|
||||
RewardCalculationRequest, RewardCalculationResponse,
|
||||
EpisodeGenerationRequest, LearningUpdateRequest,
|
||||
FQICQLRequest, ActionRecommendationRequest,
|
||||
ActionRecommendationResponse, SystemStatus
|
||||
)
|
||||
from app.services.demo_service import demo_service
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health_check():
|
||||
"""헬스 체크"""
|
||||
return {"status": "healthy", "service": "Q-Table Negotiation Demo"}
|
||||
|
||||
|
||||
@router.post("/reward/calculate", response_model=RewardCalculationResponse)
|
||||
async def calculate_reward(request: RewardCalculationRequest):
|
||||
"""보상 계산"""
|
||||
try:
|
||||
return demo_service.calculate_reward(request)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/episodes/generate")
|
||||
async def generate_episodes(request: EpisodeGenerationRequest):
|
||||
"""에피소드 생성"""
|
||||
try:
|
||||
result = demo_service.generate_episodes(request)
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/learning/q-learning")
|
||||
async def update_q_learning(request: LearningUpdateRequest):
|
||||
"""Q-Learning 업데이트"""
|
||||
try:
|
||||
result = demo_service.update_q_learning(
|
||||
learning_rate=request.learning_rate,
|
||||
discount_factor=request.discount_factor,
|
||||
batch_size=request.batch_size
|
||||
)
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/learning/fqi-cql")
|
||||
async def run_fqi_cql(request: FQICQLRequest):
|
||||
"""FQI+CQL 학습 실행"""
|
||||
try:
|
||||
result = demo_service.run_fqi_cql(
|
||||
alpha=request.alpha,
|
||||
gamma=request.gamma,
|
||||
batch_size=request.batch_size,
|
||||
num_iterations=request.num_iterations
|
||||
)
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/action/recommend", response_model=ActionRecommendationResponse)
|
||||
async def recommend_action(request: ActionRecommendationRequest):
|
||||
"""행동 추천"""
|
||||
try:
|
||||
return demo_service.get_action_recommendation(request)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/status", response_model=SystemStatus)
|
||||
async def get_system_status():
|
||||
"""시스템 상태 조회"""
|
||||
try:
|
||||
return demo_service.get_system_status()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/qtable")
|
||||
async def get_q_table():
|
||||
"""Q-Table 데이터 조회"""
|
||||
try:
|
||||
result = demo_service.get_q_table()
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/fqi-cql")
|
||||
async def get_fqi_cql_results():
|
||||
"""FQI+CQL 결과 조회"""
|
||||
try:
|
||||
result = demo_service.get_fqi_cql_results()
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/experiences")
|
||||
async def get_experience_data():
|
||||
"""경험 데이터 조회"""
|
||||
try:
|
||||
result = demo_service.get_experience_data()
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/compare/{state}")
|
||||
async def compare_policies(state: str):
|
||||
"""정책 비교"""
|
||||
try:
|
||||
result = demo_service.compare_policies(state)
|
||||
return {"success": True, "data": result}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/reset")
|
||||
async def reset_system():
|
||||
"""시스템 초기화"""
|
||||
try:
|
||||
demo_service.reset_all()
|
||||
return {"success": True, "message": "System reset completed"}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/states")
|
||||
async def get_all_states():
|
||||
"""모든 상태 목록 조회"""
|
||||
try:
|
||||
states = demo_service.env.get_all_states()
|
||||
return {"success": True, "data": {"states": states, "count": len(states)}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/actions")
|
||||
async def get_all_actions():
|
||||
"""모든 행동 목록 조회"""
|
||||
try:
|
||||
actions = [action.value for action in demo_service.env.get_all_actions()]
|
||||
return {"success": True, "data": {"actions": actions, "count": len(actions)}}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/config")
|
||||
async def get_configuration():
|
||||
"""환경 설정 조회"""
|
||||
try:
|
||||
config = {
|
||||
"scenario_weights": {k.value: v for k, v in demo_service.env.scenario_weights.items()},
|
||||
"price_zone_weights": {k.value: v for k, v in demo_service.env.price_zone_weights.items()},
|
||||
"card_effects": {k.value: v for k, v in demo_service.env.card_effects.items()},
|
||||
"scenario_difficulty": {k.value: v for k, v in demo_service.env.scenario_difficulty.items()}
|
||||
}
|
||||
return {"success": True, "data": config}
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
|
@ -0,0 +1 @@
|
|||
"""Q-Table 협상 전략 데모 애플리케이션"""
|
||||
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,62 @@
|
|||
"""
|
||||
애플리케이션 설정 관리
|
||||
"""
|
||||
import os
|
||||
from typing import List
|
||||
from pydantic import BaseSettings
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# .env 파일 로드
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
"""애플리케이션 설정"""
|
||||
|
||||
# API 서버 설정
|
||||
api_host: str = "localhost"
|
||||
api_port: int = 8000
|
||||
|
||||
# 프론트엔드 설정
|
||||
frontend_host: str = "localhost"
|
||||
frontend_port: int = 8501
|
||||
|
||||
# 강화학습 하이퍼파라미터
|
||||
default_learning_rate: float = 0.1
|
||||
default_discount_factor: float = 0.9
|
||||
default_epsilon: float = 0.1
|
||||
|
||||
# 협상 환경 설정
|
||||
default_anchor_price: int = 100
|
||||
max_episodes: int = 1000
|
||||
max_steps_per_episode: int = 10
|
||||
|
||||
# 로깅 설정
|
||||
log_level: str = "INFO"
|
||||
log_file: str = "app.log"
|
||||
|
||||
# CORS 설정
|
||||
allowed_origins: List[str] = ["http://localhost:8501", "http://127.0.0.1:8501"]
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
# 환경 변수로부터 값 로드
|
||||
self.api_host = os.getenv("API_HOST", self.api_host)
|
||||
self.api_port = int(os.getenv("API_PORT", str(self.api_port)))
|
||||
self.frontend_host = os.getenv("FRONTEND_HOST", self.frontend_host)
|
||||
self.frontend_port = int(os.getenv("FRONTEND_PORT", str(self.frontend_port)))
|
||||
self.default_learning_rate = float(os.getenv("DEFAULT_LEARNING_RATE", str(self.default_learning_rate)))
|
||||
self.default_discount_factor = float(os.getenv("DEFAULT_DISCOUNT_FACTOR", str(self.default_discount_factor)))
|
||||
self.default_epsilon = float(os.getenv("DEFAULT_EPSILON", str(self.default_epsilon)))
|
||||
self.default_anchor_price = int(os.getenv("DEFAULT_ANCHOR_PRICE", str(self.default_anchor_price)))
|
||||
self.max_episodes = int(os.getenv("MAX_EPISODES", str(self.max_episodes)))
|
||||
self.max_steps_per_episode = int(os.getenv("MAX_STEPS_PER_EPISODE", str(self.max_steps_per_episode)))
|
||||
self.log_level = os.getenv("LOG_LEVEL", self.log_level)
|
||||
self.log_file = os.getenv("LOG_FILE", self.log_file)
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
|
||||
|
||||
# 전역 설정 인스턴스
|
||||
settings = Settings()
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
"""
|
||||
FastAPI 메인 애플리케이션
|
||||
"""
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
from app.core.config import settings
|
||||
from app.api.endpoints import router
|
||||
|
||||
# FastAPI 앱 생성
|
||||
app = FastAPI(
|
||||
title="Q-Table 협상 전략 데모 API",
|
||||
description="기업 간 협상 시뮬레이션을 위한 강화학습 Q-Table 데모 API",
|
||||
version="1.0.0",
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# CORS 설정
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.allowed_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# API 라우터 등록
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
|
||||
|
||||
@app.get("/", response_class=HTMLResponse)
|
||||
async def root():
|
||||
"""루트 페이지"""
|
||||
html_content = """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Q-Table 협상 전략 데모 API</title>
|
||||
<style>
|
||||
body { font-family: Arial, sans-serif; margin: 40px; }
|
||||
.header { color: #2c3e50; }
|
||||
.section { margin: 20px 0; }
|
||||
.link { color: #3498db; text-decoration: none; }
|
||||
.link:hover { text-decoration: underline; }
|
||||
.code { background-color: #f8f9fa; padding: 2px 4px; border-radius: 3px; }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<h1 class="header">🎯 Q-Table 협상 전략 데모 API</h1>
|
||||
|
||||
<div class="section">
|
||||
<h2>📋 API 문서</h2>
|
||||
<ul>
|
||||
<li><a href="/docs" class="link">Swagger UI</a> - 대화형 API 문서</li>
|
||||
<li><a href="/redoc" class="link">ReDoc</a> - 깔끔한 API 문서</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🚀 주요 기능</h2>
|
||||
<ul>
|
||||
<li><strong>보상 계산:</strong> <span class="code">POST /api/v1/reward/calculate</span></li>
|
||||
<li><strong>에피소드 생성:</strong> <span class="code">POST /api/v1/episodes/generate</span></li>
|
||||
<li><strong>Q-Learning 업데이트:</strong> <span class="code">POST /api/v1/learning/q-learning</span></li>
|
||||
<li><strong>FQI+CQL 학습:</strong> <span class="code">POST /api/v1/learning/fqi-cql</span></li>
|
||||
<li><strong>행동 추천:</strong> <span class="code">POST /api/v1/action/recommend</span></li>
|
||||
<li><strong>시스템 상태:</strong> <span class="code">GET /api/v1/status</span></li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📊 데이터 조회</h2>
|
||||
<ul>
|
||||
<li><strong>Q-Table:</strong> <span class="code">GET /api/v1/qtable</span></li>
|
||||
<li><strong>경험 데이터:</strong> <span class="code">GET /api/v1/experiences</span></li>
|
||||
<li><strong>FQI+CQL 결과:</strong> <span class="code">GET /api/v1/fqi-cql</span></li>
|
||||
<li><strong>정책 비교:</strong> <span class="code">GET /api/v1/compare/{state}</span></li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🔧 유틸리티</h2>
|
||||
<ul>
|
||||
<li><strong>헬스 체크:</strong> <span class="code">GET /api/v1/health</span></li>
|
||||
<li><strong>시스템 초기화:</strong> <span class="code">POST /api/v1/reset</span></li>
|
||||
<li><strong>설정 조회:</strong> <span class="code">GET /api/v1/config</span></li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>🎮 프론트엔드</h2>
|
||||
<p>Streamlit 기반 대화형 인터페이스: <a href="http://localhost:8501" class="link">http://localhost:8501</a></p>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<h2>📝 보상함수</h2>
|
||||
<p><strong>R(s,a) = W × (A/P) + (1-W) × End</strong></p>
|
||||
<ul>
|
||||
<li><strong>W:</strong> 가중치 = (S_n + PZ_n) / 2</li>
|
||||
<li><strong>A:</strong> 앵커링 목표가</li>
|
||||
<li><strong>P:</strong> 상대방 제안가</li>
|
||||
<li><strong>End:</strong> 협상 종료 여부</li>
|
||||
</ul>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
return html_content
|
||||
|
||||
|
||||
def start_api():
|
||||
"""API 서버 시작 (Poetry 스크립트용)"""
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=True
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=True
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
API 요청/응답 데이터 모델
|
||||
"""
|
||||
from typing import Optional, List, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class CardType(str, Enum):
|
||||
"""협상 카드 타입"""
|
||||
C1 = "C1"
|
||||
C2 = "C2"
|
||||
C3 = "C3"
|
||||
C4 = "C4"
|
||||
|
||||
|
||||
class ScenarioType(str, Enum):
|
||||
"""시나리오 타입"""
|
||||
A = "A" # S_1
|
||||
B = "B" # S_4
|
||||
C = "C" # S_3
|
||||
D = "D" # S_2
|
||||
|
||||
|
||||
class PriceZoneType(str, Enum):
|
||||
"""가격 구간 타입"""
|
||||
PZ1 = "PZ1" # P < A
|
||||
PZ2 = "PZ2" # A < P < T
|
||||
PZ3 = "PZ3" # T < P
|
||||
|
||||
|
||||
class ExperienceData(BaseModel):
|
||||
"""경험 데이터 모델"""
|
||||
state: str = Field(..., description="현재 상태")
|
||||
action: CardType = Field(..., description="선택한 행동")
|
||||
reward: float = Field(..., description="받은 보상")
|
||||
next_state: str = Field(..., description="다음 상태")
|
||||
done: bool = Field(..., description="에피소드 종료 여부")
|
||||
timestamp: float = Field(..., description="타임스탬프")
|
||||
metadata: Optional[Dict[str, Any]] = Field(default=None, description="추가 메타데이터")
|
||||
|
||||
|
||||
class NegotiationState(BaseModel):
|
||||
"""협상 상태 모델"""
|
||||
current_card: CardType = Field(..., description="현재 카드")
|
||||
scenario: ScenarioType = Field(..., description="시나리오")
|
||||
price_zone: PriceZoneType = Field(..., description="가격 구간")
|
||||
|
||||
@property
|
||||
def state_id(self) -> str:
|
||||
"""상태 ID 생성"""
|
||||
return f"{self.current_card.value}{self.scenario.value}{self.price_zone.value}"
|
||||
|
||||
|
||||
class RewardCalculationRequest(BaseModel):
|
||||
"""보상 계산 요청 모델"""
|
||||
scenario: ScenarioType = Field(..., description="시나리오")
|
||||
price_zone: PriceZoneType = Field(..., description="가격 구간")
|
||||
anchor_price: float = Field(..., gt=0, description="목표가 (A)")
|
||||
proposed_price: float = Field(..., gt=0, description="제안가 (P)")
|
||||
is_end: bool = Field(..., description="협상 종료 여부")
|
||||
|
||||
|
||||
class RewardCalculationResponse(BaseModel):
|
||||
"""보상 계산 응답 모델"""
|
||||
reward: float = Field(..., description="계산된 보상")
|
||||
weight: float = Field(..., description="가중치 W")
|
||||
scenario_weight: float = Field(..., description="시나리오 가중치 S_n")
|
||||
price_zone_weight: float = Field(..., description="가격구간 가중치 PZ_n")
|
||||
price_ratio: float = Field(..., description="가격 비율 A/P")
|
||||
formula_breakdown: str = Field(..., description="공식 분해")
|
||||
|
||||
|
||||
class QTableState(BaseModel):
|
||||
"""Q-Table 상태 모델"""
|
||||
q_table: Dict[str, Dict[str, float]] = Field(..., description="Q-Table 데이터")
|
||||
update_count: int = Field(..., description="업데이트 횟수")
|
||||
learning_rate: float = Field(..., description="학습률")
|
||||
discount_factor: float = Field(..., description="할인율")
|
||||
|
||||
|
||||
class EpisodeGenerationRequest(BaseModel):
|
||||
"""에피소드 생성 요청 모델"""
|
||||
num_episodes: int = Field(..., ge=1, le=100, description="생성할 에피소드 수")
|
||||
max_steps: int = Field(..., ge=1, le=20, description="에피소드당 최대 스텝")
|
||||
anchor_price: float = Field(..., gt=0, description="목표가")
|
||||
exploration_rate: float = Field(default=0.4, ge=0, le=1, description="탐험율")
|
||||
|
||||
|
||||
class LearningUpdateRequest(BaseModel):
|
||||
"""학습 업데이트 요청 모델"""
|
||||
learning_rate: float = Field(..., ge=0.001, le=1.0, description="학습률")
|
||||
discount_factor: float = Field(..., ge=0.1, le=0.99, description="할인율")
|
||||
batch_size: int = Field(default=32, ge=1, le=1000, description="배치 크기")
|
||||
|
||||
|
||||
class FQICQLRequest(BaseModel):
|
||||
"""FQI+CQL 학습 요청 모델"""
|
||||
alpha: float = Field(default=1.0, ge=0, description="CQL 보수성 파라미터")
|
||||
gamma: float = Field(default=0.95, ge=0.1, le=0.99, description="할인율")
|
||||
batch_size: int = Field(default=32, ge=1, le=1000, description="배치 크기")
|
||||
num_iterations: int = Field(default=10, ge=1, le=100, description="반복 횟수")
|
||||
|
||||
|
||||
class ActionRecommendationRequest(BaseModel):
|
||||
"""행동 추천 요청 모델"""
|
||||
current_state: str = Field(..., description="현재 상태")
|
||||
use_epsilon_greedy: bool = Field(default=False, description="엡실론 그리디 사용")
|
||||
epsilon: float = Field(default=0.1, ge=0, le=1, description="엡실론 값")
|
||||
|
||||
|
||||
class ActionRecommendationResponse(BaseModel):
|
||||
"""행동 추천 응답 모델"""
|
||||
recommended_action: CardType = Field(..., description="추천 행동")
|
||||
q_values: Dict[str, float] = Field(..., description="현재 상태의 Q값들")
|
||||
confidence: float = Field(..., description="추천 신뢰도")
|
||||
exploration: bool = Field(..., description="탐험 행동 여부")
|
||||
|
||||
|
||||
class SystemStatus(BaseModel):
|
||||
"""시스템 상태 모델"""
|
||||
total_experiences: int = Field(..., description="총 경험 데이터 수")
|
||||
q_table_updates: int = Field(..., description="Q-Table 업데이트 횟수")
|
||||
unique_states: int = Field(..., description="고유 상태 수")
|
||||
average_reward: float = Field(..., description="평균 보상")
|
||||
success_rate: float = Field(..., description="성공률")
|
||||
last_update: Optional[float] = Field(default=None, description="마지막 업데이트 시간")
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,440 @@
|
|||
"""
|
||||
Q-Table 협상 전략 데모 메인 서비스
|
||||
"""
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from app.models.schemas import (
|
||||
CardType, ScenarioType, PriceZoneType,
|
||||
ExperienceData, EpisodeGenerationRequest,
|
||||
RewardCalculationRequest, RewardCalculationResponse,
|
||||
ActionRecommendationRequest, ActionRecommendationResponse,
|
||||
SystemStatus
|
||||
)
|
||||
from app.services.negotiation_env import NegotiationEnvironment
|
||||
from app.services.qtable_learner import QTableLearner, ExperienceBuffer
|
||||
from app.services.fqi_cql import FQICQLLearner
|
||||
|
||||
|
||||
class DemoService:
|
||||
"""Q-Table 협상 전략 데모 메인 서비스"""
|
||||
|
||||
def __init__(self):
|
||||
# 환경 초기화
|
||||
self.env = NegotiationEnvironment()
|
||||
|
||||
# 상태 및 행동 공간
|
||||
self.states = self.env.get_all_states()
|
||||
self.actions = self.env.get_all_actions()
|
||||
|
||||
# 학습 엔진들
|
||||
self.experience_buffer = ExperienceBuffer(max_size=10000)
|
||||
self.q_learner = QTableLearner(
|
||||
states=self.states,
|
||||
actions=self.actions,
|
||||
learning_rate=0.1,
|
||||
discount_factor=0.9,
|
||||
epsilon=0.1
|
||||
)
|
||||
self.fqi_cql_learner = FQICQLLearner(
|
||||
states=self.states,
|
||||
actions=self.actions,
|
||||
alpha=1.0,
|
||||
gamma=0.95
|
||||
)
|
||||
|
||||
# 통계 정보
|
||||
self.episode_count = 0
|
||||
self.start_time = time.time()
|
||||
|
||||
def calculate_reward(self, request: RewardCalculationRequest) -> RewardCalculationResponse:
|
||||
"""보상 계산"""
|
||||
reward, weight = self.env.calculate_reward(
|
||||
scenario=request.scenario,
|
||||
price_zone=request.price_zone,
|
||||
anchor_price=request.anchor_price,
|
||||
proposed_price=request.proposed_price,
|
||||
is_end=request.is_end
|
||||
)
|
||||
|
||||
# 시나리오 및 가격구간 가중치
|
||||
scenario_weight = self.env.scenario_weights[request.scenario]
|
||||
price_zone_weight = self.env.price_zone_weights[request.price_zone]
|
||||
|
||||
# 가격 비율
|
||||
price_ratio = request.anchor_price / request.proposed_price if request.proposed_price > 0 else float('inf')
|
||||
|
||||
# 공식 분해
|
||||
formula_breakdown = (
|
||||
f"R(s,a) = W × (A/P) + (1-W) × End\n"
|
||||
f"W = (S_n + PZ_n) / 2 = ({scenario_weight} + {price_zone_weight}) / 2 = {weight:.3f}\n"
|
||||
f"A/P = {request.anchor_price}/{request.proposed_price} = {price_ratio:.3f}\n"
|
||||
f"End = {1 if request.is_end else 0}\n"
|
||||
f"R(s,a) = {weight:.3f} × {price_ratio:.3f} + {1-weight:.3f} × {1 if request.is_end else 0} = {reward:.3f}"
|
||||
)
|
||||
|
||||
return RewardCalculationResponse(
|
||||
reward=reward,
|
||||
weight=weight,
|
||||
scenario_weight=scenario_weight,
|
||||
price_zone_weight=price_zone_weight,
|
||||
price_ratio=price_ratio,
|
||||
formula_breakdown=formula_breakdown
|
||||
)
|
||||
|
||||
def generate_episodes(self, request: EpisodeGenerationRequest) -> Dict[str, Any]:
|
||||
"""에피소드 생성"""
|
||||
new_experiences = 0
|
||||
episode_results = []
|
||||
|
||||
for episode in range(request.num_episodes):
|
||||
episode_result = self._generate_single_episode(
|
||||
max_steps=request.max_steps,
|
||||
anchor_price=request.anchor_price,
|
||||
exploration_rate=request.exploration_rate,
|
||||
episode_id=self.episode_count + episode
|
||||
)
|
||||
episode_results.append(episode_result)
|
||||
new_experiences += episode_result['steps']
|
||||
|
||||
self.episode_count += request.num_episodes
|
||||
|
||||
return {
|
||||
"episodes_generated": request.num_episodes,
|
||||
"new_experiences": new_experiences,
|
||||
"episode_results": episode_results,
|
||||
"total_episodes": self.episode_count
|
||||
}
|
||||
|
||||
def _generate_single_episode(
|
||||
self,
|
||||
max_steps: int,
|
||||
anchor_price: float,
|
||||
exploration_rate: float,
|
||||
episode_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""단일 에피소드 생성"""
|
||||
# 초기 상태
|
||||
current_state = "C0S0P0"
|
||||
scenario = random.choice(list(ScenarioType))
|
||||
|
||||
episode_reward = 0.0
|
||||
steps = 0
|
||||
success = False
|
||||
|
||||
for step in range(max_steps):
|
||||
# 행동 선택 (epsilon-greedy)
|
||||
if random.random() < exploration_rate:
|
||||
action = random.choice(self.actions)
|
||||
is_exploration = True
|
||||
else:
|
||||
action = self.q_learner.get_optimal_action(current_state)
|
||||
is_exploration = False
|
||||
|
||||
# 환경 응답
|
||||
proposed_price = self.env.simulate_opponent_response(
|
||||
current_card=action,
|
||||
scenario=scenario,
|
||||
anchor_price=anchor_price,
|
||||
step=step
|
||||
)
|
||||
|
||||
# 가격 구간 결정
|
||||
price_zone = self.env.get_price_zone(proposed_price, anchor_price)
|
||||
|
||||
# 다음 상태
|
||||
next_state = f"{action.value}{scenario.value}{price_zone.value}"
|
||||
|
||||
# 종료 조건 확인
|
||||
is_done = self.env.is_negotiation_successful(proposed_price, anchor_price) or (step >= max_steps - 1)
|
||||
if self.env.is_negotiation_successful(proposed_price, anchor_price):
|
||||
success = True
|
||||
|
||||
# 보상 계산
|
||||
reward, weight = self.env.calculate_reward(
|
||||
scenario=scenario,
|
||||
price_zone=price_zone,
|
||||
anchor_price=anchor_price,
|
||||
proposed_price=proposed_price,
|
||||
is_end=is_done
|
||||
)
|
||||
|
||||
# 경험 저장
|
||||
metadata = {
|
||||
'episode': episode_id,
|
||||
'step': step,
|
||||
'scenario': scenario.value,
|
||||
'proposed_price': proposed_price,
|
||||
'weight': weight,
|
||||
'is_exploration': is_exploration,
|
||||
'anchor_price': anchor_price
|
||||
}
|
||||
|
||||
self.experience_buffer.add_experience(
|
||||
state=current_state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=is_done,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
episode_reward += reward
|
||||
steps += 1
|
||||
current_state = next_state
|
||||
|
||||
if is_done:
|
||||
break
|
||||
|
||||
return {
|
||||
'episode_id': episode_id,
|
||||
'steps': steps,
|
||||
'total_reward': episode_reward,
|
||||
'success': success,
|
||||
'final_price': proposed_price if 'proposed_price' in locals() else anchor_price,
|
||||
'scenario': scenario.value
|
||||
}
|
||||
|
||||
def update_q_learning(self, learning_rate: float, discount_factor: float, batch_size: int) -> Dict[str, Any]:
|
||||
"""Q-Learning 업데이트"""
|
||||
# 하이퍼파라미터 설정
|
||||
self.q_learner.set_hyperparameters(
|
||||
learning_rate=learning_rate,
|
||||
discount_factor=discount_factor
|
||||
)
|
||||
|
||||
# 경험 데이터 가져오기
|
||||
experiences = self.experience_buffer.get_experiences()
|
||||
if not experiences:
|
||||
return {"message": "No experience data available", "updates": 0}
|
||||
|
||||
# 배치 샘플링
|
||||
if len(experiences) > batch_size:
|
||||
batch = self.experience_buffer.sample_batch(batch_size)
|
||||
else:
|
||||
batch = experiences
|
||||
|
||||
# 배치 업데이트
|
||||
result = self.q_learner.batch_update(batch)
|
||||
|
||||
return {
|
||||
"message": "Q-Learning update completed",
|
||||
"batch_size": len(batch),
|
||||
"updates": result["updates"],
|
||||
"avg_td_error": result["avg_td_error"],
|
||||
"total_updates": self.q_learner.update_count
|
||||
}
|
||||
|
||||
def run_fqi_cql(self, alpha: float, gamma: float, batch_size: int, num_iterations: int) -> Dict[str, Any]:
|
||||
"""FQI+CQL 실행"""
|
||||
# 하이퍼파라미터 설정
|
||||
self.fqi_cql_learner.set_hyperparameters(
|
||||
alpha=alpha,
|
||||
gamma=gamma
|
||||
)
|
||||
|
||||
# 경험 데이터 가져오기
|
||||
experiences = self.experience_buffer.get_experiences()
|
||||
if not experiences:
|
||||
return {"message": "No experience data available", "iterations": 0}
|
||||
|
||||
# 배치 샘플링
|
||||
if len(experiences) > batch_size:
|
||||
batch = self.experience_buffer.sample_batch(batch_size)
|
||||
else:
|
||||
batch = experiences
|
||||
|
||||
# FQI+CQL 학습
|
||||
result = self.fqi_cql_learner.train_multiple_iterations(
|
||||
experience_batch=batch,
|
||||
num_iterations=num_iterations
|
||||
)
|
||||
|
||||
# 정책 비교
|
||||
policy_comparison = self.fqi_cql_learner.compare_with_behavior_policy(batch)
|
||||
|
||||
return {
|
||||
"message": "FQI+CQL training completed",
|
||||
"training_result": result,
|
||||
"policy_comparison": policy_comparison,
|
||||
"batch_size": len(batch)
|
||||
}
|
||||
|
||||
def get_action_recommendation(self, request: ActionRecommendationRequest) -> ActionRecommendationResponse:
|
||||
"""행동 추천"""
|
||||
# Q값들 가져오기
|
||||
q_values = self.q_learner.get_state_q_values(request.current_state)
|
||||
|
||||
# 행동 선택
|
||||
if request.use_epsilon_greedy:
|
||||
action, is_exploration = self.q_learner.select_action(
|
||||
state=request.current_state,
|
||||
use_epsilon_greedy=True
|
||||
)
|
||||
# 임시로 epsilon 설정
|
||||
original_epsilon = self.q_learner.epsilon
|
||||
self.q_learner.epsilon = request.epsilon
|
||||
action, is_exploration = self.q_learner.select_action(
|
||||
state=request.current_state,
|
||||
use_epsilon_greedy=True
|
||||
)
|
||||
self.q_learner.epsilon = original_epsilon
|
||||
else:
|
||||
action = self.q_learner.get_optimal_action(request.current_state)
|
||||
is_exploration = False
|
||||
|
||||
# 신뢰도 계산 (Q값 분산 기반)
|
||||
if q_values and len(q_values) > 1:
|
||||
q_vals = list(q_values.values())
|
||||
max_q = max(q_vals)
|
||||
q_range = max(q_vals) - min(q_vals)
|
||||
confidence = max_q / (q_range + 1e-8) if q_range > 0 else 1.0
|
||||
confidence = min(confidence, 1.0)
|
||||
else:
|
||||
confidence = 0.0
|
||||
|
||||
return ActionRecommendationResponse(
|
||||
recommended_action=action,
|
||||
q_values=q_values,
|
||||
confidence=confidence,
|
||||
exploration=is_exploration
|
||||
)
|
||||
|
||||
def get_system_status(self) -> SystemStatus:
|
||||
"""시스템 상태 조회"""
|
||||
exp_df = self.experience_buffer.get_dataframe()
|
||||
|
||||
if not exp_df.empty:
|
||||
avg_reward = exp_df['reward'].mean()
|
||||
success_count = exp_df['done'].sum()
|
||||
success_rate = success_count / len(exp_df) if len(exp_df) > 0 else 0.0
|
||||
unique_states = exp_df['state'].nunique()
|
||||
else:
|
||||
avg_reward = 0.0
|
||||
success_rate = 0.0
|
||||
unique_states = 0
|
||||
|
||||
return SystemStatus(
|
||||
total_experiences=self.experience_buffer.size(),
|
||||
q_table_updates=self.q_learner.update_count,
|
||||
unique_states=unique_states,
|
||||
average_reward=avg_reward,
|
||||
success_rate=success_rate,
|
||||
last_update=time.time()
|
||||
)
|
||||
|
||||
def get_q_table(self) -> Dict[str, Any]:
|
||||
"""Q-Table 데이터 반환"""
|
||||
q_table_df = self.q_learner.get_q_table_copy()
|
||||
stats = self.q_learner.get_learning_statistics()
|
||||
|
||||
return {
|
||||
"q_table": q_table_df.to_dict(),
|
||||
"statistics": stats,
|
||||
"update_count": self.q_learner.update_count,
|
||||
"hyperparameters": {
|
||||
"learning_rate": self.q_learner.learning_rate,
|
||||
"discount_factor": self.q_learner.discount_factor,
|
||||
"epsilon": self.q_learner.epsilon
|
||||
}
|
||||
}
|
||||
|
||||
def get_fqi_cql_results(self) -> Dict[str, Any]:
|
||||
"""FQI+CQL 결과 반환"""
|
||||
q_network_df = self.fqi_cql_learner.get_q_network_copy()
|
||||
stats = self.fqi_cql_learner.get_training_statistics()
|
||||
|
||||
return {
|
||||
"q_network": q_network_df.to_dict(),
|
||||
"statistics": stats,
|
||||
"batch_count": self.fqi_cql_learner.batch_count,
|
||||
"hyperparameters": {
|
||||
"alpha": self.fqi_cql_learner.alpha,
|
||||
"gamma": self.fqi_cql_learner.gamma,
|
||||
"learning_rate": self.fqi_cql_learner.learning_rate
|
||||
}
|
||||
}
|
||||
|
||||
def get_experience_data(self) -> Dict[str, Any]:
|
||||
"""경험 데이터 반환"""
|
||||
exp_df = self.experience_buffer.get_dataframe()
|
||||
|
||||
if not exp_df.empty:
|
||||
# 기본 통계
|
||||
stats = {
|
||||
"total_count": len(exp_df),
|
||||
"avg_reward": exp_df['reward'].mean(),
|
||||
"reward_std": exp_df['reward'].std(),
|
||||
"success_rate": exp_df['done'].sum() / len(exp_df),
|
||||
"unique_states": exp_df['state'].nunique(),
|
||||
"unique_actions": exp_df['action'].nunique()
|
||||
}
|
||||
|
||||
# 최근 데이터
|
||||
recent_data = exp_df.tail(20).to_dict('records')
|
||||
else:
|
||||
stats = {
|
||||
"total_count": 0,
|
||||
"avg_reward": 0.0,
|
||||
"reward_std": 0.0,
|
||||
"success_rate": 0.0,
|
||||
"unique_states": 0,
|
||||
"unique_actions": 0
|
||||
}
|
||||
recent_data = []
|
||||
|
||||
return {
|
||||
"statistics": stats,
|
||||
"recent_data": recent_data,
|
||||
"buffer_size": self.experience_buffer.size(),
|
||||
"max_size": self.experience_buffer.max_size
|
||||
}
|
||||
|
||||
def reset_all(self):
|
||||
"""모든 학습 상태 초기화"""
|
||||
self.experience_buffer.clear()
|
||||
self.q_learner.reset()
|
||||
self.fqi_cql_learner.reset()
|
||||
self.episode_count = 0
|
||||
self.start_time = time.time()
|
||||
|
||||
def compare_policies(self, state: str) -> Dict[str, Any]:
|
||||
"""Q-Learning과 FQI+CQL 정책 비교"""
|
||||
# Q-Learning 정책
|
||||
q_learning_action = self.q_learner.get_optimal_action(state)
|
||||
q_learning_values = self.q_learner.get_state_q_values(state)
|
||||
|
||||
# FQI+CQL 정책
|
||||
fqi_cql_action = self.fqi_cql_learner.get_optimal_action(state)
|
||||
fqi_cql_values = self.fqi_cql_learner.get_state_q_values(state)
|
||||
|
||||
# 정책 일치 여부
|
||||
policy_agreement = (q_learning_action == fqi_cql_action)
|
||||
|
||||
# Q값 차이
|
||||
q_value_differences = {}
|
||||
for action_name in q_learning_values:
|
||||
diff = abs(q_learning_values[action_name] - fqi_cql_values.get(action_name, 0.0))
|
||||
q_value_differences[action_name] = diff
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"q_learning": {
|
||||
"action": q_learning_action.value,
|
||||
"q_values": q_learning_values
|
||||
},
|
||||
"fqi_cql": {
|
||||
"action": fqi_cql_action.value,
|
||||
"q_values": fqi_cql_values
|
||||
},
|
||||
"policy_agreement": policy_agreement,
|
||||
"q_value_differences": q_value_differences,
|
||||
"max_difference": max(q_value_differences.values()) if q_value_differences else 0.0
|
||||
}
|
||||
|
||||
|
||||
# 전역 서비스 인스턴스
|
||||
demo_service = DemoService()
|
||||
|
|
@ -0,0 +1,313 @@
|
|||
"""
|
||||
FQI (Fitted Q-Iteration) + CQL (Conservative Q-Learning) 서비스
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional
|
||||
from app.models.schemas import CardType, ExperienceData
|
||||
|
||||
|
||||
class FQICQLLearner:
|
||||
"""FQI + CQL 학습 엔진"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
states: List[str],
|
||||
actions: List[CardType],
|
||||
alpha: float = 1.0, # CQL 보수성 파라미터
|
||||
gamma: float = 0.95, # 할인율
|
||||
learning_rate: float = 0.01
|
||||
):
|
||||
self.states = states
|
||||
self.actions = actions
|
||||
self.alpha = alpha
|
||||
self.gamma = gamma
|
||||
self.learning_rate = learning_rate
|
||||
|
||||
# Q-네트워크 시뮬레이션 (실제로는 신경망이지만 여기서는 테이블로 구현)
|
||||
action_names = [action.value for action in actions]
|
||||
self.q_network = pd.DataFrame(
|
||||
np.random.uniform(0, 0.1, (len(states), len(actions))),
|
||||
index=states,
|
||||
columns=action_names
|
||||
)
|
||||
|
||||
# 학습 기록
|
||||
self.training_history = []
|
||||
self.batch_count = 0
|
||||
|
||||
def fitted_q_iteration(self, experience_batch: List[ExperienceData]) -> Dict[str, float]:
|
||||
"""
|
||||
FQI 배치 학습 수행
|
||||
|
||||
Args:
|
||||
experience_batch: 경험 데이터 배치
|
||||
|
||||
Returns:
|
||||
학습 결과 통계
|
||||
"""
|
||||
if not experience_batch:
|
||||
return {"bellman_loss": 0.0, "cql_penalty": 0.0, "batch_size": 0}
|
||||
|
||||
bellman_losses = []
|
||||
cql_penalties = []
|
||||
|
||||
for exp in experience_batch:
|
||||
state = exp.state
|
||||
action = exp.action
|
||||
reward = exp.reward
|
||||
next_state = exp.next_state
|
||||
done = exp.done
|
||||
|
||||
if state not in self.q_network.index:
|
||||
continue
|
||||
|
||||
# Bellman Target 계산
|
||||
if done or next_state not in self.q_network.index:
|
||||
target = reward
|
||||
else:
|
||||
target = reward + self.gamma * self.q_network.loc[next_state].max()
|
||||
|
||||
current_q = self.q_network.loc[state, action.value]
|
||||
|
||||
# Bellman Error
|
||||
bellman_error = (current_q - target) ** 2
|
||||
bellman_losses.append(bellman_error)
|
||||
|
||||
# CQL Conservative Penalty 계산
|
||||
# 데이터셋에 있는 행동 vs 모든 가능한 행동의 Q값 차이
|
||||
all_q_values = self.q_network.loc[state]
|
||||
dataset_q = current_q
|
||||
|
||||
# 보수적 추정: 데이터에 없는 행동의 Q값을 낮게 유지
|
||||
ood_q_values = [] # Out-of-Distribution Q값들
|
||||
for other_action in self.actions:
|
||||
if other_action != action: # 현재 행동이 아닌 다른 행동들
|
||||
ood_q_values.append(all_q_values[other_action.value])
|
||||
|
||||
if ood_q_values:
|
||||
max_ood_q = max(ood_q_values)
|
||||
cql_penalty = self.alpha * max(0, max_ood_q - dataset_q)
|
||||
else:
|
||||
cql_penalty = 0.0
|
||||
|
||||
cql_penalties.append(cql_penalty)
|
||||
|
||||
# 네트워크 업데이트 (간단한 그래디언트 스텝)
|
||||
# 실제로는 신경망 역전파이지만, 여기서는 직접 업데이트
|
||||
gradient = self.learning_rate * (target - current_q)
|
||||
conservative_gradient = self.learning_rate * cql_penalty
|
||||
|
||||
# 벨만 오차 최소화 + CQL 페널티 적용
|
||||
update = gradient - conservative_gradient
|
||||
self.q_network.loc[state, action.value] += update
|
||||
|
||||
# 학습 기록 저장
|
||||
avg_bellman_loss = np.mean(bellman_losses) if bellman_losses else 0.0
|
||||
avg_cql_penalty = np.mean(cql_penalties) if cql_penalties else 0.0
|
||||
|
||||
self.training_history.append({
|
||||
'batch': self.batch_count,
|
||||
'avg_bellman_loss': avg_bellman_loss,
|
||||
'avg_cql_penalty': avg_cql_penalty,
|
||||
'batch_size': len(experience_batch),
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
self.batch_count += 1
|
||||
|
||||
return {
|
||||
"bellman_loss": avg_bellman_loss,
|
||||
"cql_penalty": avg_cql_penalty,
|
||||
"batch_size": len(experience_batch)
|
||||
}
|
||||
|
||||
def train_multiple_iterations(
|
||||
self,
|
||||
experience_batch: List[ExperienceData],
|
||||
num_iterations: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
여러 번의 FQI 반복 수행
|
||||
|
||||
Args:
|
||||
experience_batch: 경험 데이터 배치
|
||||
num_iterations: 반복 횟수
|
||||
|
||||
Returns:
|
||||
전체 학습 결과 통계
|
||||
"""
|
||||
iteration_results = []
|
||||
|
||||
for i in range(num_iterations):
|
||||
# 각 반복에서 배치를 셔플
|
||||
shuffled_batch = np.random.permutation(experience_batch).tolist()
|
||||
result = self.fitted_q_iteration(shuffled_batch)
|
||||
|
||||
iteration_results.append({
|
||||
'iteration': i,
|
||||
**result
|
||||
})
|
||||
|
||||
# 전체 통계 계산
|
||||
return {
|
||||
'total_iterations': num_iterations,
|
||||
'avg_bellman_loss': np.mean([r['bellman_loss'] for r in iteration_results]),
|
||||
'avg_cql_penalty': np.mean([r['cql_penalty'] for r in iteration_results]),
|
||||
'final_bellman_loss': iteration_results[-1]['bellman_loss'] if iteration_results else 0.0,
|
||||
'final_cql_penalty': iteration_results[-1]['cql_penalty'] if iteration_results else 0.0,
|
||||
'iteration_details': iteration_results
|
||||
}
|
||||
|
||||
def get_q_value(self, state: str, action: CardType) -> float:
|
||||
"""Q값 조회"""
|
||||
if state in self.q_network.index:
|
||||
return self.q_network.loc[state, action.value]
|
||||
return 0.0
|
||||
|
||||
def get_optimal_action(self, state: str) -> CardType:
|
||||
"""현재 상태에서 최적 행동 선택"""
|
||||
if state not in self.q_network.index:
|
||||
import random
|
||||
return random.choice(self.actions)
|
||||
|
||||
q_values = self.q_network.loc[state]
|
||||
best_action_name = q_values.idxmax()
|
||||
|
||||
for action in self.actions:
|
||||
if action.value == best_action_name:
|
||||
return action
|
||||
|
||||
import random
|
||||
return random.choice(self.actions)
|
||||
|
||||
def get_state_q_values(self, state: str) -> Dict[str, float]:
|
||||
"""특정 상태의 Q값들 반환"""
|
||||
if state not in self.q_network.index:
|
||||
return {action.value: 0.0 for action in self.actions}
|
||||
|
||||
return self.q_network.loc[state].to_dict()
|
||||
|
||||
def get_q_network_copy(self) -> pd.DataFrame:
|
||||
"""Q-네트워크 복사본 반환"""
|
||||
return self.q_network.copy()
|
||||
|
||||
def compare_with_behavior_policy(
|
||||
self,
|
||||
experience_batch: List[ExperienceData]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
학습된 정책과 행동 정책(데이터 수집 정책) 비교
|
||||
|
||||
Args:
|
||||
experience_batch: 경험 데이터 배치
|
||||
|
||||
Returns:
|
||||
정책 비교 결과
|
||||
"""
|
||||
if not experience_batch:
|
||||
return {"policy_divergence": 0.0, "action_agreement": 0.0}
|
||||
|
||||
agreements = 0
|
||||
total_comparisons = 0
|
||||
q_value_differences = []
|
||||
|
||||
for exp in experience_batch:
|
||||
state = exp.state
|
||||
behavior_action = exp.action # 데이터 수집 시 선택된 행동
|
||||
|
||||
if state not in self.q_network.index:
|
||||
continue
|
||||
|
||||
# 현재 학습된 정책의 최적 행동
|
||||
learned_action = self.get_optimal_action(state)
|
||||
|
||||
# 행동 일치 여부
|
||||
if behavior_action == learned_action:
|
||||
agreements += 1
|
||||
|
||||
# Q값 차이 계산
|
||||
behavior_q = self.get_q_value(state, behavior_action)
|
||||
learned_q = self.get_q_value(state, learned_action)
|
||||
q_value_differences.append(abs(learned_q - behavior_q))
|
||||
|
||||
total_comparisons += 1
|
||||
|
||||
if total_comparisons == 0:
|
||||
return {"policy_divergence": 0.0, "action_agreement": 0.0}
|
||||
|
||||
action_agreement = agreements / total_comparisons
|
||||
avg_q_difference = np.mean(q_value_differences)
|
||||
|
||||
return {
|
||||
"policy_divergence": avg_q_difference,
|
||||
"action_agreement": action_agreement,
|
||||
"total_comparisons": total_comparisons,
|
||||
"agreements": agreements
|
||||
}
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""학습 통계 반환"""
|
||||
if not self.training_history:
|
||||
return {
|
||||
"total_batches": 0,
|
||||
"avg_bellman_loss": 0.0,
|
||||
"avg_cql_penalty": 0.0,
|
||||
"convergence_trend": "unknown"
|
||||
}
|
||||
|
||||
recent_history = self.training_history[-10:] # 최근 10개 배치
|
||||
|
||||
# 수렴 경향 분석
|
||||
if len(self.training_history) >= 5:
|
||||
recent_losses = [h['avg_bellman_loss'] for h in self.training_history[-5:]]
|
||||
if all(recent_losses[i] >= recent_losses[i+1] for i in range(len(recent_losses)-1)):
|
||||
convergence_trend = "improving"
|
||||
elif all(recent_losses[i] <= recent_losses[i+1] for i in range(len(recent_losses)-1)):
|
||||
convergence_trend = "deteriorating"
|
||||
else:
|
||||
convergence_trend = "fluctuating"
|
||||
else:
|
||||
convergence_trend = "insufficient_data"
|
||||
|
||||
return {
|
||||
"total_batches": self.batch_count,
|
||||
"avg_bellman_loss": np.mean([h['avg_bellman_loss'] for h in recent_history]),
|
||||
"avg_cql_penalty": np.mean([h['avg_cql_penalty'] for h in recent_history]),
|
||||
"convergence_trend": convergence_trend,
|
||||
"q_network_stats": {
|
||||
"min": float(self.q_network.min().min()),
|
||||
"max": float(self.q_network.max().max()),
|
||||
"mean": float(self.q_network.mean().mean()),
|
||||
"std": float(self.q_network.std().mean())
|
||||
}
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""학습 상태 초기화"""
|
||||
# Q-네트워크 재초기화
|
||||
action_names = [action.value for action in self.actions]
|
||||
self.q_network = pd.DataFrame(
|
||||
np.random.uniform(0, 0.1, (len(self.states), len(self.actions))),
|
||||
index=self.states,
|
||||
columns=action_names
|
||||
)
|
||||
|
||||
# 기록 초기화
|
||||
self.training_history.clear()
|
||||
self.batch_count = 0
|
||||
|
||||
def set_hyperparameters(
|
||||
self,
|
||||
alpha: Optional[float] = None,
|
||||
gamma: Optional[float] = None,
|
||||
learning_rate: Optional[float] = None
|
||||
):
|
||||
"""하이퍼파라미터 설정"""
|
||||
if alpha is not None:
|
||||
self.alpha = alpha
|
||||
if gamma is not None:
|
||||
self.gamma = gamma
|
||||
if learning_rate is not None:
|
||||
self.learning_rate = learning_rate
|
||||
|
|
@ -0,0 +1,222 @@
|
|||
"""
|
||||
협상 환경 시뮬레이터 서비스
|
||||
"""
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import Dict, Tuple, Optional
|
||||
from app.models.schemas import ScenarioType, PriceZoneType, CardType
|
||||
|
||||
|
||||
class NegotiationEnvironment:
|
||||
"""협상 환경 시뮬레이터"""
|
||||
|
||||
def __init__(self):
|
||||
# 문서 기준 가중치 설정
|
||||
self.scenario_weights = {
|
||||
ScenarioType.A: 1.0, # S_1 = A
|
||||
ScenarioType.D: 0.75, # S_2 = D
|
||||
ScenarioType.C: 0.5, # S_3 = C
|
||||
ScenarioType.B: 0.25 # S_4 = B
|
||||
}
|
||||
|
||||
self.price_zone_weights = {
|
||||
PriceZoneType.PZ1: 0.1, # P < A (가장 좋은 구간)
|
||||
PriceZoneType.PZ2: 0.5, # A < P < T (중간 구간)
|
||||
PriceZoneType.PZ3: 1.0 # T < P (나쁜 구간)
|
||||
}
|
||||
|
||||
# 카드별 협상 효과 (시뮬레이션용)
|
||||
self.card_effects = {
|
||||
CardType.C1: {"price_multiplier": 1.2, "success_rate": 0.3},
|
||||
CardType.C2: {"price_multiplier": 1.1, "success_rate": 0.5},
|
||||
CardType.C3: {"price_multiplier": 1.0, "success_rate": 0.7},
|
||||
CardType.C4: {"price_multiplier": 0.9, "success_rate": 0.8}
|
||||
}
|
||||
|
||||
# 시나리오별 협상 난이도
|
||||
self.scenario_difficulty = {
|
||||
ScenarioType.A: 1.3, # 가장 어려운 협상
|
||||
ScenarioType.B: 1.1, # 보통 난이도
|
||||
ScenarioType.C: 0.95, # 쉬운 협상
|
||||
ScenarioType.D: 0.85 # 가장 쉬운 협상
|
||||
}
|
||||
|
||||
def calculate_reward(
|
||||
self,
|
||||
scenario: ScenarioType,
|
||||
price_zone: PriceZoneType,
|
||||
anchor_price: float,
|
||||
proposed_price: float,
|
||||
is_end: bool
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
보상함수 계산: R(s,a) = W × (A/P) + (1-W) × End
|
||||
|
||||
Args:
|
||||
scenario: 시나리오 타입
|
||||
price_zone: 가격 구간
|
||||
anchor_price: 목표가 (A)
|
||||
proposed_price: 제안가 (P)
|
||||
is_end: 협상 종료 여부
|
||||
|
||||
Returns:
|
||||
(reward, weight): 보상값과 가중치
|
||||
"""
|
||||
s_n = self.scenario_weights[scenario]
|
||||
pz_n = self.price_zone_weights[price_zone]
|
||||
|
||||
# 가중치 계산: W = (S_n + PZ_n) / 2
|
||||
w = (s_n + pz_n) / 2
|
||||
|
||||
# 가격 비율 계산 (0으로 나누기 방지)
|
||||
if proposed_price == 0:
|
||||
price_ratio = float('inf')
|
||||
else:
|
||||
price_ratio = anchor_price / proposed_price
|
||||
|
||||
# 보상 계산
|
||||
reward = w * price_ratio + (1 - w) * (1 if is_end else 0)
|
||||
|
||||
return reward, w
|
||||
|
||||
def get_price_zone(
|
||||
self,
|
||||
price: float,
|
||||
anchor_price: float,
|
||||
threshold_multiplier: float = 1.2
|
||||
) -> PriceZoneType:
|
||||
"""
|
||||
가격에 따른 구간 결정
|
||||
|
||||
Args:
|
||||
price: 제안 가격
|
||||
anchor_price: 목표가
|
||||
threshold_multiplier: 임계값 배수
|
||||
|
||||
Returns:
|
||||
가격 구간
|
||||
"""
|
||||
threshold = anchor_price * threshold_multiplier
|
||||
|
||||
if price <= anchor_price:
|
||||
return PriceZoneType.PZ1 # 목표가 이하 (좋음)
|
||||
elif price <= threshold:
|
||||
return PriceZoneType.PZ2 # 목표가와 임계값 사이 (보통)
|
||||
else:
|
||||
return PriceZoneType.PZ3 # 임계값 이상 (나쁨)
|
||||
|
||||
def simulate_opponent_response(
|
||||
self,
|
||||
current_card: CardType,
|
||||
scenario: ScenarioType,
|
||||
anchor_price: float,
|
||||
step: int = 0
|
||||
) -> float:
|
||||
"""
|
||||
상대방 응답 시뮬레이션
|
||||
|
||||
Args:
|
||||
current_card: 현재 사용한 카드
|
||||
scenario: 현재 시나리오
|
||||
anchor_price: 목표가
|
||||
step: 현재 협상 단계
|
||||
|
||||
Returns:
|
||||
상대방 제안 가격
|
||||
"""
|
||||
# 카드 효과
|
||||
card_effect = self.card_effects[current_card]["price_multiplier"]
|
||||
|
||||
# 시나리오 난이도
|
||||
scenario_difficulty = self.scenario_difficulty[scenario]
|
||||
|
||||
# 협상 진행에 따른 양보 (단계가 늘어날수록 가격 하락)
|
||||
step_discount = 1.0 - (step * 0.05)
|
||||
step_discount = max(step_discount, 0.7) # 최소 30% 할인
|
||||
|
||||
# 기본 가격 계산
|
||||
base_multiplier = card_effect * scenario_difficulty * step_discount
|
||||
|
||||
# 랜덤 노이즈 추가 (현실적 변동성)
|
||||
noise = np.random.uniform(0.85, 1.15)
|
||||
|
||||
# 최종 제안 가격
|
||||
proposed_price = anchor_price * base_multiplier * noise
|
||||
|
||||
# 최소 가격 보장 (목표가의 70% 이상)
|
||||
min_price = anchor_price * 0.7
|
||||
proposed_price = max(proposed_price, min_price)
|
||||
|
||||
return round(proposed_price, 2)
|
||||
|
||||
def is_negotiation_successful(
|
||||
self,
|
||||
proposed_price: float,
|
||||
anchor_price: float,
|
||||
tolerance: float = 0.05
|
||||
) -> bool:
|
||||
"""
|
||||
협상 성공 여부 판단
|
||||
|
||||
Args:
|
||||
proposed_price: 제안 가격
|
||||
anchor_price: 목표가
|
||||
tolerance: 허용 오차 (5%)
|
||||
|
||||
Returns:
|
||||
협상 성공 여부
|
||||
"""
|
||||
success_threshold = anchor_price * (1 + tolerance)
|
||||
return proposed_price <= success_threshold
|
||||
|
||||
def get_all_states(self) -> list[str]:
|
||||
"""모든 가능한 상태 목록 반환"""
|
||||
states = ["C0S0P0"] # 초기 상태
|
||||
|
||||
for card in CardType:
|
||||
for scenario in ScenarioType:
|
||||
for price_zone in PriceZoneType:
|
||||
state_id = f"{card.value}{scenario.value}{price_zone.value}"
|
||||
states.append(state_id)
|
||||
|
||||
return states
|
||||
|
||||
def get_all_actions(self) -> list[CardType]:
|
||||
"""모든 가능한 행동 목록 반환"""
|
||||
return list(CardType)
|
||||
|
||||
def parse_state(self, state_id: str) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
상태 ID를 파싱하여 구성 요소 반환
|
||||
|
||||
Args:
|
||||
state_id: 상태 ID (예: "C1APZ1")
|
||||
|
||||
Returns:
|
||||
상태 구성 요소 딕셔너리 또는 None
|
||||
"""
|
||||
if state_id == "C0S0P0":
|
||||
return {"card": "C0", "scenario": "S0", "price_zone": "P0"}
|
||||
|
||||
if len(state_id) != 6: # 예: C1APZ1 (6글자)
|
||||
return None
|
||||
|
||||
try:
|
||||
card = state_id[:2] # C1
|
||||
scenario = state_id[2] # A
|
||||
price_zone = state_id[3:] # PZ1
|
||||
|
||||
# 유효성 검사
|
||||
if (card in [c.value for c in CardType] and
|
||||
scenario in [s.value for s in ScenarioType] and
|
||||
price_zone in [pz.value for pz in PriceZoneType]):
|
||||
|
||||
return {
|
||||
"card": card,
|
||||
"scenario": scenario,
|
||||
"price_zone": price_zone
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
|
@ -0,0 +1,293 @@
|
|||
"""
|
||||
Q-Table 학습 엔진 서비스
|
||||
"""
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from collections import deque
|
||||
from app.models.schemas import CardType, ExperienceData
|
||||
|
||||
|
||||
class ExperienceBuffer:
|
||||
"""경험 데이터 저장 및 관리"""
|
||||
|
||||
def __init__(self, max_size: int = 10000):
|
||||
self.buffer = deque(maxlen=max_size)
|
||||
self.max_size = max_size
|
||||
|
||||
def add_experience(
|
||||
self,
|
||||
state: str,
|
||||
action: CardType,
|
||||
reward: float,
|
||||
next_state: str,
|
||||
done: bool,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""경험 데이터 추가"""
|
||||
experience = ExperienceData(
|
||||
state=state,
|
||||
action=action,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
timestamp=time.time(),
|
||||
metadata=metadata or {}
|
||||
)
|
||||
self.buffer.append(experience)
|
||||
|
||||
def get_experiences(self) -> List[ExperienceData]:
|
||||
"""모든 경험 데이터 반환"""
|
||||
return list(self.buffer)
|
||||
|
||||
def get_dataframe(self) -> pd.DataFrame:
|
||||
"""경험 데이터를 DataFrame으로 반환"""
|
||||
if not self.buffer:
|
||||
return pd.DataFrame()
|
||||
|
||||
data = []
|
||||
for exp in self.buffer:
|
||||
data.append({
|
||||
'state': exp.state,
|
||||
'action': exp.action.value,
|
||||
'reward': exp.reward,
|
||||
'next_state': exp.next_state,
|
||||
'done': exp.done,
|
||||
'timestamp': exp.timestamp,
|
||||
**exp.metadata
|
||||
})
|
||||
return pd.DataFrame(data)
|
||||
|
||||
def sample_batch(self, batch_size: int = 32) -> List[ExperienceData]:
|
||||
"""배치 샘플링"""
|
||||
if len(self.buffer) <= batch_size:
|
||||
return list(self.buffer)
|
||||
|
||||
indices = np.random.choice(len(self.buffer), batch_size, replace=False)
|
||||
return [self.buffer[i] for i in indices]
|
||||
|
||||
def clear(self):
|
||||
"""버퍼 초기화"""
|
||||
self.buffer.clear()
|
||||
|
||||
def size(self) -> int:
|
||||
"""버퍼 크기 반환"""
|
||||
return len(self.buffer)
|
||||
|
||||
|
||||
class QTableLearner:
|
||||
"""Q-Table 학습 엔진"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
states: List[str],
|
||||
actions: List[CardType],
|
||||
learning_rate: float = 0.1,
|
||||
discount_factor: float = 0.9,
|
||||
epsilon: float = 0.1
|
||||
):
|
||||
self.states = states
|
||||
self.actions = actions
|
||||
self.learning_rate = learning_rate
|
||||
self.discount_factor = discount_factor
|
||||
self.epsilon = epsilon
|
||||
|
||||
# Q-Table 초기화 (모든 값 0)
|
||||
action_names = [action.value for action in actions]
|
||||
self.q_table = pd.DataFrame(
|
||||
0.0,
|
||||
index=states,
|
||||
columns=action_names
|
||||
)
|
||||
|
||||
# 학습 기록
|
||||
self.learning_history = []
|
||||
self.update_count = 0
|
||||
self.total_reward = 0.0
|
||||
|
||||
def get_q_value(self, state: str, action: CardType) -> float:
|
||||
"""Q값 조회"""
|
||||
if state in self.q_table.index:
|
||||
return self.q_table.loc[state, action.value]
|
||||
return 0.0
|
||||
|
||||
def set_q_value(self, state: str, action: CardType, value: float):
|
||||
"""Q값 설정"""
|
||||
if state in self.q_table.index:
|
||||
self.q_table.loc[state, action.value] = value
|
||||
|
||||
def get_optimal_action(self, state: str) -> CardType:
|
||||
"""현재 상태에서 최적 행동 선택 (그리디)"""
|
||||
if state not in self.q_table.index:
|
||||
return random.choice(self.actions)
|
||||
|
||||
q_values = self.q_table.loc[state]
|
||||
best_action_name = q_values.idxmax()
|
||||
|
||||
# CardType으로 변환
|
||||
for action in self.actions:
|
||||
if action.value == best_action_name:
|
||||
return action
|
||||
|
||||
return random.choice(self.actions)
|
||||
|
||||
def select_action(self, state: str, use_epsilon_greedy: bool = True) -> Tuple[CardType, bool]:
|
||||
"""
|
||||
행동 선택 (엡실론 그리디 또는 그리디)
|
||||
|
||||
Returns:
|
||||
(action, is_exploration): 선택된 행동과 탐험 여부
|
||||
"""
|
||||
if use_epsilon_greedy and random.random() < self.epsilon:
|
||||
# 탐험: 무작위 행동
|
||||
return random.choice(self.actions), True
|
||||
else:
|
||||
# 활용: 최적 행동
|
||||
return self.get_optimal_action(state), False
|
||||
|
||||
def update_q_value(
|
||||
self,
|
||||
state: str,
|
||||
action: CardType,
|
||||
reward: float,
|
||||
next_state: str,
|
||||
done: bool
|
||||
) -> float:
|
||||
"""
|
||||
Q-Learning 업데이트 규칙 적용
|
||||
Q(s,a) ← Q(s,a) + α[r + γ max Q(s',a') - Q(s,a)]
|
||||
|
||||
Returns:
|
||||
TD 오차
|
||||
"""
|
||||
if state not in self.q_table.index:
|
||||
return 0.0
|
||||
|
||||
current_q = self.get_q_value(state, action)
|
||||
|
||||
if done or next_state not in self.q_table.index:
|
||||
target = reward
|
||||
else:
|
||||
max_next_q = self.q_table.loc[next_state].max()
|
||||
target = reward + self.discount_factor * max_next_q
|
||||
|
||||
# TD 오차 계산
|
||||
td_error = target - current_q
|
||||
|
||||
# Q값 업데이트
|
||||
new_q = current_q + self.learning_rate * td_error
|
||||
self.set_q_value(state, action, new_q)
|
||||
|
||||
# 학습 기록 저장
|
||||
self.learning_history.append({
|
||||
'update': self.update_count,
|
||||
'state': state,
|
||||
'action': action.value,
|
||||
'old_q': current_q,
|
||||
'new_q': new_q,
|
||||
'reward': reward,
|
||||
'target': target,
|
||||
'td_error': abs(td_error),
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
self.update_count += 1
|
||||
self.total_reward += reward
|
||||
|
||||
return td_error
|
||||
|
||||
def batch_update(self, experiences: List[ExperienceData]) -> Dict[str, float]:
|
||||
"""배치 업데이트"""
|
||||
if not experiences:
|
||||
return {"avg_td_error": 0.0, "updates": 0}
|
||||
|
||||
td_errors = []
|
||||
updates = 0
|
||||
|
||||
for exp in experiences:
|
||||
td_error = self.update_q_value(
|
||||
exp.state,
|
||||
exp.action,
|
||||
exp.reward,
|
||||
exp.next_state,
|
||||
exp.done
|
||||
)
|
||||
if abs(td_error) > 1e-8: # 의미있는 업데이트만 카운트
|
||||
td_errors.append(abs(td_error))
|
||||
updates += 1
|
||||
|
||||
return {
|
||||
"avg_td_error": np.mean(td_errors) if td_errors else 0.0,
|
||||
"updates": updates,
|
||||
"total_experiences": len(experiences)
|
||||
}
|
||||
|
||||
def get_q_table_copy(self) -> pd.DataFrame:
|
||||
"""Q-Table 복사본 반환"""
|
||||
return self.q_table.copy()
|
||||
|
||||
def get_state_q_values(self, state: str) -> Dict[str, float]:
|
||||
"""특정 상태의 Q값들 반환"""
|
||||
if state not in self.q_table.index:
|
||||
return {action.value: 0.0 for action in self.actions}
|
||||
|
||||
return self.q_table.loc[state].to_dict()
|
||||
|
||||
def get_learning_statistics(self) -> Dict[str, Any]:
|
||||
"""학습 통계 반환"""
|
||||
if not self.learning_history:
|
||||
return {
|
||||
"total_updates": 0,
|
||||
"avg_td_error": 0.0,
|
||||
"avg_reward": 0.0,
|
||||
"q_table_sparsity": 1.0
|
||||
}
|
||||
|
||||
recent_history = self.learning_history[-100:] # 최근 100개
|
||||
|
||||
# Q-Table 희소성 계산 (0이 아닌 값의 비율)
|
||||
non_zero_values = (self.q_table != 0).sum().sum()
|
||||
total_values = self.q_table.size
|
||||
sparsity = 1.0 - (non_zero_values / total_values)
|
||||
|
||||
return {
|
||||
"total_updates": self.update_count,
|
||||
"avg_td_error": np.mean([h['td_error'] for h in recent_history]),
|
||||
"avg_reward": np.mean([h['reward'] for h in recent_history]),
|
||||
"q_table_sparsity": sparsity,
|
||||
"q_value_range": {
|
||||
"min": float(self.q_table.min().min()),
|
||||
"max": float(self.q_table.max().max()),
|
||||
"mean": float(self.q_table.mean().mean())
|
||||
}
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
"""학습 상태 초기화"""
|
||||
# Q-Table을 0으로 초기화
|
||||
self.q_table = pd.DataFrame(
|
||||
0.0,
|
||||
index=self.states,
|
||||
columns=[action.value for action in self.actions]
|
||||
)
|
||||
|
||||
# 기록 초기화
|
||||
self.learning_history.clear()
|
||||
self.update_count = 0
|
||||
self.total_reward = 0.0
|
||||
|
||||
def set_hyperparameters(
|
||||
self,
|
||||
learning_rate: Optional[float] = None,
|
||||
discount_factor: Optional[float] = None,
|
||||
epsilon: Optional[float] = None
|
||||
):
|
||||
"""하이퍼파라미터 설정"""
|
||||
if learning_rate is not None:
|
||||
self.learning_rate = learning_rate
|
||||
if discount_factor is not None:
|
||||
self.discount_factor = discount_factor
|
||||
if epsilon is not None:
|
||||
self.epsilon = epsilon
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
version: '3.8'
|
||||
|
||||
services:
|
||||
api:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
environment:
|
||||
- API_HOST=0.0.0.0
|
||||
- API_PORT=8000
|
||||
command: python run_api.py
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8000/api/v1/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
frontend:
|
||||
build: .
|
||||
ports:
|
||||
- "8501:8501"
|
||||
environment:
|
||||
- FRONTEND_HOST=0.0.0.0
|
||||
- FRONTEND_PORT=8501
|
||||
command: python run_frontend.py
|
||||
depends_on:
|
||||
- api
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8501"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
volumes:
|
||||
data:
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Poetry
|
||||
poetry.lock
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Logs
|
||||
*.log
|
||||
logs/
|
||||
|
||||
# Data
|
||||
data/
|
||||
*.sqlite
|
||||
*.db
|
||||
|
||||
# Tests
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# Git
|
||||
.git/
|
||||
.gitignore
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# 개발 환경 설정
|
||||
API_HOST=localhost
|
||||
API_PORT=8000
|
||||
FRONTEND_HOST=localhost
|
||||
FRONTEND_PORT=8501
|
||||
|
||||
# 강화학습 하이퍼파라미터
|
||||
DEFAULT_LEARNING_RATE=0.1
|
||||
DEFAULT_DISCOUNT_FACTOR=0.9
|
||||
DEFAULT_EPSILON=0.1
|
||||
|
||||
# 협상 환경 설정
|
||||
DEFAULT_ANCHOR_PRICE=100
|
||||
MAX_EPISODES=1000
|
||||
MAX_STEPS_PER_EPISODE=10
|
||||
|
||||
# 로깅 설정
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE=app.log
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# 개발 환경 설정
|
||||
API_HOST=localhost
|
||||
API_PORT=8000
|
||||
FRONTEND_HOST=localhost
|
||||
FRONTEND_PORT=8501
|
||||
|
||||
# 강화학습 하이퍼파라미터
|
||||
DEFAULT_LEARNING_RATE=0.1
|
||||
DEFAULT_DISCOUNT_FACTOR=0.9
|
||||
DEFAULT_EPSILON=0.1
|
||||
|
||||
# 협상 환경 설정
|
||||
DEFAULT_ANCHOR_PRICE=100
|
||||
MAX_EPISODES=1000
|
||||
MAX_STEPS_PER_EPISODE=10
|
||||
|
||||
# 로깅 설정
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE=app.log
|
||||
|
|
@ -0,0 +1,905 @@
|
|||
"""
|
||||
Streamlit 기반 Q-Table 협상 전략 데모 프론트엔드
|
||||
"""
|
||||
import streamlit as st
|
||||
import requests
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import plotly.graph_objects as go
|
||||
from plotly.subplots import make_subplots
|
||||
import time
|
||||
import json
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
# 페이지 설정
|
||||
st.set_page_config(
|
||||
page_title="Q-Table 협상 전략 데모",
|
||||
page_icon="🎯",
|
||||
layout="wide",
|
||||
initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
# API 기본 URL
|
||||
API_BASE_URL = "http://localhost:8000/api/v1"
|
||||
|
||||
# 세션 상태 초기화
|
||||
if 'current_state' not in st.session_state:
|
||||
st.session_state.current_state = "C0S0P0"
|
||||
if 'anchor_price' not in st.session_state:
|
||||
st.session_state.anchor_price = 100
|
||||
|
||||
|
||||
class APIClient:
|
||||
"""API 클라이언트"""
|
||||
|
||||
@staticmethod
|
||||
def get(endpoint: str) -> Optional[Dict[str, Any]]:
|
||||
"""GET 요청"""
|
||||
try:
|
||||
response = requests.get(f"{API_BASE_URL}{endpoint}")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
st.error(f"API 요청 오류: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def post(endpoint: str, data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""POST 요청"""
|
||||
try:
|
||||
response = requests.post(f"{API_BASE_URL}{endpoint}", json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except requests.exceptions.RequestException as e:
|
||||
st.error(f"API 요청 오류: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def display_header():
|
||||
"""헤더 제목 표시"""
|
||||
st.title("🎯 Q-Table 기반 협상 전략 데모")
|
||||
st.markdown("""
|
||||
### 강화학습으로 배우는 협상 전략의 진화
|
||||
|
||||
이 데모는 **콜드 스타트 문제**부터 **학습된 정책**까지 Q-Learning의 전체 여정을 보여줍니다.
|
||||
|
||||
**핵심 보상함수:** `R(s,a) = W × (A/P) + (1-W) × End`
|
||||
""")
|
||||
st.markdown("---")
|
||||
|
||||
|
||||
def display_sidebar():
|
||||
"""사이드바 설정"""
|
||||
st.sidebar.header("⚙️ 데모 설정")
|
||||
|
||||
# 시스템 상태 조회
|
||||
status_response = APIClient.get("/status")
|
||||
if status_response and status_response.get("success") is not False:
|
||||
status = status_response
|
||||
st.sidebar.metric("총 경험 데이터", status.get("total_experiences", 0))
|
||||
st.sidebar.metric("Q-Table 업데이트", status.get("q_table_updates", 0))
|
||||
st.sidebar.metric("고유 상태", status.get("unique_states", 0))
|
||||
st.sidebar.metric("평균 보상", f"{status.get('average_reward', 0):.3f}")
|
||||
st.sidebar.metric("성공률", f"{status.get('success_rate', 0)*100:.1f}%")
|
||||
|
||||
st.sidebar.markdown("---")
|
||||
|
||||
# 글로벌 설정
|
||||
st.session_state.anchor_price = st.sidebar.number_input(
|
||||
"목표가 (A)",
|
||||
value=st.session_state.anchor_price,
|
||||
min_value=50,
|
||||
max_value=300
|
||||
)
|
||||
|
||||
# 시스템 초기화
|
||||
if st.sidebar.button("🔄 시스템 초기화", type="secondary"):
|
||||
with st.spinner("시스템 초기화 중..."):
|
||||
reset_response = APIClient.post("/reset", {})
|
||||
if reset_response and reset_response.get("success"):
|
||||
st.sidebar.success("시스템이 초기화되었습니다!")
|
||||
st.rerun()
|
||||
else:
|
||||
st.sidebar.error("초기화 실패")
|
||||
|
||||
return status_response
|
||||
|
||||
|
||||
def tab_cold_start():
|
||||
"""콜드 스타트 탭"""
|
||||
st.header("🏁 콜드 스타트 문제")
|
||||
|
||||
st.markdown("""
|
||||
### 강화학습의 첫 번째 난관
|
||||
|
||||
새로운 강화학습 에이전트가 직면하는 가장 큰 문제는 **"아무것도 모른다"**는 것입니다.
|
||||
모든 Q값이 0으로 초기화되어 있어, 어떤 행동이 좋은지 전혀 알 수 없습니다.
|
||||
""")
|
||||
|
||||
# Q-Table 현재 상태 조회
|
||||
qtable_response = APIClient.get("/qtable")
|
||||
if qtable_response and qtable_response.get("success"):
|
||||
qtable_data = qtable_response["data"]
|
||||
q_table_dict = qtable_data["q_table"]
|
||||
|
||||
# DataFrame으로 변환
|
||||
q_table_df = pd.DataFrame(q_table_dict)
|
||||
|
||||
st.subheader("📋 현재 Q-Table 상태")
|
||||
|
||||
# 통계 표시
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
with col1:
|
||||
non_zero_count = (q_table_df != 0).sum().sum()
|
||||
st.metric("비어있지 않은 Q값", non_zero_count)
|
||||
with col2:
|
||||
total_count = q_table_df.size
|
||||
st.metric("전체 Q값", total_count)
|
||||
with col3:
|
||||
sparsity = (1 - non_zero_count / total_count) * 100
|
||||
st.metric("희소성", f"{sparsity:.1f}%")
|
||||
with col4:
|
||||
st.metric("업데이트 횟수", qtable_data["update_count"])
|
||||
|
||||
# Q-Table 표시 (상위 20개 상태만)
|
||||
display_rows = min(20, len(q_table_df))
|
||||
st.dataframe(
|
||||
q_table_df.head(display_rows).style.format("{:.3f}").highlight_max(axis=1),
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
if len(q_table_df) > 20:
|
||||
st.info(f"전체 {len(q_table_df)}개 상태 중 상위 20개만 표시됩니다.")
|
||||
|
||||
# 문제점과 해결방법
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.subheader("❌ 핵심 문제점")
|
||||
st.markdown("""
|
||||
- **무지상태**: 모든 Q값이 동일 (보통 0)
|
||||
- **행동 선택 불가**: 어떤 행동이 좋은지 모름
|
||||
- **무작위 탐험**: 비효율적인 학습 초기 단계
|
||||
- **데이터 부족**: 학습할 경험이 없음
|
||||
""")
|
||||
|
||||
with col2:
|
||||
st.subheader("✅ 해결 방법")
|
||||
st.markdown("""
|
||||
- **탐험 전략**: Epsilon-greedy로 무작위 탐험
|
||||
- **경험 수집**: (상태, 행동, 보상, 다음상태) 튜플 저장
|
||||
- **점진적 학습**: 수집된 경험으로 Q값 업데이트
|
||||
- **정책 개선**: 학습을 통한 점진적 성능 향상
|
||||
""")
|
||||
|
||||
# Q-Learning 공식 설명
|
||||
st.subheader("🧮 Q-Learning 업데이트 공식")
|
||||
st.latex(r"Q(s,a) \leftarrow Q(s,a) + lpha [r + \gamma \max_{a'} Q(s',a') - Q(s,a)]")
|
||||
|
||||
st.markdown("""
|
||||
**공식 설명:**
|
||||
- **Q(s,a)**: 상태 s에서 행동 a의 Q값
|
||||
- **α (알파)**: 학습률 (0 < α ≤ 1)
|
||||
- **r**: 즉시 보상
|
||||
- **γ (감마)**: 할인율 (0 ≤ γ < 1)
|
||||
- **max Q(s',a')**: 다음 상태에서의 최대 Q값
|
||||
""")
|
||||
|
||||
|
||||
def tab_data_collection():
|
||||
"""데이터 수집 탭"""
|
||||
st.header("📊 경험 데이터 수집")
|
||||
|
||||
st.markdown("""
|
||||
### 학습의 연료: 경험 데이터
|
||||
|
||||
강화학습 에이전트는 환경과 상호작용하면서 **경험 튜플**을 수집합니다.
|
||||
각 경험은 `(상태, 행동, 보상, 다음상태, 종료여부)` 형태로 저장됩니다.
|
||||
""")
|
||||
|
||||
# 설정 섹션
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
st.subheader("⚙️ 에피소드 생성 설정")
|
||||
|
||||
num_episodes = st.slider("생성할 에피소드 수", 1, 50, 10)
|
||||
max_steps = st.slider("에피소드당 최대 스텝", 3, 15, 8)
|
||||
exploration_rate = st.slider("탐험율 (Epsilon)", 0.0, 1.0, 0.4, 0.1)
|
||||
|
||||
st.markdown(f"""
|
||||
**현재 설정:**
|
||||
- 목표가: {st.session_state.anchor_price}
|
||||
- 탐험율: {exploration_rate*100:.0f}%
|
||||
- 총 예상 경험: ~{num_episodes * max_steps}개
|
||||
""")
|
||||
|
||||
if st.button("🎲 자동 에피소드 생성", type="primary"):
|
||||
with st.spinner("에피소드 생성 중..."):
|
||||
request_data = {
|
||||
"num_episodes": num_episodes,
|
||||
"max_steps": max_steps,
|
||||
"anchor_price": st.session_state.anchor_price,
|
||||
"exploration_rate": exploration_rate
|
||||
}
|
||||
|
||||
response = APIClient.post("/episodes/generate", request_data)
|
||||
if response and response.get("success"):
|
||||
result = response["data"]
|
||||
st.success(f"✅ {result['new_experiences']}개의 새로운 경험 데이터 생성!")
|
||||
|
||||
# 에피소드 결과 표시
|
||||
episode_results = result["episode_results"]
|
||||
success_count = sum(1 for ep in episode_results if ep["success"])
|
||||
|
||||
col_a, col_b, col_c = st.columns(3)
|
||||
with col_a:
|
||||
st.metric("생성된 에피소드", result["episodes_generated"])
|
||||
with col_b:
|
||||
st.metric("성공한 협상", success_count)
|
||||
with col_c:
|
||||
success_rate = (success_count / len(episode_results)) * 100
|
||||
st.metric("성공률", f"{success_rate:.1f}%")
|
||||
|
||||
time.sleep(1) # UI 업데이트를 위한 잠시 대기
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("에피소드 생성 실패")
|
||||
|
||||
with col2:
|
||||
st.subheader("📈 수집된 데이터 현황")
|
||||
|
||||
# 경험 데이터 조회
|
||||
exp_response = APIClient.get("/experiences")
|
||||
if exp_response and exp_response.get("success"):
|
||||
exp_data = exp_response["data"]
|
||||
stats = exp_data["statistics"]
|
||||
recent_data = exp_data["recent_data"]
|
||||
|
||||
# 통계 표시
|
||||
col_a, col_b, col_c, col_d = st.columns(4)
|
||||
with col_a:
|
||||
st.metric("총 경험 수", stats["total_count"])
|
||||
with col_b:
|
||||
st.metric("평균 보상", f"{stats['avg_reward']:.3f}")
|
||||
with col_c:
|
||||
st.metric("성공률", f"{stats['success_rate']*100:.1f}%")
|
||||
with col_d:
|
||||
st.metric("고유 상태", stats["unique_states"])
|
||||
|
||||
# 최근 경험 데이터 표시
|
||||
if recent_data:
|
||||
st.subheader("🔍 최근 경험 데이터")
|
||||
recent_df = pd.DataFrame(recent_data)
|
||||
|
||||
# 필요한 컬럼만 선택
|
||||
display_columns = ['state', 'action', 'reward', 'next_state', 'done']
|
||||
available_columns = [col for col in display_columns if col in recent_df.columns]
|
||||
|
||||
if available_columns:
|
||||
display_df = recent_df[available_columns].tail(10)
|
||||
st.dataframe(
|
||||
display_df.style.format({'reward': '{:.3f}'}),
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
# 데이터 분포 시각화
|
||||
if len(recent_df) > 5:
|
||||
st.subheader("📊 데이터 분포 분석")
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
|
||||
|
||||
# 보상 분포
|
||||
axes[0,0].hist(recent_df['reward'], bins=15, alpha=0.7, color='skyblue', edgecolor='black')
|
||||
axes[0,0].set_title('보상 분포')
|
||||
axes[0,0].set_xlabel('보상값')
|
||||
axes[0,0].set_ylabel('빈도')
|
||||
|
||||
# 행동 분포
|
||||
if 'action' in recent_df.columns:
|
||||
action_counts = recent_df['action'].value_counts()
|
||||
axes[0,1].bar(action_counts.index, action_counts.values, color='lightgreen', edgecolor='black')
|
||||
axes[0,1].set_title('행동 선택 빈도')
|
||||
axes[0,1].set_xlabel('행동')
|
||||
axes[0,1].set_ylabel('빈도')
|
||||
|
||||
# 상태 분포 (상위 10개)
|
||||
if 'state' in recent_df.columns:
|
||||
state_counts = recent_df['state'].value_counts().head(10)
|
||||
axes[1,0].bar(range(len(state_counts)), state_counts.values, color='orange', edgecolor='black')
|
||||
axes[1,0].set_title('상위 상태 빈도')
|
||||
axes[1,0].set_xlabel('상태 순위')
|
||||
axes[1,0].set_ylabel('빈도')
|
||||
axes[1,0].set_xticks(range(len(state_counts)))
|
||||
axes[1,0].set_xticklabels([f"{i+1}" for i in range(len(state_counts))])
|
||||
|
||||
# 성공/실패 분포
|
||||
if 'done' in recent_df.columns:
|
||||
done_counts = recent_df['done'].value_counts()
|
||||
labels = ['진행중' if not k else '완료' for k in done_counts.index]
|
||||
axes[1,1].pie(done_counts.values, labels=labels, autopct='%1.1f%%', colors=['lightcoral', 'lightblue'])
|
||||
axes[1,1].set_title('협상 완료 비율')
|
||||
|
||||
plt.tight_layout()
|
||||
st.pyplot(fig)
|
||||
else:
|
||||
st.info("아직 수집된 데이터가 없습니다. 왼쪽에서 에피소드를 생성해보세요!")
|
||||
else:
|
||||
st.warning("경험 데이터를 불러올 수 없습니다.")
|
||||
|
||||
# 경험 데이터 구조 설명
|
||||
st.subheader("📋 경험 데이터 구조")
|
||||
st.markdown("""
|
||||
각 경험 튜플은 다음 정보를 포함합니다:
|
||||
|
||||
| 항목 | 설명 | 예시 |
|
||||
|------|------|------|
|
||||
| **상태 (State)** | 현재 협상 상황 | "C1APZ2" (카드1, 시나리오A, 가격구간2) |
|
||||
| **행동 (Action)** | 선택한 협상 카드 | "C3" |
|
||||
| **보상 (Reward)** | 행동에 대한 평가 | 0.85 |
|
||||
| **다음상태 (Next State)** | 행동 후 새로운 상황 | "C3APZ1" |
|
||||
| **종료 (Done)** | 협상 완료 여부 | true/false |
|
||||
""")
|
||||
|
||||
|
||||
# 더 많은 탭 함수들을 다음 파트에서 계속...
|
||||
|
||||
|
||||
def tab_q_learning():
|
||||
"""Q-Learning 탭"""
|
||||
st.header("🔄 Q-Learning 실시간 학습")
|
||||
|
||||
st.markdown("""
|
||||
### 경험으로부터 학습하기
|
||||
|
||||
수집된 경험 데이터를 사용하여 Q-Table을 업데이트합니다.
|
||||
각 경험에서 **TD(Temporal Difference) 오차**를 계산하고 Q값을 조정합니다.
|
||||
""")
|
||||
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
st.subheader("⚙️ 학습 설정")
|
||||
|
||||
learning_rate = st.slider("학습률 (α)", 0.01, 0.5, 0.1, 0.01)
|
||||
discount_factor = st.slider("할인율 (γ)", 0.8, 0.99, 0.9, 0.01)
|
||||
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
|
||||
|
||||
st.markdown(f"""
|
||||
**하이퍼파라미터:**
|
||||
- **학습률 (α)**: {learning_rate} - 새로운 정보의 반영 정도
|
||||
- **할인율 (γ)**: {discount_factor} - 미래 보상의 중요도
|
||||
- **배치 크기**: {batch_size} - 한 번에 학습할 경험 수
|
||||
""")
|
||||
|
||||
if st.button("🚀 Q-Learning 실행", type="primary"):
|
||||
with st.spinner("Q-Learning 업데이트 중..."):
|
||||
request_data = {
|
||||
"learning_rate": learning_rate,
|
||||
"discount_factor": discount_factor,
|
||||
"batch_size": batch_size
|
||||
}
|
||||
|
||||
response = APIClient.post("/learning/q-learning", request_data)
|
||||
if response and response.get("success"):
|
||||
result = response["data"]
|
||||
st.success(f"✅ {result['updates']}개 Q값 업데이트 완료!")
|
||||
|
||||
col_a, col_b = st.columns(2)
|
||||
with col_a:
|
||||
st.metric("배치 크기", result["batch_size"])
|
||||
with col_b:
|
||||
st.metric("평균 TD 오차", f"{result.get('avg_td_error', 0):.4f}")
|
||||
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("Q-Learning 업데이트 실패")
|
||||
|
||||
with col2:
|
||||
st.subheader("📊 Q-Table 현황")
|
||||
|
||||
# Q-Table 데이터 조회
|
||||
qtable_response = APIClient.get("/qtable")
|
||||
if qtable_response and qtable_response.get("success"):
|
||||
qtable_data = qtable_response["data"]
|
||||
statistics = qtable_data["statistics"]
|
||||
|
||||
# 통계 표시
|
||||
col_a, col_b, col_c, col_d = st.columns(4)
|
||||
with col_a:
|
||||
st.metric("총 업데이트", statistics.get("total_updates", 0))
|
||||
with col_b:
|
||||
st.metric("평균 TD 오차", f"{statistics.get('avg_td_error', 0):.4f}")
|
||||
with col_c:
|
||||
st.metric("평균 보상", f"{statistics.get('avg_reward', 0):.3f}")
|
||||
with col_d:
|
||||
sparsity = statistics.get('q_table_sparsity', 1.0) * 100
|
||||
st.metric("Q-Table 희소성", f"{sparsity:.1f}%")
|
||||
|
||||
# Q값 범위 표시
|
||||
q_range = statistics.get('q_value_range', {})
|
||||
if q_range:
|
||||
st.subheader("📈 Q값 분포")
|
||||
col_a, col_b, col_c = st.columns(3)
|
||||
with col_a:
|
||||
st.metric("최솟값", f"{q_range.get('min', 0):.3f}")
|
||||
with col_b:
|
||||
st.metric("평균값", f"{q_range.get('mean', 0):.3f}")
|
||||
with col_c:
|
||||
st.metric("최댓값", f"{q_range.get('max', 0):.3f}")
|
||||
|
||||
# Q-Table 표시 (비어있지 않은 상태들만)
|
||||
q_table_dict = qtable_data["q_table"]
|
||||
q_table_df = pd.DataFrame(q_table_dict)
|
||||
|
||||
# 0이 아닌 값이 있는 행만 필터링
|
||||
non_zero_rows = (q_table_df != 0).any(axis=1)
|
||||
if non_zero_rows.any():
|
||||
st.subheader("🎯 학습된 Q값들")
|
||||
learned_qtable = q_table_df[non_zero_rows].head(15)
|
||||
st.dataframe(
|
||||
learned_qtable.style.format("{:.3f}").highlight_max(axis=1),
|
||||
use_container_width=True
|
||||
)
|
||||
|
||||
learned_count = non_zero_rows.sum()
|
||||
total_count = len(q_table_df)
|
||||
st.info(f"전체 {total_count}개 상태 중 {learned_count}개 상태가 학습되었습니다.")
|
||||
else:
|
||||
st.info("아직 학습된 Q값이 없습니다. 위에서 Q-Learning을 실행해보세요!")
|
||||
|
||||
# TD 오차 설명
|
||||
st.subheader("🧮 TD(Temporal Difference) 오차")
|
||||
st.markdown("""
|
||||
**TD 오차**는 현재 Q값과 목표값의 차이입니다:
|
||||
|
||||
`TD 오차 = [r + γ max Q(s',a')] - Q(s,a)`
|
||||
|
||||
- **양수**: 현재 Q값이 너무 낮음 → Q값 증가
|
||||
- **음수**: 현재 Q값이 너무 높음 → Q값 감소
|
||||
- **0에 가까움**: Q값이 적절함 → 학습 수렴
|
||||
""")
|
||||
|
||||
|
||||
def tab_fqi_cql():
|
||||
"""FQI+CQL 탭"""
|
||||
st.header("🧠 FQI + CQL 오프라인 학습")
|
||||
|
||||
st.markdown("""
|
||||
### 오프라인 강화학습의 핵심
|
||||
|
||||
**FQI (Fitted Q-Iteration)**와 **CQL (Conservative Q-Learning)**을 결합한
|
||||
오프라인 강화학습 방법입니다. 수집된 데이터만으로 안전하고 보수적인 정책을 학습합니다.
|
||||
""")
|
||||
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
st.subheader("⚙️ FQI+CQL 설정")
|
||||
|
||||
alpha = st.slider("CQL 보수성 파라미터 (α)", 0.0, 3.0, 1.0, 0.1)
|
||||
gamma = st.slider("할인율 (γ)", 0.8, 0.99, 0.95, 0.01)
|
||||
batch_size = st.slider("배치 크기", 16, 256, 32, 16)
|
||||
num_iterations = st.slider("반복 횟수", 1, 50, 10, 1)
|
||||
|
||||
st.markdown(f"""
|
||||
**설정값:**
|
||||
- **α (Alpha)**: {alpha} - 보수성 강도
|
||||
- **γ (Gamma)**: {gamma} - 미래 보상 할인
|
||||
- **배치 크기**: {batch_size}
|
||||
- **반복 횟수**: {num_iterations}
|
||||
""")
|
||||
|
||||
st.markdown("""
|
||||
**CQL 특징:**
|
||||
- 🛡️ **보수적 추정**: 불확실한 행동의 Q값을 낮게 유지
|
||||
- 📊 **데이터 기반**: 수집된 경험만 활용
|
||||
- 🎯 **안전한 정책**: 분포 이동 문제 해결
|
||||
""")
|
||||
|
||||
if st.button("🚀 FQI+CQL 실행", type="primary"):
|
||||
with st.spinner("FQI+CQL 학습 중..."):
|
||||
request_data = {
|
||||
"alpha": alpha,
|
||||
"gamma": gamma,
|
||||
"batch_size": batch_size,
|
||||
"num_iterations": num_iterations
|
||||
}
|
||||
|
||||
response = APIClient.post("/learning/fqi-cql", request_data)
|
||||
if response and response.get("success"):
|
||||
result = response["data"]
|
||||
training_result = result["training_result"]
|
||||
policy_comparison = result["policy_comparison"]
|
||||
|
||||
st.success(f"✅ {training_result['total_iterations']}회 반복 학습 완료!")
|
||||
|
||||
# 학습 결과 표시
|
||||
col_a, col_b = st.columns(2)
|
||||
with col_a:
|
||||
st.metric("평균 벨만 손실", f"{training_result['avg_bellman_loss']:.4f}")
|
||||
with col_b:
|
||||
st.metric("평균 CQL 페널티", f"{training_result['avg_cql_penalty']:.4f}")
|
||||
|
||||
# 정책 비교
|
||||
st.metric("행동 정책과의 일치율", f"{policy_comparison['action_agreement']*100:.1f}%")
|
||||
|
||||
time.sleep(1)
|
||||
st.rerun()
|
||||
else:
|
||||
st.error("FQI+CQL 학습 실패")
|
||||
|
||||
with col2:
|
||||
st.subheader("📊 FQI+CQL 결과")
|
||||
|
||||
# FQI+CQL 결과 조회
|
||||
fqi_response = APIClient.get("/fqi-cql")
|
||||
if fqi_response and fqi_response.get("success"):
|
||||
fqi_data = fqi_response["data"]
|
||||
statistics = fqi_data["statistics"]
|
||||
|
||||
# 통계 표시
|
||||
col_a, col_b, col_c = st.columns(3)
|
||||
with col_a:
|
||||
st.metric("학습 배치", statistics.get("total_batches", 0))
|
||||
with col_b:
|
||||
st.metric("벨만 손실", f"{statistics.get('avg_bellman_loss', 0):.4f}")
|
||||
with col_c:
|
||||
st.metric("CQL 페널티", f"{statistics.get('avg_cql_penalty', 0):.4f}")
|
||||
|
||||
# 수렴 경향
|
||||
convergence = statistics.get("convergence_trend", "unknown")
|
||||
convergence_color = {
|
||||
"improving": "🟢",
|
||||
"deteriorating": "🔴",
|
||||
"fluctuating": "🟡",
|
||||
"insufficient_data": "⚪"
|
||||
}
|
||||
st.info(f"수렴 경향: {convergence_color.get(convergence, '❓')} {convergence}")
|
||||
|
||||
# Q-Network 통계
|
||||
q_stats = statistics.get("q_network_stats", {})
|
||||
if q_stats:
|
||||
st.subheader("📈 Q-Network 분포")
|
||||
col_a, col_b, col_c, col_d = st.columns(4)
|
||||
with col_a:
|
||||
st.metric("최솟값", f"{q_stats.get('min', 0):.3f}")
|
||||
with col_b:
|
||||
st.metric("평균값", f"{q_stats.get('mean', 0):.3f}")
|
||||
with col_c:
|
||||
st.metric("최댓값", f"{q_stats.get('max', 0):.3f}")
|
||||
with col_d:
|
||||
st.metric("표준편차", f"{q_stats.get('std', 0):.3f}")
|
||||
|
||||
# Q-Network 표시 (상위 15개 상태)
|
||||
q_network_dict = fqi_data["q_network"]
|
||||
q_network_df = pd.DataFrame(q_network_dict)
|
||||
|
||||
st.subheader("🎯 학습된 Q-Network")
|
||||
display_df = q_network_df.head(15)
|
||||
st.dataframe(
|
||||
display_df.style.format("{:.3f}").highlight_max(axis=1),
|
||||
use_container_width=True
|
||||
)
|
||||
else:
|
||||
st.info("FQI+CQL을 먼저 실행해주세요!")
|
||||
|
||||
# FQI+CQL 알고리즘 설명
|
||||
st.subheader("🔬 FQI + CQL 알고리즘")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.markdown("""
|
||||
**FQI (Fitted Q-Iteration)**
|
||||
- 배치 기반 Q-Learning
|
||||
- 전체 데이터셋을 한 번에 활용
|
||||
- 함수 근사 (신경망) 사용
|
||||
- 안정적인 학습 과정
|
||||
""")
|
||||
|
||||
with col2:
|
||||
st.markdown("""
|
||||
**CQL (Conservative Q-Learning)**
|
||||
- 보수적 Q값 추정
|
||||
- Out-of-Distribution 행동 억제
|
||||
- 데이터에 없는 행동의 Q값 하향 조정
|
||||
- 안전한 정책 학습
|
||||
""")
|
||||
|
||||
|
||||
def tab_learned_policy():
|
||||
"""학습된 정책 탭"""
|
||||
st.header("🎯 학습된 정책 비교 및 활용")
|
||||
|
||||
st.markdown("""
|
||||
### 학습 완료: 정책의 진화
|
||||
|
||||
Q-Learning과 FQI+CQL로 학습된 정책을 비교하고,
|
||||
실제 협상 상황에서 어떤 행동을 추천하는지 확인해보세요.
|
||||
""")
|
||||
|
||||
# 상태 선택
|
||||
col1, col2 = st.columns([1, 2])
|
||||
|
||||
with col1:
|
||||
st.subheader("🎮 협상 시뮬레이션")
|
||||
|
||||
# 상태 구성 요소 선택
|
||||
current_card = st.selectbox("현재 카드", ["C1", "C2", "C3", "C4"])
|
||||
scenario = st.selectbox("시나리오", ["A", "B", "C", "D"])
|
||||
price_zone = st.selectbox("가격 구간", ["PZ1", "PZ2", "PZ3"])
|
||||
|
||||
# 상태 ID 생성
|
||||
selected_state = f"{current_card}{scenario}{price_zone}"
|
||||
st.session_state.current_state = selected_state
|
||||
|
||||
st.info(f"선택된 상태: **{selected_state}**")
|
||||
|
||||
# 상태 해석
|
||||
state_interpretation = {
|
||||
"A": "어려운 협상 (높은 가중치)",
|
||||
"B": "쉬운 협상 (낮은 가중치)",
|
||||
"C": "보통 협상 (중간 가중치)",
|
||||
"D": "매우 쉬운 협상 (낮은 가중치)"
|
||||
}
|
||||
|
||||
price_interpretation = {
|
||||
"PZ1": "목표가 이하 (좋은 구간)",
|
||||
"PZ2": "목표가~임계값 (보통 구간)",
|
||||
"PZ3": "임계값 이상 (나쁜 구간)"
|
||||
}
|
||||
|
||||
st.markdown(f"""
|
||||
**상태 해석:**
|
||||
- **카드**: {current_card}
|
||||
- **시나리오**: {scenario} - {state_interpretation.get(scenario, "알 수 없음")}
|
||||
- **가격구간**: {price_zone} - {price_interpretation.get(price_zone, "알 수 없음")}
|
||||
""")
|
||||
|
||||
# 행동 추천 요청
|
||||
use_epsilon = st.checkbox("엡실론 그리디 사용", value=False)
|
||||
epsilon = 0.1
|
||||
if use_epsilon:
|
||||
epsilon = st.slider("엡실론 값", 0.0, 0.5, 0.1, 0.05)
|
||||
|
||||
if st.button("🎯 행동 추천 받기", type="primary"):
|
||||
request_data = {
|
||||
"current_state": selected_state,
|
||||
"use_epsilon_greedy": use_epsilon,
|
||||
"epsilon": epsilon
|
||||
}
|
||||
|
||||
response = APIClient.post("/action/recommend", request_data)
|
||||
if response and response.get("success") is not False:
|
||||
# response가 직접 ActionRecommendationResponse 형태인 경우
|
||||
recommendation = response
|
||||
|
||||
st.success(f"🎯 추천 행동: **{recommendation.get('recommended_action', 'N/A')}**")
|
||||
|
||||
if recommendation.get('exploration', False):
|
||||
st.warning("🎲 탐험 행동 (무작위 선택)")
|
||||
else:
|
||||
confidence = recommendation.get('confidence', 0) * 100
|
||||
st.info(f"🎯 활용 행동 (신뢰도: {confidence:.1f}%)")
|
||||
|
||||
# Q값들 표시
|
||||
q_values = recommendation.get('q_values', {})
|
||||
if q_values:
|
||||
st.subheader("📊 현재 상태의 Q값들")
|
||||
q_df = pd.DataFrame([q_values]).T
|
||||
q_df.columns = ['Q값']
|
||||
q_df = q_df.sort_values('Q값', ascending=False)
|
||||
|
||||
# 추천 행동 하이라이트
|
||||
def highlight_recommended(s):
|
||||
return ['background-color: lightgreen' if x == recommendation.get('recommended_action')
|
||||
else '' for x in s.index]
|
||||
|
||||
st.dataframe(
|
||||
q_df.style.format({'Q값': '{:.3f}'}).apply(highlight_recommended, axis=0),
|
||||
use_container_width=True
|
||||
)
|
||||
else:
|
||||
st.error("행동 추천 실패")
|
||||
|
||||
with col2:
|
||||
st.subheader("⚖️ 정책 비교")
|
||||
|
||||
# 정책 비교 요청
|
||||
compare_response = APIClient.get(f"/compare/{selected_state}")
|
||||
if compare_response and compare_response.get("success"):
|
||||
comparison = compare_response["data"]
|
||||
|
||||
# 정책 일치 여부
|
||||
agreement = comparison["policy_agreement"]
|
||||
if agreement:
|
||||
st.success("✅ Q-Learning과 FQI+CQL 정책이 일치합니다!")
|
||||
else:
|
||||
st.warning("⚠️ Q-Learning과 FQI+CQL 정책이 다릅니다.")
|
||||
|
||||
# 각 정책의 추천 행동
|
||||
col_a, col_b = st.columns(2)
|
||||
|
||||
with col_a:
|
||||
st.subheader("🔄 Q-Learning 정책")
|
||||
q_learning = comparison["q_learning"]
|
||||
st.metric("추천 행동", q_learning["action"])
|
||||
|
||||
# Q값들
|
||||
q_values_ql = q_learning["q_values"]
|
||||
if q_values_ql:
|
||||
q_df_ql = pd.DataFrame([q_values_ql]).T
|
||||
q_df_ql.columns = ['Q값']
|
||||
st.dataframe(q_df_ql.style.format({'Q값': '{:.3f}'}))
|
||||
|
||||
with col_b:
|
||||
st.subheader("🧠 FQI+CQL 정책")
|
||||
fqi_cql = comparison["fqi_cql"]
|
||||
st.metric("추천 행동", fqi_cql["action"])
|
||||
|
||||
# Q값들
|
||||
q_values_fqi = fqi_cql["q_values"]
|
||||
if q_values_fqi:
|
||||
q_df_fqi = pd.DataFrame([q_values_fqi]).T
|
||||
q_df_fqi.columns = ['Q값']
|
||||
st.dataframe(q_df_fqi.style.format({'Q값': '{:.3f}'}))
|
||||
|
||||
# Q값 차이 분석
|
||||
differences = comparison["q_value_differences"]
|
||||
max_diff = comparison["max_difference"]
|
||||
|
||||
st.subheader("📊 Q값 차이 분석")
|
||||
st.metric("최대 차이", f"{max_diff:.3f}")
|
||||
|
||||
if differences:
|
||||
diff_df = pd.DataFrame([differences]).T
|
||||
diff_df.columns = ['차이']
|
||||
st.dataframe(diff_df.style.format({'차이': '{:.3f}'}))
|
||||
|
||||
else:
|
||||
st.info("정책 비교를 위해 상태를 선택하고 학습을 진행해주세요.")
|
||||
|
||||
# 보상 계산 시뮬레이션
|
||||
st.subheader("🧮 보상 계산 시뮬레이션")
|
||||
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.subheader("📋 시뮬레이션 설정")
|
||||
proposed_price = st.number_input("상대방 제안가", value=120.0, min_value=50.0, max_value=500.0)
|
||||
is_negotiation_end = st.checkbox("협상 종료", value=False)
|
||||
|
||||
if st.button("💰 보상 계산", type="secondary"):
|
||||
request_data = {
|
||||
"scenario": scenario,
|
||||
"price_zone": price_zone,
|
||||
"anchor_price": st.session_state.anchor_price,
|
||||
"proposed_price": proposed_price,
|
||||
"is_end": is_negotiation_end
|
||||
}
|
||||
|
||||
response = APIClient.post("/reward/calculate", request_data)
|
||||
if response and response.get("success") is not False:
|
||||
# response가 직접 RewardCalculationResponse 형태인 경우
|
||||
reward_result = response
|
||||
|
||||
col_a, col_b, col_c = st.columns(3)
|
||||
with col_a:
|
||||
st.metric("보상", f"{reward_result.get('reward', 0):.3f}")
|
||||
with col_b:
|
||||
st.metric("가중치 (W)", f"{reward_result.get('weight', 0):.3f}")
|
||||
with col_c:
|
||||
st.metric("가격 비율 (A/P)", f"{reward_result.get('price_ratio', 0):.3f}")
|
||||
|
||||
# 공식 분해 표시
|
||||
formula = reward_result.get('formula_breakdown', '')
|
||||
if formula:
|
||||
st.subheader('📝 계산 과정')
|
||||
st.text(formula)
|
||||
else:
|
||||
st.error("보상 계산 실패")
|
||||
|
||||
with col2:
|
||||
st.subheader("📈 학습 진행 상황")
|
||||
|
||||
# 시스템 상태 조회
|
||||
status_response = APIClient.get("/status")
|
||||
if status_response and status_response.get("success") is not False:
|
||||
status = status_response
|
||||
|
||||
# 진행 상황 메트릭
|
||||
total_exp = status.get("total_experiences", 0)
|
||||
updates = status.get("q_table_updates", 0)
|
||||
success_rate = status.get("success_rate", 0) * 100
|
||||
|
||||
progress_metrics = [
|
||||
("데이터 수집", total_exp, 1000, "개"),
|
||||
("Q-Table 업데이트", updates, 500, "회"),
|
||||
("협상 성공률", success_rate, 100, "%")
|
||||
]
|
||||
|
||||
for name, value, target, unit in progress_metrics:
|
||||
progress = min(value / target, 1.0)
|
||||
st.metric(
|
||||
name,
|
||||
f"{value}{unit}",
|
||||
delta=f"목표: {target}{unit}"
|
||||
)
|
||||
st.progress(progress)
|
||||
|
||||
st.subheader("🎓 학습 완성도")
|
||||
|
||||
# Q-Table 완성도
|
||||
qtable_response = APIClient.get("/qtable")
|
||||
if qtable_response and qtable_response.get("success"):
|
||||
qtable_data = qtable_response["data"]
|
||||
statistics = qtable_data["statistics"]
|
||||
|
||||
sparsity = statistics.get('q_table_sparsity', 1.0)
|
||||
completeness = (1 - sparsity) * 100
|
||||
|
||||
st.metric("Q-Table 완성도", f"{completeness:.1f}%")
|
||||
st.progress(completeness / 100)
|
||||
|
||||
if completeness > 80:
|
||||
st.success("🎉 충분히 학습되었습니다!")
|
||||
elif completeness > 50:
|
||||
st.info("📖 적당히 학습되었습니다.")
|
||||
else:
|
||||
st.warning("📚 더 많은 학습이 필요합니다.")
|
||||
|
||||
|
||||
def main():
|
||||
"""메인 함수"""
|
||||
# 헤더 표시
|
||||
display_header()
|
||||
|
||||
# 사이드바 표시
|
||||
sidebar_status = display_sidebar()
|
||||
|
||||
# 탭 생성
|
||||
tab1, tab2, tab3, tab4, tab5 = st.tabs([
|
||||
"🏁 1. 콜드 스타트",
|
||||
"📊 2. 데이터 수집",
|
||||
"🔄 3. Q-Learning",
|
||||
"🧠 4. FQI+CQL",
|
||||
"🎯 5. 학습된 정책"
|
||||
])
|
||||
|
||||
with tab1:
|
||||
tab_cold_start()
|
||||
|
||||
with tab2:
|
||||
tab_data_collection()
|
||||
|
||||
with tab3:
|
||||
tab_q_learning()
|
||||
|
||||
with tab4:
|
||||
tab_fqi_cql()
|
||||
|
||||
with tab5:
|
||||
tab_learned_policy()
|
||||
|
||||
# 푸터
|
||||
st.markdown("---")
|
||||
st.markdown("""
|
||||
<div style='text-align: center; color: #666;'>
|
||||
<p>🎯 Q-Table 기반 협상 전략 데모 | 강화학습의 전체 여정을 경험해보세요</p>
|
||||
<p>💡 문의사항이 있으시면 API 문서를 참고해주세요: <a href="http://localhost:8000/docs" target="_blank">http://localhost:8000/docs</a></p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
|
||||
def start_frontend():
|
||||
"""프론트엔드 시작 (Poetry 스크립트용)"""
|
||||
import subprocess
|
||||
subprocess.run(["streamlit", "run", "frontend/app.py", "--server.port", "8501"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
Pipfile.lock
|
||||
|
||||
# poetry
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
.idea/
|
||||
|
||||
# VSCode
|
||||
.vscode/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
.DS_Store?
|
||||
._*
|
||||
.Spotlight-V100
|
||||
.Trashes
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
|
||||
# Application specific
|
||||
data/
|
||||
logs/
|
||||
*.sqlite
|
||||
*.db
|
||||
temp/
|
||||
cache/
|
||||
|
||||
# Streamlit
|
||||
.streamlit/
|
||||
|
|
@ -0,0 +1,35 @@
|
|||
[tool.poetry]
|
||||
name = "qtable-negotiation-demo"
|
||||
version = "0.1.0"
|
||||
description = "Q-Table 기반 협상 전략 강화학습 데모"
|
||||
authors = ["Demo Author <demo@example.com>"]
|
||||
readme = "README.md"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.8"
|
||||
fastapi = "^0.104.1"
|
||||
uvicorn = {extras = ["standard"], version = "^0.24.0"}
|
||||
streamlit = "^1.28.0"
|
||||
pandas = "^2.1.0"
|
||||
numpy = "^1.24.0"
|
||||
matplotlib = "^3.7.0"
|
||||
seaborn = "^0.12.0"
|
||||
plotly = "^5.17.0"
|
||||
python-dotenv = "^1.0.0"
|
||||
pydantic = "^1.10.12"
|
||||
requests = "^2.31.0"
|
||||
scikit-learn = "^1.3.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest = "^7.4.0"
|
||||
black = "^23.9.0"
|
||||
flake8 = "^6.1.0"
|
||||
mypy = "^1.6.0"
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.poetry.scripts]
|
||||
start-api = "app.main:start_api"
|
||||
start-frontend = "frontend.app:start_frontend"
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
streamlit==1.28.0
|
||||
pandas==2.1.0
|
||||
numpy==1.24.0
|
||||
matplotlib==3.7.0
|
||||
seaborn==0.12.0
|
||||
plotly==5.17.0
|
||||
python-dotenv==1.0.0
|
||||
pydantic==1.10.12
|
||||
requests==2.31.0
|
||||
scikit-learn==1.3.0
|
||||
pytest==7.4.0
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
API 서버 실행 스크립트
|
||||
"""
|
||||
import uvicorn
|
||||
from app.main import app
|
||||
from app.core.config import settings
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 Q-Table 협상 전략 데모 API 서버를 시작합니다...")
|
||||
print(f"📍 주소: http://{settings.api_host}:{settings.api_port}")
|
||||
print(f"📚 API 문서: http://{settings.api_host}:{settings.api_port}/docs")
|
||||
print("🛑 종료하려면 Ctrl+C를 누르세요")
|
||||
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host=settings.api_host,
|
||||
port=settings.api_port,
|
||||
reload=True,
|
||||
log_level="info"
|
||||
)
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
API와 프론트엔드 동시 실행 스크립트
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from threading import Thread
|
||||
from app.core.config import settings
|
||||
|
||||
def run_api():
|
||||
"""API 서버 실행"""
|
||||
subprocess.run([
|
||||
sys.executable, "run_api.py"
|
||||
])
|
||||
|
||||
def run_frontend():
|
||||
"""프론트엔드 실행"""
|
||||
# API 서버가 시작될 시간을 기다림
|
||||
time.sleep(3)
|
||||
subprocess.run([
|
||||
sys.executable, "run_frontend.py"
|
||||
])
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🚀 Q-Table 협상 전략 데모 전체 시스템을 시작합니다...")
|
||||
print(f"🔧 API 서버: http://{settings.api_host}:{settings.api_port}")
|
||||
print(f"🎯 프론트엔드: http://{settings.frontend_host}:{settings.frontend_port}")
|
||||
print("🛑 종료하려면 Ctrl+C를 누르세요")
|
||||
|
||||
try:
|
||||
# API 서버를 별도 스레드에서 실행
|
||||
api_thread = Thread(target=run_api, daemon=True)
|
||||
api_thread.start()
|
||||
|
||||
# 프론트엔드 실행 (메인 스레드)
|
||||
run_frontend()
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 시스템을 종료합니다.")
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
프론트엔드 실행 스크립트
|
||||
"""
|
||||
import subprocess
|
||||
import sys
|
||||
from app.core.config import settings
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("🎯 Q-Table 협상 전략 데모 프론트엔드를 시작합니다...")
|
||||
print(f"📍 주소: http://{settings.frontend_host}:{settings.frontend_port}")
|
||||
print("🛑 종료하려면 Ctrl+C를 누르세요")
|
||||
|
||||
try:
|
||||
subprocess.run([
|
||||
sys.executable, "-m", "streamlit", "run",
|
||||
"frontend/app.py",
|
||||
"--server.port", str(settings.frontend_port),
|
||||
"--server.address", settings.frontend_host
|
||||
])
|
||||
except KeyboardInterrupt:
|
||||
print("\n👋 프론트엔드를 종료합니다.")
|
||||
|
|
@ -0,0 +1,180 @@
|
|||
"""
|
||||
기본 테스트 모듈
|
||||
"""
|
||||
import pytest
|
||||
import requests
|
||||
from app.services.negotiation_env import NegotiationEnvironment
|
||||
from app.services.qtable_learner import QTableLearner, ExperienceBuffer
|
||||
from app.models.schemas import CardType, ScenarioType, PriceZoneType
|
||||
|
||||
|
||||
class TestNegotiationEnvironment:
|
||||
"""협상 환경 테스트"""
|
||||
|
||||
def setup_method(self):
|
||||
self.env = NegotiationEnvironment()
|
||||
|
||||
def test_reward_calculation(self):
|
||||
"""보상 계산 테스트"""
|
||||
reward, weight = self.env.calculate_reward(
|
||||
scenario=ScenarioType.A,
|
||||
price_zone=PriceZoneType.PZ1,
|
||||
anchor_price=100,
|
||||
proposed_price=95,
|
||||
is_end=True
|
||||
)
|
||||
|
||||
assert reward > 0
|
||||
assert 0 <= weight <= 1
|
||||
|
||||
def test_price_zone_determination(self):
|
||||
"""가격 구간 결정 테스트"""
|
||||
# 목표가 이하
|
||||
zone = self.env.get_price_zone(90, 100)
|
||||
assert zone == PriceZoneType.PZ1
|
||||
|
||||
# 중간 구간
|
||||
zone = self.env.get_price_zone(110, 100)
|
||||
assert zone == PriceZoneType.PZ2
|
||||
|
||||
# 높은 구간
|
||||
zone = self.env.get_price_zone(150, 100)
|
||||
assert zone == PriceZoneType.PZ3
|
||||
|
||||
def test_opponent_response_simulation(self):
|
||||
"""상대방 응답 시뮬레이션 테스트"""
|
||||
price = self.env.simulate_opponent_response(
|
||||
current_card=CardType.C1,
|
||||
scenario=ScenarioType.A,
|
||||
anchor_price=100,
|
||||
step=0
|
||||
)
|
||||
|
||||
assert price > 0
|
||||
assert isinstance(price, float)
|
||||
|
||||
|
||||
class TestQTableLearner:
|
||||
"""Q-Table 학습 테스트"""
|
||||
|
||||
def setup_method(self):
|
||||
states = ["S1", "S2", "S3"]
|
||||
actions = [CardType.C1, CardType.C2]
|
||||
self.learner = QTableLearner(states, actions)
|
||||
|
||||
def test_initialization(self):
|
||||
"""초기화 테스트"""
|
||||
assert self.learner.q_table.shape == (3, 2)
|
||||
assert (self.learner.q_table == 0).all().all()
|
||||
|
||||
def test_q_value_update(self):
|
||||
"""Q값 업데이트 테스트"""
|
||||
td_error = self.learner.update_q_value(
|
||||
state="S1",
|
||||
action=CardType.C1,
|
||||
reward=1.0,
|
||||
next_state="S2",
|
||||
done=False
|
||||
)
|
||||
|
||||
assert td_error != 0
|
||||
assert self.learner.get_q_value("S1", CardType.C1) != 0
|
||||
|
||||
def test_action_selection(self):
|
||||
"""행동 선택 테스트"""
|
||||
# 초기 상태에서는 무작위 선택
|
||||
action, is_exploration = self.learner.select_action("S1")
|
||||
assert action in [CardType.C1, CardType.C2]
|
||||
|
||||
# Q값 설정 후 최적 행동 선택
|
||||
self.learner.set_q_value("S1", CardType.C2, 1.0)
|
||||
optimal_action = self.learner.get_optimal_action("S1")
|
||||
assert optimal_action == CardType.C2
|
||||
|
||||
|
||||
class TestExperienceBuffer:
|
||||
"""경험 버퍼 테스트"""
|
||||
|
||||
def setup_method(self):
|
||||
self.buffer = ExperienceBuffer(max_size=10)
|
||||
|
||||
def test_add_experience(self):
|
||||
"""경험 추가 테스트"""
|
||||
self.buffer.add_experience(
|
||||
state="S1",
|
||||
action=CardType.C1,
|
||||
reward=1.0,
|
||||
next_state="S2",
|
||||
done=False
|
||||
)
|
||||
|
||||
assert self.buffer.size() == 1
|
||||
|
||||
def test_buffer_overflow(self):
|
||||
"""버퍼 오버플로우 테스트"""
|
||||
# 최대 크기보다 많이 추가
|
||||
for i in range(15):
|
||||
self.buffer.add_experience(
|
||||
state=f"S{i}",
|
||||
action=CardType.C1,
|
||||
reward=1.0,
|
||||
next_state=f"S{i+1}",
|
||||
done=False
|
||||
)
|
||||
|
||||
# 최대 크기 유지
|
||||
assert self.buffer.size() == 10
|
||||
|
||||
def test_sampling(self):
|
||||
"""샘플링 테스트"""
|
||||
# 경험 추가
|
||||
for i in range(5):
|
||||
self.buffer.add_experience(
|
||||
state=f"S{i}",
|
||||
action=CardType.C1,
|
||||
reward=1.0,
|
||||
next_state=f"S{i+1}",
|
||||
done=False
|
||||
)
|
||||
|
||||
# 배치 샘플링
|
||||
batch = self.buffer.sample_batch(3)
|
||||
assert len(batch) == 3
|
||||
|
||||
|
||||
# API 통합 테스트 (선택사항)
|
||||
class TestAPIIntegration:
|
||||
"""API 통합 테스트"""
|
||||
|
||||
def test_health_check(self):
|
||||
"""헬스 체크 테스트"""
|
||||
try:
|
||||
response = requests.get("http://localhost:8000/api/v1/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip("API 서버가 실행되지 않음")
|
||||
|
||||
def test_reward_calculation_endpoint(self):
|
||||
"""보상 계산 엔드포인트 테스트"""
|
||||
try:
|
||||
payload = {
|
||||
"scenario": "A",
|
||||
"price_zone": "PZ1",
|
||||
"anchor_price": 100,
|
||||
"proposed_price": 95,
|
||||
"is_end": True
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:8000/api/v1/reward/calculate",
|
||||
json=payload
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert "reward" in data
|
||||
assert "weight" in data
|
||||
except requests.exceptions.ConnectionError:
|
||||
pytest.skip("API 서버가 실행되지 않음")
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
# Q-Table 프로젝트 검토 진행 상황
|
||||
|
||||
## Phase 1: 프로젝트 파일 분석 및 구조 파악 ✅
|
||||
- [x] 프로젝트 파일 복사 및 구조 확인
|
||||
- [x] README.md 분석 - 프로젝트 개요 파악
|
||||
- [x] requirements.txt 및 pyproject.toml 확인
|
||||
- [x] 주요 코드 파일들 구조 분석
|
||||
- [x] 프로젝트 디렉토리 구조 정리
|
||||
|
||||
## Phase 2: 핵심 코드 검토 및 문제점 식별 ✅
|
||||
- [x] qtable_learner.py 검토 - import 문제 발견
|
||||
- [x] negotiation_env.py 검토 - import 문제 발견
|
||||
- [x] app.py (Streamlit 프론트엔드) 검토
|
||||
- [x] schemas.py 검토 - 정상
|
||||
- [x] main(1).py (FastAPI 백엔드) 검토 - import 문제 발견
|
||||
- [x] config.py 검토 - 정상
|
||||
- [x] endpoints.py 검토 - import 문제 발견
|
||||
- [x] demo_service.py 검토 - import 문제 발견
|
||||
- [x] 실행 스크립트들 검토 - import 문제 발견
|
||||
- [x] 문제점 정리
|
||||
|
||||
## 발견된 주요 문제점:
|
||||
1. **파일 구조 불일치**: README에서 언급된 app/ 디렉토리 구조가 실제와 다름 ✅ 해결
|
||||
2. **Import 경로 오류**: `from app.models.schemas import` 등의 경로가 잘못됨 ✅ 해결
|
||||
3. **파일명 중복**: main(1).py, __init__(1).py 등 중복된 파일명 ✅ 해결
|
||||
|
||||
## Phase 3: 의존성 및 환경 설정 검증 ✅
|
||||
- [x] 의존성 설치 테스트 - 성공
|
||||
- [x] 환경 변수 설정 확인 - .env 파일 생성
|
||||
- [x] 프로젝트 구조 수정 필요 여부 확인 - 수정 완료
|
||||
|
||||
## Phase 4: 코드 실행 테스트 및 오류 수정 ✅
|
||||
- [x] 프로젝트 구조 재구성
|
||||
- [x] Import 경로 수정
|
||||
- [x] 모듈 import 테스트 - 모두 성공
|
||||
- [x] API 서버 실행 테스트 - 성공
|
||||
- [x] 기본 테스트 실행 - 성공
|
||||
|
||||
## Phase 5: 수정된 프로젝트 결과 보고 ✅
|
||||
- [x] 수정 사항 정리
|
||||
- [x] 검토 보고서 작성
|
||||
- [x] 최종 실행 가능한 프로젝트 제공
|
||||
|
||||
## 🎉 프로젝트 수정 완료!
|
||||
모든 문제점이 해결되어 Q-Table 데모 프로젝트가 정상적으로 실행 가능한 상태가 되었습니다.
|
||||
|
||||
Loading…
Reference in New Issue