Initial commit

main
mgjeon 2025-12-29 09:08:37 +09:00
commit 94bbc309fd
33 changed files with 10761 additions and 0 deletions

5
.gitignore vendored Normal file
View File

@ -0,0 +1,5 @@
.venv/
__pycache__/
*.py[cod]
.DS_Store
.env

89
requirements.txt Normal file
View File

@ -0,0 +1,89 @@
aiohappyeyeballs==2.6.1
aiohttp==3.13.2
aiosignal==1.4.0
aiosqlite==0.21.0
annotated-doc==0.0.4
annotated-types==0.7.0
anyio==4.11.0
asyncpg==0.31.0
attrs==25.4.0
certifi==2025.11.12
charset-normalizer==3.4.4
click==8.3.1
cloudpickle==3.1.2
contourpy==1.3.3
cycler==0.12.1
distro==1.9.0
Farama-Notifications==0.0.4
fastapi==0.122.0
fonttools==4.60.1
frozenlist==1.8.0
greenlet==3.2.4
gunicorn==23.0.0
gymnasium==1.2.2
h11==0.16.0
h5py==3.15.1
httpcore==1.0.9
httpx==0.28.1
idna==3.11
iniconfig==2.3.0
JayDeBeApi==1.2.3
Jinja2==3.1.6
jiter==0.12.0
jpype1==1.6.0
jsonpatch==1.33
jsonpointer==3.0.0
kiwisolver==1.4.9
langchain-core==1.1.0
langchain-openai==1.1.0
langgraph==1.0.4
langgraph-checkpoint==3.0.1
langgraph-checkpoint-sqlite==3.0.0
langgraph-prebuilt==1.0.5
langgraph-sdk==0.2.10
langsmith==0.4.49
MarkupSafe==3.0.3
matplotlib==3.10.7
multidict==6.7.0
numpy==2.3.5
openai==2.8.1
orjson==3.11.4
ormsgpack==1.12.0
packaging==25.0
passlib==1.7.4
pillow==12.0.0
pluggy==1.6.0
propcache==0.4.1
psycopg2-binary==2.9.11
pydantic==2.12.5
pydantic-settings==2.12.0
pydantic_core==2.41.5
Pygments==2.19.2
PyJWT==2.10.1
pyparsing==3.2.5
pypdf==6.1.3
pytest==9.0.1
pytest-asyncio==1.3.0
python-dateutil==2.9.0.post0
python-dotenv==1.2.1
python-multipart==0.0.20
pytz==2025.2
PyYAML==6.0.3
regex==2025.11.3
requests==2.32.5
requests-toolbelt==1.0.0
six==1.17.0
sniffio==1.3.1
SQLAlchemy==2.0.44
sqlite-vec==0.1.6
starlette==0.50.0
tenacity==9.1.2
tiktoken==0.12.0
tqdm==4.67.1
typing-inspection==0.4.2
typing_extensions==4.15.0
urllib3==2.5.0
uvicorn==0.38.0
xxhash==3.6.0
yarl==1.22.0
zstandard==0.25.0

5
setup_env.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash
python3 -m venv .venv
source .venv/bin/activate
pip install -r requirements.txt
echo "Setup complete. Activate with 'source .venv/bin/activate'"

View File

@ -0,0 +1,63 @@
# 변경 사항 (Changelog)
## Version 3.0 - 비즈니스 용어 적용 (최종)
### 주요 변경 사항
#### 1. State 변수명 최종 확정 (03_state_design.tex)
**Version 2.0 (이전):**
- 견적 금액 구간, 제조사 구분, 파트너사 구조, 가격 수용률 분위, 현재 가격 구간
**Version 3.0 (최종):**
- 매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간
**변경 이유:**
- 실제 비즈니스에서 사용하는 용어로 통일하여 가독성 및 이해도 향상
**최종 State 변수명:**
| 순번 | 변수 | 구분 | 설명 |
| :--- | :------------------- | :--- | :--------------------------------------------------------- |
| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) |
| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 |
| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) |
| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) |
| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) |
#### 2. 동적 가중치 (W) 설계 업데이트 (04_reward_function.tex)
- 변수명 변경에 따라 수식 및 테이블의 변수명도 모두 업데이트
- `S_manu``S_dist`
- 계산 로직 및 가중치 값은 동일하게 유지
```
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
```
**새로운 유통 구조 가중치 설계 의도:**
| 유통 구조 | 정규화 값 | 설계 의도 |
| :-------- | :-------- | :----------------------------------------- |
| 제조 | 0.2 | 제조사 직공급으로 가격 협상 여지 가장 적음 |
| 총판 | 0.5 | 중간 유통 단계로 협상 여지 중간 |
| 유통 | 1.0 | 복잡한 유통 단계로 가격 협상 여지 가장 큼 |
### 파일 변경 내역
- `sections/03_state_design.tex`: 변수명 변경
- `sections/04_reward_function.tex`: 변수명 변경
- `EXAMPLE_CALCULATION.md`: 변수명 변경
- `CHANGELOG.md`: 이 파일 (Version 4.0 반영)
### 최종 검증
- [x] State 변수명 5개 모두 비즈니스 용어로 변경 완료
- [x] 총 State 수 162개 유지
- [x] 동적 가중치 W 계산식 변수명 업데이트 완료
- [x] 계산 예시 변수명 업데이트 완료
이 버전은 기능적 변경 없이, 문서의 가독성과 현업 적용성을 높이는 데 중점을 둔 최종 버전입니다.

View File

@ -0,0 +1,173 @@
# Q-Table Version 3.0 Migration Guide
## 개요
Q-Table 프로젝트가 Version 2.0 (36 states)에서 Version 3.0 (162 states)으로 업그레이드되었습니다.
## 주요 변경사항
### 1. State 구조 변경
#### Before (Version 2.0 - 36 states)
```python
from negotiation_agent.Q_Table.domain.model.state import (
State, Scenario, PriceZone, AcceptanceRate
)
state = State(
scenario=Scenario.PRICE_FIRST, # 4개 값
price_zone=PriceZone.AT_OR_BELOW_ANCHOR, # 3개 값
acceptance_rate=AcceptanceRate.MEDIUM # 3개 값
)
# 총 4 × 3 × 3 = 36 states
```
#### After (Version 3.0 - 162 states)
```python
from negotiation_agent.Q_Table.domain.model.state import (
State,
RevenueRange,
DistributionStructure,
PartnerType,
AcceptanceRate,
InputPriceZone,
)
state = State(
revenue_range=RevenueRange.MID, # 3개 값
distribution=DistributionStructure.WHOLESALER, # 3개 값
partner_type=PartnerType.SINGLE, # 3개 값
acceptance_rate=AcceptanceRate.MID, # 3개 값
input_price_zone=InputPriceZone.PZ1, # 2개 값
)
# 총 3 × 3 × 3 × 3 × 2 = 162 states
```
### 2. State Builder 변경
#### Before (Version 2.0)
```python
from negotiation_agent.Q_Table.domain.service.state_calculator import (
NegotiationSnapshot, build_state
)
snapshot = NegotiationSnapshot(
scenario_code="A",
anchor_price=10000,
target_price=12000,
seller_initial_price=15000,
current_price=11000,
)
state = build_state(snapshot)
```
#### After (Version 3.0)
```python
from negotiation_agent.Q_Table.domain.service.state_calculator import (
NegotiationSnapshot, build_state
)
snapshot = NegotiationSnapshot(
revenue_amount=2000, # 매출액 (만원)
distribution_code="W", # "M": 제조, "W": 총판, "R": 유통
partner_count=1, # 파트너사 수
anchor_price=10000,
target_price=12000,
input_price=11000,
acceptance_ratio=0.5, # 0~1 사이 값
)
state = build_state(snapshot)
```
### 3. Reward 계산 변경
#### Before (Version 2.0)
```python
from negotiation_agent.Q_Table.domain.service.reward_calculator import (
calculate_reward, NegotiationOutcome
)
breakdown = calculate_reward(
scenario=state.scenario,
price_zone=state.price_zone,
current_price=11000,
anchor_price=10000,
target_price=12000,
round_number=3,
outcome=NegotiationOutcome.ONGOING,
)
```
#### After (Version 3.0)
```python
from negotiation_agent.Q_Table.domain.service.reward_calculator import (
calculate_reward, NegotiationOutcome
)
breakdown = calculate_reward(
revenue_range=state.revenue_range,
distribution=state.distribution,
partner_type=state.partner_type,
acceptance_rate=state.acceptance_rate,
input_price_zone=state.input_price_zone,
current_price=11000,
anchor_price=10000,
target_price=12000,
round_number=3,
outcome=NegotiationOutcome.ONGOING,
)
```
### 4. Q-Table 크기 변경
#### Before
```python
q_table = QTable(state_space_size=36, action_space_size=21)
visit_table = VisitTable(state_space_size=36, action_space_size=21)
```
#### After
```python
q_table = QTable(state_space_size=162, action_space_size=21)
visit_table = VisitTable(state_space_size=162, action_space_size=21)
```
## 마이그레이션 체크리스트
- [ ] State 생성 코드를 새로운 5-변수 구조로 변경
- [ ] NegotiationSnapshot 생성 코드 업데이트
- [ ] calculate_reward() 호출부 매개변수 변경
- [ ] Q-Table, VisitTable 초기화 시 state_space_size를 162로 변경
- [ ] 기존 학습된 모델 파일(.npy) 재학습 필요 (36→162 차원 불일치)
- [ ] 단위 테스트 업데이트
## 호환성 주의사항
⚠️ **기존 학습 모델 사용 불가**
- Version 2.0에서 학습된 Q-Table (36 × N)은 Version 3.0 (162 × N)과 호환되지 않습니다.
- 기존 모델을 사용하려면 재학습이 필요합니다.
⚠️ **Experience 데이터 재수집 권장**
- 기존 Experience에는 새로운 State 변수 정보가 없습니다.
- 새로운 State 정의에 맞춰 Experience를 재수집하는 것을 권장합니다.
## 변경 이유
1. **비즈니스 용어 적용**: 실제 협상 도메인에서 사용하는 용어로 통일
2. **더 세밀한 상태 표현**: 매출액, 유통 구조, 파트너사 정보 등을 명시적으로 포함
3. **확장성 향상**: 새로운 비즈니스 변수 추가 시 유연한 대응 가능
4. **협상 전략 다양화**: 162개의 상태로 더 정교한 협상 전략 학습 가능
## 문의
변경사항에 대한 문의는 팀 리드에게 연락해주세요.

View File

