"""Offline training entry point for the Q-Table negotiation agent.

Loads (or creates) the Q-table and visit table, trains them on a recorded
experience file, and saves both tables back through the model repository.
"""

import sys
import os
import argparse
import numpy as np
from pathlib import Path

# Append this script's own directory to sys.path so the `src` package imports
# below resolve when the script is launched directly.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from src.negotiation_agent.Q_Table.domain.model.q_table import QTable
from src.negotiation_agent.Q_Table.domain.model.visit_table import VisitTable
from src.negotiation_agent.Q_Table.domain.repository.experience_repository import ExperienceRepository
from src.negotiation_agent.Q_Table.infra.repository.model_repository import ModelRepository
from src.negotiation_agent.Q_Table.usecase.train_offline_usecase import TrainOfflineUsecase
from src.negotiation_agent.integration.action_card_mapper import ActionCardMapper


def _build_arg_parser():
    """Return the argument parser describing the training CLI options."""
    cli = argparse.ArgumentParser(description="Train Q-Table Agent")

    # (flag, type, default, help) -- registered in this order.
    option_specs = (
        ("--epochs", int, 10, "Number of epochs"),
        ("--batch-size", int, 32, "Batch size"),
        ("--lr", float, 0.1, "Learning rate (only for new tables)"),
        ("--gamma", float, 0.9, "Discount factor (only for new tables)"),
        (
            "--data-file",
            str,
            "experiences.jsonl",
            "Experience file name inside data/experiences/",
        ),
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli


def _resolve_action_space_size():
    """Return the action-space size reported by ActionCardMapper.

    Best-effort: if building the mapper or querying it raises, a warning is
    printed and the size falls back to 21.
    """
    try:
        return ActionCardMapper().get_action_space_size()
    except Exception as err:
        # Fallback if specific file error
        print(f"Warning: Could not load Action mapping ({err}). Defaulting to 21.")
        return 21


def main():
    """Run one offline training session driven by command-line arguments.

    Steps, in order: parse arguments, determine the table dimensions, load
    the persisted tables (creating fresh ones for whichever is missing),
    check that the experience file exists, train, print the result mapping,
    and save both tables.

    Returns None. If the experience file is missing, a warning is printed
    and the function returns before training or saving anything.
    """
    args = _build_arg_parser().parse_args()

    print("=== KTC V2 Agent Training ===")

    # 1. Config
    n_actions = _resolve_action_space_size()
    n_states = 162  # state-space size is hard-coded in this script
    print(f"Configuration: State Size={n_states}, Action Size={n_actions}")

    # 2. Repository & Models
    model_repo = ModelRepository()
    print("Loading models...")
    q_table, visit_table = model_repo.load()

    if q_table is not None:
        print("[Info] Loaded existing Q-Table.")
    else:
        print("[Info] No existing Q-Table found. Creating new one.")
        # --lr and --gamma are only consumed here, when a new table is built.
        q_table = QTable(
            state_space_size=n_states,
            action_space_size=n_actions,
            learning_rate=args.lr,
            discount_factor=args.gamma,
        )

    if visit_table is not None:
        print("[Info] Loaded existing VisitTable.")
    else:
        print("[Info] No existing VisitTable found. Creating new one.")
        visit_table = VisitTable(n_states, n_actions)

    # 3. Data Repository
    exp_repo = ExperienceRepository()

    # Bail out early when the requested experience file is not present.
    experience_file = exp_repo.data_dir / args.data_file
    if not experience_file.exists():
        print(f"[Warning] Experience file not found at: {experience_file}")
        print("Please ensure the data file is synchronized from the main server.")
        return

    # 4. Usecase
    trainer = TrainOfflineUsecase(q_table, exp_repo, visit_table)

    # 5. Train
    print(f"Starting training for {args.epochs} epochs with batch size {args.batch_size}...")
    summary = trainer.train(
        filename=args.data_file,
        epochs=args.epochs,
        batch_size=args.batch_size,
    )

    print("\nTraining Result:")
    for metric, value in summary.items():
        print(f" {metric}: {value}")

    # 6. Save
    print("Saving models...")
    model_repo.save(q_table, visit_table)
    print("Done.")


if __name__ == "__main__":
    main()