91 lines
3.3 KiB
Python
91 lines
3.3 KiB
Python
import sys
|
|
import os
|
|
import argparse
|
|
import numpy as np
|
|
from pathlib import Path
|
|
|
|
# Add project root to path to allow imports from src
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
|
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
|
|
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
|
|
from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
|
|
from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository
|
|
from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
|
|
from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper
|
|
|
|
def main():
    """Run one offline training session for the Q-Table agent from the CLI.

    Steps, in order:
      1. Parse the command-line options (--epochs, --batch-size, --lr,
         --gamma, --data-file).
      2. Ask ActionCardMapper for the action-space size, falling back to 21
         if that raises; the state-space size is fixed at 162.
      3. Load the Q-table and visit table through ModelRepository, creating a
         fresh one for whichever is missing (--lr / --gamma are only used
         when a new Q-table is created).
      4. Return early, without training or saving, when the experience file
         is not present in the experience repository's data directory.
      5. Train through TrainOfflineUsecase, print every entry of the returned
         result mapping, and save both tables.

    Returns:
        None.
    """
    # Declared as data so the option list can be read at a glance; the order
    # is kept because it determines the order shown in --help.
    option_specs = (
        ("--epochs", {"type": int, "default": 10, "help": "Number of epochs"}),
        ("--batch-size", {"type": int, "default": 32, "help": "Batch size"}),
        ("--lr", {"type": float, "default": 0.1, "help": "Learning rate (only for new tables)"}),
        ("--gamma", {"type": float, "default": 0.9, "help": "Discount factor (only for new tables)"}),
        (
            "--data-file",
            {
                "type": str,
                "default": "experiences.jsonl",
                "help": "Experience file name inside data/experiences/",
            },
        ),
    )
    cli = argparse.ArgumentParser(description="Train Q-Table Agent")
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    opts = cli.parse_args()

    print("=== KTC V2 Agent Training ===")

    # 1. Config
    try:
        n_actions = ActionCardMapper().get_action_space_size()
    except Exception as exc:
        # Best-effort: any failure while building or querying the mapper
        # falls back to the hard-coded action count instead of aborting.
        print(f"Warning: Could not load Action mapping ({exc}). Defaulting to 21.")
        n_actions = 21

    n_states = 162
    print(f"Configuration: State Size={n_states}, Action Size={n_actions}")

    # 2. Repository & Models
    models = ModelRepository()

    print("Loading models...")
    q_table, visit_table = models.load()

    if q_table is not None:
        print("[Info] Loaded existing Q-Table.")
    else:
        print("[Info] No existing Q-Table found. Creating new one.")
        q_table = QTable(
            state_space_size=n_states,
            action_space_size=n_actions,
            learning_rate=opts.lr,
            discount_factor=opts.gamma,
        )

    if visit_table is not None:
        print("[Info] Loaded existing VisitTable.")
    else:
        print("[Info] No existing VisitTable found. Creating new one.")
        visit_table = VisitTable(n_states, n_actions)

    # 3. Data Repository
    experiences = ExperienceRepository()

    # Bail out before training if the requested experience file is absent.
    experience_file = experiences.data_dir / opts.data_file
    if not experience_file.exists():
        print(f"[Warning] Experience file not found at: {experience_file}")
        print("Please ensure the data file is synchronized from the main server.")
        return

    # 4. Usecase
    usecase = TrainOfflineUsecase(q_table, experiences, visit_table)

    # 5. Train
    print(f"Starting training for {opts.epochs} epochs with batch size {opts.batch_size}...")
    summary = usecase.train(
        filename=opts.data_file,
        epochs=opts.epochs,
        batch_size=opts.batch_size,
    )

    print("\nTraining Result:")
    for metric, value in summary.items():
        print(f" {metric}: {value}")

    # 6. Save
    print("Saving models...")
    models.save(q_table, visit_table)
    print("Done.")
|
|
|
|
# Script entry point: run the training CLI only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|