@ -0,0 +1,364 @@
# Q_Table 리팩토링 완료 요약
## 📅 작업 일자
- 2025-10-29: Version 2.0 (36 states)
- 2025-11-10: **Version 3.0 (162 states)** - 비즈니스 용어 적용 및 State 재설계
## ✅ Version 3.0 주요 변경사항
### State 공간 재설계 (36 → 162 states)
**기존 Version 2.0:**
- State = (Scenario, PriceZone, AcceptanceRate)
- 총 상태 수: 4 × 3 × 3 = **36 states**
**신규 Version 3.0:**
- State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간)
- 총 상태 수: 3 × 3 × 3 × 3 × 2 = **162 states**
| 순번 | 변수 | 구분 | 설명 |
| :--- | :------------------- | :--- | :--------------------------------------------------------- |
| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) |
| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 |
| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) |
| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) |
| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) |
### 동적 가중치 (W) 설계 업데이트
**기존 Version 2.0:**
```
W = (scenario.weight + price_zone.weight) / 2.0
```
**신규 Version 3.0:**
```
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
W = clip(W_raw, 0.2, 0.8)
```
기본 가중치 계수:
- w1 = 0.20 (매출액 가격구간)
- w2 = 0.25 (유통 구조)
- w3 = 0.20 (파트너사 종류)
- w4 = 0.25 (가격 수용률)
- w5 = 0.10 (입력 금액 구간)
## ✅ 완료된 작업 (Version 2.0)
### 1. 새로운 핵심 컴포넌트 구현
#### **QTable** (`domain/model/q_table.py`)
- ✅ 동적 `action_space_size` 지원 (21개 카드)
- ✅ Q-Learning 업데이트 메서드
- ✅ 직렬화/역직렬화 지원
- ✅ 인덱스 유효성 검증
#### **EpisodePolicy** (`domain/agents/policy.py`)
- ✅ 하드코딩된 크기(9) → 동적 크기로 변경
- ✅ 생성자에서 `action_space_size` 주입
- ✅ `get_action_mask()` 동적 크기 적용
#### **State** (`domain/model/state.py`)
- ✅ `to_index()` 메서드 추가 (State → 1D 인덱스)
- ✅ `from_index()` 메서드 추가 (1D 인덱스 → State)
#### **ActionCardMapper** (`integration/action_card_mapper.py`) 🆕
- ✅ action_id ↔ card_id 양방향 매핑
- ✅ JSON 파일 기반 설정
- ✅ 유틸리티 메서드 제공
#### **action_card_mapping.json** (`integration/data/`) 🆕
- ✅ 21개 카드 매핑 (0-20 → "no_0" ~ "no_20")
#### **GetBestActionUsecase** (`usecase/get_best_action_usecase.py`)
- ✅ State → action_id → card_id 전체 플로우
- ✅ Policy 사용 여부 선택 가능
- ✅ Top-K 추천 기능
- ✅ 사용 가능한 액션/카드 조회
#### **CollectExperienceUsecase** (`usecase/collect_experience_usecase.py`) 🆕
- ✅ 협상 중 Experience 수집
- ✅ 에피소드 단위 관리
- ✅ JSONL 형식 자동 저장
- ✅ 수집 정보 조회
#### **TrainOfflineUsecase** (`usecase/train_offline_usecase.py`) 🆕
- ✅ 저장된 Experience로 Q-Table 학습
- ✅ 배치 학습 및 에포크 반복
- ✅ 특정 에피소드 선택 학습
- ✅ 성능 평가 기능
### 2. 레거시 코드 정리
#### 제거된 파일
```
❌ domain/action_space.py
❌ domain/model/action.py
❌ domain/constants.py
❌ domain/spaces.py
```
#### Legacy로 이동된 파일 (참고용 보관)
```
📦 legacy/environment.py
📦 legacy/calculate_reward_usecase.py
📦 legacy/evaluate_agent_usecase.py
📦 legacy/execute_step_usecase.py
📦 legacy/get_q_value_usecase.py
📦 legacy/get_state_info_usecase.py
📦 legacy/initialize_env_usecase.py
📦 legacy/load_q_table_usecase.py
📦 legacy/train_agent_usecase.py
📦 legacy/update_q_table_usecase.py
```
## 📁 최종 디렉토리 구조
```
negotiation_agent/
├── card_management/ # 카드 관리 (DB 기반)
│ ├── domain/
│ │ ├── model/
│ │ │ └── nego_card.py ✅ DB 엔티티
│ │ ├── repository/
│ │ │ ├── nego_card_repository.py
│ │ │ └── nego_card_script_repository.py ✅ JSON CRUD
│ │ └── value/
│ │ └── nego_card_types.py ✅ 6가지 atomic elements
│ └── data/
│ └── nego_card_scripts.json ✅ 21개 카드 스크립트
├── Q_Table/ # Q-Learning (추상 action_id)
│ ├── domain/
│ │ ├── model/
│ │ │ ├── state.py ✅ to_index() 추가
│ │ │ ├── q_table.py ✅ 새로 생성
│ │ │ └── experience.py ✅ 새로 생성
│ │ ├── agents/
│ │ │ ├── policy.py ✅ 동적 크기 수정
│ │ │ └── offline_agent.py
│ │ ├── repository/
│ │ │ ├── action_repository.py
│ │ │ └── experience_repository.py ✅ 새로 생성
│ │ └── service/
│ │ └── state_calculator.py
│ ├── usecase/
│ │ ├── get_best_action_usecase.py ✅ 완전 재작성
│ │ ├── collect_experience_usecase.py ✅ 새로 생성
│ │ └── train_offline_usecase.py ✅ 새로 생성
│ ├── infra/
│ │ ├── data_collector.py
│ │ └── gym/
│ │ └── env_wrapper.py
│ ├── data/ 🆕 데이터 저장소
│ │ └── experiences/
│ │ └── *.jsonl
│ └── legacy/ 📦 레거시 코드 보관
│ ├── README.md
│ └── ... (10개 파일)
└── integration/ 🆕 매핑 레이어
├── action_card_mapper.py ✅
└── data/
└── action_card_mapping.json ✅
```
## 🎯 핵심 설계 원칙
### 1. 명확한 책임 분리
- **Q_Table**: action_id (0~20)만 다룸
- **Card_Management**: card_id ("no_0"~"no_20")만 다룸
- **ActionCardMapper**: 둘을 연결하는 단일 책임
### 2. 동적 확장성
- 카드 추가/삭제 시 `action_card_mapping.json`만 수정
- Q-Table과 Policy가 자동으로 새 크기에 대응
### 3. 유지보수성
- 레거시 코드는 legacy/ 폴더에 보관
- 새로운 코드와 명확히 분리
## 🔄 사용 예시
### 1. 추론 (Inference)
```python
from negotiation_agent.Q_Table.domain.model.q_table import QTable
from negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
from negotiation_agent.Q_Table.domain.agents.policy import UCBPolicy
from negotiation_agent.integration.action_card_mapper import ActionCardMapper
from negotiation_agent.Q_Table.usecase.get_best_action_usecase import GetBestActionUsecase
# 매퍼 로드
mapper = ActionCardMapper()
action_space_size = mapper.get_action_space_size() # 21
# Q-Table 생성 (Version 3.0: 162 states)
q_table = QTable(
state_space_size=162, # 3 x 3 x 3 x 3 x 2
action_space_size=action_space_size # 21
)
# Policy 생성
visit_table = VisitTable(state_space_size=162, action_space_size=action_space_size)
policy = UCBPolicy(visit_table=visit_table)
# Usecase 생성
usecase = GetBestActionUsecase(
q_table=q_table,
policy=policy,
action_card_mapper=mapper
)
# 협상 추론 (Version 3.0)
from negotiation_agent.Q_Table.domain.model.state import (
State,
RevenueRange,
DistributionStructure,
PartnerType,
AcceptanceRate,
InputPriceZone,
)
state = State(
revenue_range=RevenueRange.MID, # 1,000~3,000만원
distribution=DistributionStructure.WHOLESALER, # 총판
partner_type=PartnerType.SINGLE, # 단독 파트너
acceptance_rate=AcceptanceRate.MID, # 30~90%
input_price_zone=InputPriceZone.PZ1, # A≤P≤T
)
result = usecase.execute(state)
print(result)
# Output:
# {
# 'action_id': 5,
# 'card_id': 'no_5',
# 'q_value': 0.0,
# 'state_index': 1
# }
```
### 2. Experience 수집
```python
from negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
from negotiation_agent.Q_Table.usecase.collect_experience_usecase import CollectExperienceUsecase
# Repository 생성
exp_repo = ExperienceRepository()
# Usecase 생성
collect_usecase = CollectExperienceUsecase(exp_repo)
# 에피소드 시작
collect_usecase.start_episode("ep_001")
# 협상 진행 중...
for step in range(10):
# 현재 상태
current_state = State(...)
# 액션 선택 (GetBestActionUsecase 사용)
result = usecase.execute(current_state)
action_id = result['action_id']
# 액션 실행 후 보상과 다음 상태 받음
reward = calculate_reward(...) # 보상 계산
next_state = get_next_state(...) # 다음 상태
done = check_done(...) # 종료 여부
# Experience 수집
collect_usecase.collect(
state=current_state,
action_id=action_id,
reward=reward,
next_state=next_state,
done=done
)
if done:
break
# 에피소드 종료
collect_usecase.end_episode()
# 수집 정보 확인
info = collect_usecase.get_collection_info()
print(info)
# {
# 'total_experiences': 10,
# 'episodes': 1,
# 'current_episode': None
# }
```
### 3. 오프라인 학습
```python
from negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
# Usecase 생성
train_usecase = TrainOfflineUsecase(
q_table=q_table,
experience_repository=exp_repo
)
# 학습 실행
result = train_usecase.train(
filename="experiences.jsonl",
epochs=10,
batch_size=32
)
print(result)
# {
# 'total_experiences': 1000,
# 'epochs': 10,
# 'updates': 10000,
# 'avg_loss': 0.05
# }
# 성능 평가
eval_result = train_usecase.evaluate(filename="experiences.jsonl")
print(eval_result)
# {
# 'avg_q_value': 0.5,
# 'avg_reward': 1.0,
# 'total_samples': 1000
# }
```
## 📋 향후 작업 (선택사항)
- [x] Experience 수집 Usecase 구현 ✅
- [x] 오프라인 학습 Usecase 구현 ✅
- [ ] Q-Table Repository 구현 (영속성)
- [ ] 카드 추가 시 Q-Table 자동 확장 기능
- [ ] 통계적 Q-value 초기화 기능
## 📚 참고 문서
- 설계 문서: `REFACTORING_GUIDE.md`
- 레거시 코드: `legacy/README.md`

View File

@ -0,0 +1,279 @@
# Q-Table Version 3.0 변경사항 정리
## 📋 변경 개요
CHANGELOG.md에 기술된 Version 3.0 업데이트를 코드에 반영했습니다.
- **State 공간**: 36 states → **162 states**
- **변수 구조**: 3개 변수 → **5개 변수** (비즈니스 용어 적용)
- **동적 가중치**: 2변수 평균 → **5변수 가중합**
---
## 🔄 주요 변경 파일
### 1. State 모델 (`domain/model/state.py`)
#### 기존 (Version 2.0)
```python
class Scenario(IntEnum): ... # 4개 값
class PriceZone(IntEnum): ... # 3개 값
class AcceptanceRate(IntEnum): ... # 3개 값
State = (Scenario, PriceZone, AcceptanceRate) # 4×3×3 = 36
```
#### 변경 (Version 3.0)
```python
class RevenueRange(IntEnum): ... # 3개 값 (매출액 가격구간)
class DistributionStructure(IntEnum): ... # 3개 값 (유통 구조)
class PartnerType(IntEnum): ... # 3개 값 (파트너사 종류)
class AcceptanceRate(IntEnum): ... # 3개 값 (가격 수용률 구간)
class InputPriceZone(IntEnum): ... # 2개 값 (입력 금액 구간)
State = (RevenueRange, DistributionStructure, PartnerType,
AcceptanceRate, InputPriceZone) # 3×3×3×3×2 = 162
```
**주요 특징:**
- 모든 클래스에 `.weight` 속성 추가 (동적 가중치 계산용)
- 비즈니스 용어로 명명 (매출액, 유통 구조, 파트너사 등)
- 각 변수마다 `from_*()` 클래스 메서드로 분류 로직 제공
---
### 2. Reward Calculator (`domain/service/reward_calculator.py`)
#### 기존 (Version 2.0)
```python
def calculate_reward(
scenario: Scenario,
price_zone: PriceZone,
...
)
def _calculate_weight(scenario, price_zone, cfg):
raw = (scenario.priority_weight + price_zone.zone_weight) / 2.0
return clip(raw, 0.2, 0.8)
```
#### 변경 (Version 3.0)
```python
def calculate_reward(
revenue_range: RevenueRange,
distribution: DistributionStructure,
partner_type: PartnerType,
acceptance_rate: AcceptanceRate,
input_price_zone: InputPriceZone,
...
)
def _calculate_weight(...):
W_raw = (
config.w1 * revenue_range.weight +
config.w2 * distribution.weight +
config.w3 * partner_type.weight +
config.w4 * acceptance_rate.weight +
config.w5 * input_price_zone.weight
)
return clip(W_raw, 0.2, 0.8)
```
**기본 가중치 계수:**
- w1 = 0.20 (매출액)
- w2 = 0.25 (유통 구조)
- w3 = 0.20 (파트너사)
- w4 = 0.25 (수용률)
- w5 = 0.10 (입력 금액)
---
### 3. State Calculator (`domain/service/state_calculator.py`)
#### 기존 (Version 2.0)
```python
@dataclass
class NegotiationSnapshot:
scenario_code: str
anchor_price: float
target_price: float
seller_initial_price: float
current_price: float
```
#### 변경 (Version 3.0)
```python
@dataclass
class NegotiationSnapshot:
revenue_amount: float # 매출액 (만원)
distribution_code: str # "M", "W", "R"
partner_count: int # 파트너사 수
anchor_price: float
target_price: float
input_price: float
acceptance_ratio: Optional[float]
initial_price: Optional[float]
```
**변경 이유:**
- 실제 비즈니스 데이터에 맞춰 필드 재구성
- 5개 State 변수를 도출할 수 있는 정보 제공
---
### 4. 모듈 Export (`domain/model/__init__.py`)
```python
# Version 3.0
from .state import (
AcceptanceRate,
DistributionStructure,
InputPriceZone,
PartnerType,
RevenueRange,
State,
)
```
---
## 📊 State 변수 상세 명세
| 변수 | 클래스 | 값 | 설명 | 가중치 |
|------|--------|---|------|--------|
| 매출액 가격구간 | `RevenueRange` | LOW (≤1,000만원) | 낮은 매출액 | 0.3 |
| | | MID (1,000~3,000만원) | 중간 매출액 | 0.6 |
| | | HIGH (>3,000만원) | 높은 매출액 | 1.0 |
| 유통 구조 | `DistributionStructure` | MANUFACTURER | 제조사 직공급 | 0.2 |
| | | WHOLESALER | 총판 경유 | 0.5 |
| | | RETAILER | 유통 경유 | 1.0 |
| 파트너사 종류 | `PartnerType` | NONE | 파트너 없음 | 0.3 |
| | | SINGLE | 단독 파트너 | 0.5 |
| | | MULTIPLE | 다수 파트너 | 1.0 |
| 가격 수용률 | `AcceptanceRate` | LOW (<30%) | 낮은 수용률 | 0.3 |
| | | MID (30~90%) | 중간 수용률 | 0.6 |
| | | HIGH (>90%) | 높은 수용률 | 1.0 |
| 입력 금액 구간 | `InputPriceZone` | PZ1 (A≤P≤T) | 목표 범위 내 | 1.0 |
| | | PZ2 (P>T or P<A) | | 0.3 |
---
## 🧪 새로운 테스트 파일
### `tests/Q_Table/test_state_v3.py`
- State 변수 분류 로직 검증
- State ↔ index 변환 정합성 테스트
- 162개 state 전체 순회 검증
- 가중치 값 검증
### `tests/Q_Table/test_reward_v3.py`
- NegotiationSnapshot → State 변환 테스트
- 동적 가중치 계산 검증
- 성공/실패/진행중 보상 계산 검증
- 최소/최대 가중치 경계 케이스 테스트
---
## 📚 문서 업데이트
### `REFACTORING_SUMMARY.md`
- Version 3.0 변경사항 추가
- 사용 예시 업데이트 (36 → 162 states)
- State 구조 비교표 추가
### `MIGRATION_V3.md` (신규)
- Version 2.0 → 3.0 마이그레이션 가이드
- Before/After 코드 비교
- 호환성 주의사항
- 체크리스트 제공
### `CHANGELOG.md` (기존 파일 참조)
- Version 3.0 변경사항 문서화
- 비즈니스 용어 적용 배경 설명
---
## ⚠️ Breaking Changes
### 1. Q-Table 크기 변경
```python
# Before
QTable(state_space_size=36, action_space_size=21)
# After
QTable(state_space_size=162, action_space_size=21)
```
### 2. State 생성 방식 변경
```python
# Before
State(
scenario=Scenario.PRICE_FIRST,
price_zone=PriceZone.AT_OR_BELOW_ANCHOR,
acceptance_rate=AcceptanceRate.MEDIUM,
)
# After
State(
revenue_range=RevenueRange.MID,
distribution=DistributionStructure.WHOLESALER,
partner_type=PartnerType.SINGLE,
acceptance_rate=AcceptanceRate.MID,
input_price_zone=InputPriceZone.PZ1,
)
```
### 3. 기존 학습 모델 호환 불가
- Version 2.0 모델 파일(.npy)은 36×N 크기
- Version 3.0은 162×N 크기로 로드 불가
- **재학습 필수**
---
## ✅ 검증 항목
- [x] State 클래스 5개 변수로 재정의
- [x] 각 변수의 `.weight` 속성 구현
- [x] State.to_index() / from_index() 162 states 대응
- [x] RewardConfig에 w1~w5 가중치 계수 추가
- [x] _calculate_weight() 5변수 가중합으로 변경
- [x] NegotiationSnapshot 필드 재구성
- [x] build_state() 5변수 매핑 로직 구현
- [x] __init__.py export 업데이트
- [x] REFACTORING_SUMMARY.md 업데이트
- [x] MIGRATION_V3.md 작성
- [x] 단위 테스트 작성 (test_state_v3.py, test_reward_v3.py)
---
## 🚀 다음 단계
1. **Python 환경 설정**
```bash
poetry install
# or
pip install -r requirements.txt
```
2. **테스트 실행**
```bash
pytest tests/Q_Table/test_state_v3.py -v
pytest tests/Q_Table/test_reward_v3.py -v
```
3. **기존 코드 마이그레이션**
- `MIGRATION_V3.md` 가이드 참조
- State 생성 코드 일괄 변경
- Q-Table 초기화 크기 변경
4. **모델 재학습**
- 새로운 162-state 공간으로 학습 데이터 재수집
- Q-Table 재학습 실행
---
## 📞 문의
Version 3.0 적용 중 문제 발생 시:
1. `MIGRATION_V3.md` 참조
2. 테스트 코드 참조 (`test_state_v3.py`, `test_reward_v3.py`)
3. 팀 리드에게 문의

View File

@ -0,0 +1,3 @@
"""Q-Table package exports."""
from . import domain, usecase # noqa: F401

File diff suppressed because it is too large Load Diff

Binary file not shown.

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,120 @@
"""Offline Q-learning agent backed by a UCB policy."""
from __future__ import annotations
import os
from typing import Optional
import numpy as np
from ..model.visit_table import VisitTable
from ..model.q_table import QTable
from .policy import UCBPolicy
class QLearningAgent:
def __init__(
self,
agent_params,
state_size: int,
action_size: int,
visit_table: Optional[VisitTable] = None,
) -> None:
"""
Args:
agent_params: 에이전트 파라미터 (learning_rate, discount_factor )
state_size: 상태 공간 크기
action_size: 액션 공간 크기
visit_table: 방문 기록 테이블 (None이면 새로 생성)
"""
self.state_size = state_size
self.action_size = action_size
# Q-Table 객체 생성 (Composition)
self.q_table = QTable(
state_space_size=state_size,
action_space_size=action_size,
learning_rate=agent_params["learning_rate"],
discount_factor=agent_params["discount_factor"]
)
self.visit_table = visit_table or VisitTable(state_size, action_size)
self.policy = UCBPolicy(
visit_table=self.visit_table,
exploration_constant=agent_params.get("exploration_constant", np.sqrt(2.0)),
)
def get_action(self, state: int, action_mask=None):
"""
현재 상태에서 액션 선택 (UCB 정책 사용)
Args:
state: 현재 상태 인덱스
action_mask: 가능한 액션 마스크 (Optional)
Returns:
선택된 액션 ID
"""
# QTable 객체에서 Q-value 조회
q_values = self.q_table.get_q_values_for_state(state)
mask = None if action_mask is None else np.asarray(action_mask, dtype=bool)
return self.policy.select_action(state, q_values, available_mask=mask)
def learn(self, batch):
"""
배치 데이터를 사용하여 Q-Table 업데이트
Args:
batch: 학습 데이터 배치 (observations, actions, rewards, next_observations, terminals)
"""
for state, action, reward, next_state, terminated in zip(
batch["observations"],
batch["actions"],
batch["rewards"],
batch["next_observations"],
batch["terminals"],
):
# QTable 객체의 update 메서드 사용
# terminated가 True이면 next_state는 의미가 없거나 None이어야 함
# QTable.update는 next_state_index가 None이면 종료 상태로 처리
next_s = next_state if not terminated else None
self.q_table.update(
state_index=state,
action_id=action,
reward=reward,
next_state_index=next_s,
done=terminated
)
self.visit_table.increment(state, action)
def save_model(self, path):
"""
Q-Table을 파일로 저장
Args:
path: 저장할 파일 경로
"""
# QTable 내부의 numpy array를 저장 (기존 호환성 유지)
np.save(path, self.q_table.q_values)
print(f"Q-Table saved to {path}")
def load_q_table(self, file_path):
"""
파일에서 Q-Table 로드
Args:
file_path: 로드할 파일 경로
"""
if os.path.exists(file_path):
# QTable 객체의 q_values 속성에 직접 할당
self.q_table.q_values = np.load(file_path)
print(f"Q-Table loaded from {file_path}")
else:
print(f"Error: No Q-Table found at {file_path}")
def reset_episode(self):
"""에피소드 초기화 (정책 상태 등 리셋)"""
self.policy.reset_episode()

