Initial commit
commit
94bbc309fd
|
|
@ -0,0 +1,5 @@
|
|||
.venv/
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
.DS_Store
|
||||
.env
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.13.2
|
||||
aiosignal==1.4.0
|
||||
aiosqlite==0.21.0
|
||||
annotated-doc==0.0.4
|
||||
annotated-types==0.7.0
|
||||
anyio==4.11.0
|
||||
asyncpg==0.31.0
|
||||
attrs==25.4.0
|
||||
certifi==2025.11.12
|
||||
charset-normalizer==3.4.4
|
||||
click==8.3.1
|
||||
cloudpickle==3.1.2
|
||||
contourpy==1.3.3
|
||||
cycler==0.12.1
|
||||
distro==1.9.0
|
||||
Farama-Notifications==0.0.4
|
||||
fastapi==0.122.0
|
||||
fonttools==4.60.1
|
||||
frozenlist==1.8.0
|
||||
greenlet==3.2.4
|
||||
gunicorn==23.0.0
|
||||
gymnasium==1.2.2
|
||||
h11==0.16.0
|
||||
h5py==3.15.1
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
idna==3.11
|
||||
iniconfig==2.3.0
|
||||
JayDeBeApi==1.2.3
|
||||
Jinja2==3.1.6
|
||||
jiter==0.12.0
|
||||
jpype1==1.6.0
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
kiwisolver==1.4.9
|
||||
langchain-core==1.1.0
|
||||
langchain-openai==1.1.0
|
||||
langgraph==1.0.4
|
||||
langgraph-checkpoint==3.0.1
|
||||
langgraph-checkpoint-sqlite==3.0.0
|
||||
langgraph-prebuilt==1.0.5
|
||||
langgraph-sdk==0.2.10
|
||||
langsmith==0.4.49
|
||||
MarkupSafe==3.0.3
|
||||
matplotlib==3.10.7
|
||||
multidict==6.7.0
|
||||
numpy==2.3.5
|
||||
openai==2.8.1
|
||||
orjson==3.11.4
|
||||
ormsgpack==1.12.0
|
||||
packaging==25.0
|
||||
passlib==1.7.4
|
||||
pillow==12.0.0
|
||||
pluggy==1.6.0
|
||||
propcache==0.4.1
|
||||
psycopg2-binary==2.9.11
|
||||
pydantic==2.12.5
|
||||
pydantic-settings==2.12.0
|
||||
pydantic_core==2.41.5
|
||||
Pygments==2.19.2
|
||||
PyJWT==2.10.1
|
||||
pyparsing==3.2.5
|
||||
pypdf==6.1.3
|
||||
pytest==9.0.1
|
||||
pytest-asyncio==1.3.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.2.1
|
||||
python-multipart==0.0.20
|
||||
pytz==2025.2
|
||||
PyYAML==6.0.3
|
||||
regex==2025.11.3
|
||||
requests==2.32.5
|
||||
requests-toolbelt==1.0.0
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
SQLAlchemy==2.0.44
|
||||
sqlite-vec==0.1.6
|
||||
starlette==0.50.0
|
||||
tenacity==9.1.2
|
||||
tiktoken==0.12.0
|
||||
tqdm==4.67.1
|
||||
typing-inspection==0.4.2
|
||||
typing_extensions==4.15.0
|
||||
urllib3==2.5.0
|
||||
uvicorn==0.38.0
|
||||
xxhash==3.6.0
|
||||
yarl==1.22.0
|
||||
zstandard==0.25.0
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
#!/bin/bash
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
echo "Setup complete. Activate with 'source .venv/bin/activate'"
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
# 변경 사항 (Changelog)
|
||||
|
||||
## Version 3.0 - 비즈니스 용어 적용 (최종)
|
||||
|
||||
### 주요 변경 사항
|
||||
|
||||
#### 1. State 변수명 최종 확정 (03_state_design.tex)
|
||||
|
||||
**Version 2.0 (이전):**
|
||||
|
||||
- 견적 금액 구간, 제조사 구분, 파트너사 구조, 가격 수용률 분위, 현재 가격 구간
|
||||
|
||||
**Version 3.0 (최종):**
|
||||
|
||||
- 매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간
|
||||
|
||||
**변경 이유:**
|
||||
|
||||
- 실제 비즈니스에서 사용하는 용어로 통일하여 가독성 및 이해도 향상
|
||||
|
||||
**최종 State 변수명:**
|
||||
|
||||
| 순번 | 변수 | 구분 | 설명 |
|
||||
| :--- | :------------------- | :--- | :--------------------------------------------------------- |
|
||||
| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) |
|
||||
| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 |
|
||||
| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) |
|
||||
| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) |
|
||||
| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) |
|
||||
|
||||
#### 2. 동적 가중치 (W) 설계 업데이트 (04_reward_function.tex)
|
||||
|
||||
- 변수명 변경에 따라 수식 및 테이블의 변수명도 모두 업데이트
|
||||
- `S_manu` → `S_dist`
|
||||
- 계산 로직 및 가중치 값은 동일하게 유지
|
||||
|
||||
```
|
||||
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
|
||||
```
|
||||
|
||||
**새로운 유통 구조 가중치 설계 의도:**
|
||||
|
||||
| 유통 구조 | 정규화 값 | 설계 의도 |
|
||||
| :-------- | :-------- | :----------------------------------------- |
|
||||
| 제조 | 0.2 | 제조사 직공급으로 가격 협상 여지 가장 적음 |
|
||||
| 총판 | 0.5 | 중간 유통 단계로 협상 여지 중간 |
|
||||
| 유통 | 1.0 | 복잡한 유통 단계로 가격 협상 여지 가장 큼 |
|
||||
|
||||
### 파일 변경 내역
|
||||
|
||||
- `sections/03_state_design.tex`: 변수명 변경
|
||||
- `sections/04_reward_function.tex`: 변수명 변경
|
||||
- `EXAMPLE_CALCULATION.md`: 변수명 변경
|
||||
- `CHANGELOG.md`: 이 파일 (Version 4.0 반영)
|
||||
|
||||
### 최종 검증
|
||||
|
||||
- [x] State 변수명 5개 모두 비즈니스 용어로 변경 완료
|
||||
- [x] 총 State 수 162개 유지
|
||||
- [x] 동적 가중치 W 계산식 변수명 업데이트 완료
|
||||
- [x] 계산 예시 변수명 업데이트 완료
|
||||
|
||||
이 버전은 기능적 변경 없이, 문서의 가독성과 현업 적용성을 높이는 데 중점을 둔 최종 버전입니다.
|
||||
|
|
@ -0,0 +1,173 @@
|
|||
# Q-Table Version 3.0 Migration Guide
|
||||
|
||||
## 개요
|
||||
|
||||
Q-Table 프로젝트가 Version 2.0 (36 states)에서 Version 3.0 (162 states)으로 업그레이드되었습니다.
|
||||
|
||||
## 주요 변경사항
|
||||
|
||||
### 1. State 구조 변경
|
||||
|
||||
#### Before (Version 2.0 - 36 states)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.model.state import (
|
||||
State, Scenario, PriceZone, AcceptanceRate
|
||||
)
|
||||
|
||||
state = State(
|
||||
scenario=Scenario.PRICE_FIRST, # 4개 값
|
||||
price_zone=PriceZone.AT_OR_BELOW_ANCHOR, # 3개 값
|
||||
acceptance_rate=AcceptanceRate.MEDIUM # 3개 값
|
||||
)
|
||||
# 총 4 × 3 × 3 = 36 states
|
||||
```
|
||||
|
||||
#### After (Version 3.0 - 162 states)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.model.state import (
|
||||
State,
|
||||
RevenueRange,
|
||||
DistributionStructure,
|
||||
PartnerType,
|
||||
AcceptanceRate,
|
||||
InputPriceZone,
|
||||
)
|
||||
|
||||
state = State(
|
||||
revenue_range=RevenueRange.MID, # 3개 값
|
||||
distribution=DistributionStructure.WHOLESALER, # 3개 값
|
||||
partner_type=PartnerType.SINGLE, # 3개 값
|
||||
acceptance_rate=AcceptanceRate.MID, # 3개 값
|
||||
input_price_zone=InputPriceZone.PZ1, # 2개 값
|
||||
)
|
||||
# 총 3 × 3 × 3 × 3 × 2 = 162 states
|
||||
```
|
||||
|
||||
### 2. State Builder 변경
|
||||
|
||||
#### Before (Version 2.0)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.service.state_calculator import (
|
||||
NegotiationSnapshot, build_state
|
||||
)
|
||||
|
||||
snapshot = NegotiationSnapshot(
|
||||
scenario_code="A",
|
||||
anchor_price=10000,
|
||||
target_price=12000,
|
||||
seller_initial_price=15000,
|
||||
current_price=11000,
|
||||
)
|
||||
state = build_state(snapshot)
|
||||
```
|
||||
|
||||
#### After (Version 3.0)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.service.state_calculator import (
|
||||
NegotiationSnapshot, build_state
|
||||
)
|
||||
|
||||
snapshot = NegotiationSnapshot(
|
||||
revenue_amount=2000, # 매출액 (만원)
|
||||
distribution_code="W", # "M": 제조, "W": 총판, "R": 유통
|
||||
partner_count=1, # 파트너사 수
|
||||
anchor_price=10000,
|
||||
target_price=12000,
|
||||
input_price=11000,
|
||||
acceptance_ratio=0.5, # 0~1 사이 값
|
||||
)
|
||||
state = build_state(snapshot)
|
||||
```
|
||||
|
||||
### 3. Reward 계산 변경
|
||||
|
||||
#### Before (Version 2.0)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.service.reward_calculator import (
|
||||
calculate_reward, NegotiationOutcome
|
||||
)
|
||||
|
||||
breakdown = calculate_reward(
|
||||
scenario=state.scenario,
|
||||
price_zone=state.price_zone,
|
||||
current_price=11000,
|
||||
anchor_price=10000,
|
||||
target_price=12000,
|
||||
round_number=3,
|
||||
outcome=NegotiationOutcome.ONGOING,
|
||||
)
|
||||
```
|
||||
|
||||
#### After (Version 3.0)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.service.reward_calculator import (
|
||||
calculate_reward, NegotiationOutcome
|
||||
)
|
||||
|
||||
breakdown = calculate_reward(
|
||||
revenue_range=state.revenue_range,
|
||||
distribution=state.distribution,
|
||||
partner_type=state.partner_type,
|
||||
acceptance_rate=state.acceptance_rate,
|
||||
input_price_zone=state.input_price_zone,
|
||||
current_price=11000,
|
||||
anchor_price=10000,
|
||||
target_price=12000,
|
||||
round_number=3,
|
||||
outcome=NegotiationOutcome.ONGOING,
|
||||
)
|
||||
```
|
||||
|
||||
### 4. Q-Table 크기 변경
|
||||
|
||||
#### Before
|
||||
|
||||
```python
|
||||
q_table = QTable(state_space_size=36, action_space_size=21)
|
||||
visit_table = VisitTable(state_space_size=36, action_space_size=21)
|
||||
```
|
||||
|
||||
#### After
|
||||
|
||||
```python
|
||||
q_table = QTable(state_space_size=162, action_space_size=21)
|
||||
visit_table = VisitTable(state_space_size=162, action_space_size=21)
|
||||
```
|
||||
|
||||
## 마이그레이션 체크리스트
|
||||
|
||||
- [ ] State 생성 코드를 새로운 5-변수 구조로 변경
|
||||
- [ ] NegotiationSnapshot 생성 코드 업데이트
|
||||
- [ ] calculate_reward() 호출부 매개변수 변경
|
||||
- [ ] Q-Table, VisitTable 초기화 시 state_space_size를 162로 변경
|
||||
- [ ] 기존 학습된 모델 파일(.npy) 재학습 필요 (36→162 차원 불일치)
|
||||
- [ ] 단위 테스트 업데이트
|
||||
|
||||
## 호환성 주의사항
|
||||
|
||||
⚠️ **기존 학습 모델 사용 불가**
|
||||
|
||||
- Version 2.0에서 학습된 Q-Table (36 × N)은 Version 3.0 (162 × N)과 호환되지 않습니다.
|
||||
- 기존 모델을 사용하려면 재학습이 필요합니다.
|
||||
|
||||
⚠️ **Experience 데이터 재수집 권장**
|
||||
|
||||
- 기존 Experience에는 새로운 State 변수 정보가 없습니다.
|
||||
- 새로운 State 정의에 맞춰 Experience를 재수집하는 것을 권장합니다.
|
||||
|
||||
## 변경 이유
|
||||
|
||||
1. **비즈니스 용어 적용**: 실제 협상 도메인에서 사용하는 용어로 통일
|
||||
2. **더 세밀한 상태 표현**: 매출액, 유통 구조, 파트너사 정보 등을 명시적으로 포함
|
||||
3. **확장성 향상**: 새로운 비즈니스 변수 추가 시 유연한 대응 가능
|
||||
4. **협상 전략 다양화**: 162개의 상태로 더 정교한 협상 전략 학습 가능
|
||||
|
||||
## 문의
|
||||
|
||||
변경사항에 대한 문의는 팀 리드에게 연락해주세요.
|
||||
|
|
@ -0,0 +1,364 @@
|
|||
# Q_Table 리팩토링 완료 요약
|
||||
|
||||
## 📅 작업 일자
|
||||
|
||||
- 2025-10-29: Version 2.0 (36 states)
|
||||
- 2025-11-10: **Version 3.0 (162 states)** - 비즈니스 용어 적용 및 State 재설계
|
||||
|
||||
## ✅ Version 3.0 주요 변경사항
|
||||
|
||||
### State 공간 재설계 (36 → 162 states)
|
||||
|
||||
**기존 Version 2.0:**
|
||||
|
||||
- State = (Scenario, PriceZone, AcceptanceRate)
|
||||
- 총 상태 수: 4 × 3 × 3 = **36 states**
|
||||
|
||||
**신규 Version 3.0:**
|
||||
|
||||
- State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간)
|
||||
- 총 상태 수: 3 × 3 × 3 × 3 × 2 = **162 states**
|
||||
|
||||
| 순번 | 변수 | 구분 | 설명 |
|
||||
| :--- | :------------------- | :--- | :--------------------------------------------------------- |
|
||||
| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) |
|
||||
| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 |
|
||||
| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) |
|
||||
| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) |
|
||||
| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) |
|
||||
|
||||
### 동적 가중치 (W) 설계 업데이트
|
||||
|
||||
**기존 Version 2.0:**
|
||||
|
||||
```
|
||||
W = (scenario.weight + price_zone.weight) / 2.0
|
||||
```
|
||||
|
||||
**신규 Version 3.0:**
|
||||
|
||||
```
|
||||
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
|
||||
W = clip(W_raw, 0.2, 0.8)
|
||||
```
|
||||
|
||||
기본 가중치 계수:
|
||||
|
||||
- w1 = 0.20 (매출액 가격구간)
|
||||
- w2 = 0.25 (유통 구조)
|
||||
- w3 = 0.20 (파트너사 종류)
|
||||
- w4 = 0.25 (가격 수용률)
|
||||
- w5 = 0.10 (입력 금액 구간)
|
||||
|
||||
## ✅ 완료된 작업 (Version 2.0)
|
||||
|
||||
### 1. 새로운 핵심 컴포넌트 구현
|
||||
|
||||
#### **QTable** (`domain/model/q_table.py`)
|
||||
|
||||
- ✅ 동적 `action_space_size` 지원 (21개 카드)
|
||||
- ✅ Q-Learning 업데이트 메서드
|
||||
- ✅ 직렬화/역직렬화 지원
|
||||
- ✅ 인덱스 유효성 검증
|
||||
|
||||
#### **EpisodePolicy** (`domain/agents/policy.py`)
|
||||
|
||||
- ✅ 하드코딩된 크기(9) → 동적 크기로 변경
|
||||
- ✅ 생성자에서 `action_space_size` 주입
|
||||
- ✅ `get_action_mask()` 동적 크기 적용
|
||||
|
||||
#### **State** (`domain/model/state.py`)
|
||||
|
||||
- ✅ `to_index()` 메서드 추가 (State → 1D 인덱스)
|
||||
- ✅ `from_index()` 메서드 추가 (1D 인덱스 → State)
|
||||
|
||||
#### **ActionCardMapper** (`integration/action_card_mapper.py`) 🆕
|
||||
|
||||
- ✅ action_id ↔ card_id 양방향 매핑
|
||||
- ✅ JSON 파일 기반 설정
|
||||
- ✅ 유틸리티 메서드 제공
|
||||
|
||||
#### **action_card_mapping.json** (`integration/data/`) 🆕
|
||||
|
||||
- ✅ 21개 카드 매핑 (0-20 → "no_0" ~ "no_20")
|
||||
|
||||
#### **GetBestActionUsecase** (`usecase/get_best_action_usecase.py`)
|
||||
|
||||
- ✅ State → action_id → card_id 전체 플로우
|
||||
- ✅ Policy 사용 여부 선택 가능
|
||||
- ✅ Top-K 추천 기능
|
||||
- ✅ 사용 가능한 액션/카드 조회
|
||||
|
||||
#### **CollectExperienceUsecase** (`usecase/collect_experience_usecase.py`) 🆕
|
||||
|
||||
- ✅ 협상 중 Experience 수집
|
||||
- ✅ 에피소드 단위 관리
|
||||
- ✅ JSONL 형식 자동 저장
|
||||
- ✅ 수집 정보 조회
|
||||
|
||||
#### **TrainOfflineUsecase** (`usecase/train_offline_usecase.py`) 🆕
|
||||
|
||||
- ✅ 저장된 Experience로 Q-Table 학습
|
||||
- ✅ 배치 학습 및 에포크 반복
|
||||
- ✅ 특정 에피소드 선택 학습
|
||||
- ✅ 성능 평가 기능
|
||||
|
||||
### 2. 레거시 코드 정리
|
||||
|
||||
#### 제거된 파일
|
||||
|
||||
```
|
||||
❌ domain/action_space.py
|
||||
❌ domain/model/action.py
|
||||
❌ domain/constants.py
|
||||
❌ domain/spaces.py
|
||||
```
|
||||
|
||||
#### Legacy로 이동된 파일 (참고용 보관)
|
||||
|
||||
```
|
||||
📦 legacy/environment.py
|
||||
📦 legacy/calculate_reward_usecase.py
|
||||
📦 legacy/evaluate_agent_usecase.py
|
||||
📦 legacy/execute_step_usecase.py
|
||||
📦 legacy/get_q_value_usecase.py
|
||||
📦 legacy/get_state_info_usecase.py
|
||||
📦 legacy/initialize_env_usecase.py
|
||||
📦 legacy/load_q_table_usecase.py
|
||||
📦 legacy/train_agent_usecase.py
|
||||
📦 legacy/update_q_table_usecase.py
|
||||
```
|
||||
|
||||
## 📁 최종 디렉토리 구조
|
||||
|
||||
```
|
||||
negotiation_agent/
|
||||
│
|
||||
├── card_management/ # 카드 관리 (DB 기반)
|
||||
│ ├── domain/
|
||||
│ │ ├── model/
|
||||
│ │ │ └── nego_card.py ✅ DB 엔티티
|
||||
│ │ ├── repository/
|
||||
│ │ │ ├── nego_card_repository.py
|
||||
│ │ │ └── nego_card_script_repository.py ✅ JSON CRUD
|
||||
│ │ └── value/
|
||||
│ │ └── nego_card_types.py ✅ 6가지 atomic elements
|
||||
│ └── data/
|
||||
│ └── nego_card_scripts.json ✅ 21개 카드 스크립트
|
||||
│
|
||||
├── Q_Table/ # Q-Learning (추상 action_id)
|
||||
│ ├── domain/
|
||||
│ │ ├── model/
|
||||
│ │ │ ├── state.py ✅ to_index() 추가
|
||||
│ │ │ ├── q_table.py ✅ 새로 생성
|
||||
│ │ │ └── experience.py ✅ 새로 생성
|
||||
│ │ ├── agents/
|
||||
│ │ │ ├── policy.py ✅ 동적 크기 수정
|
||||
│ │ │ └── offline_agent.py
|
||||
│ │ ├── repository/
|
||||
│ │ │ ├── action_repository.py
|
||||
│ │ │ └── experience_repository.py ✅ 새로 생성
|
||||
│ │ └── service/
|
||||
│ │ └── state_calculator.py
|
||||
│ ├── usecase/
|
||||
│ │ ├── get_best_action_usecase.py ✅ 완전 재작성
|
||||
│ │ ├── collect_experience_usecase.py ✅ 새로 생성
|
||||
│ │ └── train_offline_usecase.py ✅ 새로 생성
|
||||
│ ├── infra/
|
||||
│ │ ├── data_collector.py
|
||||
│ │ └── gym/
|
||||
│ │ └── env_wrapper.py
|
||||
│ ├── data/ 🆕 데이터 저장소
|
||||
│ │ └── experiences/
|
||||
│ │ └── *.jsonl
|
||||
│ └── legacy/ 📦 레거시 코드 보관
|
||||
│ ├── README.md
|
||||
│ └── ... (10개 파일)
|
||||
│
|
||||
└── integration/ 🆕 매핑 레이어
|
||||
├── action_card_mapper.py ✅
|
||||
└── data/
|
||||
└── action_card_mapping.json ✅
|
||||
```
|
||||
|
||||
## 🎯 핵심 설계 원칙
|
||||
|
||||
### 1. 명확한 책임 분리
|
||||
|
||||
- **Q_Table**: action_id (0~20)만 다룸
|
||||
- **Card_Management**: card_id ("no_0"~"no_20")만 다룸
|
||||
- **ActionCardMapper**: 둘을 연결하는 단일 책임
|
||||
|
||||
### 2. 동적 확장성
|
||||
|
||||
- 카드 추가/삭제 시 `action_card_mapping.json`만 수정
|
||||
- Q-Table과 Policy가 자동으로 새 크기에 대응
|
||||
|
||||
### 3. 유지보수성
|
||||
|
||||
- 레거시 코드는 legacy/ 폴더에 보관
|
||||
- 새로운 코드와 명확히 분리
|
||||
|
||||
## 🔄 사용 예시
|
||||
|
||||
### 1. 추론 (Inference)
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.model.q_table import QTable
|
||||
from negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
|
||||
from negotiation_agent.Q_Table.domain.agents.policy import UCBPolicy
|
||||
from negotiation_agent.integration.action_card_mapper import ActionCardMapper
|
||||
from negotiation_agent.Q_Table.usecase.get_best_action_usecase import GetBestActionUsecase
|
||||
|
||||
# 매퍼 로드
|
||||
mapper = ActionCardMapper()
|
||||
action_space_size = mapper.get_action_space_size() # 21
|
||||
|
||||
# Q-Table 생성 (Version 3.0: 162 states)
|
||||
q_table = QTable(
|
||||
state_space_size=162, # 3 x 3 x 3 x 3 x 2
|
||||
action_space_size=action_space_size # 21
|
||||
)
|
||||
|
||||
# Policy 생성
|
||||
visit_table = VisitTable(state_space_size=162, action_space_size=action_space_size)
|
||||
policy = UCBPolicy(visit_table=visit_table)
|
||||
|
||||
# Usecase 생성
|
||||
usecase = GetBestActionUsecase(
|
||||
q_table=q_table,
|
||||
policy=policy,
|
||||
action_card_mapper=mapper
|
||||
)
|
||||
|
||||
# 협상 추론 (Version 3.0)
|
||||
from negotiation_agent.Q_Table.domain.model.state import (
|
||||
State,
|
||||
RevenueRange,
|
||||
DistributionStructure,
|
||||
PartnerType,
|
||||
AcceptanceRate,
|
||||
InputPriceZone,
|
||||
)
|
||||
|
||||
state = State(
|
||||
revenue_range=RevenueRange.MID, # 1,000~3,000만원
|
||||
distribution=DistributionStructure.WHOLESALER, # 총판
|
||||
partner_type=PartnerType.SINGLE, # 단독 파트너
|
||||
acceptance_rate=AcceptanceRate.MID, # 30~90%
|
||||
input_price_zone=InputPriceZone.PZ1, # A≤P≤T
|
||||
)
|
||||
|
||||
result = usecase.execute(state)
|
||||
print(result)
|
||||
# Output:
|
||||
# {
|
||||
# 'action_id': 5,
|
||||
# 'card_id': 'no_5',
|
||||
# 'q_value': 0.0,
|
||||
# 'state_index': 1
|
||||
# }
|
||||
```
|
||||
|
||||
### 2. Experience 수집
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
|
||||
from negotiation_agent.Q_Table.usecase.collect_experience_usecase import CollectExperienceUsecase
|
||||
|
||||
# Repository 생성
|
||||
exp_repo = ExperienceRepository()
|
||||
|
||||
# Usecase 생성
|
||||
collect_usecase = CollectExperienceUsecase(exp_repo)
|
||||
|
||||
# 에피소드 시작
|
||||
collect_usecase.start_episode("ep_001")
|
||||
|
||||
# 협상 진행 중...
|
||||
for step in range(10):
|
||||
# 현재 상태
|
||||
current_state = State(...)
|
||||
|
||||
# 액션 선택 (GetBestActionUsecase 사용)
|
||||
result = usecase.execute(current_state)
|
||||
action_id = result['action_id']
|
||||
|
||||
# 액션 실행 후 보상과 다음 상태 받음
|
||||
reward = calculate_reward(...) # 보상 계산
|
||||
next_state = get_next_state(...) # 다음 상태
|
||||
done = check_done(...) # 종료 여부
|
||||
|
||||
# Experience 수집
|
||||
collect_usecase.collect(
|
||||
state=current_state,
|
||||
action_id=action_id,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done
|
||||
)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# 에피소드 종료
|
||||
collect_usecase.end_episode()
|
||||
|
||||
# 수집 정보 확인
|
||||
info = collect_usecase.get_collection_info()
|
||||
print(info)
|
||||
# {
|
||||
# 'total_experiences': 10,
|
||||
# 'episodes': 1,
|
||||
# 'current_episode': None
|
||||
# }
|
||||
```
|
||||
|
||||
### 3. 오프라인 학습
|
||||
|
||||
```python
|
||||
from negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
|
||||
|
||||
# Usecase 생성
|
||||
train_usecase = TrainOfflineUsecase(
|
||||
q_table=q_table,
|
||||
experience_repository=exp_repo
|
||||
)
|
||||
|
||||
# 학습 실행
|
||||
result = train_usecase.train(
|
||||
filename="experiences.jsonl",
|
||||
epochs=10,
|
||||
batch_size=32
|
||||
)
|
||||
|
||||
print(result)
|
||||
# {
|
||||
# 'total_experiences': 1000,
|
||||
# 'epochs': 10,
|
||||
# 'updates': 10000,
|
||||
# 'avg_loss': 0.05
|
||||
# }
|
||||
|
||||
# 성능 평가
|
||||
eval_result = train_usecase.evaluate(filename="experiences.jsonl")
|
||||
print(eval_result)
|
||||
# {
|
||||
# 'avg_q_value': 0.5,
|
||||
# 'avg_reward': 1.0,
|
||||
# 'total_samples': 1000
|
||||
# }
|
||||
```
|
||||
|
||||
## 📋 향후 작업 (선택사항)
|
||||
|
||||
- [x] Experience 수집 Usecase 구현 ✅
|
||||
- [x] 오프라인 학습 Usecase 구현 ✅
|
||||
- [ ] Q-Table Repository 구현 (영속성)
|
||||
- [ ] 카드 추가 시 Q-Table 자동 확장 기능
|
||||
- [ ] 통계적 Q-value 초기화 기능
|
||||
|
||||
## 📚 참고 문서
|
||||
|
||||
- 설계 문서: `REFACTORING_GUIDE.md`
|
||||
- 레거시 코드: `legacy/README.md`
|
||||
|
|
@ -0,0 +1,279 @@
|
|||
# Q-Table Version 3.0 변경사항 정리
|
||||
|
||||
## 📋 변경 개요
|
||||
|
||||
CHANGELOG.md에 기술된 Version 3.0 업데이트를 코드에 반영했습니다.
|
||||
- **State 공간**: 36 states → **162 states**
|
||||
- **변수 구조**: 3개 변수 → **5개 변수** (비즈니스 용어 적용)
|
||||
- **동적 가중치**: 2변수 평균 → **5변수 가중합**
|
||||
|
||||
---
|
||||
|
||||
## 🔄 주요 변경 파일
|
||||
|
||||
### 1. State 모델 (`domain/model/state.py`)
|
||||
|
||||
#### 기존 (Version 2.0)
|
||||
```python
|
||||
class Scenario(IntEnum): ... # 4개 값
|
||||
class PriceZone(IntEnum): ... # 3개 값
|
||||
class AcceptanceRate(IntEnum): ... # 3개 값
|
||||
|
||||
State = (Scenario, PriceZone, AcceptanceRate) # 4×3×3 = 36
|
||||
```
|
||||
|
||||
#### 변경 (Version 3.0)
|
||||
```python
|
||||
class RevenueRange(IntEnum): ... # 3개 값 (매출액 가격구간)
|
||||
class DistributionStructure(IntEnum): ... # 3개 값 (유통 구조)
|
||||
class PartnerType(IntEnum): ... # 3개 값 (파트너사 종류)
|
||||
class AcceptanceRate(IntEnum): ... # 3개 값 (가격 수용률 구간)
|
||||
class InputPriceZone(IntEnum): ... # 2개 값 (입력 금액 구간)
|
||||
|
||||
State = (RevenueRange, DistributionStructure, PartnerType,
|
||||
AcceptanceRate, InputPriceZone) # 3×3×3×3×2 = 162
|
||||
```
|
||||
|
||||
**주요 특징:**
|
||||
- 모든 클래스에 `.weight` 속성 추가 (동적 가중치 계산용)
|
||||
- 비즈니스 용어로 명명 (매출액, 유통 구조, 파트너사 등)
|
||||
- 각 변수마다 `from_*()` 클래스 메서드로 분류 로직 제공
|
||||
|
||||
---
|
||||
|
||||
### 2. Reward Calculator (`domain/service/reward_calculator.py`)
|
||||
|
||||
#### 기존 (Version 2.0)
|
||||
```python
|
||||
def calculate_reward(
|
||||
scenario: Scenario,
|
||||
price_zone: PriceZone,
|
||||
...
|
||||
)
|
||||
|
||||
def _calculate_weight(scenario, price_zone, cfg):
|
||||
raw = (scenario.priority_weight + price_zone.zone_weight) / 2.0
|
||||
return clip(raw, 0.2, 0.8)
|
||||
```
|
||||
|
||||
#### 변경 (Version 3.0)
|
||||
```python
|
||||
def calculate_reward(
|
||||
revenue_range: RevenueRange,
|
||||
distribution: DistributionStructure,
|
||||
partner_type: PartnerType,
|
||||
acceptance_rate: AcceptanceRate,
|
||||
input_price_zone: InputPriceZone,
|
||||
...
|
||||
)
|
||||
|
||||
def _calculate_weight(...):
|
||||
W_raw = (
|
||||
config.w1 * revenue_range.weight +
|
||||
config.w2 * distribution.weight +
|
||||
config.w3 * partner_type.weight +
|
||||
config.w4 * acceptance_rate.weight +
|
||||
config.w5 * input_price_zone.weight
|
||||
)
|
||||
return clip(W_raw, 0.2, 0.8)
|
||||
```
|
||||
|
||||
**기본 가중치 계수:**
|
||||
- w1 = 0.20 (매출액)
|
||||
- w2 = 0.25 (유통 구조)
|
||||
- w3 = 0.20 (파트너사)
|
||||
- w4 = 0.25 (수용률)
|
||||
- w5 = 0.10 (입력 금액)
|
||||
|
||||
---
|
||||
|
||||
### 3. State Calculator (`domain/service/state_calculator.py`)
|
||||
|
||||
#### 기존 (Version 2.0)
|
||||
```python
|
||||
@dataclass
|
||||
class NegotiationSnapshot:
|
||||
scenario_code: str
|
||||
anchor_price: float
|
||||
target_price: float
|
||||
seller_initial_price: float
|
||||
current_price: float
|
||||
```
|
||||
|
||||
#### 변경 (Version 3.0)
|
||||
```python
|
||||
@dataclass
|
||||
class NegotiationSnapshot:
|
||||
revenue_amount: float # 매출액 (만원)
|
||||
distribution_code: str # "M", "W", "R"
|
||||
partner_count: int # 파트너사 수
|
||||
anchor_price: float
|
||||
target_price: float
|
||||
input_price: float
|
||||
acceptance_ratio: Optional[float]
|
||||
initial_price: Optional[float]
|
||||
```
|
||||
|
||||
**변경 이유:**
|
||||
- 실제 비즈니스 데이터에 맞춰 필드 재구성
|
||||
- 5개 State 변수를 도출할 수 있는 정보 제공
|
||||
|
||||
---
|
||||
|
||||
### 4. 모듈 Export (`domain/model/__init__.py`)
|
||||
|
||||
```python
|
||||
# Version 3.0
|
||||
from .state import (
|
||||
AcceptanceRate,
|
||||
DistributionStructure,
|
||||
InputPriceZone,
|
||||
PartnerType,
|
||||
RevenueRange,
|
||||
State,
|
||||
)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 State 변수 상세 명세
|
||||
|
||||
| 변수 | 클래스 | 값 | 설명 | 가중치 |
|
||||
|------|--------|---|------|--------|
|
||||
| 매출액 가격구간 | `RevenueRange` | LOW (≤1,000만원) | 낮은 매출액 | 0.3 |
|
||||
| | | MID (1,000~3,000만원) | 중간 매출액 | 0.6 |
|
||||
| | | HIGH (>3,000만원) | 높은 매출액 | 1.0 |
|
||||
| 유통 구조 | `DistributionStructure` | MANUFACTURER | 제조사 직공급 | 0.2 |
|
||||
| | | WHOLESALER | 총판 경유 | 0.5 |
|
||||
| | | RETAILER | 유통 경유 | 1.0 |
|
||||
| 파트너사 종류 | `PartnerType` | NONE | 파트너 없음 | 0.3 |
|
||||
| | | SINGLE | 단독 파트너 | 0.5 |
|
||||
| | | MULTIPLE | 다수 파트너 | 1.0 |
|
||||
| 가격 수용률 | `AcceptanceRate` | LOW (<30%) | 낮은 수용률 | 0.3 |
|
||||
| | | MID (30~90%) | 중간 수용률 | 0.6 |
|
||||
| | | HIGH (>90%) | 높은 수용률 | 1.0 |
|
||||
| 입력 금액 구간 | `InputPriceZone` | PZ1 (A≤P≤T) | 목표 범위 내 | 1.0 |
|
||||
| | | PZ2 (P>T or P<A) | 목표 범위 밖 | 0.3 |
|
||||
|
||||
---
|
||||
|
||||
## 🧪 새로운 테스트 파일
|
||||
|
||||
### `tests/Q_Table/test_state_v3.py`
|
||||
- State 변수 분류 로직 검증
|
||||
- State ↔ index 변환 정합성 테스트
|
||||
- 162개 state 전체 순회 검증
|
||||
- 가중치 값 검증
|
||||
|
||||
### `tests/Q_Table/test_reward_v3.py`
|
||||
- NegotiationSnapshot → State 변환 테스트
|
||||
- 동적 가중치 계산 검증
|
||||
- 성공/실패/진행중 보상 계산 검증
|
||||
- 최소/최대 가중치 경계 케이스 테스트
|
||||
|
||||
---
|
||||
|
||||
## 📚 문서 업데이트
|
||||
|
||||
### `REFACTORING_SUMMARY.md`
|
||||
- Version 3.0 변경사항 추가
|
||||
- 사용 예시 업데이트 (36 → 162 states)
|
||||
- State 구조 비교표 추가
|
||||
|
||||
### `MIGRATION_V3.md` (신규)
|
||||
- Version 2.0 → 3.0 마이그레이션 가이드
|
||||
- Before/After 코드 비교
|
||||
- 호환성 주의사항
|
||||
- 체크리스트 제공
|
||||
|
||||
### `CHANGELOG.md` (기존 파일 참조)
|
||||
- Version 3.0 변경사항 문서화
|
||||
- 비즈니스 용어 적용 배경 설명
|
||||
|
||||
---
|
||||
|
||||
## ⚠️ Breaking Changes
|
||||
|
||||
### 1. Q-Table 크기 변경
|
||||
```python
|
||||
# Before
|
||||
QTable(state_space_size=36, action_space_size=21)
|
||||
|
||||
# After
|
||||
QTable(state_space_size=162, action_space_size=21)
|
||||
```
|
||||
|
||||
### 2. State 생성 방식 변경
|
||||
```python
|
||||
# Before
|
||||
State(
|
||||
scenario=Scenario.PRICE_FIRST,
|
||||
price_zone=PriceZone.AT_OR_BELOW_ANCHOR,
|
||||
acceptance_rate=AcceptanceRate.MEDIUM,
|
||||
)
|
||||
|
||||
# After
|
||||
State(
|
||||
revenue_range=RevenueRange.MID,
|
||||
distribution=DistributionStructure.WHOLESALER,
|
||||
partner_type=PartnerType.SINGLE,
|
||||
acceptance_rate=AcceptanceRate.MID,
|
||||
input_price_zone=InputPriceZone.PZ1,
|
||||
)
|
||||
```
|
||||
|
||||
### 3. 기존 학습 모델 호환 불가
|
||||
- Version 2.0 모델 파일(.npy)은 36×N 크기
|
||||
- Version 3.0은 162×N 크기로 로드 불가
|
||||
- **재학습 필수**
|
||||
|
||||
---
|
||||
|
||||
## ✅ 검증 항목
|
||||
|
||||
- [x] State 클래스 5개 변수로 재정의
|
||||
- [x] 각 변수의 `.weight` 속성 구현
|
||||
- [x] State.to_index() / from_index() 162 states 대응
|
||||
- [x] RewardConfig에 w1~w5 가중치 계수 추가
|
||||
- [x] _calculate_weight() 5변수 가중합으로 변경
|
||||
- [x] NegotiationSnapshot 필드 재구성
|
||||
- [x] build_state() 5변수 매핑 로직 구현
|
||||
- [x] __init__.py export 업데이트
|
||||
- [x] REFACTORING_SUMMARY.md 업데이트
|
||||
- [x] MIGRATION_V3.md 작성
|
||||
- [x] 단위 테스트 작성 (test_state_v3.py, test_reward_v3.py)
|
||||
|
||||
---
|
||||
|
||||
## 🚀 다음 단계
|
||||
|
||||
1. **Python 환경 설정**
|
||||
```bash
|
||||
poetry install
|
||||
# or
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. **테스트 실행**
|
||||
```bash
|
||||
pytest tests/Q_Table/test_state_v3.py -v
|
||||
pytest tests/Q_Table/test_reward_v3.py -v
|
||||
```
|
||||
|
||||
3. **기존 코드 마이그레이션**
|
||||
- `MIGRATION_V3.md` 가이드 참조
|
||||
- State 생성 코드 일괄 변경
|
||||
- Q-Table 초기화 크기 변경
|
||||
|
||||
4. **모델 재학습**
|
||||
- 새로운 162-state 공간으로 학습 데이터 재수집
|
||||
- Q-Table 재학습 실행
|
||||
|
||||
---
|
||||
|
||||
## 📞 문의
|
||||
|
||||
Version 3.0 적용 중 문제 발생 시:
|
||||
1. `MIGRATION_V3.md` 참조
|
||||
2. 테스트 코드 참조 (`test_state_v3.py`, `test_reward_v3.py`)
|
||||
3. 팀 리드에게 문의
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
"""Q-Table package exports."""
|
||||
|
||||
from . import domain, usecase # noqa: F401
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,120 @@
|
|||
"""Offline Q-learning agent backed by a UCB policy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..model.visit_table import VisitTable
|
||||
from ..model.q_table import QTable
|
||||
from .policy import UCBPolicy
|
||||
|
||||
|
||||
class QLearningAgent:
|
||||
def __init__(
|
||||
self,
|
||||
agent_params,
|
||||
state_size: int,
|
||||
action_size: int,
|
||||
visit_table: Optional[VisitTable] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
agent_params: 에이전트 파라미터 (learning_rate, discount_factor 등)
|
||||
state_size: 상태 공간 크기
|
||||
action_size: 액션 공간 크기
|
||||
visit_table: 방문 기록 테이블 (None이면 새로 생성)
|
||||
"""
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
|
||||
# Q-Table 객체 생성 (Composition)
|
||||
self.q_table = QTable(
|
||||
state_space_size=state_size,
|
||||
action_space_size=action_size,
|
||||
learning_rate=agent_params["learning_rate"],
|
||||
discount_factor=agent_params["discount_factor"]
|
||||
)
|
||||
|
||||
self.visit_table = visit_table or VisitTable(state_size, action_size)
|
||||
self.policy = UCBPolicy(
|
||||
visit_table=self.visit_table,
|
||||
exploration_constant=agent_params.get("exploration_constant", np.sqrt(2.0)),
|
||||
)
|
||||
|
||||
def get_action(self, state: int, action_mask=None):
|
||||
"""
|
||||
현재 상태에서 액션 선택 (UCB 정책 사용)
|
||||
|
||||
Args:
|
||||
state: 현재 상태 인덱스
|
||||
action_mask: 가능한 액션 마스크 (Optional)
|
||||
|
||||
Returns:
|
||||
선택된 액션 ID
|
||||
"""
|
||||
# QTable 객체에서 Q-value 조회
|
||||
q_values = self.q_table.get_q_values_for_state(state)
|
||||
mask = None if action_mask is None else np.asarray(action_mask, dtype=bool)
|
||||
return self.policy.select_action(state, q_values, available_mask=mask)
|
||||
|
||||
def learn(self, batch):
|
||||
"""
|
||||
배치 데이터를 사용하여 Q-Table 업데이트
|
||||
|
||||
Args:
|
||||
batch: 학습 데이터 배치 (observations, actions, rewards, next_observations, terminals)
|
||||
"""
|
||||
for state, action, reward, next_state, terminated in zip(
|
||||
batch["observations"],
|
||||
batch["actions"],
|
||||
batch["rewards"],
|
||||
batch["next_observations"],
|
||||
batch["terminals"],
|
||||
):
|
||||
# QTable 객체의 update 메서드 사용
|
||||
# terminated가 True이면 next_state는 의미가 없거나 None이어야 함
|
||||
# QTable.update는 next_state_index가 None이면 종료 상태로 처리
|
||||
|
||||
next_s = next_state if not terminated else None
|
||||
|
||||
self.q_table.update(
|
||||
state_index=state,
|
||||
action_id=action,
|
||||
reward=reward,
|
||||
next_state_index=next_s,
|
||||
done=terminated
|
||||
)
|
||||
|
||||
self.visit_table.increment(state, action)
|
||||
|
||||
def save_model(self, path):
|
||||
"""
|
||||
Q-Table을 파일로 저장
|
||||
|
||||
Args:
|
||||
path: 저장할 파일 경로
|
||||
"""
|
||||
# QTable 내부의 numpy array를 저장 (기존 호환성 유지)
|
||||
np.save(path, self.q_table.q_values)
|
||||
print(f"Q-Table saved to {path}")
|
||||
|
||||
def load_q_table(self, file_path):
|
||||
"""
|
||||
파일에서 Q-Table 로드
|
||||
|
||||
Args:
|
||||
file_path: 로드할 파일 경로
|
||||
"""
|
||||
if os.path.exists(file_path):
|
||||
# QTable 객체의 q_values 속성에 직접 할당
|
||||
self.q_table.q_values = np.load(file_path)
|
||||
print(f"Q-Table loaded from {file_path}")
|
||||
else:
|
||||
print(f"Error: No Q-Table found at {file_path}")
|
||||
|
||||
def reset_episode(self):
|
||||
"""에피소드 초기화 (정책 상태 등 리셋)"""
|
||||
self.policy.reset_episode()
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..model.visit_table import VisitTable
|
||||
|
||||
|
||||
class Policy(ABC):
|
||||
@abstractmethod
|
||||
def select_action(
|
||||
self,
|
||||
state_index: int,
|
||||
q_values: np.ndarray,
|
||||
available_mask: Optional[np.ndarray] = None,
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
주어진 상태와 Q-value를 기반으로 액션 선택
|
||||
|
||||
Args:
|
||||
state_index: 현재 상태 인덱스
|
||||
q_values: 해당 상태의 Q-value 배열
|
||||
available_mask: 선택 가능한 액션 마스크 (True=선택가능)
|
||||
|
||||
Returns:
|
||||
선택된 액션 ID (선택 불가 시 None)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def reset_episode(self) -> None:
|
||||
"""에피소드 시작 시 상태 초기화"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_action_mask(self) -> np.ndarray:
|
||||
"""
|
||||
현재 정책 상태에 따른 액션 마스크 반환
|
||||
|
||||
Returns:
|
||||
액션 마스크 배열 (1=선택가능, 0=선택불가)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UCBPolicy(Policy):
|
||||
"""Upper Confidence Bound policy with per-episode action masking."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
visit_table: VisitTable,
|
||||
exploration_constant: float = np.sqrt(2.0),
|
||||
rng: Optional[np.random.Generator] = None,
|
||||
) -> None:
|
||||
self.visit_table = visit_table
|
||||
self.action_space_size = visit_table.action_space_size
|
||||
self.exploration_constant = exploration_constant
|
||||
self._rng = rng or np.random.default_rng()
|
||||
self._episode_actions: set[int] = set()
|
||||
|
||||
def select_action(
|
||||
self,
|
||||
state_index: int,
|
||||
q_values: np.ndarray,
|
||||
available_mask: Optional[np.ndarray] = None,
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
UCB 알고리즘을 사용하여 액션 선택
|
||||
|
||||
UCB = Q(s,a) + c * sqrt(ln(N(s)) / N(s,a))
|
||||
"""
|
||||
mask = self._prepare_mask(available_mask)
|
||||
if not mask.any():
|
||||
self.reset_episode()
|
||||
mask = self._prepare_mask(available_mask)
|
||||
if not mask.any():
|
||||
return None
|
||||
|
||||
counts = self.visit_table.get_state_counts(state_index)
|
||||
masked_counts = np.where(mask, counts, 0)
|
||||
|
||||
zero_visit_candidates = np.where((counts == 0) & mask)[0]
|
||||
if zero_visit_candidates.size > 0:
|
||||
best = zero_visit_candidates[np.argmax(q_values[zero_visit_candidates])]
|
||||
action = int(best)
|
||||
else:
|
||||
total = masked_counts.sum()
|
||||
# Avoid division by zero while keeping exploration pressure.
|
||||
denom = counts.astype(float) + 1e-9
|
||||
bonus = self.exploration_constant * np.sqrt(np.log(total + 1.0) / denom)
|
||||
scores = q_values + bonus
|
||||
scores[~mask] = -np.inf
|
||||
action = int(np.argmax(scores))
|
||||
|
||||
self.visit_table.increment(state_index, action)
|
||||
self._episode_actions.add(action)
|
||||
return action
|
||||
|
||||
def reset_episode(self) -> None:
|
||||
"""에피소드 내 사용된 액션 기록 초기화"""
|
||||
self._episode_actions.clear()
|
||||
|
||||
def get_action_mask(self) -> np.ndarray:
|
||||
"""
|
||||
이미 사용된 액션을 제외한 마스크 반환
|
||||
|
||||
Returns:
|
||||
마스크 배열 (1=선택가능, 0=선택불가)
|
||||
"""
|
||||
mask = np.ones(self.action_space_size, dtype=int)
|
||||
if self._episode_actions:
|
||||
for action in self._episode_actions:
|
||||
mask[action] = 0
|
||||
return mask
|
||||
|
||||
def _prepare_mask(self, available_mask: Optional[np.ndarray]) -> np.ndarray:
|
||||
"""입력 마스크와 이미 사용된 액션을 결합하여 최종 마스크 생성"""
|
||||
if available_mask is None:
|
||||
mask = np.ones(self.action_space_size, dtype=bool)
|
||||
else:
|
||||
mask = np.asarray(available_mask, dtype=bool).copy()
|
||||
if self._episode_actions:
|
||||
for action in self._episode_actions:
|
||||
mask[action] = False
|
||||
return mask
|
||||
|
|
@ -0,0 +1,11 @@
|
|||
"""Domain model exports for Q-Table (Version 3.0 - 162 states)."""
|
||||
|
||||
from .state import ( # noqa: F401
|
||||
AcceptanceRate,
|
||||
DistributionStructure,
|
||||
InputPriceZone,
|
||||
PartnerType,
|
||||
RevenueRange,
|
||||
State,
|
||||
)
|
||||
from .visit_table import VisitTable # noqa: F401
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
"""
|
||||
Experience 모델
|
||||
Q-Learning을 위한 경험 데이터 (SARS: State, Action, Reward, next_State)
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from .state import State
|
||||
|
||||
|
||||
@dataclass
|
||||
class Experience:
|
||||
"""
|
||||
Q-Learning에서 사용하는 경험 데이터
|
||||
|
||||
SARS 형식:
|
||||
- State: 현재 상태
|
||||
- Action: 선택한 액션 (action_id)
|
||||
- Reward: 받은 보상
|
||||
- next_State: 다음 상태 (종료 시 None)
|
||||
"""
|
||||
|
||||
state: State
|
||||
action_id: int
|
||||
reward: float
|
||||
next_state: Optional[State]
|
||||
done: bool # 에피소드 종료 여부
|
||||
|
||||
# 메타데이터 (선택사항)
|
||||
episode_id: Optional[str] = None
|
||||
step: Optional[int] = None
|
||||
timestamp: Optional[str] = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""
|
||||
Experience를 딕셔너리로 직렬화
|
||||
|
||||
Returns:
|
||||
{
|
||||
'state': [scenario, price_zone, acceptance_rate],
|
||||
'action_id': 5,
|
||||
'reward': 1.0,
|
||||
'next_state': [scenario, price_zone, acceptance_rate] or None,
|
||||
'done': False,
|
||||
'episode_id': 'ep_001',
|
||||
'step': 3,
|
||||
'timestamp': '2025-10-29T16:30:00'
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"state": self.state.to_array(),
|
||||
"action_id": self.action_id,
|
||||
"reward": self.reward,
|
||||
"next_state": self.next_state.to_array() if self.next_state else None,
|
||||
"done": self.done,
|
||||
"episode_id": self.episode_id,
|
||||
"step": self.step,
|
||||
"timestamp": self.timestamp,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict) -> "Experience":
|
||||
"""
|
||||
딕셔너리에서 Experience 복원
|
||||
|
||||
Args:
|
||||
data: 직렬화된 Experience 데이터
|
||||
|
||||
Returns:
|
||||
Experience 인스턴스
|
||||
"""
|
||||
return cls(
|
||||
state=State.from_array(data["state"]),
|
||||
action_id=data["action_id"],
|
||||
reward=data["reward"],
|
||||
next_state=(
|
||||
State.from_array(data["next_state"]) if data["next_state"] else None
|
||||
),
|
||||
done=data["done"],
|
||||
episode_id=data.get("episode_id"),
|
||||
step=data.get("step"),
|
||||
timestamp=data.get("timestamp"),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Experience("
|
||||
f"state={self.state.to_array()}, "
|
||||
f"action_id={self.action_id}, "
|
||||
f"reward={self.reward}, "
|
||||
f"next_state={self.next_state.to_array() if self.next_state else None}, "
|
||||
f"done={self.done}"
|
||||
f")"
|
||||
)
|
||||
|
|
@ -0,0 +1,209 @@
|
|||
"""
|
||||
Q-Table 모델
|
||||
동적 action_space_size를 지원하여 협상 카드 추가/삭제에 대응
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
class QTable:
|
||||
"""
|
||||
Q-Learning을 위한 Q-Table
|
||||
|
||||
Q-Table은 추상적인 action_id만 다루며,
|
||||
실제 협상 카드(card_id)는 ActionCardMapper에서 매핑
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
state_space_size: int,
|
||||
action_space_size: int,
|
||||
learning_rate: float = 0.1,
|
||||
discount_factor: float = 0.95,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
state_space_size: 상태 공간 크기
|
||||
(Scenario=4 x PriceZone=3 x AcceptanceRate=3 = 36)
|
||||
action_space_size: 액션 공간 크기 (협상 카드 개수, 현재 21개)
|
||||
learning_rate: 학습률 (alpha)
|
||||
discount_factor: 할인 인자 (gamma)
|
||||
"""
|
||||
self.state_space_size = state_space_size
|
||||
self.action_space_size = action_space_size
|
||||
self.learning_rate = learning_rate
|
||||
self.discount_factor = discount_factor
|
||||
|
||||
# Q-Table 초기화: (state_size, action_size) 형태
|
||||
self.q_values = np.zeros((state_space_size, action_space_size))
|
||||
|
||||
def get_q_value(self, state_index: int, action_id: int) -> float:
|
||||
"""
|
||||
특정 (state, action)의 Q-value 조회
|
||||
|
||||
Args:
|
||||
state_index: State의 인덱스 (0 ~ state_space_size-1)
|
||||
action_id: Action ID (0 ~ action_space_size-1)
|
||||
|
||||
Returns:
|
||||
Q-value
|
||||
"""
|
||||
self._validate_indices(state_index, action_id)
|
||||
return float(self.q_values[state_index, action_id])
|
||||
|
||||
def get_q_values_for_state(self, state_index: int) -> np.ndarray:
|
||||
"""
|
||||
특정 state의 모든 action에 대한 Q-values
|
||||
|
||||
Args:
|
||||
state_index: State의 인덱스
|
||||
|
||||
Returns:
|
||||
shape (action_space_size,)의 Q-values 배열
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
return self.q_values[state_index, :].copy()
|
||||
|
||||
def set_q_value(self, state_index: int, action_id: int, value: float):
|
||||
"""
|
||||
특정 (state, action)의 Q-value 직접 설정
|
||||
|
||||
Args:
|
||||
state_index: State의 인덱스
|
||||
action_id: Action ID
|
||||
value: 설정할 Q-value
|
||||
"""
|
||||
self._validate_indices(state_index, action_id)
|
||||
self.q_values[state_index, action_id] = value
|
||||
|
||||
def update(
|
||||
self,
|
||||
state_index: int,
|
||||
action_id: int,
|
||||
reward: float,
|
||||
next_state_index: Optional[int] = None,
|
||||
done: bool = False,
|
||||
):
|
||||
"""
|
||||
Q-Learning 업데이트
|
||||
|
||||
Q(s,a) ← Q(s,a) + α[r + γ·max_a'Q(s',a') - Q(s,a)]
|
||||
|
||||
Args:
|
||||
state_index: 현재 상태 인덱스
|
||||
action_id: 선택한 액션 ID
|
||||
reward: 받은 보상
|
||||
next_state_index: 다음 상태 인덱스 (종료 시 None)
|
||||
done: 에피소드 종료 여부
|
||||
"""
|
||||
self._validate_indices(state_index, action_id)
|
||||
|
||||
current_q = self.q_values[state_index, action_id]
|
||||
|
||||
if done or next_state_index is None:
|
||||
# 종료 상태: target = reward만 사용
|
||||
target = reward
|
||||
else:
|
||||
# 비종료 상태: target = reward + γ·max Q(s',a')
|
||||
if next_state_index < 0 or next_state_index >= self.state_space_size:
|
||||
raise ValueError(f"next_state_index {next_state_index} out of range")
|
||||
max_next_q = np.max(self.q_values[next_state_index, :])
|
||||
target = reward + self.discount_factor * max_next_q
|
||||
|
||||
# Q-Learning 업데이트
|
||||
self.q_values[state_index, action_id] += self.learning_rate * (
|
||||
target - current_q
|
||||
)
|
||||
|
||||
def get_best_action(self, state_index: int) -> int:
|
||||
"""
|
||||
해당 state에서 최고 Q-value를 가진 action_id 반환
|
||||
|
||||
Args:
|
||||
state_index: State의 인덱스
|
||||
|
||||
Returns:
|
||||
action_id (Q-value가 최대인 액션)
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
return int(np.argmax(self.q_values[state_index, :]))
|
||||
|
||||
def get_best_actions_with_ties(self, state_index: int) -> list:
|
||||
"""
|
||||
동일한 최대 Q-value를 가진 모든 action_id 반환
|
||||
|
||||
Args:
|
||||
state_index: State의 인덱스
|
||||
|
||||
Returns:
|
||||
최대 Q-value를 가진 action_id들의 리스트
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
|
||||
max_q = np.max(self.q_values[state_index, :])
|
||||
max_actions = np.where(self.q_values[state_index, :] == max_q)[0]
|
||||
return max_actions.tolist()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""
|
||||
Q-Table을 딕셔너리로 직렬화
|
||||
|
||||
Returns:
|
||||
직렬화된 Q-Table 데이터
|
||||
"""
|
||||
return {
|
||||
"state_space_size": self.state_space_size,
|
||||
"action_space_size": self.action_space_size,
|
||||
"learning_rate": self.learning_rate,
|
||||
"discount_factor": self.discount_factor,
|
||||
"q_values": self.q_values.tolist(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "QTable":
|
||||
"""
|
||||
딕셔너리에서 Q-Table 복원
|
||||
|
||||
Args:
|
||||
data: 직렬화된 Q-Table 데이터
|
||||
|
||||
Returns:
|
||||
복원된 QTable 인스턴스
|
||||
"""
|
||||
q_table = cls(
|
||||
state_space_size=data["state_space_size"],
|
||||
action_space_size=data["action_space_size"],
|
||||
learning_rate=data["learning_rate"],
|
||||
discount_factor=data["discount_factor"],
|
||||
)
|
||||
q_table.q_values = np.array(data["q_values"])
|
||||
return q_table
|
||||
|
||||
def _validate_indices(self, state_index: int, action_id: int):
|
||||
"""인덱스 유효성 검증"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
if action_id < 0 or action_id >= self.action_space_size:
|
||||
raise ValueError(
|
||||
f"action_id {action_id} out of range [0, {self.action_space_size})"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"QTable(state_space_size={self.state_space_size}, "
|
||||
f"action_space_size={self.action_space_size}, "
|
||||
f"learning_rate={self.learning_rate}, "
|
||||
f"discount_factor={self.discount_factor})"
|
||||
)
|
||||
|
|
@ -0,0 +1,258 @@
|
|||
from dataclasses import dataclass
|
||||
from enum import IntEnum, nonmember
|
||||
from typing import List
|
||||
|
||||
|
||||
class RevenueRange(IntEnum):
|
||||
"""매출액 가격구간 (Revenue Price Range)"""
|
||||
|
||||
LOW = 0 # ≤ 1,000만원
|
||||
MID = 1 # 1,000만원 ~ 3,000만원
|
||||
HIGH = 2 # > 3,000만원
|
||||
|
||||
_DESCRIPTIONS = nonmember(("Low (≤1,000만원)", "Mid (1,000~3,000만원)", "High (>3,000만원)"))
|
||||
_WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 매출액이 높을수록 협상 여지 큼
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTIONS[self.value]
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._WEIGHTS[self.value]
|
||||
|
||||
@classmethod
|
||||
def from_amount(cls, amount: float) -> "RevenueRange":
|
||||
"""매출액(만원 단위)으로부터 구간 결정"""
|
||||
if amount <= 1000:
|
||||
return cls.LOW
|
||||
elif amount <= 3000:
|
||||
return cls.MID
|
||||
else:
|
||||
return cls.HIGH
|
||||
|
||||
|
||||
class DistributionStructure(IntEnum):
|
||||
"""유통 구조 (Distribution Structure)"""
|
||||
|
||||
MANUFACTURER = 0 # 제조
|
||||
WHOLESALER = 1 # 총판
|
||||
RETAILER = 2 # 유통
|
||||
|
||||
_DESCRIPTIONS = nonmember(("제조", "총판", "유통"))
|
||||
_WEIGHTS = nonmember((0.2, 0.5, 1.0)) # 유통 단계가 복잡할수록 협상 여지 큼
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTIONS[self.value]
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._WEIGHTS[self.value]
|
||||
|
||||
@classmethod
|
||||
def from_code(cls, code: str) -> "DistributionStructure":
|
||||
"""코드로부터 유통 구조 결정"""
|
||||
normalized = (code or "").strip().upper()
|
||||
mapping = {"M": cls.MANUFACTURER, "W": cls.WHOLESALER, "R": cls.RETAILER}
|
||||
if normalized not in mapping:
|
||||
raise ValueError(f"unknown distribution code: {code}")
|
||||
return mapping[normalized]
|
||||
|
||||
|
||||
class PartnerType(IntEnum):
|
||||
"""파트너사 종류 (Partner Type)"""
|
||||
|
||||
SINGLE = 0 # 단독
|
||||
MULTIPLE = 1 # 다수
|
||||
NONE = 2 # 없음
|
||||
|
||||
_DESCRIPTIONS = nonmember(("Single (단독)", "Multiple (다수)", "None (없음)"))
|
||||
_WEIGHTS = nonmember((0.5, 1.0, 0.3)) # 다수 파트너사일수록 협상 복잡도 증가
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTIONS[self.value]
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._WEIGHTS[self.value]
|
||||
|
||||
@classmethod
|
||||
def from_count(cls, count: int) -> "PartnerType":
|
||||
"""파트너사 수로부터 구간 결정"""
|
||||
if count == 0:
|
||||
return cls.NONE
|
||||
elif count == 1:
|
||||
return cls.SINGLE
|
||||
else:
|
||||
return cls.MULTIPLE
|
||||
|
||||
|
||||
class AcceptanceRate(IntEnum):
|
||||
"""가격 수용률 구간 (Price Acceptance Rate)"""
|
||||
|
||||
LOW = 0 # < 30%
|
||||
MID = 1 # 30% ~ 90%
|
||||
HIGH = 2 # > 90%
|
||||
|
||||
_DESCRIPTIONS = nonmember(("Low (<30%)", "Mid (30~90%)", "High (>90%)"))
|
||||
_WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 수용률이 높을수록 협상 성공 가능성 높음
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTIONS[self.value]
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._WEIGHTS[self.value]
|
||||
|
||||
@classmethod
|
||||
def from_ratio(cls, ratio: float) -> "AcceptanceRate":
|
||||
"""수용률(0~1)로부터 구간 결정"""
|
||||
normalized = max(0.0, min(1.0, ratio))
|
||||
if normalized < 0.30:
|
||||
return cls.LOW
|
||||
elif normalized <= 0.90:
|
||||
return cls.MID
|
||||
else:
|
||||
return cls.HIGH
|
||||
|
||||
|
||||
class InputPriceZone(IntEnum):
|
||||
"""입력 금액 구간 (Input Price Zone)"""
|
||||
|
||||
PZ1 = 0 # A ≤ P ≤ T (앵커가격 ~ 목표가격)
|
||||
PZ2 = 1 # P > T (목표가격 초과)
|
||||
|
||||
_DESCRIPTIONS = nonmember(("PZ1 (A≤P≤T)", "PZ2 (P>T)"))
|
||||
_WEIGHTS = nonmember((1.0, 0.3)) # 목표가격 이내일수록 협상 유리
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTIONS[self.value]
|
||||
|
||||
@property
|
||||
def weight(self) -> float:
|
||||
return self._WEIGHTS[self.value]
|
||||
|
||||
@classmethod
|
||||
def from_prices(
|
||||
cls,
|
||||
input_price: float,
|
||||
anchor_price: float,
|
||||
target_price: float,
|
||||
) -> "InputPriceZone":
|
||||
"""입력가격, 앵커가격, 목표가격으로부터 구간 결정"""
|
||||
if anchor_price <= 0 or target_price <= 0:
|
||||
raise ValueError("anchor_price and target_price must be positive")
|
||||
if anchor_price > target_price:
|
||||
raise ValueError("anchor_price must not exceed target_price")
|
||||
|
||||
if anchor_price <= input_price <= target_price:
|
||||
return cls.PZ1
|
||||
else:
|
||||
return cls.PZ2
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""
|
||||
Q-Learning State 표현 (Version 3.0 - 162 states)
|
||||
|
||||
State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간)
|
||||
Total: 3 × 3 × 3 × 3 × 2 = 162 states
|
||||
"""
|
||||
|
||||
revenue_range: RevenueRange
|
||||
distribution: DistributionStructure
|
||||
partner_type: PartnerType
|
||||
acceptance_rate: AcceptanceRate
|
||||
input_price_zone: InputPriceZone
|
||||
|
||||
def to_array(self) -> List[int]:
|
||||
"""State를 배열로 변환"""
|
||||
return [
|
||||
self.revenue_range.value,
|
||||
self.distribution.value,
|
||||
self.partner_type.value,
|
||||
self.acceptance_rate.value,
|
||||
self.input_price_zone.value,
|
||||
]
|
||||
|
||||
def to_index(self) -> int:
|
||||
"""
|
||||
State를 1차원 인덱스로 변환 (Q-Table 인덱싱용)
|
||||
|
||||
State space: 3 × 3 × 3 × 3 × 2 = 162
|
||||
|
||||
Returns:
|
||||
0 ~ 161 사이의 인덱스
|
||||
"""
|
||||
# 5D to 1D index conversion
|
||||
# index = revenue * (3*3*3*2) + dist * (3*3*2) + partner * (3*2) + accept * 2 + pricezone
|
||||
return (
|
||||
self.revenue_range.value * 54
|
||||
+ self.distribution.value * 18
|
||||
+ self.partner_type.value * 6
|
||||
+ self.acceptance_rate.value * 2
|
||||
+ self.input_price_zone.value
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_index(cls, index: int) -> "State":
|
||||
"""
|
||||
1차원 인덱스에서 State 복원
|
||||
|
||||
Args:
|
||||
index: 0 ~ 161 사이의 인덱스
|
||||
|
||||
Returns:
|
||||
State 객체
|
||||
"""
|
||||
if index < 0 or index >= 162:
|
||||
raise ValueError(f"index {index} out of range [0, 162)")
|
||||
|
||||
# 1D to 5D index conversion
|
||||
revenue_value = index // 54
|
||||
remainder = index % 54
|
||||
|
||||
dist_value = remainder // 18
|
||||
remainder = remainder % 18
|
||||
|
||||
partner_value = remainder // 6
|
||||
remainder = remainder % 6
|
||||
|
||||
accept_value = remainder // 2
|
||||
pricezone_value = remainder % 2
|
||||
|
||||
return cls(
|
||||
revenue_range=RevenueRange(revenue_value),
|
||||
distribution=DistributionStructure(dist_value),
|
||||
partner_type=PartnerType(partner_value),
|
||||
acceptance_rate=AcceptanceRate(accept_value),
|
||||
input_price_zone=InputPriceZone(pricezone_value),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_array(cls, arr: List[int]) -> "State":
|
||||
"""배열로부터 State 생성"""
|
||||
if len(arr) != 5:
|
||||
raise ValueError(f"Expected 5 elements, got {len(arr)}")
|
||||
return cls(
|
||||
revenue_range=RevenueRange(arr[0]),
|
||||
distribution=DistributionStructure(arr[1]),
|
||||
partner_type=PartnerType(arr[2]),
|
||||
acceptance_rate=AcceptanceRate(arr[3]),
|
||||
input_price_zone=InputPriceZone(arr[4]),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"State("
|
||||
f"revenue={self.revenue_range.description}, "
|
||||
f"distribution={self.distribution.description}, "
|
||||
f"partner={self.partner_type.description}, "
|
||||
f"acceptance={self.acceptance_rate.description}, "
|
||||
f"price_zone={self.input_price_zone.description})"
|
||||
)
|
||||
|
|
@ -0,0 +1,148 @@
|
|||
"""Visit (N) table for UCB-based policies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class VisitTable:
|
||||
"""Tracks state-action visit counts for UCB selection."""
|
||||
|
||||
state_space_size: int
|
||||
action_space_size: int
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""
|
||||
Args:
|
||||
state_space_size: 상태 공간 크기
|
||||
action_space_size: 액션 공간 크기
|
||||
"""
|
||||
self._counts = np.zeros(
|
||||
(self.state_space_size, self.action_space_size), dtype=np.int64
|
||||
)
|
||||
|
||||
def increment(self, state_index: int, action_id: int, count: int = 1) -> None:
|
||||
"""
|
||||
특정 (state, action)의 방문 횟수 증가
|
||||
|
||||
Args:
|
||||
state_index: State 인덱스
|
||||
action_id: Action ID
|
||||
count: 증가시킬 횟수 (기본 1)
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
if action_id < 0 or action_id >= self.action_space_size:
|
||||
raise ValueError(
|
||||
f"action_id {action_id} out of range [0, {self.action_space_size})"
|
||||
)
|
||||
if count < 0:
|
||||
raise ValueError("count must be non-negative")
|
||||
self._counts[state_index, action_id] += count
|
||||
|
||||
def get_action_count(self, state_index: int, action_id: int) -> int:
|
||||
"""
|
||||
특정 (state, action)의 방문 횟수 조회
|
||||
|
||||
Args:
|
||||
state_index: State 인덱스
|
||||
action_id: Action ID
|
||||
|
||||
Returns:
|
||||
방문 횟수
|
||||
"""
|
||||
self._validate_indices(state_index, action_id)
|
||||
return int(self._counts[state_index, action_id])
|
||||
|
||||
def get_state_counts(self, state_index: int) -> np.ndarray:
|
||||
"""
|
||||
특정 state의 모든 action에 대한 방문 횟수 조회
|
||||
|
||||
Args:
|
||||
state_index: State 인덱스
|
||||
|
||||
Returns:
|
||||
shape (action_space_size,)의 방문 횟수 배열
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
return self._counts[state_index, :].copy()
|
||||
|
||||
def total_visits(self, state_index: int) -> int:
|
||||
"""
|
||||
특정 state의 총 방문 횟수 (모든 action 합계)
|
||||
|
||||
Args:
|
||||
state_index: State 인덱스
|
||||
|
||||
Returns:
|
||||
총 방문 횟수
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
return int(self._counts[state_index, :].sum())
|
||||
|
||||
def reset_state(self, state_index: int) -> None:
|
||||
"""
|
||||
특정 state의 방문 기록 초기화
|
||||
|
||||
Args:
|
||||
state_index: State 인덱스
|
||||
"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
self._counts[state_index, :] = 0
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
"""
|
||||
VisitTable을 딕셔너리로 직렬화
|
||||
|
||||
Returns:
|
||||
직렬화된 VisitTable 데이터
|
||||
"""
|
||||
return {
|
||||
"state_space_size": self.state_space_size,
|
||||
"action_space_size": self.action_space_size,
|
||||
"counts": self._counts.tolist(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> "VisitTable":
|
||||
"""
|
||||
딕셔너리에서 VisitTable 복원
|
||||
|
||||
Args:
|
||||
data: 직렬화된 VisitTable 데이터
|
||||
|
||||
Returns:
|
||||
복원된 VisitTable 인스턴스
|
||||
"""
|
||||
table = cls(
|
||||
state_space_size=data["state_space_size"],
|
||||
action_space_size=data["action_space_size"],
|
||||
)
|
||||
table._counts = np.array(data["counts"], dtype=np.int64)
|
||||
return table
|
||||
|
||||
def _validate_indices(self, state_index: int, action_id: int) -> None:
|
||||
"""인덱스 유효성 검증"""
|
||||
if state_index < 0 or state_index >= self.state_space_size:
|
||||
raise ValueError(
|
||||
f"state_index {state_index} out of range [0, {self.state_space_size})"
|
||||
)
|
||||
if action_id < 0 or action_id >= self.action_space_size:
|
||||
raise ValueError(
|
||||
f"action_id {action_id} out of range [0, {self.action_space_size})"
|
||||
)
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
"""
|
||||
Experience Repository
|
||||
Experience 데이터를 JSONL 형식으로 저장/로드
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from ..model.experience import Experience
|
||||
|
||||
|
||||
class ExperienceRepository:
|
||||
"""
|
||||
Experience 데이터를 파일 시스템에 저장/로드하는 Repository
|
||||
|
||||
JSONL (JSON Lines) 형식 사용:
|
||||
- 각 줄이 하나의 JSON 객체 (Experience)
|
||||
- 스트리밍 방식으로 읽기/쓰기 가능
|
||||
- 대용량 데이터 처리에 유리
|
||||
"""
|
||||
|
||||
def __init__(self, data_dir: Optional[str] = None):
|
||||
"""
|
||||
Args:
|
||||
data_dir: Experience 데이터를 저장할 디렉토리
|
||||
(기본: Q_Table/data/experiences/)
|
||||
"""
|
||||
if data_dir is None:
|
||||
current_dir = Path(__file__).parent.parent.parent
|
||||
data_dir = current_dir / "data" / "experiences"
|
||||
|
||||
self.data_dir = Path(data_dir)
|
||||
self.data_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def save(self, experience: Experience, filename: str = "experiences.jsonl"):
|
||||
"""
|
||||
단일 Experience를 파일에 추가 저장
|
||||
|
||||
Args:
|
||||
experience: 저장할 Experience
|
||||
filename: 파일명 (기본: experiences.jsonl)
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
with open(file_path, "a", encoding="utf-8") as f:
|
||||
json.dump(experience.to_dict(), f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
|
||||
def save_batch(
|
||||
self, experiences: List[Experience], filename: str = "experiences.jsonl"
|
||||
):
|
||||
"""
|
||||
여러 Experience를 배치로 저장
|
||||
|
||||
Args:
|
||||
experiences: Experience 리스트
|
||||
filename: 파일명
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
with open(file_path, "a", encoding="utf-8") as f:
|
||||
for exp in experiences:
|
||||
json.dump(exp.to_dict(), f, ensure_ascii=False)
|
||||
f.write("\n")
|
||||
|
||||
def load_all(self, filename: str = "experiences.jsonl") -> List[Experience]:
|
||||
"""
|
||||
파일에서 모든 Experience 로드
|
||||
|
||||
Args:
|
||||
filename: 파일명
|
||||
|
||||
Returns:
|
||||
Experience 리스트
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return []
|
||||
|
||||
experiences = []
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.strip():
|
||||
data = json.loads(line)
|
||||
experiences.append(Experience.from_dict(data))
|
||||
|
||||
return experiences
|
||||
|
||||
def load_by_episode(
|
||||
self, episode_id: str, filename: str = "experiences.jsonl"
|
||||
) -> List[Experience]:
|
||||
"""
|
||||
특정 에피소드의 Experience들만 로드
|
||||
|
||||
Args:
|
||||
episode_id: 에피소드 ID
|
||||
filename: 파일명
|
||||
|
||||
Returns:
|
||||
해당 에피소드의 Experience 리스트
|
||||
"""
|
||||
all_experiences = self.load_all(filename)
|
||||
return [exp for exp in all_experiences if exp.episode_id == episode_id]
|
||||
|
||||
def count(self, filename: str = "experiences.jsonl") -> int:
|
||||
"""
|
||||
저장된 Experience 개수
|
||||
|
||||
Args:
|
||||
filename: 파일명
|
||||
|
||||
Returns:
|
||||
Experience 개수
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return 0
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return sum(1 for line in f if line.strip())
|
||||
|
||||
def clear(self, filename: str = "experiences.jsonl"):
|
||||
"""
|
||||
파일 내용 삭제
|
||||
|
||||
Args:
|
||||
filename: 파일명
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
if file_path.exists():
|
||||
file_path.unlink()
|
||||
|
||||
def create_new_file(self, prefix: str = "exp") -> str:
|
||||
"""
|
||||
타임스탬프가 포함된 새 파일 생성
|
||||
|
||||
Args:
|
||||
prefix: 파일명 접두사
|
||||
|
||||
Returns:
|
||||
생성된 파일명
|
||||
|
||||
Example:
|
||||
>>> repo.create_new_file("train")
|
||||
'train_20251029_163000.jsonl'
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"{prefix}_{timestamp}.jsonl"
|
||||
file_path = self.data_dir / filename
|
||||
file_path.touch()
|
||||
return filename
|
||||
|
||||
def list_files(self) -> List[str]:
|
||||
"""
|
||||
데이터 디렉토리의 모든 JSONL 파일 목록
|
||||
|
||||
Returns:
|
||||
파일명 리스트
|
||||
"""
|
||||
return [f.name for f in self.data_dir.glob("*.jsonl")]
|
||||
|
||||
def get_file_info(self, filename: str) -> dict:
|
||||
"""
|
||||
파일 정보 조회
|
||||
|
||||
Args:
|
||||
filename: 파일명
|
||||
|
||||
Returns:
|
||||
{
|
||||
'filename': 'experiences.jsonl',
|
||||
'path': '/path/to/file',
|
||||
'size': 1024, # bytes
|
||||
'count': 100, # experience 개수
|
||||
'created': '2025-10-29 16:30:00'
|
||||
}
|
||||
"""
|
||||
file_path = self.data_dir / filename
|
||||
|
||||
if not file_path.exists():
|
||||
return None
|
||||
|
||||
stat = file_path.stat()
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"path": str(file_path),
|
||||
"size": stat.st_size,
|
||||
"count": self.count(filename),
|
||||
"created": datetime.fromtimestamp(stat.st_ctime).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
"modified": datetime.fromtimestamp(stat.st_mtime).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
),
|
||||
}
|
||||
|
|
@ -0,0 +1,166 @@
|
|||
"""Reward function aligned with the Q-Table design specification (Version 3.0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from ..model.state import (
|
||||
AcceptanceRate,
|
||||
DistributionStructure,
|
||||
InputPriceZone,
|
||||
PartnerType,
|
||||
RevenueRange,
|
||||
)
|
||||
|
||||
|
||||
class NegotiationOutcome(Enum):
|
||||
ONGOING = "ongoing"
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardConfig:
|
||||
"""보상 함수 설정"""
|
||||
|
||||
beta: float = 0.2
|
||||
success_reward: float = 1.0
|
||||
ongoing_reward: float = 0.0
|
||||
failure_penalty: float = -0.5
|
||||
penalty_lambda: float = 0.02
|
||||
|
||||
# 동적 가중치 계수 (Version 3.0)
|
||||
w1: float = 0.2 # 매출액 가격구간
|
||||
w2: float = 0.25 # 유통 구조
|
||||
w3: float = 0.2 # 파트너사 종류
|
||||
w4: float = 0.25 # 가격 수용률
|
||||
w5: float = 0.1 # 입력 금액 구간
|
||||
|
||||
min_weight: float = 0.2
|
||||
max_weight: float = 0.8
|
||||
|
||||
|
||||
@dataclass
|
||||
class RewardBreakdown:
|
||||
price_reward: float
|
||||
end_reward: float
|
||||
penalty: float
|
||||
weight: float
|
||||
total: float
|
||||
|
||||
|
||||
def calculate_reward(
|
||||
*,
|
||||
revenue_range: RevenueRange,
|
||||
distribution: DistributionStructure,
|
||||
partner_type: PartnerType,
|
||||
acceptance_rate: AcceptanceRate,
|
||||
input_price_zone: InputPriceZone,
|
||||
current_price: float,
|
||||
anchor_price: float,
|
||||
target_price: float,
|
||||
round_number: int,
|
||||
outcome: NegotiationOutcome,
|
||||
config: Optional[RewardConfig] = None,
|
||||
) -> RewardBreakdown:
|
||||
"""
|
||||
보상 계산 (Version 3.0)
|
||||
|
||||
R = W × R_price + (1-W) × R_end - λ × t
|
||||
|
||||
W = clip(W_raw, min_weight, max_weight)
|
||||
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
|
||||
"""
|
||||
cfg = config or RewardConfig()
|
||||
|
||||
price_reward = _calculate_price_reward(
|
||||
current_price=current_price,
|
||||
anchor_price=anchor_price,
|
||||
target_price=target_price,
|
||||
beta=cfg.beta,
|
||||
)
|
||||
end_reward = _calculate_end_reward(outcome, cfg)
|
||||
weight = _calculate_weight(
|
||||
revenue_range=revenue_range,
|
||||
distribution=distribution,
|
||||
partner_type=partner_type,
|
||||
acceptance_rate=acceptance_rate,
|
||||
input_price_zone=input_price_zone,
|
||||
config=cfg,
|
||||
)
|
||||
penalty = _calculate_penalty(round_number, cfg)
|
||||
|
||||
total = weight * price_reward + (1.0 - weight) * end_reward - penalty
|
||||
return RewardBreakdown(
|
||||
price_reward=price_reward,
|
||||
end_reward=end_reward,
|
||||
penalty=penalty,
|
||||
weight=weight,
|
||||
total=total,
|
||||
)
|
||||
|
||||
|
||||
def _calculate_price_reward(
|
||||
*,
|
||||
current_price: float,
|
||||
anchor_price: float,
|
||||
target_price: float,
|
||||
beta: float,
|
||||
) -> float:
|
||||
if anchor_price <= 0 or target_price <= 0:
|
||||
raise ValueError("anchor_price and target_price must be positive")
|
||||
if anchor_price > target_price:
|
||||
raise ValueError("anchor_price must not exceed target_price")
|
||||
|
||||
if current_price < anchor_price:
|
||||
improvement = (anchor_price - current_price) / anchor_price
|
||||
return 1.0 + beta * improvement
|
||||
|
||||
if anchor_price <= current_price <= target_price:
|
||||
span = target_price - anchor_price
|
||||
if span <= 0:
|
||||
return 0.0
|
||||
return (target_price - current_price) / span
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
def _calculate_end_reward(outcome: NegotiationOutcome, cfg: RewardConfig) -> float:
|
||||
if outcome is NegotiationOutcome.SUCCESS:
|
||||
return cfg.success_reward
|
||||
if outcome is NegotiationOutcome.FAILURE:
|
||||
return cfg.failure_penalty
|
||||
return cfg.ongoing_reward
|
||||
|
||||
|
||||
def _calculate_weight(
|
||||
*,
|
||||
revenue_range: RevenueRange,
|
||||
distribution: DistributionStructure,
|
||||
partner_type: PartnerType,
|
||||
acceptance_rate: AcceptanceRate,
|
||||
input_price_zone: InputPriceZone,
|
||||
config: RewardConfig,
|
||||
) -> float:
|
||||
"""
|
||||
동적 가중치 계산 (Version 3.0)
|
||||
|
||||
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
|
||||
W = clip(W_raw, min_weight, max_weight)
|
||||
"""
|
||||
w_raw = (
|
||||
config.w1 * revenue_range.weight
|
||||
+ config.w2 * distribution.weight
|
||||
+ config.w3 * partner_type.weight
|
||||
+ config.w4 * acceptance_rate.weight
|
||||
+ config.w5 * input_price_zone.weight
|
||||
)
|
||||
return max(config.min_weight, min(config.max_weight, w_raw))
|
||||
|
||||
|
||||
def _calculate_penalty(round_number: int, cfg: RewardConfig) -> float:
|
||||
if round_number < 0:
|
||||
raise ValueError("round_number must be non-negative")
|
||||
return cfg.penalty_lambda * float(round_number)
|
||||
|
|
@ -0,0 +1,87 @@
|
|||
"""Utility helpers to build Q-Table states from negotiation snapshots (Version 3.0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from ..model.state import (
|
||||
AcceptanceRate,
|
||||
DistributionStructure,
|
||||
InputPriceZone,
|
||||
PartnerType,
|
||||
RevenueRange,
|
||||
State,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class NegotiationSnapshot:
|
||||
"""협상 스냅샷 - 실제 비즈니스 데이터"""
|
||||
|
||||
# 매출액 (만원 단위)
|
||||
revenue_amount: float
|
||||
|
||||
# 유통 구조 코드 ("M": 제조, "W": 총판, "R": 유통)
|
||||
distribution_code: str
|
||||
|
||||
# 파트너사 수
|
||||
partner_count: int
|
||||
|
||||
# 가격 정보
|
||||
anchor_price: float
|
||||
target_price: float
|
||||
input_price: float
|
||||
|
||||
# 가격 수용률 (0~1, 계산되거나 직접 입력)
|
||||
acceptance_ratio: Optional[float] = None
|
||||
|
||||
# 초기 가격 (수용률 계산용, acceptance_ratio가 없을 때)
|
||||
initial_price: Optional[float] = None
|
||||
|
||||
|
||||
def build_state(snapshot: NegotiationSnapshot) -> State:
|
||||
"""
|
||||
협상 스냅샷으로부터 Q-Learning State 생성 (Version 3.0)
|
||||
|
||||
Args:
|
||||
snapshot: 협상 스냅샷
|
||||
|
||||
Returns:
|
||||
162-dimensional state
|
||||
"""
|
||||
# 1. 매출액 가격구간
|
||||
revenue_range = RevenueRange.from_amount(snapshot.revenue_amount)
|
||||
|
||||
# 2. 유통 구조
|
||||
distribution = DistributionStructure.from_code(snapshot.distribution_code)
|
||||
|
||||
# 3. 파트너사 종류
|
||||
partner_type = PartnerType.from_count(snapshot.partner_count)
|
||||
|
||||
# 4. 가격 수용률 구간
|
||||
if snapshot.acceptance_ratio is not None:
|
||||
acceptance_rate = AcceptanceRate.from_ratio(snapshot.acceptance_ratio)
|
||||
elif snapshot.initial_price is not None:
|
||||
# 초기가격으로부터 수용률 계산: (initial - input) / initial
|
||||
if snapshot.initial_price <= 0:
|
||||
raise ValueError("initial_price must be positive")
|
||||
ratio = (snapshot.initial_price - snapshot.input_price) / snapshot.initial_price
|
||||
acceptance_rate = AcceptanceRate.from_ratio(ratio)
|
||||
else:
|
||||
raise ValueError("Either acceptance_ratio or initial_price must be provided")
|
||||
|
||||
# 5. 입력 금액 구간
|
||||
input_price_zone = InputPriceZone.from_prices(
|
||||
input_price=snapshot.input_price,
|
||||
anchor_price=snapshot.anchor_price,
|
||||
target_price=snapshot.target_price,
|
||||
)
|
||||
|
||||
return State(
|
||||
revenue_range=revenue_range,
|
||||
distribution=distribution,
|
||||
partner_type=partner_type,
|
||||
acceptance_rate=acceptance_rate,
|
||||
input_price_zone=input_price_zone,
|
||||
)
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
|
||||
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
|
||||
|
||||
|
||||
class ModelRepository:
|
||||
"""Simple repository to save/load Q-Table and VisitTable models."""
|
||||
|
||||
def __init__(self, model_dir: Optional[str] = None):
|
||||
if model_dir is None:
|
||||
# Current: src/negotiation_agent/Q_Table/infra/repository/model_repository.py
|
||||
# Target Data: src/negotiation_agent/Q_Table/data/model
|
||||
# Path: ../../../data/model
|
||||
current_dir = Path(__file__).parent
|
||||
model_dir = current_dir.parent.parent / "data" / "model"
|
||||
|
||||
self.model_dir = Path(model_dir)
|
||||
self.model_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.q_table_path = self.model_dir / "q_table.json"
|
||||
self.visit_table_path = self.model_dir / "visit_table.json"
|
||||
|
||||
def save(self, q_table: QTable, visit_table: VisitTable):
|
||||
with open(self.q_table_path, "w", encoding="utf-8") as f:
|
||||
json.dump(q_table.to_dict(), f, indent=2)
|
||||
|
||||
with open(self.visit_table_path, "w", encoding="utf-8") as f:
|
||||
json.dump(visit_table.to_dict(), f, indent=2)
|
||||
|
||||
print(f"Models saved to {self.model_dir}")
|
||||
|
||||
def load(self) -> Tuple[Optional[QTable], Optional[VisitTable]]:
|
||||
q_table = None
|
||||
visit_table = None
|
||||
|
||||
if self.q_table_path.exists():
|
||||
with open(self.q_table_path, "r", encoding="utf-8") as f:
|
||||
q_table = QTable.from_dict(json.load(f))
|
||||
|
||||
if self.visit_table_path.exists():
|
||||
with open(self.visit_table_path, "r", encoding="utf-8") as f:
|
||||
visit_table = VisitTable.from_dict(json.load(f))
|
||||
|
||||
return q_table, visit_table
|
||||
Binary file not shown.
|
|
@ -0,0 +1,143 @@
|
|||
"""
|
||||
CollectExperienceUsecase
|
||||
협상 중 Experience 데이터를 수집하여 저장
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from ..domain.model.state import State
|
||||
from ..domain.model.experience import Experience
|
||||
from ..domain.repository.experience_repository import ExperienceRepository
|
||||
|
||||
|
||||
class CollectExperienceUsecase:
|
||||
"""
|
||||
협상 과정에서 발생하는 Experience를 수집하여 저장
|
||||
|
||||
오프라인 학습을 위한 데이터 수집 담당
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experience_repository: ExperienceRepository,
|
||||
filename: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
experience_repository: Experience 저장소
|
||||
filename: 저장할 파일명 (None이면 기본 파일 사용)
|
||||
"""
|
||||
self.experience_repository = experience_repository
|
||||
self.filename = filename or "experiences.jsonl"
|
||||
self.current_episode_id: Optional[str] = None
|
||||
self.current_step: int = 0
|
||||
|
||||
def start_episode(self, episode_id: Optional[str] = None):
|
||||
"""
|
||||
새로운 에피소드 시작
|
||||
|
||||
Args:
|
||||
episode_id: 에피소드 ID (None이면 자동 생성)
|
||||
"""
|
||||
if episode_id is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
episode_id = f"ep_{timestamp}"
|
||||
|
||||
self.current_episode_id = episode_id
|
||||
self.current_step = 0
|
||||
|
||||
def collect(
|
||||
self,
|
||||
state: State,
|
||||
action_id: int,
|
||||
reward: float,
|
||||
next_state: Optional[State],
|
||||
done: bool,
|
||||
) -> Experience:
|
||||
"""
|
||||
단일 Experience 수집 및 저장
|
||||
|
||||
Args:
|
||||
state: 현재 상태
|
||||
action_id: 선택한 액션 ID
|
||||
reward: 받은 보상
|
||||
next_state: 다음 상태 (종료 시 None)
|
||||
done: 에피소드 종료 여부
|
||||
|
||||
Returns:
|
||||
수집된 Experience
|
||||
"""
|
||||
if self.current_episode_id is None:
|
||||
self.start_episode()
|
||||
|
||||
# Experience 생성
|
||||
experience = Experience(
|
||||
state=state,
|
||||
action_id=action_id,
|
||||
reward=reward,
|
||||
next_state=next_state,
|
||||
done=done,
|
||||
episode_id=self.current_episode_id,
|
||||
step=self.current_step,
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# 저장
|
||||
self.experience_repository.save(experience, self.filename)
|
||||
|
||||
# 스텝 증가
|
||||
self.current_step += 1
|
||||
|
||||
return experience
|
||||
|
||||
def end_episode(self):
|
||||
"""에피소드 종료"""
|
||||
self.current_episode_id = None
|
||||
self.current_step = 0
|
||||
|
||||
def get_episode_count(self) -> int:
|
||||
"""수집된 에피소드 개수 (근사치)"""
|
||||
# 파일에서 unique episode_id 개수 계산
|
||||
experiences = self.experience_repository.load_all(self.filename)
|
||||
episode_ids = set(exp.episode_id for exp in experiences if exp.episode_id)
|
||||
return len(episode_ids)
|
||||
|
||||
def get_total_count(self) -> int:
|
||||
"""수집된 총 Experience 개수"""
|
||||
return self.experience_repository.count(self.filename)
|
||||
|
||||
def get_collection_info(self) -> dict:
|
||||
"""
|
||||
수집 정보 조회
|
||||
|
||||
Returns:
|
||||
{
|
||||
'filename': 'experiences.jsonl',
|
||||
'total_experiences': 1000,
|
||||
'episodes': 50,
|
||||
'current_episode': 'ep_001' or None,
|
||||
'current_step': 5
|
||||
}
|
||||
"""
|
||||
return {
|
||||
"filename": self.filename,
|
||||
"total_experiences": self.get_total_count(),
|
||||
"episodes": self.get_episode_count(),
|
||||
"current_episode": self.current_episode_id,
|
||||
"current_step": self.current_step,
|
||||
"file_info": self.experience_repository.get_file_info(self.filename),
|
||||
}
|
||||
|
||||
def create_new_collection_file(self, prefix: str = "exp") -> str:
|
||||
"""
|
||||
새로운 수집 파일 생성 및 전환
|
||||
|
||||
Args:
|
||||
prefix: 파일명 접두사
|
||||
|
||||
Returns:
|
||||
생성된 파일명
|
||||
"""
|
||||
new_filename = self.experience_repository.create_new_file(prefix)
|
||||
self.filename = new_filename
|
||||
return new_filename
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
"""
|
||||
EvaluateAgentUsecase
|
||||
학습된 Q-Table의 성능을 평가하는 Usecase
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from ..domain.model.q_table import QTable
|
||||
from ..domain.repository.experience_repository import ExperienceRepository
|
||||
|
||||
|
||||
class EvaluateAgentUsecase:
|
||||
"""
|
||||
저장된 Experience 데이터를 사용하여 현재 에이전트(Q-Table)의 성능을 평가
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
q_table: QTable,
|
||||
experience_repository: ExperienceRepository,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q_table: 평가할 Q-Table
|
||||
experience_repository: Experience 저장소
|
||||
"""
|
||||
self.q_table = q_table
|
||||
self.experience_repository = experience_repository
|
||||
|
||||
def execute(
|
||||
self, filename: str = "experiences.jsonl", sample_size: Optional[int] = None
|
||||
) -> dict:
|
||||
"""
|
||||
현재 Q-Table의 성능 평가
|
||||
|
||||
Args:
|
||||
filename: 평가용 Experience 파일
|
||||
sample_size: 샘플링 크기 (None이면 전체)
|
||||
|
||||
Returns:
|
||||
{
|
||||
'avg_q_value': 0.5,
|
||||
'avg_reward': 1.0,
|
||||
'total_samples': 100
|
||||
}
|
||||
"""
|
||||
experiences = self.experience_repository.load_all(filename)
|
||||
|
||||
if not experiences:
|
||||
return {"avg_q_value": 0.0, "avg_reward": 0.0, "total_samples": 0}
|
||||
|
||||
# 샘플링
|
||||
if sample_size and sample_size < len(experiences):
|
||||
import random
|
||||
|
||||
experiences = random.sample(experiences, sample_size)
|
||||
|
||||
total_q_value = 0.0
|
||||
total_reward = 0.0
|
||||
|
||||
for exp in experiences:
|
||||
state_index = exp.state.to_index()
|
||||
q_value = self.q_table.get_q_value(state_index, exp.action_id)
|
||||
total_q_value += q_value
|
||||
total_reward += exp.reward
|
||||
|
||||
n = len(experiences)
|
||||
|
||||
return {
|
||||
"avg_q_value": total_q_value / n,
|
||||
"avg_reward": total_reward / n,
|
||||
"total_samples": n,
|
||||
}
|
||||
|
|
@ -0,0 +1,179 @@
|
|||
"""
|
||||
GetBestActionUsecase
|
||||
State를 받아 최적의 협상 카드를 선택하는 Usecase
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
from ..domain.model.state import State
|
||||
from ..domain.model.q_table import QTable
|
||||
from ..domain.agents.policy import UCBPolicy
|
||||
from ...integration.action_card_mapper import ActionCardMapper
|
||||
|
||||
|
||||
class GetBestActionUsecase:
|
||||
"""
|
||||
협상 추론 Usecase: State → action_id → card_id
|
||||
|
||||
Q_Table은 추상적인 action_id만 반환하고,
|
||||
ActionCardMapper를 통해 구체적인 card_id로 변환한다.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
q_table: QTable,
|
||||
policy: UCBPolicy,
|
||||
action_card_mapper: ActionCardMapper,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q_table: 학습된 Q-Table
|
||||
policy: 액션 선택 정책 (UCB)
|
||||
action_card_mapper: action_id ↔ card_id 매핑
|
||||
"""
|
||||
self.q_table = q_table
|
||||
self.policy = policy
|
||||
self.action_card_mapper = action_card_mapper
|
||||
|
||||
def execute(self, state: State, use_policy: bool = True) -> Dict:
|
||||
"""
|
||||
현재 상태에서 최적의 협상 카드 선택
|
||||
|
||||
Args:
|
||||
state: 현재 협상 상태
|
||||
use_policy: True면 Policy 사용 (UCB + 중복 방지)
|
||||
False면 순수 Q-value 최대값만 사용
|
||||
|
||||
Returns:
|
||||
{
|
||||
'action_id': 5,
|
||||
'card_id': 'no_5',
|
||||
'q_value': 0.85,
|
||||
'state_index': 12,
|
||||
'all_q_values': [0.1, 0.2, ..., 0.85, ...] # optional
|
||||
}
|
||||
|
||||
사용 가능한 액션이 없으면:
|
||||
{
|
||||
'action_id': None,
|
||||
'card_id': None,
|
||||
'q_value': None,
|
||||
'message': 'No available actions'
|
||||
}
|
||||
"""
|
||||
# 1. State → state_index 변환
|
||||
state_index = state.to_index()
|
||||
|
||||
# 2. Q-values 조회
|
||||
q_values = self.q_table.get_q_values_for_state(state_index)
|
||||
|
||||
# 3. Action 선택
|
||||
if use_policy:
|
||||
action_id = self.policy.select_action(state_index, q_values)
|
||||
else:
|
||||
# 순수 Q-value 최대값
|
||||
action_id = self.q_table.get_best_action(state_index)
|
||||
|
||||
if action_id is None:
|
||||
return {
|
||||
"action_id": None,
|
||||
"card_id": None,
|
||||
"q_value": None,
|
||||
"state_index": state_index,
|
||||
"message": "No available actions (all actions used in this episode)",
|
||||
}
|
||||
|
||||
# 4. action_id → card_id 매핑
|
||||
card_id = self.action_card_mapper.get_card_id(action_id)
|
||||
|
||||
if card_id is None:
|
||||
return {
|
||||
"action_id": action_id,
|
||||
"card_id": None,
|
||||
"q_value": float(q_values[action_id]),
|
||||
"state_index": state_index,
|
||||
"message": f"action_id {action_id} has no card mapping",
|
||||
}
|
||||
|
||||
# 5. Q-value 추출
|
||||
q_value = self.q_table.get_q_value(state_index, action_id)
|
||||
|
||||
return {
|
||||
"action_id": action_id,
|
||||
"card_id": card_id,
|
||||
"q_value": float(q_value),
|
||||
"state_index": state_index,
|
||||
"all_q_values": q_values.tolist(), # 디버깅용
|
||||
}
|
||||
|
||||
def get_top_k_actions(self, state: State, k: int = 3) -> list:
|
||||
"""
|
||||
상위 k개의 액션 추천
|
||||
|
||||
Args:
|
||||
state: 현재 협상 상태
|
||||
k: 추천할 액션 개수
|
||||
|
||||
Returns:
|
||||
[
|
||||
{'action_id': 5, 'card_id': 'no_5', 'q_value': 0.85},
|
||||
{'action_id': 3, 'card_id': 'no_3', 'q_value': 0.72},
|
||||
{'action_id': 1, 'card_id': 'no_1', 'q_value': 0.68}
|
||||
]
|
||||
"""
|
||||
state_index = state.to_index()
|
||||
q_values = self.q_table.get_q_values_for_state(state_index)
|
||||
|
||||
# Q-value 내림차순 정렬
|
||||
sorted_actions = sorted(enumerate(q_values), key=lambda x: x[1], reverse=True)
|
||||
|
||||
# 상위 k개 추출
|
||||
top_k = sorted_actions[:k]
|
||||
|
||||
results = []
|
||||
for action_id, q_value in top_k:
|
||||
card_id = self.action_card_mapper.get_card_id(action_id)
|
||||
results.append(
|
||||
{"action_id": action_id, "card_id": card_id, "q_value": float(q_value)}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def reset_episode(self):
|
||||
"""
|
||||
에피소드 종료 시 Policy 초기화
|
||||
(다음 협상 세션 시작 전 호출)
|
||||
"""
|
||||
self.policy.reset_episode()
|
||||
|
||||
def get_available_actions(self) -> list:
|
||||
"""
|
||||
현재 에피소드에서 아직 사용하지 않은 액션들
|
||||
|
||||
Returns:
|
||||
[0, 2, 4, 5, ...] (사용하지 않은 action_id들)
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
mask = self.policy.get_action_mask()
|
||||
available_action_ids = np.where(mask > 0)[0].tolist()
|
||||
return available_action_ids
|
||||
|
||||
def get_available_cards(self) -> list:
|
||||
"""
|
||||
현재 에피소드에서 아직 사용하지 않은 카드들
|
||||
|
||||
Returns:
|
||||
[
|
||||
{'action_id': 0, 'card_id': 'no_0'},
|
||||
{'action_id': 2, 'card_id': 'no_2'},
|
||||
...
|
||||
]
|
||||
"""
|
||||
available_action_ids = self.get_available_actions()
|
||||
|
||||
results = []
|
||||
for action_id in available_action_ids:
|
||||
card_id = self.action_card_mapper.get_card_id(action_id)
|
||||
results.append({"action_id": action_id, "card_id": card_id})
|
||||
|
||||
return results
|
||||
|
|
@ -0,0 +1,203 @@
|
|||
"""
|
||||
TrainOfflineUsecase
|
||||
저장된 Experience 데이터로 Q-Table을 오프라인 학습
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from ..domain.model.q_table import QTable
|
||||
from ..domain.model.experience import Experience
|
||||
from ..domain.model.visit_table import VisitTable
|
||||
from ..domain.repository.experience_repository import ExperienceRepository
|
||||
|
||||
|
||||
class TrainOfflineUsecase:
|
||||
"""
|
||||
수집된 Experience 데이터를 사용하여 Q-Table을 오프라인으로 학습
|
||||
|
||||
배치 학습 방식:
|
||||
- 저장된 Experience를 로드
|
||||
- Q-Learning 알고리즘으로 Q-Table 업데이트
|
||||
- 여러 에포크 반복 가능
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
q_table: QTable,
|
||||
experience_repository: ExperienceRepository,
|
||||
visit_table: Optional[VisitTable] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
q_table: 학습할 Q-Table
|
||||
experience_repository: Experience 저장소
|
||||
visit_table: 방문 횟수를 함께 추적할 N-Table (선택)
|
||||
"""
|
||||
self.q_table = q_table
|
||||
self.experience_repository = experience_repository
|
||||
self.visit_table = visit_table
|
||||
|
||||
def train(
|
||||
self,
|
||||
filename: str = "experiences.jsonl",
|
||||
epochs: int = 1,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
저장된 Experience로 Q-Table 학습
|
||||
|
||||
Args:
|
||||
filename: Experience 파일명
|
||||
epochs: 학습 반복 횟수
|
||||
batch_size: 배치 크기 (None이면 전체 데이터 사용)
|
||||
|
||||
Returns:
|
||||
{
|
||||
'total_experiences': 1000,
|
||||
'epochs': 10,
|
||||
'updates': 10000,
|
||||
'avg_loss': 0.05
|
||||
}
|
||||
"""
|
||||
# Experience 로드
|
||||
experiences = self.experience_repository.load_all(filename)
|
||||
|
||||
if not experiences:
|
||||
return {
|
||||
"total_experiences": 0,
|
||||
"epochs": 0,
|
||||
"updates": 0,
|
||||
"avg_loss": 0.0,
|
||||
"message": "No experiences found",
|
||||
}
|
||||
|
||||
total_updates = 0
|
||||
total_loss = 0.0
|
||||
|
||||
# 에포크 반복
|
||||
for epoch in range(epochs):
|
||||
epoch_loss = 0.0
|
||||
|
||||
# 배치 처리
|
||||
if batch_size:
|
||||
for i in range(0, len(experiences), batch_size):
|
||||
batch = experiences[i : i + batch_size]
|
||||
loss = self._train_batch(batch)
|
||||
epoch_loss += loss
|
||||
total_updates += len(batch)
|
||||
else:
|
||||
# 전체 데이터 한번에
|
||||
loss = self._train_batch(experiences)
|
||||
epoch_loss += loss
|
||||
total_updates += len(experiences)
|
||||
|
||||
total_loss += epoch_loss
|
||||
|
||||
avg_loss = total_loss / total_updates if total_updates > 0 else 0.0
|
||||
|
||||
return {
|
||||
"total_experiences": len(experiences),
|
||||
"epochs": epochs,
|
||||
"updates": total_updates,
|
||||
"avg_loss": avg_loss,
|
||||
}
|
||||
|
||||
def _train_batch(self, experiences: List[Experience]) -> float:
|
||||
"""
|
||||
배치 Experience로 Q-Table 업데이트
|
||||
|
||||
Args:
|
||||
experiences: Experience 리스트
|
||||
|
||||
Returns:
|
||||
평균 손실값
|
||||
"""
|
||||
total_loss = 0.0
|
||||
|
||||
for exp in experiences:
|
||||
# State → state_index
|
||||
state_index = exp.state.to_index()
|
||||
|
||||
# Q-value 업데이트 전 값
|
||||
old_q = self.q_table.get_q_value(state_index, exp.action_id)
|
||||
|
||||
# Q-Table 업데이트
|
||||
if exp.next_state:
|
||||
next_state_index = exp.next_state.to_index()
|
||||
self.q_table.update(
|
||||
state_index=state_index,
|
||||
action_id=exp.action_id,
|
||||
reward=exp.reward,
|
||||
next_state_index=next_state_index,
|
||||
done=exp.done,
|
||||
)
|
||||
else:
|
||||
# 종료 상태
|
||||
self.q_table.update(
|
||||
state_index=state_index,
|
||||
action_id=exp.action_id,
|
||||
reward=exp.reward,
|
||||
next_state_index=None,
|
||||
done=True,
|
||||
)
|
||||
|
||||
if self.visit_table is not None:
|
||||
self.visit_table.increment(state_index, exp.action_id)
|
||||
|
||||
# Q-value 업데이트 후 값
|
||||
new_q = self.q_table.get_q_value(state_index, exp.action_id)
|
||||
|
||||
# 손실 계산 (업데이트 크기)
|
||||
loss = abs(new_q - old_q)
|
||||
total_loss += loss
|
||||
|
||||
return total_loss / len(experiences) if experiences else 0.0
|
||||
|
||||
def train_by_episodes(
|
||||
self,
|
||||
episode_ids: List[str],
|
||||
filename: str = "experiences.jsonl",
|
||||
epochs: int = 1,
|
||||
) -> dict:
|
||||
"""
|
||||
특정 에피소드들만 선택하여 학습
|
||||
|
||||
Args:
|
||||
episode_ids: 학습할 에피소드 ID 리스트
|
||||
filename: Experience 파일명
|
||||
epochs: 학습 반복 횟수
|
||||
|
||||
Returns:
|
||||
학습 결과
|
||||
"""
|
||||
# 해당 에피소드의 Experience만 로드
|
||||
all_experiences = self.experience_repository.load_all(filename)
|
||||
filtered_experiences = [
|
||||
exp for exp in all_experiences if exp.episode_id in episode_ids
|
||||
]
|
||||
|
||||
if not filtered_experiences:
|
||||
return {
|
||||
"total_experiences": 0,
|
||||
"episodes": 0,
|
||||
"epochs": 0,
|
||||
"updates": 0,
|
||||
"avg_loss": 0.0,
|
||||
"message": "No matching episodes found",
|
||||
}
|
||||
|
||||
total_updates = 0
|
||||
total_loss = 0.0
|
||||
|
||||
for epoch in range(epochs):
|
||||
loss = self._train_batch(filtered_experiences)
|
||||
total_loss += loss * len(filtered_experiences)
|
||||
total_updates += len(filtered_experiences)
|
||||
|
||||
return {
|
||||
"total_experiences": len(filtered_experiences),
|
||||
"episodes": len(episode_ids),
|
||||
"epochs": epochs,
|
||||
"updates": total_updates,
|
||||
"avg_loss": total_loss / total_updates if total_updates > 0 else 0.0,
|
||||
}
|
||||
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
"""
|
||||
Action-Card 매핑 레이어
|
||||
Q_Table의 action_id와 Card_Management의 card_id를 연결
|
||||
"""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class ActionCardMapper:
|
||||
"""
|
||||
action_id (0, 1, 2, ...) ↔ card_id ("no_0", "no_1", ...)
|
||||
|
||||
Q_Table은 추상적인 action_id만 다루고,
|
||||
Card_Management는 구체적인 card_id만 다룬다.
|
||||
이 클래스가 둘을 연결하는 단일 책임을 가진다.
|
||||
|
||||
매핑은 JSON 파일에서 로드되며, 카드 추가/삭제 시 수동으로 업데이트된다.
|
||||
"""
|
||||
|
||||
def __init__(self, mapping_file: Optional[str] = None):
|
||||
"""
|
||||
Args:
|
||||
mapping_file: JSON 매핑 파일 경로
|
||||
(기본: integration/data/action_card_mapping.json)
|
||||
"""
|
||||
if mapping_file is None:
|
||||
current_dir = Path(__file__).parent
|
||||
mapping_file = current_dir / "data" / "action_card_mapping.json"
|
||||
|
||||
self.mapping_file = Path(mapping_file)
|
||||
self.action_to_card: Dict[int, str] = {}
|
||||
self.card_to_action: Dict[str, int] = {}
|
||||
|
||||
self._load_mapping()
|
||||
|
||||
def _load_mapping(self):
|
||||
"""JSON 파일에서 매핑 로드"""
|
||||
if not self.mapping_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Mapping file not found: {self.mapping_file}\n"
|
||||
f"Please create action_card_mapping.json first."
|
||||
)
|
||||
|
||||
with open(self.mapping_file, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
# action_to_card: {"0": "no_0", "1": "no_1", ...}
|
||||
# JSON keys are strings, convert to int
|
||||
self.action_to_card = {int(k): v for k, v in data["action_to_card"].items()}
|
||||
|
||||
# card_to_action: reverse mapping for quick lookup
|
||||
self.card_to_action = {v: int(k) for k, v in data["action_to_card"].items()}
|
||||
|
||||
def get_card_id(self, action_id: int) -> Optional[str]:
|
||||
"""
|
||||
action_id → card_id 변환
|
||||
|
||||
Args:
|
||||
action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...)
|
||||
|
||||
Returns:
|
||||
card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...)
|
||||
매핑이 없으면 None
|
||||
|
||||
Example:
|
||||
>>> mapper.get_card_id(5)
|
||||
'no_5'
|
||||
"""
|
||||
return self.action_to_card.get(action_id)
|
||||
|
||||
def get_action_id(self, card_id: str) -> Optional[int]:
|
||||
"""
|
||||
card_id → action_id 변환
|
||||
|
||||
Args:
|
||||
card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...)
|
||||
|
||||
Returns:
|
||||
action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...)
|
||||
매핑이 없으면 None
|
||||
|
||||
Example:
|
||||
>>> mapper.get_action_id('no_5')
|
||||
5
|
||||
"""
|
||||
return self.card_to_action.get(card_id)
|
||||
|
||||
def get_action_space_size(self) -> int:
|
||||
"""
|
||||
현재 매핑된 액션 개수 (= 카드 개수)
|
||||
|
||||
Returns:
|
||||
액션 공간 크기
|
||||
|
||||
Example:
|
||||
>>> mapper.get_action_space_size()
|
||||
21 # 21개의 카드
|
||||
"""
|
||||
return len(self.action_to_card)
|
||||
|
||||
def get_all_mappings(self) -> Dict[int, str]:
|
||||
"""
|
||||
전체 매핑 반환
|
||||
|
||||
Returns:
|
||||
{action_id: card_id} 딕셔너리
|
||||
|
||||
Example:
|
||||
>>> mapper.get_all_mappings()
|
||||
{0: 'no_0', 1: 'no_1', ..., 20: 'no_20'}
|
||||
"""
|
||||
return self.action_to_card.copy()
|
||||
|
||||
def get_all_action_ids(self) -> list:
|
||||
"""
|
||||
모든 action_id 리스트 반환 (정렬됨)
|
||||
|
||||
Returns:
|
||||
[0, 1, 2, ..., 20]
|
||||
"""
|
||||
return sorted(self.action_to_card.keys())
|
||||
|
||||
def get_all_card_ids(self) -> list:
|
||||
"""
|
||||
모든 card_id 리스트 반환 (action_id 순서)
|
||||
|
||||
Returns:
|
||||
['no_0', 'no_1', ..., 'no_20']
|
||||
"""
|
||||
return [self.action_to_card[i] for i in sorted(self.action_to_card.keys())]
|
||||
|
||||
def is_valid_action_id(self, action_id: int) -> bool:
|
||||
"""action_id가 매핑에 존재하는지 확인"""
|
||||
return action_id in self.action_to_card
|
||||
|
||||
def is_valid_card_id(self, card_id: str) -> bool:
|
||||
"""card_id가 매핑에 존재하는지 확인"""
|
||||
return card_id in self.card_to_action
|
||||
|
||||
def reload(self):
|
||||
"""매핑 파일 다시 로드 (런타임 중 파일이 변경된 경우)"""
|
||||
self._load_mapping()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"ActionCardMapper("
|
||||
f"action_space_size={self.get_action_space_size()}, "
|
||||
f"mapping_file={self.mapping_file}"
|
||||
f")"
|
||||
)
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"action_to_card": {
|
||||
"0": "no_1",
|
||||
"1": "no_2",
|
||||
"2": "no_3",
|
||||
"3": "no_5",
|
||||
"4": "no_6",
|
||||
"5": "no_8",
|
||||
"6": "no_11",
|
||||
"7": "no_13",
|
||||
"8": "no_14"
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path to allow imports from src
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
|
||||
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
|
||||
from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
|
||||
from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository
|
||||
from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
|
||||
from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Train Q-Table Agent")
|
||||
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
|
||||
parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
|
||||
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate (only for new tables)")
|
||||
parser.add_argument("--gamma", type=float, default=0.9, help="Discount factor (only for new tables)")
|
||||
parser.add_argument("--data-file", type=str, default="experiences.jsonl", help="Experience file name inside data/experiences/")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("=== KTC V2 Agent Training ===")
|
||||
|
||||
# 1. Config
|
||||
try:
|
||||
mapper = ActionCardMapper()
|
||||
ACTION_SIZE = mapper.get_action_space_size()
|
||||
except Exception as e:
|
||||
# Fallback if specific file error
|
||||
print(f"Warning: Could not load Action mapping ({e}). Defaulting to 21.")
|
||||
ACTION_SIZE = 21
|
||||
|
||||
STATE_SIZE = 162
|
||||
print(f"Configuration: State Size={STATE_SIZE}, Action Size={ACTION_SIZE}")
|
||||
|
||||
# 2. Repository & Models
|
||||
model_repo = ModelRepository()
|
||||
|
||||
print("Loading models...")
|
||||
q_table, visit_table = model_repo.load()
|
||||
|
||||
if q_table is None:
|
||||
print("[Info] No existing Q-Table found. Creating new one.")
|
||||
q_table = QTable(
|
||||
state_space_size=STATE_SIZE,
|
||||
action_space_size=ACTION_SIZE,
|
||||
learning_rate=args.lr,
|
||||
discount_factor=args.gamma
|
||||
)
|
||||
else:
|
||||
print("[Info] Loaded existing Q-Table.")
|
||||
|
||||
if visit_table is None:
|
||||
print("[Info] No existing VisitTable found. Creating new one.")
|
||||
visit_table = VisitTable(STATE_SIZE, ACTION_SIZE)
|
||||
else:
|
||||
print("[Info] Loaded existing VisitTable.")
|
||||
|
||||
# 3. Data Repository
|
||||
exp_repo = ExperienceRepository()
|
||||
|
||||
# Check if data file exists
|
||||
data_path = exp_repo.data_dir / args.data_file
|
||||
if not data_path.exists():
|
||||
print(f"[Warning] Experience file not found at: {data_path}")
|
||||
print("Please ensure the data file is synchronized from the main server.")
|
||||
return
|
||||
|
||||
# 4. Usecase
|
||||
trainer = TrainOfflineUsecase(q_table, exp_repo, visit_table)
|
||||
|
||||
# 5. Train
|
||||
print(f"Starting training for {args.epochs} epochs with batch size {args.batch_size}...")
|
||||
result = trainer.train(filename=args.data_file, epochs=args.epochs, batch_size=args.batch_size)
|
||||
|
||||
print("\nTraining Result:")
|
||||
for k, v in result.items():
|
||||
print(f" {k}: {v}")
|
||||
|
||||
# 6. Save
|
||||
print("Saving models...")
|
||||
model_repo.save(q_table, visit_table)
|
||||
print("Done.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue