commit 94bbc309fd1a4763afc6a3819774c4d0da135568 Author: mgjeon Date: Mon Dec 29 09:08:37 2025 +0900 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2747813 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.venv/ +__pycache__/ +*.py[cod] +.DS_Store +.env diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..807b21c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,89 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.13.2 +aiosignal==1.4.0 +aiosqlite==0.21.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +anyio==4.11.0 +asyncpg==0.31.0 +attrs==25.4.0 +certifi==2025.11.12 +charset-normalizer==3.4.4 +click==8.3.1 +cloudpickle==3.1.2 +contourpy==1.3.3 +cycler==0.12.1 +distro==1.9.0 +Farama-Notifications==0.0.4 +fastapi==0.122.0 +fonttools==4.60.1 +frozenlist==1.8.0 +greenlet==3.2.4 +gunicorn==23.0.0 +gymnasium==1.2.2 +h11==0.16.0 +h5py==3.15.1 +httpcore==1.0.9 +httpx==0.28.1 +idna==3.11 +iniconfig==2.3.0 +JayDeBeApi==1.2.3 +Jinja2==3.1.6 +jiter==0.12.0 +jpype1==1.6.0 +jsonpatch==1.33 +jsonpointer==3.0.0 +kiwisolver==1.4.9 +langchain-core==1.1.0 +langchain-openai==1.1.0 +langgraph==1.0.4 +langgraph-checkpoint==3.0.1 +langgraph-checkpoint-sqlite==3.0.0 +langgraph-prebuilt==1.0.5 +langgraph-sdk==0.2.10 +langsmith==0.4.49 +MarkupSafe==3.0.3 +matplotlib==3.10.7 +multidict==6.7.0 +numpy==2.3.5 +openai==2.8.1 +orjson==3.11.4 +ormsgpack==1.12.0 +packaging==25.0 +passlib==1.7.4 +pillow==12.0.0 +pluggy==1.6.0 +propcache==0.4.1 +psycopg2-binary==2.9.11 +pydantic==2.12.5 +pydantic-settings==2.12.0 +pydantic_core==2.41.5 +Pygments==2.19.2 +PyJWT==2.10.1 +pyparsing==3.2.5 +pypdf==6.1.3 +pytest==9.0.1 +pytest-asyncio==1.3.0 +python-dateutil==2.9.0.post0 +python-dotenv==1.2.1 +python-multipart==0.0.20 +pytz==2025.2 +PyYAML==6.0.3 +regex==2025.11.3 +requests==2.32.5 +requests-toolbelt==1.0.0 +six==1.17.0 +sniffio==1.3.1 +SQLAlchemy==2.0.44 +sqlite-vec==0.1.6 +starlette==0.50.0 +tenacity==9.1.2 +tiktoken==0.12.0 +tqdm==4.67.1 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +urllib3==2.5.0 +uvicorn==0.38.0 +xxhash==3.6.0 +yarl==1.22.0 +zstandard==0.25.0 diff --git a/setup_env.sh b/setup_env.sh new file mode 100755 index 0000000..bf52ec5 --- /dev/null +++ b/setup_env.sh @@ -0,0 +1,5 @@ +#!/bin/bash +python3 -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt +echo "Setup complete. Activate with 'source .venv/bin/activate'" diff --git a/src/negotiation_agent/Q_Table/CHANGELOG.md b/src/negotiation_agent/Q_Table/CHANGELOG.md new file mode 100644 index 0000000..d05ba3e --- /dev/null +++ b/src/negotiation_agent/Q_Table/CHANGELOG.md @@ -0,0 +1,63 @@ +# 변경 사항 (Changelog) + +## Version 3.0 - 비즈니스 용어 적용 (최종) + +### 주요 변경 사항 + +#### 1. State 변수명 최종 확정 (03_state_design.tex) + +**Version 2.0 (이전):** + +- 견적 금액 구간, 제조사 구분, 파트너사 구조, 가격 수용률 분위, 현재 가격 구간 + +**Version 3.0 (최종):** + +- 매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간 + +**변경 이유:** + +- 실제 비즈니스에서 사용하는 용어로 통일하여 가독성 및 이해도 향상 + +**최종 State 변수명:** + +| 순번 | 변수 | 구분 | 설명 | +| :--- | :------------------- | :--- | :--------------------------------------------------------- | +| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) | +| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 | +| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) | +| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) | +| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) | + +#### 2. 동적 가중치 (W) 설계 업데이트 (04_reward_function.tex) + +- 변수명 변경에 따라 수식 및 테이블의 변수명도 모두 업데이트 +- `S_manu` → `S_dist` +- 계산 로직 및 가중치 값은 동일하게 유지 + +``` +W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone +``` + +**새로운 유통 구조 가중치 설계 의도:** + +| 유통 구조 | 정규화 값 | 설계 의도 | +| :-------- | :-------- | :----------------------------------------- | +| 제조 | 0.2 | 제조사 직공급으로 가격 협상 여지 가장 적음 | +| 총판 | 0.5 | 중간 유통 단계로 협상 여지 중간 | +| 유통 | 1.0 | 복잡한 유통 단계로 가격 협상 여지 가장 큼 | + +### 파일 변경 내역 + +- `sections/03_state_design.tex`: 변수명 변경 +- `sections/04_reward_function.tex`: 변수명 변경 +- `EXAMPLE_CALCULATION.md`: 변수명 변경 +- `CHANGELOG.md`: 이 파일 (Version 4.0 반영) + +### 최종 검증 + +- [x] State 변수명 5개 모두 비즈니스 용어로 변경 완료 +- [x] 총 State 수 162개 유지 +- [x] 동적 가중치 W 계산식 변수명 업데이트 완료 +- [x] 계산 예시 변수명 업데이트 완료 + +이 버전은 기능적 변경 없이, 문서의 가독성과 현업 적용성을 높이는 데 중점을 둔 최종 버전입니다. diff --git a/src/negotiation_agent/Q_Table/MIGRATION_V3.md b/src/negotiation_agent/Q_Table/MIGRATION_V3.md new file mode 100644 index 0000000..4725b4d --- /dev/null +++ b/src/negotiation_agent/Q_Table/MIGRATION_V3.md @@ -0,0 +1,173 @@ +# Q-Table Version 3.0 Migration Guide + +## 개요 + +Q-Table 프로젝트가 Version 2.0 (36 states)에서 Version 3.0 (162 states)으로 업그레이드되었습니다. + +## 주요 변경사항 + +### 1. State 구조 변경 + +#### Before (Version 2.0 - 36 states) + +```python +from negotiation_agent.Q_Table.domain.model.state import ( + State, Scenario, PriceZone, AcceptanceRate +) + +state = State( + scenario=Scenario.PRICE_FIRST, # 4개 값 + price_zone=PriceZone.AT_OR_BELOW_ANCHOR, # 3개 값 + acceptance_rate=AcceptanceRate.MEDIUM # 3개 값 +) +# 총 4 × 3 × 3 = 36 states +``` + +#### After (Version 3.0 - 162 states) + +```python +from negotiation_agent.Q_Table.domain.model.state import ( + State, + RevenueRange, + DistributionStructure, + PartnerType, + AcceptanceRate, + InputPriceZone, +) + +state = State( + revenue_range=RevenueRange.MID, # 3개 값 + distribution=DistributionStructure.WHOLESALER, # 3개 값 + partner_type=PartnerType.SINGLE, # 3개 값 + acceptance_rate=AcceptanceRate.MID, # 3개 값 + input_price_zone=InputPriceZone.PZ1, # 2개 값 +) +# 총 3 × 3 × 3 × 3 × 2 = 162 states +``` + +### 2. State Builder 변경 + +#### Before (Version 2.0) + +```python +from negotiation_agent.Q_Table.domain.service.state_calculator import ( + NegotiationSnapshot, build_state +) + +snapshot = NegotiationSnapshot( + scenario_code="A", + anchor_price=10000, + target_price=12000, + seller_initial_price=15000, + current_price=11000, +) +state = build_state(snapshot) +``` + +#### After (Version 3.0) + +```python +from negotiation_agent.Q_Table.domain.service.state_calculator import ( + NegotiationSnapshot, build_state +) + +snapshot = NegotiationSnapshot( + revenue_amount=2000, # 매출액 (만원) + distribution_code="W", # "M": 제조, "W": 총판, "R": 유통 + partner_count=1, # 파트너사 수 + anchor_price=10000, + target_price=12000, + input_price=11000, + acceptance_ratio=0.5, # 0~1 사이 값 +) +state = build_state(snapshot) +``` + +### 3. Reward 계산 변경 + +#### Before (Version 2.0) + +```python +from negotiation_agent.Q_Table.domain.service.reward_calculator import ( + calculate_reward, NegotiationOutcome +) + +breakdown = calculate_reward( + scenario=state.scenario, + price_zone=state.price_zone, + current_price=11000, + anchor_price=10000, + target_price=12000, + round_number=3, + outcome=NegotiationOutcome.ONGOING, +) +``` + +#### After (Version 3.0) + +```python +from negotiation_agent.Q_Table.domain.service.reward_calculator import ( + calculate_reward, NegotiationOutcome +) + +breakdown = calculate_reward( + revenue_range=state.revenue_range, + distribution=state.distribution, + partner_type=state.partner_type, + acceptance_rate=state.acceptance_rate, + input_price_zone=state.input_price_zone, + current_price=11000, + anchor_price=10000, + target_price=12000, + round_number=3, + outcome=NegotiationOutcome.ONGOING, +) +``` + +### 4. Q-Table 크기 변경 + +#### Before + +```python +q_table = QTable(state_space_size=36, action_space_size=21) +visit_table = VisitTable(state_space_size=36, action_space_size=21) +``` + +#### After + +```python +q_table = QTable(state_space_size=162, action_space_size=21) +visit_table = VisitTable(state_space_size=162, action_space_size=21) +``` + +## 마이그레이션 체크리스트 + +- [ ] State 생성 코드를 새로운 5-변수 구조로 변경 +- [ ] NegotiationSnapshot 생성 코드 업데이트 +- [ ] calculate_reward() 호출부 매개변수 변경 +- [ ] Q-Table, VisitTable 초기화 시 state_space_size를 162로 변경 +- [ ] 기존 학습된 모델 파일(.npy) 재학습 필요 (36→162 차원 불일치) +- [ ] 단위 테스트 업데이트 + +## 호환성 주의사항 + +⚠️ **기존 학습 모델 사용 불가** + +- Version 2.0에서 학습된 Q-Table (36 × N)은 Version 3.0 (162 × N)과 호환되지 않습니다. +- 기존 모델을 사용하려면 재학습이 필요합니다. + +⚠️ **Experience 데이터 재수집 권장** + +- 기존 Experience에는 새로운 State 변수 정보가 없습니다. +- 새로운 State 정의에 맞춰 Experience를 재수집하는 것을 권장합니다. + +## 변경 이유 + +1. **비즈니스 용어 적용**: 실제 협상 도메인에서 사용하는 용어로 통일 +2. **더 세밀한 상태 표현**: 매출액, 유통 구조, 파트너사 정보 등을 명시적으로 포함 +3. **확장성 향상**: 새로운 비즈니스 변수 추가 시 유연한 대응 가능 +4. **협상 전략 다양화**: 162개의 상태로 더 정교한 협상 전략 학습 가능 + +## 문의 + +변경사항에 대한 문의는 팀 리드에게 연락해주세요. diff --git a/src/negotiation_agent/Q_Table/REFACTORING_SUMMARY.md b/src/negotiation_agent/Q_Table/REFACTORING_SUMMARY.md new file mode 100644 index 0000000..0b5034e --- /dev/null +++ b/src/negotiation_agent/Q_Table/REFACTORING_SUMMARY.md @@ -0,0 +1,364 @@ +# Q_Table 리팩토링 완료 요약 + +## 📅 작업 일자 + +- 2025-10-29: Version 2.0 (36 states) +- 2025-11-10: **Version 3.0 (162 states)** - 비즈니스 용어 적용 및 State 재설계 + +## ✅ Version 3.0 주요 변경사항 + +### State 공간 재설계 (36 → 162 states) + +**기존 Version 2.0:** + +- State = (Scenario, PriceZone, AcceptanceRate) +- 총 상태 수: 4 × 3 × 3 = **36 states** + +**신규 Version 3.0:** + +- State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간) +- 총 상태 수: 3 × 3 × 3 × 3 × 2 = **162 states** + +| 순번 | 변수 | 구분 | 설명 | +| :--- | :------------------- | :--- | :--------------------------------------------------------- | +| 1 | **매출액 가격구간** | 3개 | Low (≤1,000만원), Mid (1,000~3,000만원), High (>3,000만원) | +| 2 | **유통 구조** | 3개 | 제조, 총판, 유통 | +| 3 | **파트너사 종류** | 3개 | Single (단독), Multiple (다수), None (없음) | +| 4 | **가격 수용률 구간** | 3개 | Low (<30%), Mid (30~90%), High (>90%) | +| 5 | **입력 금액 구간** | 2개 | PZ1 (A ≤ P ≤ T), PZ2 (P > T) | + +### 동적 가중치 (W) 설계 업데이트 + +**기존 Version 2.0:** + +``` +W = (scenario.weight + price_zone.weight) / 2.0 +``` + +**신규 Version 3.0:** + +``` +W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone +W = clip(W_raw, 0.2, 0.8) +``` + +기본 가중치 계수: + +- w1 = 0.20 (매출액 가격구간) +- w2 = 0.25 (유통 구조) +- w3 = 0.20 (파트너사 종류) +- w4 = 0.25 (가격 수용률) +- w5 = 0.10 (입력 금액 구간) + +## ✅ 완료된 작업 (Version 2.0) + +### 1. 새로운 핵심 컴포넌트 구현 + +#### **QTable** (`domain/model/q_table.py`) + +- ✅ 동적 `action_space_size` 지원 (21개 카드) +- ✅ Q-Learning 업데이트 메서드 +- ✅ 직렬화/역직렬화 지원 +- ✅ 인덱스 유효성 검증 + +#### **EpisodePolicy** (`domain/agents/policy.py`) + +- ✅ 하드코딩된 크기(9) → 동적 크기로 변경 +- ✅ 생성자에서 `action_space_size` 주입 +- ✅ `get_action_mask()` 동적 크기 적용 + +#### **State** (`domain/model/state.py`) + +- ✅ `to_index()` 메서드 추가 (State → 1D 인덱스) +- ✅ `from_index()` 메서드 추가 (1D 인덱스 → State) + +#### **ActionCardMapper** (`integration/action_card_mapper.py`) 🆕 + +- ✅ action_id ↔ card_id 양방향 매핑 +- ✅ JSON 파일 기반 설정 +- ✅ 유틸리티 메서드 제공 + +#### **action_card_mapping.json** (`integration/data/`) 🆕 + +- ✅ 21개 카드 매핑 (0-20 → "no_0" ~ "no_20") + +#### **GetBestActionUsecase** (`usecase/get_best_action_usecase.py`) + +- ✅ State → action_id → card_id 전체 플로우 +- ✅ Policy 사용 여부 선택 가능 +- ✅ Top-K 추천 기능 +- ✅ 사용 가능한 액션/카드 조회 + +#### **CollectExperienceUsecase** (`usecase/collect_experience_usecase.py`) 🆕 + +- ✅ 협상 중 Experience 수집 +- ✅ 에피소드 단위 관리 +- ✅ JSONL 형식 자동 저장 +- ✅ 수집 정보 조회 + +#### **TrainOfflineUsecase** (`usecase/train_offline_usecase.py`) 🆕 + +- ✅ 저장된 Experience로 Q-Table 학습 +- ✅ 배치 학습 및 에포크 반복 +- ✅ 특정 에피소드 선택 학습 +- ✅ 성능 평가 기능 + +### 2. 레거시 코드 정리 + +#### 제거된 파일 + +``` +❌ domain/action_space.py +❌ domain/model/action.py +❌ domain/constants.py +❌ domain/spaces.py +``` + +#### Legacy로 이동된 파일 (참고용 보관) + +``` +📦 legacy/environment.py +📦 legacy/calculate_reward_usecase.py +📦 legacy/evaluate_agent_usecase.py +📦 legacy/execute_step_usecase.py +📦 legacy/get_q_value_usecase.py +📦 legacy/get_state_info_usecase.py +📦 legacy/initialize_env_usecase.py +📦 legacy/load_q_table_usecase.py +📦 legacy/train_agent_usecase.py +📦 legacy/update_q_table_usecase.py +``` + +## 📁 최종 디렉토리 구조 + +``` +negotiation_agent/ +│ +├── card_management/ # 카드 관리 (DB 기반) +│ ├── domain/ +│ │ ├── model/ +│ │ │ └── nego_card.py ✅ DB 엔티티 +│ │ ├── repository/ +│ │ │ ├── nego_card_repository.py +│ │ │ └── nego_card_script_repository.py ✅ JSON CRUD +│ │ └── value/ +│ │ └── nego_card_types.py ✅ 6가지 atomic elements +│ └── data/ +│ └── nego_card_scripts.json ✅ 21개 카드 스크립트 +│ +├── Q_Table/ # Q-Learning (추상 action_id) +│ ├── domain/ +│ │ ├── model/ +│ │ │ ├── state.py ✅ to_index() 추가 +│ │ │ ├── q_table.py ✅ 새로 생성 +│ │ │ └── experience.py ✅ 새로 생성 +│ │ ├── agents/ +│ │ │ ├── policy.py ✅ 동적 크기 수정 +│ │ │ └── offline_agent.py +│ │ ├── repository/ +│ │ │ ├── action_repository.py +│ │ │ └── experience_repository.py ✅ 새로 생성 +│ │ └── service/ +│ │ └── state_calculator.py +│ ├── usecase/ +│ │ ├── get_best_action_usecase.py ✅ 완전 재작성 +│ │ ├── collect_experience_usecase.py ✅ 새로 생성 +│ │ └── train_offline_usecase.py ✅ 새로 생성 +│ ├── infra/ +│ │ ├── data_collector.py +│ │ └── gym/ +│ │ └── env_wrapper.py +│ ├── data/ 🆕 데이터 저장소 +│ │ └── experiences/ +│ │ └── *.jsonl +│ └── legacy/ 📦 레거시 코드 보관 +│ ├── README.md +│ └── ... (10개 파일) +│ +└── integration/ 🆕 매핑 레이어 + ├── action_card_mapper.py ✅ + └── data/ + └── action_card_mapping.json ✅ +``` + +## 🎯 핵심 설계 원칙 + +### 1. 명확한 책임 분리 + +- **Q_Table**: action_id (0~20)만 다룸 +- **Card_Management**: card_id ("no_0"~"no_20")만 다룸 +- **ActionCardMapper**: 둘을 연결하는 단일 책임 + +### 2. 동적 확장성 + +- 카드 추가/삭제 시 `action_card_mapping.json`만 수정 +- Q-Table과 Policy가 자동으로 새 크기에 대응 + +### 3. 유지보수성 + +- 레거시 코드는 legacy/ 폴더에 보관 +- 새로운 코드와 명확히 분리 + +## 🔄 사용 예시 + +### 1. 추론 (Inference) + +```python +from negotiation_agent.Q_Table.domain.model.q_table import QTable +from negotiation_agent.Q_Table.domain.model.visit_table import VisitTable +from negotiation_agent.Q_Table.domain.agents.policy import UCBPolicy +from negotiation_agent.integration.action_card_mapper import ActionCardMapper +from negotiation_agent.Q_Table.usecase.get_best_action_usecase import GetBestActionUsecase + +# 매퍼 로드 +mapper = ActionCardMapper() +action_space_size = mapper.get_action_space_size() # 21 + +# Q-Table 생성 (Version 3.0: 162 states) +q_table = QTable( + state_space_size=162, # 3 x 3 x 3 x 3 x 2 + action_space_size=action_space_size # 21 +) + +# Policy 생성 +visit_table = VisitTable(state_space_size=162, action_space_size=action_space_size) +policy = UCBPolicy(visit_table=visit_table) + +# Usecase 생성 +usecase = GetBestActionUsecase( + q_table=q_table, + policy=policy, + action_card_mapper=mapper +) + +# 협상 추론 (Version 3.0) +from negotiation_agent.Q_Table.domain.model.state import ( + State, + RevenueRange, + DistributionStructure, + PartnerType, + AcceptanceRate, + InputPriceZone, +) + +state = State( + revenue_range=RevenueRange.MID, # 1,000~3,000만원 + distribution=DistributionStructure.WHOLESALER, # 총판 + partner_type=PartnerType.SINGLE, # 단독 파트너 + acceptance_rate=AcceptanceRate.MID, # 30~90% + input_price_zone=InputPriceZone.PZ1, # A≤P≤T +) + +result = usecase.execute(state) +print(result) +# Output: +# { +# 'action_id': 5, +# 'card_id': 'no_5', +# 'q_value': 0.0, +# 'state_index': 1 +# } +``` + +### 2. Experience 수집 + +```python +from negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository +from negotiation_agent.Q_Table.usecase.collect_experience_usecase import CollectExperienceUsecase + +# Repository 생성 +exp_repo = ExperienceRepository() + +# Usecase 생성 +collect_usecase = CollectExperienceUsecase(exp_repo) + +# 에피소드 시작 +collect_usecase.start_episode("ep_001") + +# 협상 진행 중... +for step in range(10): + # 현재 상태 + current_state = State(...) + + # 액션 선택 (GetBestActionUsecase 사용) + result = usecase.execute(current_state) + action_id = result['action_id'] + + # 액션 실행 후 보상과 다음 상태 받음 + reward = calculate_reward(...) # 보상 계산 + next_state = get_next_state(...) # 다음 상태 + done = check_done(...) # 종료 여부 + + # Experience 수집 + collect_usecase.collect( + state=current_state, + action_id=action_id, + reward=reward, + next_state=next_state, + done=done + ) + + if done: + break + +# 에피소드 종료 +collect_usecase.end_episode() + +# 수집 정보 확인 +info = collect_usecase.get_collection_info() +print(info) +# { +# 'total_experiences': 10, +# 'episodes': 1, +# 'current_episode': None +# } +``` + +### 3. 오프라인 학습 + +```python +from negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase + +# Usecase 생성 +train_usecase = TrainOfflineUsecase( + q_table=q_table, + experience_repository=exp_repo +) + +# 학습 실행 +result = train_usecase.train( + filename="experiences.jsonl", + epochs=10, + batch_size=32 +) + +print(result) +# { +# 'total_experiences': 1000, +# 'epochs': 10, +# 'updates': 10000, +# 'avg_loss': 0.05 +# } + +# 성능 평가 +eval_result = train_usecase.evaluate(filename="experiences.jsonl") +print(eval_result) +# { +# 'avg_q_value': 0.5, +# 'avg_reward': 1.0, +# 'total_samples': 1000 +# } +``` + +## 📋 향후 작업 (선택사항) + +- [x] Experience 수집 Usecase 구현 ✅ +- [x] 오프라인 학습 Usecase 구현 ✅ +- [ ] Q-Table Repository 구현 (영속성) +- [ ] 카드 추가 시 Q-Table 자동 확장 기능 +- [ ] 통계적 Q-value 초기화 기능 + +## 📚 참고 문서 + +- 설계 문서: `REFACTORING_GUIDE.md` +- 레거시 코드: `legacy/README.md` diff --git a/src/negotiation_agent/Q_Table/VERSION_3_SUMMARY.md b/src/negotiation_agent/Q_Table/VERSION_3_SUMMARY.md new file mode 100644 index 0000000..281fd05 --- /dev/null +++ b/src/negotiation_agent/Q_Table/VERSION_3_SUMMARY.md @@ -0,0 +1,279 @@ +# Q-Table Version 3.0 변경사항 정리 + +## 📋 변경 개요 + +CHANGELOG.md에 기술된 Version 3.0 업데이트를 코드에 반영했습니다. +- **State 공간**: 36 states → **162 states** +- **변수 구조**: 3개 변수 → **5개 변수** (비즈니스 용어 적용) +- **동적 가중치**: 2변수 평균 → **5변수 가중합** + +--- + +## 🔄 주요 변경 파일 + +### 1. State 모델 (`domain/model/state.py`) + +#### 기존 (Version 2.0) +```python +class Scenario(IntEnum): ... # 4개 값 +class PriceZone(IntEnum): ... # 3개 값 +class AcceptanceRate(IntEnum): ... # 3개 값 + +State = (Scenario, PriceZone, AcceptanceRate) # 4×3×3 = 36 +``` + +#### 변경 (Version 3.0) +```python +class RevenueRange(IntEnum): ... # 3개 값 (매출액 가격구간) +class DistributionStructure(IntEnum): ... # 3개 값 (유통 구조) +class PartnerType(IntEnum): ... # 3개 값 (파트너사 종류) +class AcceptanceRate(IntEnum): ... # 3개 값 (가격 수용률 구간) +class InputPriceZone(IntEnum): ... # 2개 값 (입력 금액 구간) + +State = (RevenueRange, DistributionStructure, PartnerType, + AcceptanceRate, InputPriceZone) # 3×3×3×3×2 = 162 +``` + +**주요 특징:** +- 모든 클래스에 `.weight` 속성 추가 (동적 가중치 계산용) +- 비즈니스 용어로 명명 (매출액, 유통 구조, 파트너사 등) +- 각 변수마다 `from_*()` 클래스 메서드로 분류 로직 제공 + +--- + +### 2. Reward Calculator (`domain/service/reward_calculator.py`) + +#### 기존 (Version 2.0) +```python +def calculate_reward( + scenario: Scenario, + price_zone: PriceZone, + ... +) + +def _calculate_weight(scenario, price_zone, cfg): + raw = (scenario.priority_weight + price_zone.zone_weight) / 2.0 + return clip(raw, 0.2, 0.8) +``` + +#### 변경 (Version 3.0) +```python +def calculate_reward( + revenue_range: RevenueRange, + distribution: DistributionStructure, + partner_type: PartnerType, + acceptance_rate: AcceptanceRate, + input_price_zone: InputPriceZone, + ... +) + +def _calculate_weight(...): + W_raw = ( + config.w1 * revenue_range.weight + + config.w2 * distribution.weight + + config.w3 * partner_type.weight + + config.w4 * acceptance_rate.weight + + config.w5 * input_price_zone.weight + ) + return clip(W_raw, 0.2, 0.8) +``` + +**기본 가중치 계수:** +- w1 = 0.20 (매출액) +- w2 = 0.25 (유통 구조) +- w3 = 0.20 (파트너사) +- w4 = 0.25 (수용률) +- w5 = 0.10 (입력 금액) + +--- + +### 3. State Calculator (`domain/service/state_calculator.py`) + +#### 기존 (Version 2.0) +```python +@dataclass +class NegotiationSnapshot: + scenario_code: str + anchor_price: float + target_price: float + seller_initial_price: float + current_price: float +``` + +#### 변경 (Version 3.0) +```python +@dataclass +class NegotiationSnapshot: + revenue_amount: float # 매출액 (만원) + distribution_code: str # "M", "W", "R" + partner_count: int # 파트너사 수 + anchor_price: float + target_price: float + input_price: float + acceptance_ratio: Optional[float] + initial_price: Optional[float] +``` + +**변경 이유:** +- 실제 비즈니스 데이터에 맞춰 필드 재구성 +- 5개 State 변수를 도출할 수 있는 정보 제공 + +--- + +### 4. 모듈 Export (`domain/model/__init__.py`) + +```python +# Version 3.0 +from .state import ( + AcceptanceRate, + DistributionStructure, + InputPriceZone, + PartnerType, + RevenueRange, + State, +) +``` + +--- + +## 📊 State 변수 상세 명세 + +| 변수 | 클래스 | 값 | 설명 | 가중치 | +|------|--------|---|------|--------| +| 매출액 가격구간 | `RevenueRange` | LOW (≤1,000만원) | 낮은 매출액 | 0.3 | +| | | MID (1,000~3,000만원) | 중간 매출액 | 0.6 | +| | | HIGH (>3,000만원) | 높은 매출액 | 1.0 | +| 유통 구조 | `DistributionStructure` | MANUFACTURER | 제조사 직공급 | 0.2 | +| | | WHOLESALER | 총판 경유 | 0.5 | +| | | RETAILER | 유통 경유 | 1.0 | +| 파트너사 종류 | `PartnerType` | NONE | 파트너 없음 | 0.3 | +| | | SINGLE | 단독 파트너 | 0.5 | +| | | MULTIPLE | 다수 파트너 | 1.0 | +| 가격 수용률 | `AcceptanceRate` | LOW (<30%) | 낮은 수용률 | 0.3 | +| | | MID (30~90%) | 중간 수용률 | 0.6 | +| | | HIGH (>90%) | 높은 수용률 | 1.0 | +| 입력 금액 구간 | `InputPriceZone` | PZ1 (A≤P≤T) | 목표 범위 내 | 1.0 | +| | | PZ2 (P>T or P None: + """ + Args: + agent_params: 에이전트 파라미터 (learning_rate, discount_factor 등) + state_size: 상태 공간 크기 + action_size: 액션 공간 크기 + visit_table: 방문 기록 테이블 (None이면 새로 생성) + """ + self.state_size = state_size + self.action_size = action_size + + # Q-Table 객체 생성 (Composition) + self.q_table = QTable( + state_space_size=state_size, + action_space_size=action_size, + learning_rate=agent_params["learning_rate"], + discount_factor=agent_params["discount_factor"] + ) + + self.visit_table = visit_table or VisitTable(state_size, action_size) + self.policy = UCBPolicy( + visit_table=self.visit_table, + exploration_constant=agent_params.get("exploration_constant", np.sqrt(2.0)), + ) + + def get_action(self, state: int, action_mask=None): + """ + 현재 상태에서 액션 선택 (UCB 정책 사용) + + Args: + state: 현재 상태 인덱스 + action_mask: 가능한 액션 마스크 (Optional) + + Returns: + 선택된 액션 ID + """ + # QTable 객체에서 Q-value 조회 + q_values = self.q_table.get_q_values_for_state(state) + mask = None if action_mask is None else np.asarray(action_mask, dtype=bool) + return self.policy.select_action(state, q_values, available_mask=mask) + + def learn(self, batch): + """ + 배치 데이터를 사용하여 Q-Table 업데이트 + + Args: + batch: 학습 데이터 배치 (observations, actions, rewards, next_observations, terminals) + """ + for state, action, reward, next_state, terminated in zip( + batch["observations"], + batch["actions"], + batch["rewards"], + batch["next_observations"], + batch["terminals"], + ): + # QTable 객체의 update 메서드 사용 + # terminated가 True이면 next_state는 의미가 없거나 None이어야 함 + # QTable.update는 next_state_index가 None이면 종료 상태로 처리 + + next_s = next_state if not terminated else None + + self.q_table.update( + state_index=state, + action_id=action, + reward=reward, + next_state_index=next_s, + done=terminated + ) + + self.visit_table.increment(state, action) + + def save_model(self, path): + """ + Q-Table을 파일로 저장 + + Args: + path: 저장할 파일 경로 + """ + # QTable 내부의 numpy array를 저장 (기존 호환성 유지) + np.save(path, self.q_table.q_values) + print(f"Q-Table saved to {path}") + + def load_q_table(self, file_path): + """ + 파일에서 Q-Table 로드 + + Args: + file_path: 로드할 파일 경로 + """ + if os.path.exists(file_path): + # QTable 객체의 q_values 속성에 직접 할당 + self.q_table.q_values = np.load(file_path) + print(f"Q-Table loaded from {file_path}") + else: + print(f"Error: No Q-Table found at {file_path}") + + def reset_episode(self): + """에피소드 초기화 (정책 상태 등 리셋)""" + self.policy.reset_episode() diff --git a/src/negotiation_agent/Q_Table/domain/agents/policy.py b/src/negotiation_agent/Q_Table/domain/agents/policy.py new file mode 100644 index 0000000..52895b3 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/agents/policy.py @@ -0,0 +1,123 @@ +from abc import ABC, abstractmethod +from typing import Optional + +import numpy as np + +from ..model.visit_table import VisitTable + + +class Policy(ABC): + @abstractmethod + def select_action( + self, + state_index: int, + q_values: np.ndarray, + available_mask: Optional[np.ndarray] = None, + ) -> Optional[int]: + """ + 주어진 상태와 Q-value를 기반으로 액션 선택 + + Args: + state_index: 현재 상태 인덱스 + q_values: 해당 상태의 Q-value 배열 + available_mask: 선택 가능한 액션 마스크 (True=선택가능) + + Returns: + 선택된 액션 ID (선택 불가 시 None) + """ + raise NotImplementedError + + def reset_episode(self) -> None: + """에피소드 시작 시 상태 초기화""" + raise NotImplementedError + + def get_action_mask(self) -> np.ndarray: + """ + 현재 정책 상태에 따른 액션 마스크 반환 + + Returns: + 액션 마스크 배열 (1=선택가능, 0=선택불가) + """ + raise NotImplementedError + + +class UCBPolicy(Policy): + """Upper Confidence Bound policy with per-episode action masking.""" + + def __init__( + self, + visit_table: VisitTable, + exploration_constant: float = np.sqrt(2.0), + rng: Optional[np.random.Generator] = None, + ) -> None: + self.visit_table = visit_table + self.action_space_size = visit_table.action_space_size + self.exploration_constant = exploration_constant + self._rng = rng or np.random.default_rng() + self._episode_actions: set[int] = set() + + def select_action( + self, + state_index: int, + q_values: np.ndarray, + available_mask: Optional[np.ndarray] = None, + ) -> Optional[int]: + """ + UCB 알고리즘을 사용하여 액션 선택 + + UCB = Q(s,a) + c * sqrt(ln(N(s)) / N(s,a)) + """ + mask = self._prepare_mask(available_mask) + if not mask.any(): + self.reset_episode() + mask = self._prepare_mask(available_mask) + if not mask.any(): + return None + + counts = self.visit_table.get_state_counts(state_index) + masked_counts = np.where(mask, counts, 0) + + zero_visit_candidates = np.where((counts == 0) & mask)[0] + if zero_visit_candidates.size > 0: + best = zero_visit_candidates[np.argmax(q_values[zero_visit_candidates])] + action = int(best) + else: + total = masked_counts.sum() + # Avoid division by zero while keeping exploration pressure. + denom = counts.astype(float) + 1e-9 + bonus = self.exploration_constant * np.sqrt(np.log(total + 1.0) / denom) + scores = q_values + bonus + scores[~mask] = -np.inf + action = int(np.argmax(scores)) + + self.visit_table.increment(state_index, action) + self._episode_actions.add(action) + return action + + def reset_episode(self) -> None: + """에피소드 내 사용된 액션 기록 초기화""" + self._episode_actions.clear() + + def get_action_mask(self) -> np.ndarray: + """ + 이미 사용된 액션을 제외한 마스크 반환 + + Returns: + 마스크 배열 (1=선택가능, 0=선택불가) + """ + mask = np.ones(self.action_space_size, dtype=int) + if self._episode_actions: + for action in self._episode_actions: + mask[action] = 0 + return mask + + def _prepare_mask(self, available_mask: Optional[np.ndarray]) -> np.ndarray: + """입력 마스크와 이미 사용된 액션을 결합하여 최종 마스크 생성""" + if available_mask is None: + mask = np.ones(self.action_space_size, dtype=bool) + else: + mask = np.asarray(available_mask, dtype=bool).copy() + if self._episode_actions: + for action in self._episode_actions: + mask[action] = False + return mask diff --git a/src/negotiation_agent/Q_Table/domain/model/__init__.py b/src/negotiation_agent/Q_Table/domain/model/__init__.py new file mode 100644 index 0000000..f2fd5d8 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/model/__init__.py @@ -0,0 +1,11 @@ +"""Domain model exports for Q-Table (Version 3.0 - 162 states).""" + +from .state import ( # noqa: F401 + AcceptanceRate, + DistributionStructure, + InputPriceZone, + PartnerType, + RevenueRange, + State, +) +from .visit_table import VisitTable # noqa: F401 diff --git a/src/negotiation_agent/Q_Table/domain/model/experience.py b/src/negotiation_agent/Q_Table/domain/model/experience.py new file mode 100644 index 0000000..1e7c3db --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/model/experience.py @@ -0,0 +1,94 @@ +""" +Experience 모델 +Q-Learning을 위한 경험 데이터 (SARS: State, Action, Reward, next_State) +""" + +from dataclasses import dataclass +from typing import Optional +from .state import State + + +@dataclass +class Experience: + """ + Q-Learning에서 사용하는 경험 데이터 + + SARS 형식: + - State: 현재 상태 + - Action: 선택한 액션 (action_id) + - Reward: 받은 보상 + - next_State: 다음 상태 (종료 시 None) + """ + + state: State + action_id: int + reward: float + next_state: Optional[State] + done: bool # 에피소드 종료 여부 + + # 메타데이터 (선택사항) + episode_id: Optional[str] = None + step: Optional[int] = None + timestamp: Optional[str] = None + + def to_dict(self) -> dict: + """ + Experience를 딕셔너리로 직렬화 + + Returns: + { + 'state': [scenario, price_zone, acceptance_rate], + 'action_id': 5, + 'reward': 1.0, + 'next_state': [scenario, price_zone, acceptance_rate] or None, + 'done': False, + 'episode_id': 'ep_001', + 'step': 3, + 'timestamp': '2025-10-29T16:30:00' + } + """ + return { + "state": self.state.to_array(), + "action_id": self.action_id, + "reward": self.reward, + "next_state": self.next_state.to_array() if self.next_state else None, + "done": self.done, + "episode_id": self.episode_id, + "step": self.step, + "timestamp": self.timestamp, + } + + @classmethod + def from_dict(cls, data: dict) -> "Experience": + """ + 딕셔너리에서 Experience 복원 + + Args: + data: 직렬화된 Experience 데이터 + + Returns: + Experience 인스턴스 + """ + return cls( + state=State.from_array(data["state"]), + action_id=data["action_id"], + reward=data["reward"], + next_state=( + State.from_array(data["next_state"]) if data["next_state"] else None + ), + done=data["done"], + episode_id=data.get("episode_id"), + step=data.get("step"), + timestamp=data.get("timestamp"), + ) + + def __repr__(self) -> str: + return ( + f"Experience(" + f"state={self.state.to_array()}, " + f"action_id={self.action_id}, " + f"reward={self.reward}, " + f"next_state={self.next_state.to_array() if self.next_state else None}, " + f"done={self.done}" + f")" + ) diff --git a/src/negotiation_agent/Q_Table/domain/model/q_table.py b/src/negotiation_agent/Q_Table/domain/model/q_table.py new file mode 100644 index 0000000..120c6c8 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/model/q_table.py @@ -0,0 +1,209 @@ +""" +Q-Table 모델 +동적 action_space_size를 지원하여 협상 카드 추가/삭제에 대응 +""" + +import numpy as np +from typing import Optional, Dict + + +class QTable: + """ + Q-Learning을 위한 Q-Table + + Q-Table은 추상적인 action_id만 다루며, + 실제 협상 카드(card_id)는 ActionCardMapper에서 매핑 + """ + + def __init__( + self, + state_space_size: int, + action_space_size: int, + learning_rate: float = 0.1, + discount_factor: float = 0.95, + ): + """ + Args: + state_space_size: 상태 공간 크기 + (Scenario=4 x PriceZone=3 x AcceptanceRate=3 = 36) + action_space_size: 액션 공간 크기 (협상 카드 개수, 현재 21개) + learning_rate: 학습률 (alpha) + discount_factor: 할인 인자 (gamma) + """ + self.state_space_size = state_space_size + self.action_space_size = action_space_size + self.learning_rate = learning_rate + self.discount_factor = discount_factor + + # Q-Table 초기화: (state_size, action_size) 형태 + self.q_values = np.zeros((state_space_size, action_space_size)) + + def get_q_value(self, state_index: int, action_id: int) -> float: + """ + 특정 (state, action)의 Q-value 조회 + + Args: + state_index: State의 인덱스 (0 ~ state_space_size-1) + action_id: Action ID (0 ~ action_space_size-1) + + Returns: + Q-value + """ + self._validate_indices(state_index, action_id) + return float(self.q_values[state_index, action_id]) + + def get_q_values_for_state(self, state_index: int) -> np.ndarray: + """ + 특정 state의 모든 action에 대한 Q-values + + Args: + state_index: State의 인덱스 + + Returns: + shape (action_space_size,)의 Q-values 배열 + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + return self.q_values[state_index, :].copy() + + def set_q_value(self, state_index: int, action_id: int, value: float): + """ + 특정 (state, action)의 Q-value 직접 설정 + + Args: + state_index: State의 인덱스 + action_id: Action ID + value: 설정할 Q-value + """ + self._validate_indices(state_index, action_id) + self.q_values[state_index, action_id] = value + + def update( + self, + state_index: int, + action_id: int, + reward: float, + next_state_index: Optional[int] = None, + done: bool = False, + ): + """ + Q-Learning 업데이트 + + Q(s,a) ← Q(s,a) + α[r + γ·max_a'Q(s',a') - Q(s,a)] + + Args: + state_index: 현재 상태 인덱스 + action_id: 선택한 액션 ID + reward: 받은 보상 + next_state_index: 다음 상태 인덱스 (종료 시 None) + done: 에피소드 종료 여부 + """ + self._validate_indices(state_index, action_id) + + current_q = self.q_values[state_index, action_id] + + if done or next_state_index is None: + # 종료 상태: target = reward만 사용 + target = reward + else: + # 비종료 상태: target = reward + γ·max Q(s',a') + if next_state_index < 0 or next_state_index >= self.state_space_size: + raise ValueError(f"next_state_index {next_state_index} out of range") + max_next_q = np.max(self.q_values[next_state_index, :]) + target = reward + self.discount_factor * max_next_q + + # Q-Learning 업데이트 + self.q_values[state_index, action_id] += self.learning_rate * ( + target - current_q + ) + + def get_best_action(self, state_index: int) -> int: + """ + 해당 state에서 최고 Q-value를 가진 action_id 반환 + + Args: + state_index: State의 인덱스 + + Returns: + action_id (Q-value가 최대인 액션) + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + return int(np.argmax(self.q_values[state_index, :])) + + def get_best_actions_with_ties(self, state_index: int) -> list: + """ + 동일한 최대 Q-value를 가진 모든 action_id 반환 + + Args: + state_index: State의 인덱스 + + Returns: + 최대 Q-value를 가진 action_id들의 리스트 + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + + max_q = np.max(self.q_values[state_index, :]) + max_actions = np.where(self.q_values[state_index, :] == max_q)[0] + return max_actions.tolist() + + def to_dict(self) -> Dict: + """ + Q-Table을 딕셔너리로 직렬화 + + Returns: + 직렬화된 Q-Table 데이터 + """ + return { + "state_space_size": self.state_space_size, + "action_space_size": self.action_space_size, + "learning_rate": self.learning_rate, + "discount_factor": self.discount_factor, + "q_values": self.q_values.tolist(), + } + + @classmethod + def from_dict(cls, data: Dict) -> "QTable": + """ + 딕셔너리에서 Q-Table 복원 + + Args: + data: 직렬화된 Q-Table 데이터 + + Returns: + 복원된 QTable 인스턴스 + """ + q_table = cls( + state_space_size=data["state_space_size"], + action_space_size=data["action_space_size"], + learning_rate=data["learning_rate"], + discount_factor=data["discount_factor"], + ) + q_table.q_values = np.array(data["q_values"]) + return q_table + + def _validate_indices(self, state_index: int, action_id: int): + """인덱스 유효성 검증""" + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + if action_id < 0 or action_id >= self.action_space_size: + raise ValueError( + f"action_id {action_id} out of range [0, {self.action_space_size})" + ) + + def __repr__(self) -> str: + return ( + f"QTable(state_space_size={self.state_space_size}, " + f"action_space_size={self.action_space_size}, " + f"learning_rate={self.learning_rate}, " + f"discount_factor={self.discount_factor})" + ) diff --git a/src/negotiation_agent/Q_Table/domain/model/state.py b/src/negotiation_agent/Q_Table/domain/model/state.py new file mode 100644 index 0000000..cc05e59 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/model/state.py @@ -0,0 +1,258 @@ +from dataclasses import dataclass +from enum import IntEnum, nonmember +from typing import List + + +class RevenueRange(IntEnum): + """매출액 가격구간 (Revenue Price Range)""" + + LOW = 0 # ≤ 1,000만원 + MID = 1 # 1,000만원 ~ 3,000만원 + HIGH = 2 # > 3,000만원 + + _DESCRIPTIONS = nonmember(("Low (≤1,000만원)", "Mid (1,000~3,000만원)", "High (>3,000만원)")) + _WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 매출액이 높을수록 협상 여지 큼 + + @property + def description(self) -> str: + return self._DESCRIPTIONS[self.value] + + @property + def weight(self) -> float: + return self._WEIGHTS[self.value] + + @classmethod + def from_amount(cls, amount: float) -> "RevenueRange": + """매출액(만원 단위)으로부터 구간 결정""" + if amount <= 1000: + return cls.LOW + elif amount <= 3000: + return cls.MID + else: + return cls.HIGH + + +class DistributionStructure(IntEnum): + """유통 구조 (Distribution Structure)""" + + MANUFACTURER = 0 # 제조 + WHOLESALER = 1 # 총판 + RETAILER = 2 # 유통 + + _DESCRIPTIONS = nonmember(("제조", "총판", "유통")) + _WEIGHTS = nonmember((0.2, 0.5, 1.0)) # 유통 단계가 복잡할수록 협상 여지 큼 + + @property + def description(self) -> str: + return self._DESCRIPTIONS[self.value] + + @property + def weight(self) -> float: + return self._WEIGHTS[self.value] + + @classmethod + def from_code(cls, code: str) -> "DistributionStructure": + """코드로부터 유통 구조 결정""" + normalized = (code or "").strip().upper() + mapping = {"M": cls.MANUFACTURER, "W": cls.WHOLESALER, "R": cls.RETAILER} + if normalized not in mapping: + raise ValueError(f"unknown distribution code: {code}") + return mapping[normalized] + + +class PartnerType(IntEnum): + """파트너사 종류 (Partner Type)""" + + SINGLE = 0 # 단독 + MULTIPLE = 1 # 다수 + NONE = 2 # 없음 + + _DESCRIPTIONS = nonmember(("Single (단독)", "Multiple (다수)", "None (없음)")) + _WEIGHTS = nonmember((0.5, 1.0, 0.3)) # 다수 파트너사일수록 협상 복잡도 증가 + + @property + def description(self) -> str: + return self._DESCRIPTIONS[self.value] + + @property + def weight(self) -> float: + return self._WEIGHTS[self.value] + + @classmethod + def from_count(cls, count: int) -> "PartnerType": + """파트너사 수로부터 구간 결정""" + if count == 0: + return cls.NONE + elif count == 1: + return cls.SINGLE + else: + return cls.MULTIPLE + + +class AcceptanceRate(IntEnum): + """가격 수용률 구간 (Price Acceptance Rate)""" + + LOW = 0 # < 30% + MID = 1 # 30% ~ 90% + HIGH = 2 # > 90% + + _DESCRIPTIONS = nonmember(("Low (<30%)", "Mid (30~90%)", "High (>90%)")) + _WEIGHTS = nonmember((0.3, 0.6, 1.0)) # 수용률이 높을수록 협상 성공 가능성 높음 + + @property + def description(self) -> str: + return self._DESCRIPTIONS[self.value] + + @property + def weight(self) -> float: + return self._WEIGHTS[self.value] + + @classmethod + def from_ratio(cls, ratio: float) -> "AcceptanceRate": + """수용률(0~1)로부터 구간 결정""" + normalized = max(0.0, min(1.0, ratio)) + if normalized < 0.30: + return cls.LOW + elif normalized <= 0.90: + return cls.MID + else: + return cls.HIGH + + +class InputPriceZone(IntEnum): + """입력 금액 구간 (Input Price Zone)""" + + PZ1 = 0 # A ≤ P ≤ T (앵커가격 ~ 목표가격) + PZ2 = 1 # P > T (목표가격 초과) + + _DESCRIPTIONS = nonmember(("PZ1 (A≤P≤T)", "PZ2 (P>T)")) + _WEIGHTS = nonmember((1.0, 0.3)) # 목표가격 이내일수록 협상 유리 + + @property + def description(self) -> str: + return self._DESCRIPTIONS[self.value] + + @property + def weight(self) -> float: + return self._WEIGHTS[self.value] + + @classmethod + def from_prices( + cls, + input_price: float, + anchor_price: float, + target_price: float, + ) -> "InputPriceZone": + """입력가격, 앵커가격, 목표가격으로부터 구간 결정""" + if anchor_price <= 0 or target_price <= 0: + raise ValueError("anchor_price and target_price must be positive") + if anchor_price > target_price: + raise ValueError("anchor_price must not exceed target_price") + + if anchor_price <= input_price <= target_price: + return cls.PZ1 + else: + return cls.PZ2 + + +@dataclass +class State: + """ + Q-Learning State 표현 (Version 3.0 - 162 states) + + State = (매출액 가격구간, 유통 구조, 파트너사 종류, 가격 수용률 구간, 입력 금액 구간) + Total: 3 × 3 × 3 × 3 × 2 = 162 states + """ + + revenue_range: RevenueRange + distribution: DistributionStructure + partner_type: PartnerType + acceptance_rate: AcceptanceRate + input_price_zone: InputPriceZone + + def to_array(self) -> List[int]: + """State를 배열로 변환""" + return [ + self.revenue_range.value, + self.distribution.value, + self.partner_type.value, + self.acceptance_rate.value, + self.input_price_zone.value, + ] + + def to_index(self) -> int: + """ + State를 1차원 인덱스로 변환 (Q-Table 인덱싱용) + + State space: 3 × 3 × 3 × 3 × 2 = 162 + + Returns: + 0 ~ 161 사이의 인덱스 + """ + # 5D to 1D index conversion + # index = revenue * (3*3*3*2) + dist * (3*3*2) + partner * (3*2) + accept * 2 + pricezone + return ( + self.revenue_range.value * 54 + + self.distribution.value * 18 + + self.partner_type.value * 6 + + self.acceptance_rate.value * 2 + + self.input_price_zone.value + ) + + @classmethod + def from_index(cls, index: int) -> "State": + """ + 1차원 인덱스에서 State 복원 + + Args: + index: 0 ~ 161 사이의 인덱스 + + Returns: + State 객체 + """ + if index < 0 or index >= 162: + raise ValueError(f"index {index} out of range [0, 162)") + + # 1D to 5D index conversion + revenue_value = index // 54 + remainder = index % 54 + + dist_value = remainder // 18 + remainder = remainder % 18 + + partner_value = remainder // 6 + remainder = remainder % 6 + + accept_value = remainder // 2 + pricezone_value = remainder % 2 + + return cls( + revenue_range=RevenueRange(revenue_value), + distribution=DistributionStructure(dist_value), + partner_type=PartnerType(partner_value), + acceptance_rate=AcceptanceRate(accept_value), + input_price_zone=InputPriceZone(pricezone_value), + ) + + @classmethod + def from_array(cls, arr: List[int]) -> "State": + """배열로부터 State 생성""" + if len(arr) != 5: + raise ValueError(f"Expected 5 elements, got {len(arr)}") + return cls( + revenue_range=RevenueRange(arr[0]), + distribution=DistributionStructure(arr[1]), + partner_type=PartnerType(arr[2]), + acceptance_rate=AcceptanceRate(arr[3]), + input_price_zone=InputPriceZone(arr[4]), + ) + + def __str__(self) -> str: + return ( + f"State(" + f"revenue={self.revenue_range.description}, " + f"distribution={self.distribution.description}, " + f"partner={self.partner_type.description}, " + f"acceptance={self.acceptance_rate.description}, " + f"price_zone={self.input_price_zone.description})" + ) diff --git a/src/negotiation_agent/Q_Table/domain/model/visit_table.py b/src/negotiation_agent/Q_Table/domain/model/visit_table.py new file mode 100644 index 0000000..b1608c9 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/model/visit_table.py @@ -0,0 +1,148 @@ +"""Visit (N) table for UCB-based policies.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict + +import numpy as np + + +@dataclass +class VisitTable: + """Tracks state-action visit counts for UCB selection.""" + + state_space_size: int + action_space_size: int + + def __post_init__(self) -> None: + """ + Args: + state_space_size: 상태 공간 크기 + action_space_size: 액션 공간 크기 + """ + self._counts = np.zeros( + (self.state_space_size, self.action_space_size), dtype=np.int64 + ) + + def increment(self, state_index: int, action_id: int, count: int = 1) -> None: + """ + 특정 (state, action)의 방문 횟수 증가 + + Args: + state_index: State 인덱스 + action_id: Action ID + count: 증가시킬 횟수 (기본 1) + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + if action_id < 0 or action_id >= self.action_space_size: + raise ValueError( + f"action_id {action_id} out of range [0, {self.action_space_size})" + ) + if count < 0: + raise ValueError("count must be non-negative") + self._counts[state_index, action_id] += count + + def get_action_count(self, state_index: int, action_id: int) -> int: + """ + 특정 (state, action)의 방문 횟수 조회 + + Args: + state_index: State 인덱스 + action_id: Action ID + + Returns: + 방문 횟수 + """ + self._validate_indices(state_index, action_id) + return int(self._counts[state_index, action_id]) + + def get_state_counts(self, state_index: int) -> np.ndarray: + """ + 특정 state의 모든 action에 대한 방문 횟수 조회 + + Args: + state_index: State 인덱스 + + Returns: + shape (action_space_size,)의 방문 횟수 배열 + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + return self._counts[state_index, :].copy() + + def total_visits(self, state_index: int) -> int: + """ + 특정 state의 총 방문 횟수 (모든 action 합계) + + Args: + state_index: State 인덱스 + + Returns: + 총 방문 횟수 + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + return int(self._counts[state_index, :].sum()) + + def reset_state(self, state_index: int) -> None: + """ + 특정 state의 방문 기록 초기화 + + Args: + state_index: State 인덱스 + """ + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + self._counts[state_index, :] = 0 + + def to_dict(self) -> Dict: + """ + VisitTable을 딕셔너리로 직렬화 + + Returns: + 직렬화된 VisitTable 데이터 + """ + return { + "state_space_size": self.state_space_size, + "action_space_size": self.action_space_size, + "counts": self._counts.tolist(), + } + + @classmethod + def from_dict(cls, data: Dict) -> "VisitTable": + """ + 딕셔너리에서 VisitTable 복원 + + Args: + data: 직렬화된 VisitTable 데이터 + + Returns: + 복원된 VisitTable 인스턴스 + """ + table = cls( + state_space_size=data["state_space_size"], + action_space_size=data["action_space_size"], + ) + table._counts = np.array(data["counts"], dtype=np.int64) + return table + + def _validate_indices(self, state_index: int, action_id: int) -> None: + """인덱스 유효성 검증""" + if state_index < 0 or state_index >= self.state_space_size: + raise ValueError( + f"state_index {state_index} out of range [0, {self.state_space_size})" + ) + if action_id < 0 or action_id >= self.action_space_size: + raise ValueError( + f"action_id {action_id} out of range [0, {self.action_space_size})" + ) diff --git a/src/negotiation_agent/Q_Table/domain/repository/experience_repository.py b/src/negotiation_agent/Q_Table/domain/repository/experience_repository.py new file mode 100644 index 0000000..17e93e6 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/repository/experience_repository.py @@ -0,0 +1,200 @@ +""" +Experience Repository +Experience 데이터를 JSONL 형식으로 저장/로드 +""" + +import json +from pathlib import Path +from typing import List, Optional +from datetime import datetime +from ..model.experience import Experience + + +class ExperienceRepository: + """ + Experience 데이터를 파일 시스템에 저장/로드하는 Repository + + JSONL (JSON Lines) 형식 사용: + - 각 줄이 하나의 JSON 객체 (Experience) + - 스트리밍 방식으로 읽기/쓰기 가능 + - 대용량 데이터 처리에 유리 + """ + + def __init__(self, data_dir: Optional[str] = None): + """ + Args: + data_dir: Experience 데이터를 저장할 디렉토리 + (기본: Q_Table/data/experiences/) + """ + if data_dir is None: + current_dir = Path(__file__).parent.parent.parent + data_dir = current_dir / "data" / "experiences" + + self.data_dir = Path(data_dir) + self.data_dir.mkdir(parents=True, exist_ok=True) + + def save(self, experience: Experience, filename: str = "experiences.jsonl"): + """ + 단일 Experience를 파일에 추가 저장 + + Args: + experience: 저장할 Experience + filename: 파일명 (기본: experiences.jsonl) + """ + file_path = self.data_dir / filename + + with open(file_path, "a", encoding="utf-8") as f: + json.dump(experience.to_dict(), f, ensure_ascii=False) + f.write("\n") + + def save_batch( + self, experiences: List[Experience], filename: str = "experiences.jsonl" + ): + """ + 여러 Experience를 배치로 저장 + + Args: + experiences: Experience 리스트 + filename: 파일명 + """ + file_path = self.data_dir / filename + + with open(file_path, "a", encoding="utf-8") as f: + for exp in experiences: + json.dump(exp.to_dict(), f, ensure_ascii=False) + f.write("\n") + + def load_all(self, filename: str = "experiences.jsonl") -> List[Experience]: + """ + 파일에서 모든 Experience 로드 + + Args: + filename: 파일명 + + Returns: + Experience 리스트 + """ + file_path = self.data_dir / filename + + if not file_path.exists(): + return [] + + experiences = [] + with open(file_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + data = json.loads(line) + experiences.append(Experience.from_dict(data)) + + return experiences + + def load_by_episode( + self, episode_id: str, filename: str = "experiences.jsonl" + ) -> List[Experience]: + """ + 특정 에피소드의 Experience들만 로드 + + Args: + episode_id: 에피소드 ID + filename: 파일명 + + Returns: + 해당 에피소드의 Experience 리스트 + """ + all_experiences = self.load_all(filename) + return [exp for exp in all_experiences if exp.episode_id == episode_id] + + def count(self, filename: str = "experiences.jsonl") -> int: + """ + 저장된 Experience 개수 + + Args: + filename: 파일명 + + Returns: + Experience 개수 + """ + file_path = self.data_dir / filename + + if not file_path.exists(): + return 0 + + with open(file_path, "r", encoding="utf-8") as f: + return sum(1 for line in f if line.strip()) + + def clear(self, filename: str = "experiences.jsonl"): + """ + 파일 내용 삭제 + + Args: + filename: 파일명 + """ + file_path = self.data_dir / filename + + if file_path.exists(): + file_path.unlink() + + def create_new_file(self, prefix: str = "exp") -> str: + """ + 타임스탬프가 포함된 새 파일 생성 + + Args: + prefix: 파일명 접두사 + + Returns: + 생성된 파일명 + + Example: + >>> repo.create_new_file("train") + 'train_20251029_163000.jsonl' + """ + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"{prefix}_{timestamp}.jsonl" + file_path = self.data_dir / filename + file_path.touch() + return filename + + def list_files(self) -> List[str]: + """ + 데이터 디렉토리의 모든 JSONL 파일 목록 + + Returns: + 파일명 리스트 + """ + return [f.name for f in self.data_dir.glob("*.jsonl")] + + def get_file_info(self, filename: str) -> dict: + """ + 파일 정보 조회 + + Args: + filename: 파일명 + + Returns: + { + 'filename': 'experiences.jsonl', + 'path': '/path/to/file', + 'size': 1024, # bytes + 'count': 100, # experience 개수 + 'created': '2025-10-29 16:30:00' + } + """ + file_path = self.data_dir / filename + + if not file_path.exists(): + return None + + stat = file_path.stat() + + return { + "filename": filename, + "path": str(file_path), + "size": stat.st_size, + "count": self.count(filename), + "created": datetime.fromtimestamp(stat.st_ctime).strftime( + "%Y-%m-%d %H:%M:%S" + ), + "modified": datetime.fromtimestamp(stat.st_mtime).strftime( + "%Y-%m-%d %H:%M:%S" + ), + } diff --git a/src/negotiation_agent/Q_Table/domain/service/reward_calculator.py b/src/negotiation_agent/Q_Table/domain/service/reward_calculator.py new file mode 100644 index 0000000..a230e31 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/service/reward_calculator.py @@ -0,0 +1,166 @@ +"""Reward function aligned with the Q-Table design specification (Version 3.0).""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from ..model.state import ( + AcceptanceRate, + DistributionStructure, + InputPriceZone, + PartnerType, + RevenueRange, +) + + +class NegotiationOutcome(Enum): + ONGOING = "ongoing" + SUCCESS = "success" + FAILURE = "failure" + + +@dataclass +class RewardConfig: + """보상 함수 설정""" + + beta: float = 0.2 + success_reward: float = 1.0 + ongoing_reward: float = 0.0 + failure_penalty: float = -0.5 + penalty_lambda: float = 0.02 + + # 동적 가중치 계수 (Version 3.0) + w1: float = 0.2 # 매출액 가격구간 + w2: float = 0.25 # 유통 구조 + w3: float = 0.2 # 파트너사 종류 + w4: float = 0.25 # 가격 수용률 + w5: float = 0.1 # 입력 금액 구간 + + min_weight: float = 0.2 + max_weight: float = 0.8 + + +@dataclass +class RewardBreakdown: + price_reward: float + end_reward: float + penalty: float + weight: float + total: float + + +def calculate_reward( + *, + revenue_range: RevenueRange, + distribution: DistributionStructure, + partner_type: PartnerType, + acceptance_rate: AcceptanceRate, + input_price_zone: InputPriceZone, + current_price: float, + anchor_price: float, + target_price: float, + round_number: int, + outcome: NegotiationOutcome, + config: Optional[RewardConfig] = None, +) -> RewardBreakdown: + """ + 보상 계산 (Version 3.0) + + R = W × R_price + (1-W) × R_end - λ × t + + W = clip(W_raw, min_weight, max_weight) + W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone + """ + cfg = config or RewardConfig() + + price_reward = _calculate_price_reward( + current_price=current_price, + anchor_price=anchor_price, + target_price=target_price, + beta=cfg.beta, + ) + end_reward = _calculate_end_reward(outcome, cfg) + weight = _calculate_weight( + revenue_range=revenue_range, + distribution=distribution, + partner_type=partner_type, + acceptance_rate=acceptance_rate, + input_price_zone=input_price_zone, + config=cfg, + ) + penalty = _calculate_penalty(round_number, cfg) + + total = weight * price_reward + (1.0 - weight) * end_reward - penalty + return RewardBreakdown( + price_reward=price_reward, + end_reward=end_reward, + penalty=penalty, + weight=weight, + total=total, + ) + + +def _calculate_price_reward( + *, + current_price: float, + anchor_price: float, + target_price: float, + beta: float, +) -> float: + if anchor_price <= 0 or target_price <= 0: + raise ValueError("anchor_price and target_price must be positive") + if anchor_price > target_price: + raise ValueError("anchor_price must not exceed target_price") + + if current_price < anchor_price: + improvement = (anchor_price - current_price) / anchor_price + return 1.0 + beta * improvement + + if anchor_price <= current_price <= target_price: + span = target_price - anchor_price + if span <= 0: + return 0.0 + return (target_price - current_price) / span + + return 0.0 + + +def _calculate_end_reward(outcome: NegotiationOutcome, cfg: RewardConfig) -> float: + if outcome is NegotiationOutcome.SUCCESS: + return cfg.success_reward + if outcome is NegotiationOutcome.FAILURE: + return cfg.failure_penalty + return cfg.ongoing_reward + + +def _calculate_weight( + *, + revenue_range: RevenueRange, + distribution: DistributionStructure, + partner_type: PartnerType, + acceptance_rate: AcceptanceRate, + input_price_zone: InputPriceZone, + config: RewardConfig, +) -> float: + """ + 동적 가중치 계산 (Version 3.0) + + W_raw = w1×S_amount + w2×S_dist + w3×S_partner + w4×S_accept + w5×S_pricezone + W = clip(W_raw, min_weight, max_weight) + """ + w_raw = ( + config.w1 * revenue_range.weight + + config.w2 * distribution.weight + + config.w3 * partner_type.weight + + config.w4 * acceptance_rate.weight + + config.w5 * input_price_zone.weight + ) + return max(config.min_weight, min(config.max_weight, w_raw)) + + +def _calculate_penalty(round_number: int, cfg: RewardConfig) -> float: + if round_number < 0: + raise ValueError("round_number must be non-negative") + return cfg.penalty_lambda * float(round_number) diff --git a/src/negotiation_agent/Q_Table/domain/service/state_calculator.py b/src/negotiation_agent/Q_Table/domain/service/state_calculator.py new file mode 100644 index 0000000..8e4d870 --- /dev/null +++ b/src/negotiation_agent/Q_Table/domain/service/state_calculator.py @@ -0,0 +1,87 @@ +"""Utility helpers to build Q-Table states from negotiation snapshots (Version 3.0).""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from ..model.state import ( + AcceptanceRate, + DistributionStructure, + InputPriceZone, + PartnerType, + RevenueRange, + State, +) + + +@dataclass(frozen=True) +class NegotiationSnapshot: + """협상 스냅샷 - 실제 비즈니스 데이터""" + + # 매출액 (만원 단위) + revenue_amount: float + + # 유통 구조 코드 ("M": 제조, "W": 총판, "R": 유통) + distribution_code: str + + # 파트너사 수 + partner_count: int + + # 가격 정보 + anchor_price: float + target_price: float + input_price: float + + # 가격 수용률 (0~1, 계산되거나 직접 입력) + acceptance_ratio: Optional[float] = None + + # 초기 가격 (수용률 계산용, acceptance_ratio가 없을 때) + initial_price: Optional[float] = None + + +def build_state(snapshot: NegotiationSnapshot) -> State: + """ + 협상 스냅샷으로부터 Q-Learning State 생성 (Version 3.0) + + Args: + snapshot: 협상 스냅샷 + + Returns: + 162-dimensional state + """ + # 1. 매출액 가격구간 + revenue_range = RevenueRange.from_amount(snapshot.revenue_amount) + + # 2. 유통 구조 + distribution = DistributionStructure.from_code(snapshot.distribution_code) + + # 3. 파트너사 종류 + partner_type = PartnerType.from_count(snapshot.partner_count) + + # 4. 가격 수용률 구간 + if snapshot.acceptance_ratio is not None: + acceptance_rate = AcceptanceRate.from_ratio(snapshot.acceptance_ratio) + elif snapshot.initial_price is not None: + # 초기가격으로부터 수용률 계산: (initial - input) / initial + if snapshot.initial_price <= 0: + raise ValueError("initial_price must be positive") + ratio = (snapshot.initial_price - snapshot.input_price) / snapshot.initial_price + acceptance_rate = AcceptanceRate.from_ratio(ratio) + else: + raise ValueError("Either acceptance_ratio or initial_price must be provided") + + # 5. 입력 금액 구간 + input_price_zone = InputPriceZone.from_prices( + input_price=snapshot.input_price, + anchor_price=snapshot.anchor_price, + target_price=snapshot.target_price, + ) + + return State( + revenue_range=revenue_range, + distribution=distribution, + partner_type=partner_type, + acceptance_rate=acceptance_rate, + input_price_zone=input_price_zone, + ) diff --git a/src/negotiation_agent/Q_Table/infra/repository/model_repository.py b/src/negotiation_agent/Q_Table/infra/repository/model_repository.py new file mode 100644 index 0000000..5914092 --- /dev/null +++ b/src/negotiation_agent/Q_Table/infra/repository/model_repository.py @@ -0,0 +1,46 @@ +import json +from pathlib import Path +from typing import Optional, Tuple + +from src.negotiation_agent.Q_Table.domain.model.q_table import QTable +from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable + + +class ModelRepository: + """Simple repository to save/load Q-Table and VisitTable models.""" + + def __init__(self, model_dir: Optional[str] = None): + if model_dir is None: + # Current: src/negotiation_agent/Q_Table/infra/repository/model_repository.py + # Target Data: src/negotiation_agent/Q_Table/data/model + # Path: ../../../data/model + current_dir = Path(__file__).parent + model_dir = current_dir.parent.parent / "data" / "model" + + self.model_dir = Path(model_dir) + self.model_dir.mkdir(parents=True, exist_ok=True) + self.q_table_path = self.model_dir / "q_table.json" + self.visit_table_path = self.model_dir / "visit_table.json" + + def save(self, q_table: QTable, visit_table: VisitTable): + with open(self.q_table_path, "w", encoding="utf-8") as f: + json.dump(q_table.to_dict(), f, indent=2) + + with open(self.visit_table_path, "w", encoding="utf-8") as f: + json.dump(visit_table.to_dict(), f, indent=2) + + print(f"Models saved to {self.model_dir}") + + def load(self) -> Tuple[Optional[QTable], Optional[VisitTable]]: + q_table = None + visit_table = None + + if self.q_table_path.exists(): + with open(self.q_table_path, "r", encoding="utf-8") as f: + q_table = QTable.from_dict(json.load(f)) + + if self.visit_table_path.exists(): + with open(self.visit_table_path, "r", encoding="utf-8") as f: + visit_table = VisitTable.from_dict(json.load(f)) + + return q_table, visit_table diff --git a/src/negotiation_agent/Q_Table/main.pdf b/src/negotiation_agent/Q_Table/main.pdf new file mode 100644 index 0000000..a274719 Binary files /dev/null and b/src/negotiation_agent/Q_Table/main.pdf differ diff --git a/src/negotiation_agent/Q_Table/usecase/__init__.py b/src/negotiation_agent/Q_Table/usecase/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/negotiation_agent/Q_Table/usecase/collect_experience_usecase.py b/src/negotiation_agent/Q_Table/usecase/collect_experience_usecase.py new file mode 100644 index 0000000..4995e0d --- /dev/null +++ b/src/negotiation_agent/Q_Table/usecase/collect_experience_usecase.py @@ -0,0 +1,143 @@ +""" +CollectExperienceUsecase +협상 중 Experience 데이터를 수집하여 저장 +""" + +from typing import Optional +from datetime import datetime +from ..domain.model.state import State +from ..domain.model.experience import Experience +from ..domain.repository.experience_repository import ExperienceRepository + + +class CollectExperienceUsecase: + """ + 협상 과정에서 발생하는 Experience를 수집하여 저장 + + 오프라인 학습을 위한 데이터 수집 담당 + """ + + def __init__( + self, + experience_repository: ExperienceRepository, + filename: Optional[str] = None, + ): + """ + Args: + experience_repository: Experience 저장소 + filename: 저장할 파일명 (None이면 기본 파일 사용) + """ + self.experience_repository = experience_repository + self.filename = filename or "experiences.jsonl" + self.current_episode_id: Optional[str] = None + self.current_step: int = 0 + + def start_episode(self, episode_id: Optional[str] = None): + """ + 새로운 에피소드 시작 + + Args: + episode_id: 에피소드 ID (None이면 자동 생성) + """ + if episode_id is None: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + episode_id = f"ep_{timestamp}" + + self.current_episode_id = episode_id + self.current_step = 0 + + def collect( + self, + state: State, + action_id: int, + reward: float, + next_state: Optional[State], + done: bool, + ) -> Experience: + """ + 단일 Experience 수집 및 저장 + + Args: + state: 현재 상태 + action_id: 선택한 액션 ID + reward: 받은 보상 + next_state: 다음 상태 (종료 시 None) + done: 에피소드 종료 여부 + + Returns: + 수집된 Experience + """ + if self.current_episode_id is None: + self.start_episode() + + # Experience 생성 + experience = Experience( + state=state, + action_id=action_id, + reward=reward, + next_state=next_state, + done=done, + episode_id=self.current_episode_id, + step=self.current_step, + timestamp=datetime.now().isoformat(), + ) + + # 저장 + self.experience_repository.save(experience, self.filename) + + # 스텝 증가 + self.current_step += 1 + + return experience + + def end_episode(self): + """에피소드 종료""" + self.current_episode_id = None + self.current_step = 0 + + def get_episode_count(self) -> int: + """수집된 에피소드 개수 (근사치)""" + # 파일에서 unique episode_id 개수 계산 + experiences = self.experience_repository.load_all(self.filename) + episode_ids = set(exp.episode_id for exp in experiences if exp.episode_id) + return len(episode_ids) + + def get_total_count(self) -> int: + """수집된 총 Experience 개수""" + return self.experience_repository.count(self.filename) + + def get_collection_info(self) -> dict: + """ + 수집 정보 조회 + + Returns: + { + 'filename': 'experiences.jsonl', + 'total_experiences': 1000, + 'episodes': 50, + 'current_episode': 'ep_001' or None, + 'current_step': 5 + } + """ + return { + "filename": self.filename, + "total_experiences": self.get_total_count(), + "episodes": self.get_episode_count(), + "current_episode": self.current_episode_id, + "current_step": self.current_step, + "file_info": self.experience_repository.get_file_info(self.filename), + } + + def create_new_collection_file(self, prefix: str = "exp") -> str: + """ + 새로운 수집 파일 생성 및 전환 + + Args: + prefix: 파일명 접두사 + + Returns: + 생성된 파일명 + """ + new_filename = self.experience_repository.create_new_file(prefix) + self.filename = new_filename + return new_filename diff --git a/src/negotiation_agent/Q_Table/usecase/evaluate_agent_usecase.py b/src/negotiation_agent/Q_Table/usecase/evaluate_agent_usecase.py new file mode 100644 index 0000000..6a59e38 --- /dev/null +++ b/src/negotiation_agent/Q_Table/usecase/evaluate_agent_usecase.py @@ -0,0 +1,72 @@ +""" +EvaluateAgentUsecase +학습된 Q-Table의 성능을 평가하는 Usecase +""" + +from typing import Optional +from ..domain.model.q_table import QTable +from ..domain.repository.experience_repository import ExperienceRepository + + +class EvaluateAgentUsecase: + """ + 저장된 Experience 데이터를 사용하여 현재 에이전트(Q-Table)의 성능을 평가 + """ + + def __init__( + self, + q_table: QTable, + experience_repository: ExperienceRepository, + ): + """ + Args: + q_table: 평가할 Q-Table + experience_repository: Experience 저장소 + """ + self.q_table = q_table + self.experience_repository = experience_repository + + def execute( + self, filename: str = "experiences.jsonl", sample_size: Optional[int] = None + ) -> dict: + """ + 현재 Q-Table의 성능 평가 + + Args: + filename: 평가용 Experience 파일 + sample_size: 샘플링 크기 (None이면 전체) + + Returns: + { + 'avg_q_value': 0.5, + 'avg_reward': 1.0, + 'total_samples': 100 + } + """ + experiences = self.experience_repository.load_all(filename) + + if not experiences: + return {"avg_q_value": 0.0, "avg_reward": 0.0, "total_samples": 0} + + # 샘플링 + if sample_size and sample_size < len(experiences): + import random + + experiences = random.sample(experiences, sample_size) + + total_q_value = 0.0 + total_reward = 0.0 + + for exp in experiences: + state_index = exp.state.to_index() + q_value = self.q_table.get_q_value(state_index, exp.action_id) + total_q_value += q_value + total_reward += exp.reward + + n = len(experiences) + + return { + "avg_q_value": total_q_value / n, + "avg_reward": total_reward / n, + "total_samples": n, + } diff --git a/src/negotiation_agent/Q_Table/usecase/get_best_action_usecase.py b/src/negotiation_agent/Q_Table/usecase/get_best_action_usecase.py new file mode 100644 index 0000000..2f5c127 --- /dev/null +++ b/src/negotiation_agent/Q_Table/usecase/get_best_action_usecase.py @@ -0,0 +1,179 @@ +""" +GetBestActionUsecase +State를 받아 최적의 협상 카드를 선택하는 Usecase +""" + +from typing import Dict, Optional +from ..domain.model.state import State +from ..domain.model.q_table import QTable +from ..domain.agents.policy import UCBPolicy +from ...integration.action_card_mapper import ActionCardMapper + + +class GetBestActionUsecase: + """ + 협상 추론 Usecase: State → action_id → card_id + + Q_Table은 추상적인 action_id만 반환하고, + ActionCardMapper를 통해 구체적인 card_id로 변환한다. + """ + + def __init__( + self, + q_table: QTable, + policy: UCBPolicy, + action_card_mapper: ActionCardMapper, + ): + """ + Args: + q_table: 학습된 Q-Table + policy: 액션 선택 정책 (UCB) + action_card_mapper: action_id ↔ card_id 매핑 + """ + self.q_table = q_table + self.policy = policy + self.action_card_mapper = action_card_mapper + + def execute(self, state: State, use_policy: bool = True) -> Dict: + """ + 현재 상태에서 최적의 협상 카드 선택 + + Args: + state: 현재 협상 상태 + use_policy: True면 Policy 사용 (UCB + 중복 방지) + False면 순수 Q-value 최대값만 사용 + + Returns: + { + 'action_id': 5, + 'card_id': 'no_5', + 'q_value': 0.85, + 'state_index': 12, + 'all_q_values': [0.1, 0.2, ..., 0.85, ...] # optional + } + + 사용 가능한 액션이 없으면: + { + 'action_id': None, + 'card_id': None, + 'q_value': None, + 'message': 'No available actions' + } + """ + # 1. State → state_index 변환 + state_index = state.to_index() + + # 2. Q-values 조회 + q_values = self.q_table.get_q_values_for_state(state_index) + + # 3. Action 선택 + if use_policy: + action_id = self.policy.select_action(state_index, q_values) + else: + # 순수 Q-value 최대값 + action_id = self.q_table.get_best_action(state_index) + + if action_id is None: + return { + "action_id": None, + "card_id": None, + "q_value": None, + "state_index": state_index, + "message": "No available actions (all actions used in this episode)", + } + + # 4. action_id → card_id 매핑 + card_id = self.action_card_mapper.get_card_id(action_id) + + if card_id is None: + return { + "action_id": action_id, + "card_id": None, + "q_value": float(q_values[action_id]), + "state_index": state_index, + "message": f"action_id {action_id} has no card mapping", + } + + # 5. Q-value 추출 + q_value = self.q_table.get_q_value(state_index, action_id) + + return { + "action_id": action_id, + "card_id": card_id, + "q_value": float(q_value), + "state_index": state_index, + "all_q_values": q_values.tolist(), # 디버깅용 + } + + def get_top_k_actions(self, state: State, k: int = 3) -> list: + """ + 상위 k개의 액션 추천 + + Args: + state: 현재 협상 상태 + k: 추천할 액션 개수 + + Returns: + [ + {'action_id': 5, 'card_id': 'no_5', 'q_value': 0.85}, + {'action_id': 3, 'card_id': 'no_3', 'q_value': 0.72}, + {'action_id': 1, 'card_id': 'no_1', 'q_value': 0.68} + ] + """ + state_index = state.to_index() + q_values = self.q_table.get_q_values_for_state(state_index) + + # Q-value 내림차순 정렬 + sorted_actions = sorted(enumerate(q_values), key=lambda x: x[1], reverse=True) + + # 상위 k개 추출 + top_k = sorted_actions[:k] + + results = [] + for action_id, q_value in top_k: + card_id = self.action_card_mapper.get_card_id(action_id) + results.append( + {"action_id": action_id, "card_id": card_id, "q_value": float(q_value)} + ) + + return results + + def reset_episode(self): + """ + 에피소드 종료 시 Policy 초기화 + (다음 협상 세션 시작 전 호출) + """ + self.policy.reset_episode() + + def get_available_actions(self) -> list: + """ + 현재 에피소드에서 아직 사용하지 않은 액션들 + + Returns: + [0, 2, 4, 5, ...] (사용하지 않은 action_id들) + """ + import numpy as np + + mask = self.policy.get_action_mask() + available_action_ids = np.where(mask > 0)[0].tolist() + return available_action_ids + + def get_available_cards(self) -> list: + """ + 현재 에피소드에서 아직 사용하지 않은 카드들 + + Returns: + [ + {'action_id': 0, 'card_id': 'no_0'}, + {'action_id': 2, 'card_id': 'no_2'}, + ... + ] + """ + available_action_ids = self.get_available_actions() + + results = [] + for action_id in available_action_ids: + card_id = self.action_card_mapper.get_card_id(action_id) + results.append({"action_id": action_id, "card_id": card_id}) + + return results diff --git a/src/negotiation_agent/Q_Table/usecase/train_offline_usecase.py b/src/negotiation_agent/Q_Table/usecase/train_offline_usecase.py new file mode 100644 index 0000000..edd11a7 --- /dev/null +++ b/src/negotiation_agent/Q_Table/usecase/train_offline_usecase.py @@ -0,0 +1,203 @@ +""" +TrainOfflineUsecase +저장된 Experience 데이터로 Q-Table을 오프라인 학습 +""" + +from typing import List, Optional +from ..domain.model.q_table import QTable +from ..domain.model.experience import Experience +from ..domain.model.visit_table import VisitTable +from ..domain.repository.experience_repository import ExperienceRepository + + +class TrainOfflineUsecase: + """ + 수집된 Experience 데이터를 사용하여 Q-Table을 오프라인으로 학습 + + 배치 학습 방식: + - 저장된 Experience를 로드 + - Q-Learning 알고리즘으로 Q-Table 업데이트 + - 여러 에포크 반복 가능 + """ + + def __init__( + self, + q_table: QTable, + experience_repository: ExperienceRepository, + visit_table: Optional[VisitTable] = None, + ): + """ + Args: + q_table: 학습할 Q-Table + experience_repository: Experience 저장소 + visit_table: 방문 횟수를 함께 추적할 N-Table (선택) + """ + self.q_table = q_table + self.experience_repository = experience_repository + self.visit_table = visit_table + + def train( + self, + filename: str = "experiences.jsonl", + epochs: int = 1, + batch_size: Optional[int] = None, + ) -> dict: + """ + 저장된 Experience로 Q-Table 학습 + + Args: + filename: Experience 파일명 + epochs: 학습 반복 횟수 + batch_size: 배치 크기 (None이면 전체 데이터 사용) + + Returns: + { + 'total_experiences': 1000, + 'epochs': 10, + 'updates': 10000, + 'avg_loss': 0.05 + } + """ + # Experience 로드 + experiences = self.experience_repository.load_all(filename) + + if not experiences: + return { + "total_experiences": 0, + "epochs": 0, + "updates": 0, + "avg_loss": 0.0, + "message": "No experiences found", + } + + total_updates = 0 + total_loss = 0.0 + + # 에포크 반복 + for epoch in range(epochs): + epoch_loss = 0.0 + + # 배치 처리 + if batch_size: + for i in range(0, len(experiences), batch_size): + batch = experiences[i : i + batch_size] + loss = self._train_batch(batch) + epoch_loss += loss + total_updates += len(batch) + else: + # 전체 데이터 한번에 + loss = self._train_batch(experiences) + epoch_loss += loss + total_updates += len(experiences) + + total_loss += epoch_loss + + avg_loss = total_loss / total_updates if total_updates > 0 else 0.0 + + return { + "total_experiences": len(experiences), + "epochs": epochs, + "updates": total_updates, + "avg_loss": avg_loss, + } + + def _train_batch(self, experiences: List[Experience]) -> float: + """ + 배치 Experience로 Q-Table 업데이트 + + Args: + experiences: Experience 리스트 + + Returns: + 평균 손실값 + """ + total_loss = 0.0 + + for exp in experiences: + # State → state_index + state_index = exp.state.to_index() + + # Q-value 업데이트 전 값 + old_q = self.q_table.get_q_value(state_index, exp.action_id) + + # Q-Table 업데이트 + if exp.next_state: + next_state_index = exp.next_state.to_index() + self.q_table.update( + state_index=state_index, + action_id=exp.action_id, + reward=exp.reward, + next_state_index=next_state_index, + done=exp.done, + ) + else: + # 종료 상태 + self.q_table.update( + state_index=state_index, + action_id=exp.action_id, + reward=exp.reward, + next_state_index=None, + done=True, + ) + + if self.visit_table is not None: + self.visit_table.increment(state_index, exp.action_id) + + # Q-value 업데이트 후 값 + new_q = self.q_table.get_q_value(state_index, exp.action_id) + + # 손실 계산 (업데이트 크기) + loss = abs(new_q - old_q) + total_loss += loss + + return total_loss / len(experiences) if experiences else 0.0 + + def train_by_episodes( + self, + episode_ids: List[str], + filename: str = "experiences.jsonl", + epochs: int = 1, + ) -> dict: + """ + 특정 에피소드들만 선택하여 학습 + + Args: + episode_ids: 학습할 에피소드 ID 리스트 + filename: Experience 파일명 + epochs: 학습 반복 횟수 + + Returns: + 학습 결과 + """ + # 해당 에피소드의 Experience만 로드 + all_experiences = self.experience_repository.load_all(filename) + filtered_experiences = [ + exp for exp in all_experiences if exp.episode_id in episode_ids + ] + + if not filtered_experiences: + return { + "total_experiences": 0, + "episodes": 0, + "epochs": 0, + "updates": 0, + "avg_loss": 0.0, + "message": "No matching episodes found", + } + + total_updates = 0 + total_loss = 0.0 + + for epoch in range(epochs): + loss = self._train_batch(filtered_experiences) + total_loss += loss * len(filtered_experiences) + total_updates += len(filtered_experiences) + + return { + "total_experiences": len(filtered_experiences), + "episodes": len(episode_ids), + "epochs": epochs, + "updates": total_updates, + "avg_loss": total_loss / total_updates if total_updates > 0 else 0.0, + } + diff --git a/src/negotiation_agent/integration/action_card_mapper.py b/src/negotiation_agent/integration/action_card_mapper.py new file mode 100644 index 0000000..c7f9e7a --- /dev/null +++ b/src/negotiation_agent/integration/action_card_mapper.py @@ -0,0 +1,152 @@ +""" +Action-Card 매핑 레이어 +Q_Table의 action_id와 Card_Management의 card_id를 연결 +""" + +import json +from pathlib import Path +from typing import Dict, Optional + + +class ActionCardMapper: + """ + action_id (0, 1, 2, ...) ↔ card_id ("no_0", "no_1", ...) + + Q_Table은 추상적인 action_id만 다루고, + Card_Management는 구체적인 card_id만 다룬다. + 이 클래스가 둘을 연결하는 단일 책임을 가진다. + + 매핑은 JSON 파일에서 로드되며, 카드 추가/삭제 시 수동으로 업데이트된다. + """ + + def __init__(self, mapping_file: Optional[str] = None): + """ + Args: + mapping_file: JSON 매핑 파일 경로 + (기본: integration/data/action_card_mapping.json) + """ + if mapping_file is None: + current_dir = Path(__file__).parent + mapping_file = current_dir / "data" / "action_card_mapping.json" + + self.mapping_file = Path(mapping_file) + self.action_to_card: Dict[int, str] = {} + self.card_to_action: Dict[str, int] = {} + + self._load_mapping() + + def _load_mapping(self): + """JSON 파일에서 매핑 로드""" + if not self.mapping_file.exists(): + raise FileNotFoundError( + f"Mapping file not found: {self.mapping_file}\n" + f"Please create action_card_mapping.json first." + ) + + with open(self.mapping_file, "r", encoding="utf-8") as f: + data = json.load(f) + + # action_to_card: {"0": "no_0", "1": "no_1", ...} + # JSON keys are strings, convert to int + self.action_to_card = {int(k): v for k, v in data["action_to_card"].items()} + + # card_to_action: reverse mapping for quick lookup + self.card_to_action = {v: int(k) for k, v in data["action_to_card"].items()} + + def get_card_id(self, action_id: int) -> Optional[str]: + """ + action_id → card_id 변환 + + Args: + action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...) + + Returns: + card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...) + 매핑이 없으면 None + + Example: + >>> mapper.get_card_id(5) + 'no_5' + """ + return self.action_to_card.get(action_id) + + def get_action_id(self, card_id: str) -> Optional[int]: + """ + card_id → action_id 변환 + + Args: + card_id: Card_Management에서 사용하는 카드 ID ("no_0", "no_1", ...) + + Returns: + action_id: Q_Table에서 사용하는 액션 ID (0, 1, 2, ...) + 매핑이 없으면 None + + Example: + >>> mapper.get_action_id('no_5') + 5 + """ + return self.card_to_action.get(card_id) + + def get_action_space_size(self) -> int: + """ + 현재 매핑된 액션 개수 (= 카드 개수) + + Returns: + 액션 공간 크기 + + Example: + >>> mapper.get_action_space_size() + 21 # 21개의 카드 + """ + return len(self.action_to_card) + + def get_all_mappings(self) -> Dict[int, str]: + """ + 전체 매핑 반환 + + Returns: + {action_id: card_id} 딕셔너리 + + Example: + >>> mapper.get_all_mappings() + {0: 'no_0', 1: 'no_1', ..., 20: 'no_20'} + """ + return self.action_to_card.copy() + + def get_all_action_ids(self) -> list: + """ + 모든 action_id 리스트 반환 (정렬됨) + + Returns: + [0, 1, 2, ..., 20] + """ + return sorted(self.action_to_card.keys()) + + def get_all_card_ids(self) -> list: + """ + 모든 card_id 리스트 반환 (action_id 순서) + + Returns: + ['no_0', 'no_1', ..., 'no_20'] + """ + return [self.action_to_card[i] for i in sorted(self.action_to_card.keys())] + + def is_valid_action_id(self, action_id: int) -> bool: + """action_id가 매핑에 존재하는지 확인""" + return action_id in self.action_to_card + + def is_valid_card_id(self, card_id: str) -> bool: + """card_id가 매핑에 존재하는지 확인""" + return card_id in self.card_to_action + + def reload(self): + """매핑 파일 다시 로드 (런타임 중 파일이 변경된 경우)""" + self._load_mapping() + + def __repr__(self) -> str: + return ( + f"ActionCardMapper(" + f"action_space_size={self.get_action_space_size()}, " + f"mapping_file={self.mapping_file}" + f")" + ) diff --git a/src/negotiation_agent/integration/data/action_card_mapping.json b/src/negotiation_agent/integration/data/action_card_mapping.json new file mode 100644 index 0000000..2278b88 --- /dev/null +++ b/src/negotiation_agent/integration/data/action_card_mapping.json @@ -0,0 +1,13 @@ +{ + "action_to_card": { + "0": "no_1", + "1": "no_2", + "2": "no_3", + "3": "no_5", + "4": "no_6", + "5": "no_8", + "6": "no_11", + "7": "no_13", + "8": "no_14" + } +} \ No newline at end of file diff --git a/train.py b/train.py new file mode 100644 index 0000000..eb142be --- /dev/null +++ b/train.py @@ -0,0 +1,90 @@ +import sys +import os +import argparse +import numpy as np +from pathlib import Path + +# Add project root to path to allow imports from src +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from src.negotiation_agent.Q_Table.domain.model.q_table import QTable +from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable +from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository +from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository +from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase +from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper + +def main(): + parser = argparse.ArgumentParser(description="Train Q-Table Agent") + parser.add_argument("--epochs", type=int, default=10, help="Number of epochs") + parser.add_argument("--batch-size", type=int, default=32, help="Batch size") + parser.add_argument("--lr", type=float, default=0.1, help="Learning rate (only for new tables)") + parser.add_argument("--gamma", type=float, default=0.9, help="Discount factor (only for new tables)") + parser.add_argument("--data-file", type=str, default="experiences.jsonl", help="Experience file name inside data/experiences/") + args = parser.parse_args() + + print("=== KTC V2 Agent Training ===") + + # 1. Config + try: + mapper = ActionCardMapper() + ACTION_SIZE = mapper.get_action_space_size() + except Exception as e: + # Fallback if specific file error + print(f"Warning: Could not load Action mapping ({e}). Defaulting to 21.") + ACTION_SIZE = 21 + + STATE_SIZE = 162 + print(f"Configuration: State Size={STATE_SIZE}, Action Size={ACTION_SIZE}") + + # 2. Repository & Models + model_repo = ModelRepository() + + print("Loading models...") + q_table, visit_table = model_repo.load() + + if q_table is None: + print("[Info] No existing Q-Table found. Creating new one.") + q_table = QTable( + state_space_size=STATE_SIZE, + action_space_size=ACTION_SIZE, + learning_rate=args.lr, + discount_factor=args.gamma + ) + else: + print("[Info] Loaded existing Q-Table.") + + if visit_table is None: + print("[Info] No existing VisitTable found. Creating new one.") + visit_table = VisitTable(STATE_SIZE, ACTION_SIZE) + else: + print("[Info] Loaded existing VisitTable.") + + # 3. Data Repository + exp_repo = ExperienceRepository() + + # Check if data file exists + data_path = exp_repo.data_dir / args.data_file + if not data_path.exists(): + print(f"[Warning] Experience file not found at: {data_path}") + print("Please ensure the data file is synchronized from the main server.") + return + + # 4. Usecase + trainer = TrainOfflineUsecase(q_table, exp_repo, visit_table) + + # 5. Train + print(f"Starting training for {args.epochs} epochs with batch size {args.batch_size}...") + result = trainer.train(filename=args.data_file, epochs=args.epochs, batch_size=args.batch_size) + + print("\nTraining Result:") + for k, v in result.items(): + print(f" {k}: {v}") + + # 6. Save + print("Saving models...") + model_repo.save(q_table, visit_table) + print("Done.") + +if __name__ == "__main__": + main()