View File

@ -0,0 +1,123 @@
from abc import ABC, abstractmethod
from typing import Optional
import numpy as np
from ..model.visit_table import VisitTable
class Policy(ABC):
@abstractmethod
def select_action(
self,
state_index: int,
q_values: np.ndarray,
available_mask: Optional[np.ndarray] = None,
) -> Optional[int]:
"""
주어진 상태와 Q-value를 기반으로 액션 선택
Args:
state_index: 현재 상태 인덱스
q_values: 해당 상태의 Q-value 배열
available_mask: 선택 가능한 액션 마스크 (True=선택가능)
Returns:
선택된 액션 ID (선택 불가 None)
"""
raise NotImplementedError
def reset_episode(self) -> None:
"""에피소드 시작 시 상태 초기화"""
raise NotImplementedError
def get_action_mask(self) -> np.ndarray:
"""
현재 정책 상태에 따른 액션 마스크 반환
Returns:
액션 마스크 배열 (1=선택가능, 0=선택불가)
"""
raise NotImplementedError
class UCBPolicy(Policy):
"""Upper Confidence Bound policy with per-episode action masking."""
def __init__(
self,
visit_table: VisitTable,
exploration_constant: float = np.sqrt(2.0),
rng: Optional[np.random.Generator] = None,
) -> None:
self.visit_table = visit_table
self.action_space_size = visit_table.action_space_size
self.exploration_constant = exploration_constant
self._rng = rng or np.random.default_rng()
self._episode_actions: set[int] = set()
def select_action(
self,
state_index: int,
q_values: np.ndarray,
available_mask: Optional[np.ndarray] = None,
) -> Optional[int]:
"""
UCB 알고리즘을 사용하여 액션 선택
UCB = Q(s,a) + c * sqrt(ln(N(s)) / N(s,a))
"""
mask = self._prepare_mask(available_mask)
if not mask.any():
self.reset_episode()
mask = self._prepare_mask(available_mask)
if not mask.any():
return None
counts = self.visit_table.get_state_counts(state_index)
masked_counts = np.where(mask, counts, 0)
zero_visit_candidates = np.where((counts == 0) & mask)[0]
if zero_visit_candidates.size > 0:
best = zero_visit_candidates[np.argmax(q_values[zero_visit_candidates])]
action = int(best)
else:
total = masked_counts.sum()
# Avoid division by zero while keeping exploration pressure.
denom = counts.astype(float) + 1e-9
bonus = self.exploration_constant * np.sqrt(np.log(total + 1.0) / denom)
scores = q_values + bonus
scores[~mask] = -np.inf
action = int(np.argmax(scores))
self.visit_table.increment(state_index, action)
self._episode_actions.add(action)
return action
def reset_episode(self) -> None:
"""에피소드 내 사용된 액션 기록 초기화"""
self._episode_actions.clear()
def get_action_mask(self) -> np.ndarray:
"""
이미 사용된 액션을 제외한 마스크 반환
Returns:
마스크 배열 (1=선택가능, 0=선택불가)
"""
mask = np.ones(self.action_space_size, dtype=int)
if self._episode_actions:
for action in self._episode_actions:
mask[action] = 0
return mask
def _prepare_mask(self, available_mask: Optional[np.ndarray]) -> np.ndarray:
"""입력 마스크와 이미 사용된 액션을 결합하여 최종 마스크 생성"""
if available_mask is None:
mask = np.ones(self.action_space_size, dtype=bool)
else:
mask = np.asarray(available_mask, dtype=bool).copy()
if self._episode_actions:
for action in self._episode_actions:
mask[action] = False
return mask

View File

@ -0,0 +1,11 @@
"""Domain model exports for Q-Table (Version 3.0 - 162 states)."""
from .state import ( # noqa: F401
AcceptanceRate,
DistributionStructure,
InputPriceZone,
PartnerType,
RevenueRange,
State,
)
from .visit_table import VisitTable # noqa: F401

View File

@ -0,0 +1,94 @@
"""
Experience 모델
Q-Learning을 위한 경험 데이터 (SARS: State, Action, Reward, next_State)
"""
from dataclasses import dataclass
from typing import Optional
from .state import State
@dataclass
class Experience:
"""
Q-Learning에서 사용하는 경험 데이터
SARS 형식:
- State: 현재 상태
- Action: 선택한 액션 (action_id)
- Reward: 받은 보상
- next_State: 다음 상태 (종료 None)
"""
state: State
action_id: int
reward: float
next_state: Optional[State]
done: bool # 에피소드 종료 여부
# 메타데이터 (선택사항)
episode_id: Optional[str] = None
step: Optional[int] = None
timestamp: Optional[str] = None
def to_dict(self) -> dict:
"""
Experience를 딕셔너리로 직렬화
Returns:
{
'state': [scenario, price_zone, acceptance_rate],
'action_id': 5,
'reward': 1.0,
'next_state': [scenario, price_zone, acceptance_rate] or None,
'done': False,
'episode_id': 'ep_001',
'step': 3,
'timestamp': '2025-10-29T16:30:00'
}
"""
return {
"state": self.state.to_array(),
"action_id": self.action_id,
"reward": self.reward,
"next_state": self.next_state.to_array() if self.next_state else None,
"done": self.done,
"episode_id": self.episode_id,
"step": self.step,
"timestamp": self.timestamp,
}
@classmethod
def from_dict(cls, data: dict) -> "Experience":
"""
딕셔너리에서 Experience 복원
Args:
data: 직렬화된 Experience 데이터
Returns:
Experience 인스턴스
"""
return cls(
state=State.from_array(data["state"]),
action_id=data["action_id"],
reward=data["reward"],
next_state=(
State.from_array(data["next_state"]) if data["next_state"] else None
),
done=data["done"],
episode_id=data.get("episode_id"),
step=data.get("step"),
timestamp=data.get("timestamp"),
)
def __repr__(self) -> str:
return (
f"Experience("
f"state={self.state.to_array()}, "
f"action_id={self.action_id}, "
f"reward={self.reward}, "
f"next_state={self.next_state.to_array() if self.next_state else None}, "
f"done={self.done}"
f")"
)

View File

