KTC_v2_Negotiation_agent_train/train.py

91 lines
3.3 KiB
Python

import sys
import os
import argparse
import numpy as np
from pathlib import Path
# Add project root to path to allow imports from src
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository
from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper
def main():
    """Train the Q-Table agent offline from a recorded experience file.

    Parses the command-line options, loads the persisted Q-table and visit
    table through ``ModelRepository`` (creating new ones when ``load()``
    returns ``None`` for either), runs ``TrainOfflineUsecase.train`` on the
    selected experience file, prints the returned metrics, and saves both
    tables back through the repository.

    Returns:
        int: ``0`` when training ran and the models were saved; ``1`` when the
        experience file does not exist (nothing is trained or saved).

    Raises:
        SystemExit: raised by argparse (status 2) for unparsable arguments,
            a negative ``--epochs`` or a ``--batch-size`` below 1.
    """
    parser = argparse.ArgumentParser(description="Train Q-Table Agent")
    parser.add_argument("--epochs", type=int, default=10, help="Number of epochs")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.1, help="Learning rate (only for new tables)")
    parser.add_argument("--gamma", type=float, default=0.9, help="Discount factor (only for new tables)")
    parser.add_argument(
        "--data-file",
        type=str,
        default="experiences.jsonl",
        help="Experience file name inside data/experiences/",
    )
    args = parser.parse_args()

    # Reject values that cannot describe a training run before any model or
    # data is touched.
    if args.epochs < 0:
        parser.error("--epochs must be >= 0")
    if args.batch_size < 1:
        parser.error("--batch-size must be >= 1")

    print("=== KTC V2 Agent Training ===")

    # 1. Config
    try:
        mapper = ActionCardMapper()
        action_size = mapper.get_action_space_size()
    except Exception as e:
        # Deliberate best-effort: any failure to build the mapping falls back
        # to a fixed action count instead of aborting the run.
        print(f"Warning: Could not load Action mapping ({e}). Defaulting to 21.")
        action_size = 21
    state_size = 162
    print(f"Configuration: State Size={state_size}, Action Size={action_size}")

    # 2. Repository & Models
    model_repo = ModelRepository()
    print("Loading models...")
    q_table, visit_table = model_repo.load()

    if q_table is None:
        print("[Info] No existing Q-Table found. Creating new one.")
        # --lr / --gamma are only applied here, when a new table is created.
        q_table = QTable(
            state_space_size=state_size,
            action_space_size=action_size,
            learning_rate=args.lr,
            discount_factor=args.gamma,
        )
    else:
        print("[Info] Loaded existing Q-Table.")

    if visit_table is None:
        print("[Info] No existing VisitTable found. Creating new one.")
        visit_table = VisitTable(state_size, action_size)
    else:
        print("[Info] Loaded existing VisitTable.")

    # 3. Data Repository
    exp_repo = ExperienceRepository()
    data_path = exp_repo.data_dir / args.data_file
    if not data_path.exists():
        print(f"[Warning] Experience file not found at: {data_path}")
        print("Please ensure the data file is synchronized from the main server.")
        # Non-zero status so a caller can tell that no training happened.
        return 1

    # 4. Usecase
    trainer = TrainOfflineUsecase(q_table, exp_repo, visit_table)

    # 5. Train
    print(f"Starting training for {args.epochs} epochs with batch size {args.batch_size}...")
    result = trainer.train(filename=args.data_file, epochs=args.epochs, batch_size=args.batch_size)
    print("\nTraining Result:")
    for k, v in result.items():
        print(f" {k}: {v}")

    # 6. Save
    print("Saving models...")
    model_repo.save(q_table, visit_table)
    print("Done.")
    return 0
if __name__ == "__main__":
    # Forward whatever main() returns as the process exit status; a None
    # return still exits with status 0.
    sys.exit(main())