# KT_Q_Table/usecases/train_agent_usecase.py
# (source listing header: 41 lines, 1.4 KiB, Python)

import numpy as np
from typing import Type
from negotiation_agent.environment import NegotiationEnv
from negotiation_agent.agent import QLearningAgent
import config
class TrainAgentUseCase:
    """Use case that encapsulates the "train the agent" business logic.

    Runs ``config.TOTAL_EPISODES`` episodes of ``agent`` interacting with
    ``env``, updates the agent's exploration rate after every episode,
    saves the Q-table to ``config.Q_TABLE_SAVE_PATH`` and closes the
    environment.
    """

    def __init__(self, env: NegotiationEnv, agent: QLearningAgent):
        """Store the collaborators used during training.

        Args:
            env: Environment that is reset, stepped and finally closed.
            agent: Agent that picks actions, learns from transitions and
                saves its Q-table.
        """
        self.env = env
        self.agent = agent

    def _decayed_epsilon(self, episode: int):
        """Return the exploration rate for a 0-based ``episode`` index.

        The value decays exponentially from ``config.EPSILON_START``
        towards ``config.EPSILON_END`` at ``config.EPSILON_DECAY_RATE``.
        """
        return config.EPSILON_END + (
            config.EPSILON_START - config.EPSILON_END
        ) * np.exp(-config.EPSILON_DECAY_RATE * episode)

    def execute(self, log_interval: int = 1000):
        """Run the training use case.

        Args:
            log_interval: Print a progress line every ``log_interval``
                episodes. A value of 0 or less disables progress output.
                Defaults to 1000, the previously hard-coded interval.
        """
        print("--- [UseCase] 학습 시작 ---")
        try:
            for episode in range(config.TOTAL_EPISODES):
                state, _info = self.env.reset()
                done = False
                while not done:
                    action = self.agent.get_action(state)
                    (
                        next_state,
                        reward,
                        terminated,
                        truncated,
                        _info,
                    ) = self.env.step(action)
                    # NOTE(review): learn() is not told whether next_state is
                    # terminal; confirm the agent handles the final transition
                    # of an episode as intended.
                    self.agent.learn(state, action, reward, next_state)
                    state = next_state
                    # An episode ends on termination OR truncation; checking
                    # only `terminated` would loop forever on a truncated one.
                    done = terminated or truncated

                # Update epsilon once per episode, after the episode finishes.
                self.agent.epsilon = self._decayed_epsilon(episode)

                if log_interval > 0 and (episode + 1) % log_interval == 0:
                    print(
                        f"Episode {episode + 1}/{config.TOTAL_EPISODES} | Epsilon: {self.agent.epsilon:.4f}"
                    )

            print("\n✅ [UseCase] 학습 완료!")
            self.agent.save_q_table(config.Q_TABLE_SAVE_PATH)
        finally:
            # Release the environment even if training or saving fails.
            self.env.close()