@ -0,0 +1,209 @@
"""
Q-Table 모델
동적 action_space_size를 지원하여 협상 카드 추가/삭제에 대응
"""
import numpy as np
from typing import Optional, Dict
class QTable:
"""
Q-Learning을 위한 Q-Table
Q-Table은 추상적인 action_id만 다루며,
실제 협상 카드(card_id) ActionCardMapper에서 매핑
"""
def __init__(
self,
state_space_size: int,
action_space_size: int,
learning_rate: float = 0.1,
discount_factor: float = 0.95,
):
"""
Args:
state_space_size: 상태 공간 크기
(Scenario=4 x PriceZone=3 x AcceptanceRate=3 = 36)
action_space_size: 액션 공간 크기 (협상 카드 개수, 현재 21)
learning_rate: 학습률 (alpha)
discount_factor: 할인 인자 (gamma)
"""
self.state_space_size = state_space_size
self.action_space_size = action_space_size
self.learning_rate = learning_rate
self.discount_factor = discount_factor
# Q-Table 초기화: (state_size, action_size) 형태
self.q_values = np.zeros((state_space_size, action_space_size))
def get_q_value(self, state_index: int, action_id: int) -> float:
"""
특정 (state, action) Q-value 조회
Args:
state_index: State의 인덱스 (0 ~ state_space_size-1)
action_id: Action ID (0 ~ action_space_size-1)
Returns:
Q-value
"""
self._validate_indices(state_index, action_id)
return float(self.q_values[state_index, action_id])
def get_q_values_for_state(self, state_index: int) -> np.ndarray:
"""
특정 state의 모든 action에 대한 Q-values
Args:
state_index: State의 인덱스
Returns:
shape (action_space_size,) Q-values 배열
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
return self.q_values[state_index, :].copy()
def set_q_value(self, state_index: int, action_id: int, value: float):
"""
특정 (state, action) Q-value 직접 설정
Args:
state_index: State의 인덱스
action_id: Action ID
value: 설정할 Q-value
"""
self._validate_indices(state_index, action_id)
self.q_values[state_index, action_id] = value
def update(
self,
state_index: int,
action_id: int,
reward: float,
next_state_index: Optional[int] = None,
done: bool = False,
):
"""
Q-Learning 업데이트
Q(s,a) Q(s,a) + α[r + γ·max_a'Q(s',a') - Q(s,a)]
Args:
state_index: 현재 상태 인덱스
action_id: 선택한 액션 ID
reward: 받은 보상
next_state_index: 다음 상태 인덱스 (종료 None)
done: 에피소드 종료 여부
"""
self._validate_indices(state_index, action_id)
current_q = self.q_values[state_index, action_id]
if done or next_state_index is None:
# 종료 상태: target = reward만 사용
target = reward
else:
# 비종료 상태: target = reward + γ·max Q(s',a')
if next_state_index < 0 or next_state_index >= self.state_space_size:
raise ValueError(f"next_state_index {next_state_index} out of range")
max_next_q = np.max(self.q_values[next_state_index, :])
target = reward + self.discount_factor * max_next_q
# Q-Learning 업데이트
self.q_values[state_index, action_id] += self.learning_rate * (
target - current_q
)
def get_best_action(self, state_index: int) -> int:
"""
해당 state에서 최고 Q-value를 가진 action_id 반환
Args:
state_index: State의 인덱스
Returns:
action_id (Q-value가 최대인 액션)
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
return int(np.argmax(self.q_values[state_index, :]))
def get_best_actions_with_ties(self, state_index: int) -> list:
"""
동일한 최대 Q-value를 가진 모든 action_id 반환
Args:
state_index: State의 인덱스
Returns:
최대 Q-value를 가진 action_id들의 리스트
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
max_q = np.max(self.q_values[state_index, :])
max_actions = np.where(self.q_values[state_index, :] == max_q)[0]
return max_actions.tolist()
def to_dict(self) -> Dict:
"""
Q-Table을 딕셔너리로 직렬화
Returns:
직렬화된 Q-Table 데이터
"""
return {
"state_space_size": self.state_space_size,
"action_space_size": self.action_space_size,
"learning_rate": self.learning_rate,
"discount_factor": self.discount_factor,
"q_values": self.q_values.tolist(),
}
@classmethod
def from_dict(cls, data: Dict) -> "QTable":
"""
딕셔너리에서 Q-Table 복원
Args:
data: 직렬화된 Q-Table 데이터
Returns:
복원된 QTable 인스턴스
"""
q_table = cls(
state_space_size=data["state_space_size"],
action_space_size=data["action_space_size"],
learning_rate=data["learning_rate"],
discount_factor=data["discount_factor"],
)
q_table.q_values = np.array(data["q_values"])
return q_table
def _validate_indices(self, state_index: int, action_id: int):
"""인덱스 유효성 검증"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
if action_id < 0 or action_id >= self.action_space_size:
raise ValueError(
f"action_id {action_id} out of range [0, {self.action_space_size})"
)
def __repr__(self) -> str:
return (
f"QTable(state_space_size={self.state_space_size}, "
f"action_space_size={self.action_space_size}, "
f"learning_rate={self.learning_rate}, "
f"discount_factor={self.discount_factor})"
)

View File

@ -0,0 +1,258 @@
from dataclasses import dataclass
from enum import IntEnum, nonmember
from typing import List
class RevenueRange(IntEnum):
"""매출액 가격구간 (Revenue Price Range)"""
LOW = 0 # ≤ 1,000만원
MID = 1 # 1,000만원 ~ 3,000만원
HIGH = 2 # > 3,000만원
_DESCRIPTIONS = nonmember(("Low (≤1,000만원)", "Mid (1,000~3,000만원)", "High (>3,000만원)"))
_WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 매출액이 높을수록 협상 여지 큼
@property
def description(self) -> str:
return self._DESCRIPTIONS[self.value]
@property
def weight(self) -> float:
return self._WEIGHTS[self.value]
@classmethod
def from_amount(cls, amount: float) -> "RevenueRange":
"""매출액(만원 단위)으로부터 구간 결정"""
if amount <= 1000:
return cls.LOW
elif amount <= 3000:
return cls.MID
else:
return cls.HIGH
class DistributionStructure(IntEnum):
"""유통 구조 (Distribution Structure)"""
MANUFACTURER = 0 # 제조
WHOLESALER = 1 # 총판
RETAILER = 2 # 유통
_DESCRIPTIONS = nonmember(("제조", "총판", "유통"))
_WEIGHTS = nonmember((0.2, 0.5, 1.0)) # 유통 단계가 복잡할수록 협상 여지 큼
@property
def description(self) -> str:
return self._DESCRIPTIONS[self.value]
@property
def weight(self) -> float:
return self._WEIGHTS[self.value]
@classmethod
def from_code(cls, code: str) -> "DistributionStructure":
"""코드로부터 유통 구조 결정"""
normalized = (code or "").strip().upper()
mapping = {"M": cls.MANUFACTURER, "W": cls.WHOLESALER, "R": cls.RETAILER}
if normalized not in mapping:
raise ValueError(f"unknown distribution code: {code}")
return mapping[normalized]
class PartnerType(IntEnum):
"""파트너사 종류 (Partner Type)"""
SINGLE = 0 # 단독
MULTIPLE = 1 # 다수
NONE = 2 # 없음
_DESCRIPTIONS = nonmember(("Single (단독)", "Multiple (다수)", "None (없음)"))
_WEIGHTS = nonmember((0.5, 1.0, 0.3)) # 다수 파트너사일수록 협상 복잡도 증가
@property
def description(self) -> str:
return self._DESCRIPTIONS[self.value]
@property
def weight(self) -> float:
return self._WEIGHTS[self.value]
@classmethod
def from_count(cls, count: int) -> "PartnerType":
"""파트너사 수로부터 구간 결정"""
if count == 0:
return cls.NONE
elif count == 1:
return cls.SINGLE
else:
return cls.MULTIPLE
class AcceptanceRate(IntEnum):
"""가격 수용률 구간 (Price Acceptance Rate)"""
LOW = 0 # < 30%
MID = 1 # 30% ~ 90%
HIGH = 2 # > 90%
_DESCRIPTIONS = nonmember(("Low (<30%)", "Mid (30~90%)", "High (>90%)"))
_WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 수용률이 높을수록 협상 성공 가능성 높음
@property
def description(self) -> str:
return self._DESCRIPTIONS[self.value]
@property
def weight(self) -> float:
return self._WEIGHTS[self.value]
@classmethod
def from_ratio(cls, ratio: float) -> "AcceptanceRate":
"""수용률(0~1)로부터 구간 결정"""
normalized = max(0.0, min(1.0, ratio))
if normalized < 0.30:
return cls.LOW
elif normalized <= 0.90:
return cls.MID
else:
return cls.HIGH
class InputPriceZone(IntEnum):
"""입력 금액 구간 (Input Price Zone)"""
PZ1 = 0 # A ≤ P ≤ T (앵커가격 ~ 목표가격)
PZ2 = 1 # P > T (목표가격 초과)
_DESCRIPTIONS = nonmember(("PZ1 (A≤P≤T)", "PZ2 (P>T)"))
_WEIGHTS = nonmember((1.0, 0.3)) # 목표가격 이내일수록 협상 유리
@property
def description(self) -> str:
return self._DESCRIPTIONS[self.value]
@property
def weight(self) -> float:
return self._WEIGHTS[self.value]
@classmethod
def from_prices(
cls,
input_price: float,
anchor_price: float,
target_price: float,
) -> "InputPriceZone":
"""입력가격, 앵커가격, 목표가격으로부터 구간 결정"""
if anchor_price <= 0 or target_price <= 0:
raise ValueError("anchor_price and target_price must be positive")
if anchor_price > target_price:
raise ValueError("anchor_price must not exceed target_price")
if anchor_price <= input_price <= target_price:
return cls.PZ1
else:
return cls.PZ2
@dataclass
class State:
"""
Q-Learning State 표현 (Version 3.0 - 162 states)
State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간)
Total: 3 × 3 × 3 × 3 × 2 = 162 states
"""
revenue_range: RevenueRange
distribution: DistributionStructure
partner_type: PartnerType
acceptance_rate: AcceptanceRate
input_price_zone: InputPriceZone
def to_array(self) -> List[int]:
"""State를 배열로 변환"""
return [
self.revenue_range.value,
self.distribution.value,
self.partner_type.value,
self.acceptance_rate.value,
self.input_price_zone.value,
]
def to_index(self) -> int:
"""
State를 1차원 인덱스로 변환 (Q-Table 인덱싱용)
State space: 3 × 3 × 3 × 3 × 2 = 162
Returns:
0 ~ 161 사이의 인덱스
"""
# 5D to 1D index conversion
# index = revenue * (3*3*3*2) + dist * (3*3*2) + partner * (3*2) + accept * 2 + pricezone
return (
self.revenue_range.value * 54
+ self.distribution.value * 18
+ self.partner_type.value * 6
+ self.acceptance_rate.value * 2
+ self.input_price_zone.value
)
@classmethod
def from_index(cls, index: int) -> "State":
"""
1차원 인덱스에서 State 복원
Args:
index: 0 ~ 161 사이의 인덱스
Returns:
State 객체
"""
if index < 0 or index >= 162:
raise ValueError(f"index {index} out of range [0, 162)")
# 1D to 5D index conversion
revenue_value = index // 54
remainder = index % 54
dist_value = remainder // 18
remainder = remainder % 18
partner_value = remainder // 6
remainder = remainder % 6
accept_value = remainder // 2
pricezone_value = remainder % 2
return cls(
revenue_range=RevenueRange(revenue_value),
distribution=DistributionStructure(dist_value),
partner_type=PartnerType(partner_value),
acceptance_rate=AcceptanceRate(accept_value),
input_price_zone=InputPriceZone(pricezone_value),
)
@classmethod
def from_array(cls, arr: List[int]) -> "State":
"""배열로부터 State 생성"""
if len(arr) != 5:
raise ValueError(f"Expected 5 elements, got {len(arr)}")
return cls(
revenue_range=RevenueRange(arr[0]),
distribution=DistributionStructure(arr[1]),
partner_type=PartnerType(arr[2]),
acceptance_rate=AcceptanceRate(arr[3]),
input_price_zone=InputPriceZone(arr[4]),
)
def __str__(self) -> str:
return (
f"State("
f"revenue={self.revenue_range.description}, "
f"distribution={self.distribution.description}, "
f"partner={self.partner_type.description}, "
f"acceptance={self.acceptance_rate.description}, "
f"price_zone={self.input_price_zone.description})"
)

View File

@ -0,0 +1,148 @@
"""Visit (N) table for UCB-based policies."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict
import numpy as np
@dataclass
class VisitTable:
"""Tracks state-action visit counts for UCB selection."""
state_space_size: int
action_space_size: int
def __post_init__(self) -> None:
"""
Args:
state_space_size: 상태 공간 크기
action_space_size: 액션 공간 크기
"""
self._counts = np.zeros(
(self.state_space_size, self.action_space_size), dtype=np.int64
)
def increment(self, state_index: int, action_id: int, count: int = 1) -> None:
"""
특정 (state, action) 방문 횟수 증가
Args:
state_index: State 인덱스
action_id: Action ID
count: 증가시킬 횟수 (기본 1)
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
if action_id < 0 or action_id >= self.action_space_size:
raise ValueError(
f"action_id {action_id} out of range [0, {self.action_space_size})"
)
if count < 0:
raise ValueError("count must be non-negative")
self._counts[state_index, action_id] += count
def get_action_count(self, state_index: int, action_id: int) -> int:
"""
특정 (state, action) 방문 횟수 조회
Args:
state_index: State 인덱스
action_id: Action ID
Returns:
방문 횟수
"""
self._validate_indices(state_index, action_id)
return int(self._counts[state_index, action_id])
def get_state_counts(self, state_index: int) -> np.ndarray:
"""
특정 state의 모든 action에 대한 방문 횟수 조회
Args:
state_index: State 인덱스
Returns:
shape (action_space_size,) 방문 횟수 배열
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
return self._counts[state_index, :].copy()
def total_visits(self, state_index: int) -> int:
"""
특정 state의 방문 횟수 (모든 action 합계)
Args:
state_index: State 인덱스
Returns:
방문 횟수
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
return int(self._counts[state_index, :].sum())
def reset_state(self, state_index: int) -> None:
"""
특정 state의 방문 기록 초기화
Args:
state_index: State 인덱스
"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
self._counts[state_index, :] = 0
def to_dict(self) -> Dict:
"""
VisitTable을 딕셔너리로 직렬화
Returns:
직렬화된 VisitTable 데이터
"""
return {
"state_space_size": self.state_space_size,
"action_space_size": self.action_space_size,
"counts": self._counts.tolist(),
}
@classmethod
def from_dict(cls, data: Dict) -> "VisitTable":
"""
딕셔너리에서 VisitTable 복원
Args:
data: 직렬화된 VisitTable 데이터
Returns:
복원된 VisitTable 인스턴스
"""
table = cls(
state_space_size=data["state_space_size"],
action_space_size=data["action_space_size"],
)
table._counts = np.array(data["counts"], dtype=np.int64)
return table
def _validate_indices(self, state_index: int, action_id: int) -> None:
"""인덱스 유효성 검증"""
if state_index < 0 or state_index >= self.state_space_size:
raise ValueError(
f"state_index {state_index} out of range [0, {self.state_space_size})"
)
if action_id < 0 or action_id >= self.action_space_size:
raise ValueError(
f"action_id {action_id} out of range [0, {self.action_space_size})"
)

View File

@ -0,0 +1,200 @@
"""
Experience Repository
Experience 데이터를 JSONL 형식으로 저장/로드
"""
import json
from pathlib import Path
from typing import List, Optional
from datetime import datetime
from ..model.experience import Experience
class ExperienceRepository:
"""
Experience 데이터를 파일 시스템에 저장/로드하는 Repository
JSONL (JSON Lines) 형식 사용:
- 줄이 하나의 JSON 객체 (Experience)
- 스트리밍 방식으로 읽기/쓰기 가능
- 대용량 데이터 처리에 유리
"""
def __init__(self, data_dir: Optional[str] = None):
"""
Args:
data_dir: Experience 데이터를 저장할 디렉토리
(기본: Q_Table/data/experiences/)
"""
if data_dir is None:
current_dir = Path(__file__).parent.parent.parent
data_dir = current_dir / "data" / "experiences"
self.data_dir = Path(data_dir)
self.data_dir.mkdir(parents=True, exist_ok=True)
def save(self, experience: Experience, filename: str = "experiences.jsonl"):
"""
단일 Experience를 파일에 추가 저장
Args:
experience: 저장할 Experience
filename: 파일명 (기본: experiences.jsonl)
"""
file_path = self.data_dir / filename
with open(file_path, "a", encoding="utf-8") as f:
json.dump(experience.to_dict(), f, ensure_ascii=False)
f.write("\n")
def save_batch(
self, experiences: List[Experience], filename: str = "experiences.jsonl"
):
"""
여러 Experience를 배치로 저장
Args:
experiences: Experience 리스트
filename: 파일명
"""
file_path = self.data_dir / filename
with open(file_path, "a", encoding="utf-8") as f:
for exp in experiences:
json.dump(exp.to_dict(), f, ensure_ascii=False)
f.write("\n")
def load_all(self, filename: str = "experiences.jsonl") -> List[Experience]:
"""
파일에서 모든 Experience 로드
Args:
filename: 파일명
Returns:
Experience 리스트
"""
file_path = self.data_dir / filename
if not file_path.exists():
return []
experiences = []
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
data = json.loads(line)
experiences.append(Experience.from_dict(data))
return experiences
def load_by_episode(
self, episode_id: str, filename: str = "experiences.jsonl"
) -> List[Experience]:
"""
특정 에피소드의 Experience들만 로드
Args:
episode_id: 에피소드 ID
filename: 파일명
Returns:
해당 에피소드의 Experience 리스트
"""
all_experiences = self.load_all(filename)
return [exp for exp in all_experiences if exp.episode_id == episode_id]
def count(self, filename: str = "experiences.jsonl") -> int:
"""
저장된 Experience 개수
Args:
filename: 파일명
Returns:
Experience 개수
"""
file_path = self.data_dir / filename
if not file_path.exists():
return 0
with open(file_path, "r", encoding="utf-8") as f:
return sum(1 for line in f if line.strip())
def clear(self, filename: str = "experiences.jsonl"):
"""
파일 내용 삭제
Args:
filename: 파일명
"""
file_path = self.data_dir / filename
if file_path.exists():
file_path.unlink()
def create_new_file(self, prefix: str = "exp") -> str:
"""
타임스탬프가 포함된 파일 생성
Args:
prefix: 파일명 접두사
Returns:
생성된 파일명
Example:
>>> repo.create_new_file("train")
'train_20251029_163000.jsonl'
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"{prefix}_{timestamp}.jsonl"
file_path = self.data_dir / filename
file_path.touch()
return filename
def list_files(self) -> List[str]:
"""
데이터 디렉토리의 모든 JSONL 파일 목록
Returns:
파일명 리스트
"""
return [f.name for f in self.data_dir.glob("*.jsonl")]
def get_file_info(self, filename: str) -> dict:
"""
파일 정보 조회
Args:
filename: 파일명
Returns:
{
'filename': 'experiences.jsonl',
'path': '/path/to/file',
'size': 1024, # bytes
'count': 100, # experience 개수
'created': '2025-10-29 16:30:00'
}
"""
file_path = self.data_dir / filename
if not file_path.exists():
return None
stat = file_path.stat()
return {
"filename": filename,
"path": str(file_path),
"size": stat.st_size,
"count": self.count(filename),
"created": datetime.fromtimestamp(stat.st_ctime).strftime(
"%Y-%m-%d %H:%M:%S"
),
"modified": datetime.fromtimestamp(stat.st_mtime).strftime(
"%Y-%m-%d %H:%M:%S"
),
}

View File

@ -0,0 +1,166 @@
"""Reward function aligned with the Q-Table design specification (Version 3.0)."""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from ..model.state import (
AcceptanceRate,
DistributionStructure,
InputPriceZone,
PartnerType,
RevenueRange,
)
class NegotiationOutcome(Enum):
ONGOING = "ongoing"
SUCCESS = "success"
FAILURE = "failure"
@dataclass
class RewardConfig:
"""보상 함수 설정"""
beta: float = 0.2
success_reward: float = 1.0
ongoing_reward: float = 0.0
failure_penalty: float = -0.5
penalty_lambda: float = 0.02
# 동적 가중치 계수 (Version 3.0)
w1: float = 0.2 # 매출액 가격구간
w2: float = 0.25 # 유통 구조
w3: float = 0.2 # 파트너사 종류
w4: float = 0.25 # 가격 수용률
w5: float = 0.1 # 입력 금액 구간
min_weight: float = 0.2
max_weight: float = 0.8
@dataclass
class RewardBreakdown:
price_reward: float
end_reward: float
penalty: float
weight: float
total: float
def calculate_reward(
*,
revenue_range: RevenueRange,
distribution: DistributionStructure,
partner_type: PartnerType,
acceptance_rate: AcceptanceRate,
input_price_zone: InputPriceZone,
current_price: float,
anchor_price: float,
target_price: float,
round_number: int,
outcome: NegotiationOutcome,
config: Optional[RewardConfig] = None,
) -> RewardBreakdown:
"""
보상 계산 (Version 3.0)
R = W × R_price + (1-W) × R_end - λ × t
W = clip(W_raw, min_weight, max_weight)
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
"""
cfg = config or RewardConfig()
price_reward = _calculate_price_reward(
current_price=current_price,
anchor_price=anchor_price,
target_price=target_price,
beta=cfg.beta,
)
end_reward = _calculate_end_reward(outcome, cfg)
weight = _calculate_weight(
revenue_range=revenue_range,
distribution=distribution,
partner_type=partner_type,
acceptance_rate=acceptance_rate,
input_price_zone=input_price_zone,
config=cfg,
)
penalty = _calculate_penalty(round_number, cfg)
total = weight * price_reward + (1.0 - weight) * end_reward - penalty
return RewardBreakdown(
price_reward=price_reward,
end_reward=end_reward,
penalty=penalty,
weight=weight,
total=total,
)
def _calculate_price_reward(
*,
current_price: float,
anchor_price: float,
target_price: float,
beta: float,
) -> float:
if anchor_price <= 0 or target_price <= 0:
raise ValueError("anchor_price and target_price must be positive")
if anchor_price > target_price:
raise ValueError("anchor_price must not exceed target_price")
if current_price < anchor_price:
improvement = (anchor_price - current_price) / anchor_price
return 1.0 + beta * improvement
if anchor_price <= current_price <= target_price:
span = target_price - anchor_price
if span <= 0:
return 0.0
return (target_price - current_price) / span
return 0.0
def _calculate_end_reward(outcome: NegotiationOutcome, cfg: RewardConfig) -> float:
if outcome is NegotiationOutcome.SUCCESS:
return cfg.success_reward
if outcome is NegotiationOutcome.FAILURE:
return cfg.failure_penalty
return cfg.ongoing_reward
def _calculate_weight(
*,
revenue_range: RevenueRange,
distribution: DistributionStructure,
partner_type: PartnerType,
acceptance_rate: AcceptanceRate,
input_price_zone: InputPriceZone,
config: RewardConfig,
) -> float:
"""
동적 가중치 계산 (Version 3.0)
W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone
W = clip(W_raw, min_weight, max_weight)
"""
w_raw = (
config.w1 * revenue_range.weight
+ config.w2 * distribution.weight
+ config.w3 * partner_type.weight
+ config.w4 * acceptance_rate.weight
+ config.w5 * input_price_zone.weight
)
return max(config.min_weight, min(config.max_weight, w_raw))
def _calculate_penalty(round_number: int, cfg: RewardConfig) -> float:
if round_number < 0:
raise ValueError("round_number must be non-negative")
return cfg.penalty_lambda * float(round_number)

View File

@ -0,0 +1,87 @@
"""Utility helpers to build Q-Table states from negotiation snapshots (Version 3.0)."""
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from ..model.state import (
AcceptanceRate,
DistributionStructure,
InputPriceZone,
PartnerType,
RevenueRange,
State,
)
@dataclass(frozen=True)
class NegotiationSnapshot:
"""협상 스냅샷 - 실제 비즈니스 데이터"""
# 매출액 (만원 단위)
revenue_amount: float
# 유통 구조 코드 ("M": 제조, "W": 총판, "R": 유통)
distribution_code: str
# 파트너사 수
partner_count: int
# 가격 정보
anchor_price: float
target_price: float
input_price: float
# 가격 수용률 (0~1, 계산되거나 직접 입력)
acceptance_ratio: Optional[float] = None
# 초기 가격 (수용률 계산용, acceptance_ratio가 없을 때)
initial_price: Optional[float] = None
def build_state(snapshot: NegotiationSnapshot) -> State:
"""
협상 스냅샷으로부터 Q-Learning State 생성 (Version 3.0)
Args:
snapshot: 협상 스냅샷
Returns:
162-dimensional state
"""
# 1. 매출액 가격구간
revenue_range = RevenueRange.from_amount(snapshot.revenue_amount)
# 2. 유통 구조
distribution = DistributionStructure.from_code(snapshot.distribution_code)
# 3. 파트너사 종류
partner_type = PartnerType.from_count(snapshot.partner_count)
# 4. 가격 수용률 구간
if snapshot.acceptance_ratio is not None:
acceptance_rate = AcceptanceRate.from_ratio(snapshot.acceptance_ratio)
elif snapshot.initial_price is not None:
# 초기가격으로부터 수용률 계산: (initial - input) / initial
if snapshot.initial_price <= 0:
raise ValueError("initial_price must be positive")
ratio = (snapshot.initial_price - snapshot.input_price) / snapshot.initial_price
acceptance_rate = AcceptanceRate.from_ratio(ratio)
else:
raise ValueError("Either acceptance_ratio or initial_price must be provided")
# 5. 입력 금액 구간
input_price_zone = InputPriceZone.from_prices(
input_price=snapshot.input_price,
anchor_price=snapshot.anchor_price,
target_price=snapshot.target_price,
)
return State(
revenue_range=revenue_range,
distribution=distribution,
partner_type=partner_type,
acceptance_rate=acceptance_rate,
input_price_zone=input_price_zone,
)

View File

@ -0,0 +1,46 @@
import json
from pathlib import Path
from typing import Optional, Tuple
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
class ModelRepository:
"""Simple repository to save/load Q-Table and VisitTable models."""
def __init__(self, model_dir: Optional[str] = None):
if model_dir is None:
# Current: src/negotiation_agent/Q_Table/infra/repository/model_repository.py
# Target Data: src/negotiation_agent/Q_Table/data/model
# Path: ../../../data/model
current_dir = Path(__file__).parent
model_dir = current_dir.parent.parent / "data" / "model"
self.model_dir = Path(model_dir)
self.model_dir.mkdir(parents=True, exist_ok=True)
self.q_table_path = self.model_dir / "q_table.json"
self.visit_table_path = self.model_dir / "visit_table.json"
def save(self, q_table: QTable, visit_table: VisitTable):
with open(self.q_table_path, "w", encoding="utf-8") as f:
json.dump(q_table.to_dict(), f, indent=2)
with open(self.visit_table_path, "w", encoding="utf-8") as f:
json.dump(visit_table.to_dict(), f, indent=2)
print(f"Models saved to {self.model_dir}")
def load(self) -> Tuple[Optional[QTable], Optional[VisitTable]]:
q_table = None
visit_table = None
if self.q_table_path.exists():
with open(self.q_table_path, "r", encoding="utf-8") as f:
q_table = QTable.from_dict(json.load(f))
if self.visit_table_path.exists():
with open(self.visit_table_path, "r", encoding="utf-8") as f:
visit_table = VisitTable.from_dict(json.load(f))
return q_table, visit_table

Binary file not shown.

View File

@ -0,0 +1,143 @@
"""
CollectExperienceUsecase
협상 Experience 데이터를 수집하여 저장
"""
from typing import Optional
from datetime import datetime
from ..domain.model.state import State
from ..domain.model.experience import Experience
from ..domain.repository.experience_repository import ExperienceRepository
class CollectExperienceUsecase:
"""
협상 과정에서 발생하는 Experience를 수집하여 저장
오프라인 학습을 위한 데이터 수집 담당
"""
def __init__(
self,
experience_repository: ExperienceRepository,
filename: Optional[str] = None,
):
"""
Args:
experience_repository: Experience 저장소
filename: 저장할 파일명 (None이면 기본 파일 사용)
"""
self.experience_repository = experience_repository
self.filename = filename or "experiences.jsonl"
self.current_episode_id: Optional[str] = None
self.current_step: int = 0
def start_episode(self, episode_id: Optional[str] = None):
"""
새로운 에피소드 시작
Args:
episode_id: 에피소드 ID (None이면 자동 생성)
"""
if episode_id is None:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
episode_id = f"ep_{timestamp}"
self.current_episode_id = episode_id
self.current_step = 0
def collect(
self,
state: State,
action_id: int,
reward: float,
next_state: Optional[State],
done: bool,
) -> Experience:
"""
단일 Experience 수집 저장
Args:
state: 현재 상태
action_id: 선택한 액션 ID
reward: 받은 보상
next_state: 다음 상태 (종료 None)
done: 에피소드 종료 여부
Returns:
수집된 Experience
"""
if self.current_episode_id is None:
self.start_episode()
# Experience 생성
experience = Experience(
state=state,
action_id=action_id,
reward=reward,
next_state=next_state,
done=done,
episode_id=self.current_episode_id,
step=self.current_step,
timestamp=datetime.now().isoformat(),
)
# 저장
self.experience_repository.save(experience, self.filename)
# 스텝 증가
self.current_step += 1
return experience
def end_episode(self):
"""에피소드 종료"""
self.current_episode_id = None
self.current_step = 0
def get_episode_count(self) -> int:
"""수집된 에피소드 개수 (근사치)"""
# 파일에서 unique episode_id 개수 계산
experiences = self.experience_repository.load_all(self.filename)
episode_ids = set(exp.episode_id for exp in experiences if exp.episode_id)
return len(episode_ids)
def get_total_count(self) -> int:
"""수집된 총 Experience 개수"""
return self.experience_repository.count(self.filename)
def get_collection_info(self) -> dict:
"""
수집 정보 조회
Returns:
{
'filename': 'experiences.jsonl',
'total_experiences': 1000,
'episodes': 50,
'current_episode': 'ep_001' or None,
'current_step': 5
}
"""
return {
"filename": self.filename,
"total_experiences": self.get_total_count(),
"episodes": self.get_episode_count(),
"current_episode": self.current_episode_id,
"current_step": self.current_step,
"file_info": self.experience_repository.get_file_info(self.filename),
}
def create_new_collection_file(self, prefix: str = "exp") -> str:
"""
새로운 수집 파일 생성 전환
Args:
prefix: 파일명 접두사
Returns:
생성된 파일명
"""
new_filename = self.experience_repository.create_new_file(prefix)
self.filename = new_filename
return new_filename

View File

@ -0,0 +1,72 @@
"""
EvaluateAgentUsecase
학습된 Q-Table의 성능을 평가하는 Usecase
"""
from typing import Optional
from ..domain.model.q_table import QTable
from ..domain.repository.experience_repository import ExperienceRepository
class EvaluateAgentUsecase:
"""
저장된 Experience 데이터를 사용하여 현재 에이전트(Q-Table) 성능을 평가
"""
def __init__(
self,
q_table: QTable,
experience_repository: ExperienceRepository,
):
"""
Args:
q_table: 평가할 Q-Table
experience_repository: Experience 저장소
"""
self.q_table = q_table
self.experience_repository = experience_repository
def execute(
self, filename: str = "experiences.jsonl", sample_size: Optional[int] = None
) -> dict:
"""
현재 Q-Table의 성능 평가
Args:
filename: 평가용 Experience 파일
sample_size: 샘플링 크기 (None이면 전체)
Returns:
{
'avg_q_value': 0.5,
'avg_reward': 1.0,
'total_samples': 100
}
"""
experiences = self.experience_repository.load_all(filename)
if not experiences:
return {"avg_q_value": 0.0, "avg_reward": 0.0, "total_samples": 0}
# 샘플링
if sample_size and sample_size < len(experiences):
import random
experiences = random.sample(experiences, sample_size)
total_q_value = 0.0
total_reward = 0.0
for exp in experiences:
state_index = exp.state.to_index()
q_value = self.q_table.get_q_value(state_index, exp.action_id)
total_q_value += q_value
total_reward += exp.reward
n = len(experiences)
return {
"avg_q_value": total_q_value / n,
"avg_reward": total_reward / n,
"total_samples": n,
}

View File

@ -0,0 +1,179 @@
"""
GetBestActionUsecase
State를 받아 최적의 협상 카드를 선택하는 Usecase
"""
from typing import Dict, Optional
from ..domain.model.state import State
from ..domain.model.q_table import QTable
from ..domain.agents.policy import UCBPolicy
from ...integration.action_card_mapper import ActionCardMapper
class GetBestActionUsecase:
"""
협상 추론 Usecase: State action_id card_id
Q_Table은 추상적인 action_id만 반환하고,
ActionCardMapper를 통해 구체적인 card_id로 변환한다.
"""
def __init__(
self,
q_table: QTable,
policy: UCBPolicy,
action_card_mapper: ActionCardMapper,
):
"""
Args:
q_table: 학습된 Q-Table
policy: 액션 선택 정책 (UCB)
action_card_mapper: action_id card_id 매핑
"""
self.q_table = q_table
self.policy = policy
self.action_card_mapper = action_card_mapper
def execute(self, state: State, use_policy: bool = True) -> Dict:
"""
현재 상태에서 최적의 협상 카드 선택
Args:
state: 현재 협상 상태
use_policy: True면 Policy 사용 (UCB + 중복 방지)
False면 순수 Q-value 최대값만 사용
Returns:
{
'action_id': 5,
'card_id': 'no_5',
'q_value': 0.85,
'state_index': 12,
'all_q_values': [0.1, 0.2, ..., 0.85, ...] # optional
}
사용 가능한 액션이 없으면:
{
'action_id': None,
'card_id': None,
'q_value': None,
'message': 'No available actions'
}
"""
# 1. State → state_index 변환
state_index = state.to_index()
# 2. Q-values 조회
q_values = self.q_table.get_q_values_for_state(state_index)
# 3. Action 선택
if use_policy:
action_id = self.policy.select_action(state_index, q_values)
else:
# 순수 Q-value 최대값
action_id = self.q_table.get_best_action(state_index)
if action_id is None:
return {
"action_id": None,
"card_id": None,
"q_value": None,
"state_index": state_index,
"message": "No available actions (all actions used in this episode)",
}
# 4. action_id → card_id 매핑
card_id = self.action_card_mapper.get_card_id(action_id)
if card_id is None:
return {
"action_id": action_id,
"card_id": None,
"q_value": float(q_values[action_id]),
"state_index": state_index,
"message": f"action_id {action_id} has no card mapping",
}
# 5. Q-value 추출
q_value = self.q_table.get_q_value(state_index, action_id)
return {
"action_id": action_id,
"card_id": card_id,
"q_value": float(q_value),
"state_index": state_index,
"all_q_values": q_values.tolist(), # 디버깅용
}
def get_top_k_actions(self, state: State, k: int = 3) -> list:
"""
상위 k개의 액션 추천
Args:
state: 현재 협상 상태
k: 추천할 액션 개수
Returns:
[
{'action_id': 5, 'card_id': 'no_5', 'q_value': 0.85},
{'action_id': 3, 'card_id': 'no_3', 'q_value': 0.72},
{'action_id': 1, 'card_id': 'no_1', 'q_value': 0.68}
]
"""
state_index = state.to_index()
q_values = self.q_table.get_q_values_for_state(state_index)
# Q-value 내림차순 정렬
sorted_actions = sorted(enumerate(q_values), key=lambda x: x[1], reverse=True)
# 상위 k개 추출
top_k = sorted_actions[:k]
results = []
for action_id, q_value in top_k:
card_id = self.action_card_mapper.get_card_id(action_id)
results.append(
{"action_id": action_id, "card_id": card_id, "q_value": float(q_value)}
)
return results
def reset_episode(self):
"""
에피소드 종료 Policy 초기화
(다음 협상 세션 시작 호출)
"""
self.policy.reset_episode()
def get_available_actions(self) -> list:
"""
현재 에피소드에서 아직 사용하지 않은 액션들
Returns:
[0, 2, 4, 5, ...] (사용하지 않은 action_id들)
"""
import numpy as np
mask = self.policy.get_action_mask()
available_action_ids = np.where(mask > 0)[0].tolist()
return available_action_ids
def get_available_cards(self) -> list:
"""
현재 에피소드에서 아직 사용하지 않은 카드들
Returns:
[
{'action_id': 0, 'card_id': 'no_0'},
{'action_id': 2, 'card_id': 'no_2'},
...
]
"""
available_action_ids = self.get_available_actions()
results = []
for action_id in available_action_ids:
card_id = self.action_card_mapper.get_card_id(action_id)
results.append({"action_id": action_id, "card_id": card_id})
return results

View File

@ -0,0 +1,203 @@
"""
TrainOfflineUsecase
저장된 Experience 데이터로 Q-Table을 오프라인 학습
"""
from typing import List, Optional
from ..domain.model.q_table import QTable
from ..domain.model.experience import Experience
from ..domain.model.visit_table import VisitTable
from ..domain.repository.experience_repository import ExperienceRepository
class TrainOfflineUsecase:
"""
수집된 Experience 데이터를 사용하여 Q-Table을 오프라인으로 학습
배치 학습 방식:
- 저장된 Experience를 로드
- Q-Learning 알고리즘으로 Q-Table 업데이트
- 여러 에포크 반복 가능
"""
def __init__(
self,
q_table: QTable,
experience_repository: ExperienceRepository,
visit_table: Optional[VisitTable] = None,
):
"""
Args:
q_table: 학습할 Q-Table
experience_repository: Experience 저장소
visit_table: 방문 횟수를 함께 추적할 N-Table (선택)
"""
self.q_table = q_table
self.experience_repository = experience_repository
self.visit_table = visit_table
def train(
self,
filename: str = "experiences.jsonl",
epochs: int = 1,
batch_size: Optional[int] = None,
) -> dict:
"""
저장된 Experience로 Q-Table 학습
Args:
filename: Experience 파일명
epochs: 학습 반복 횟수
batch_size: 배치 크기 (None이면 전체 데이터 사용)
Returns:
{
'total_experiences': 1000,
'epochs': 10,
'updates': 10000,
'avg_loss': 0.05
}
"""
# Experience 로드
experiences = self.experience_repository.load_all(filename)
if not experiences:
return {
"total_experiences": 0,
"epochs": 0,
"updates": 0,
"avg_loss": 0.0,
"message": "No experiences found",
}
total_updates = 0
total_loss = 0.0
# 에포크 반복
for epoch in range(epochs):
epoch_loss = 0.0
# 배치 처리
if batch_size:
for i in range(0, len(experiences), batch_size):
batch = experiences[i : i + batch_size]
loss = self._train_batch(batch)
epoch_loss += loss
total_updates += len(batch)
else:
# 전체 데이터 한번에
loss = self._train_batch(experiences)
epoch_loss += loss
total_updates += len(experiences)
total_loss += epoch_loss
avg_loss = total_loss / total_updates if total_updates > 0 else 0.0
return {
"total_experiences": len(experiences),
"epochs": epochs,
"updates": total_updates,
"avg_loss": avg_loss,
}
def _train_batch(self, experiences: List[Experience]) -> float:
"""
배치 Experience로 Q-Table 업데이트
Args:
experiences: Experience 리스트
Returns:
평균 손실값
"""
total_loss = 0.0
for exp in experiences:
# State → state_index
state_index = exp.state.to_index()
# Q-value 업데이트 전 값
old_q = self.q_table.get_q_value(state_index, exp.action_id)
# Q-Table 업데이트
if exp.next_state:
next_state_index = exp.next_state.to_index()
self.q_table.update(
state_index=state_index,
action_id=exp.action_id,
reward=exp.reward,
next_state_index=next_state_index,
done=exp.done,
)
else:
# 종료 상태
self.q_table.update(
state_index=state_index,
action_id=exp.action_id,
reward=exp.reward,
next_state_index=None,
done=True,
)
if self.visit_table is not None:
self.visit_table.increment(state_index, exp.action_id)
# Q-value 업데이트 후 값
new_q = self.q_table.get_q_value(state_index, exp.action_id)
# 손실 계산 (업데이트 크기)
loss = abs(new_q - old_q)
total_loss += loss
return total_loss / len(experiences) if experiences else 0.0
def train_by_episodes(
self,
episode_ids: List[str],
filename: str = "experiences.jsonl",
epochs: int = 1,
) -> dict:
"""
특정 에피소드들만 선택하여 학습
Args:
episode_ids: 학습할 에피소드 ID 리스트
filename: Experience 파일명
epochs: 학습 반복 횟수
Returns:
학습 결과
"""
# 해당 에피소드의 Experience만 로드
all_experiences = self.experience_repository.load_all(filename)
filtered_experiences = [
exp for exp in all_experiences if exp.episode_id in episode_ids
]
if not filtered_experiences:
return {
"total_experiences": 0,
"episodes": 0,
"epochs": 0,
"updates": 0,
"avg_loss": 0.0,
"message": "No matching episodes found",
}
total_updates = 0
total_loss = 0.0
for epoch in range(epochs):
loss = self._train_batch(filtered_experiences)
total_loss += loss * len(filtered_experiences)
total_updates += len(filtered_experiences)
return {
"total_experiences": len(filtered_experiences),
"episodes": len(episode_ids),
"epochs": epochs,
"updates": total_updates,
"avg_loss": total_loss / total_updates if total_updates > 0 else 0.0,
}

View File

@ -0,0 +1,152 @@
"""
Action-Card 매핑 레이어
Q_Table의 action_id와 Card_Management의 card_id를 연결
"""
import json
from pathlib import Path
from typing import Dict, Optional
class ActionCardMapper:
"""
action_id (0, 1, 2, ...) card_id ("no_0", "no_1", ...)
Q_Table은 추상적인 action_id만 다루고,
Card_Management는 구체적인 card_id만 다룬다.
클래스가 둘을 연결하는 단일 책임을 가진다.
매핑은 JSON 파일에서 로드되며, 카드 추가/삭제 수동으로 업데이트된다.
"""
def __init__(self, mapping_file: Optional[str] = None):
"""
Args:
mapping_file: JSON 매핑 파일 경로
(기본: integration/data/action_card_mapping.json)
"""
if mapping_file is None:
current_dir = Path(__file__).parent
mapping_file = current_dir / "data" / "action_card_mapping.json"
self.mapping_file = Path(mapping_file)
self.action_to_card: Dict[int, str] = {}
self.card_to_action: Dict[str, int] = {}
self._load_mapping()
def _load_mapping(self):
"""JSON 파일에서 매핑 로드"""
if not self.mapping_file.exists():
raise FileNotFoundError(
f"Mapping file not found: {self.mapping_file}\n"
f"Please create action_card_mapping.json first."
)
with open(self.mapping_file, "r", encoding="utf-8") as f:
data = json.load(f)
# action_to_card: {"0": "no_0", "1": "no_1", ...}
# JSON keys are strings, convert to int
self.action_to_card = {int(k): v for k, v in data["action_to_card"].items()}
# card_to_action: reverse mapping for quick lookup
self.card_to_action = {v: int(k) for k, v in data["action_to_card"].items()}
def get_card_id(self, action_id: int) -> Optional[str]:
"""
action_id card_id 변환
Args:
action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...)
Returns:
card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...)
매핑이 없으면 None
Example:
>>> mapper.get_card_id(5)
'no_5'
"""
return self.action_to_card.get(action_id)
def get_action_id(self, card_id: str) -> Optional[int]:
"""
card_id action_id 변환
Args:
card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...)
Returns:
action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...)
매핑이 없으면 None
Example:
>>> mapper.get_action_id('no_5')
5
"""
return self.card_to_action.get(card_id)
def get_action_space_size(self) -> int:
"""
현재 매핑된 액션 개수 (= 카드 개수)
Returns:
액션 공간 크기
Example:
>>> mapper.get_action_space_size()
21 # 21개의 카드
"""
return len(self.action_to_card)
def get_all_mappings(self) -> Dict[int, str]:
"""
전체 매핑 반환
Returns:
{action_id: card_id} 딕셔너리
Example:
>>> mapper.get_all_mappings()
{0: 'no_0', 1: 'no_1', ..., 20: 'no_20'}
"""
return self.action_to_card.copy()
def get_all_action_ids(self) -> list:
"""
모든 action_id 리스트 반환 (정렬됨)
Returns:
[0, 1, 2, ..., 20]
"""
return sorted(self.action_to_card.keys())
def get_all_card_ids(self) -> list:
"""
모든 card_id 리스트 반환 (action_id 순서)
Returns:
['no_0', 'no_1', ..., 'no_20']
"""
return [self.action_to_card[i] for i in sorted(self.action_to_card.keys())]
def is_valid_action_id(self, action_id: int) -> bool:
"""action_id가 매핑에 존재하는지 확인"""
return action_id in self.action_to_card
def is_valid_card_id(self, card_id: str) -> bool:
"""card_id가 매핑에 존재하는지 확인"""
return card_id in self.card_to_action
def reload(self):
"""매핑 파일 다시 로드 (런타임 중 파일이 변경된 경우)"""
self._load_mapping()
def __repr__(self) -> str:
return (
f"ActionCardMapper("
f"action_space_size={self.get_action_space_size()}, "
f"mapping_file={self.mapping_file}"
f")"
)

View File

@ -0,0 +1,13 @@
{
"action_to_card": {
"0": "no_1",
"1": "no_2",
"2": "no_3",
"3": "no_5",
"4": "no_6",
"5": "no_8",
"6": "no_11",
"7": "no_13",
"8": "no_14"
}
}

90
train.py Normal file
View File

@ -0,0 +1,90 @@
import sys
import os
import argparse
import numpy as np
from pathlib import Path
# Add project root to path to allow imports from src
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository
from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper
def main():
parser = argparse.ArgumentParser(description="Train Q-Table Agent")
parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate (only for new tables)")
parser.add_argument("--gamma", type=float, default=0.9, help="Discount factor (only for new tables)")
parser.add_argument("--data-file", type=str, default="experiences.jsonl", help="Experience file name inside data/experiences/")
args = parser.parse_args()
print("=== KTC V2 Agent Training ===")
# 1. Config
try:
mapper = ActionCardMapper()
ACTION_SIZE = mapper.get_action_space_size()
except Exception as e:
# Fallback if specific file error
print(f"Warning: Could not load Action mapping ({e}). Defaulting to 21.")
ACTION_SIZE = 21
STATE_SIZE = 162
print(f"Configuration: State Size={STATE_SIZE}, Action Size={ACTION_SIZE}")
# 2. Repository & Models
model_repo = ModelRepository()
print("Loading models...")
q_table, visit_table = model_repo.load()
if q_table is None:
print("[Info] No existing Q-Table found. Creating new one.")
q_table = QTable(
state_space_size=STATE_SIZE,
action_space_size=ACTION_SIZE,
learning_rate=args.lr,
discount_factor=args.gamma
)
else:
print("[Info] Loaded existing Q-Table.")
if visit_table is None:
print("[Info] No existing VisitTable found. Creating new one.")
visit_table = VisitTable(STATE_SIZE, ACTION_SIZE)
else:
print("[Info] Loaded existing VisitTable.")
# 3. Data Repository
exp_repo = ExperienceRepository()
# Check if data file exists
data_path = exp_repo.data_dir / args.data_file
if not data_path.exists():
print(f"[Warning] Experience file not found at: {data_path}")
print("Please ensure the data file is synchronized from the main server.")
return
# 4. Usecase
trainer = TrainOfflineUsecase(q_table, exp_repo, visit_table)
# 5. Train
print(f"Starting training for {args.epochs} epochs with batch size {args.batch_size}...")
result = trainer.train(filename=args.data_file, epochs=args.epochs, batch_size=args.batch_size)
print("\nTraining Result:")
for k, v in result.items():
print(f" {k}: {v}")
# 6. Save
print("Saving models...")
model_repo.save(q_table, visit_table)
print("Done.")
if __name__ == "__main__":
main()