41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
import numpy as np
|
|
from typing import Type
|
|
from negotiation_agent.environment import NegotiationEnv
|
|
from negotiation_agent.agent import QLearningAgent
|
|
import config
|
|
|
|
|
|
class TrainAgentUseCase:
    """Use case encapsulating the "train the agent" business logic.

    Drives the Q-learning training loop against a negotiation environment,
    decays the exploration rate (epsilon) after every episode, and persists
    the learned Q-table when training finishes.
    """

    def __init__(self, env: NegotiationEnv, agent: QLearningAgent):
        # Injected collaborators: the Gymnasium-style environment and the
        # tabular Q-learning agent to train in it.
        self.env = env
        self.agent = agent

    def execute(self) -> None:
        """Run the full training loop.

        Trains ``self.agent`` for ``config.TOTAL_EPISODES`` episodes. After
        each episode, epsilon is updated with exponential decay from
        ``EPSILON_START`` toward ``EPSILON_END``. On completion the Q-table
        is saved to ``config.Q_TABLE_SAVE_PATH`` and the environment closed.
        """
        print("--- [UseCase] 학습 시작 ---")
        for episode in range(config.TOTAL_EPISODES):
            state, info = self.env.reset()
            terminated = False
            truncated = False

            # BUG FIX: the loop must also stop on `truncated`. Gymnasium's
            # step() returns a 5-tuple where an episode may end via
            # truncation (e.g. a time limit) with `terminated` still False;
            # looping on `terminated` alone could spin forever.
            while not (terminated or truncated):
                action = self.agent.get_action(state)
                next_state, reward, terminated, truncated, info = self.env.step(action)
                self.agent.learn(state, action, reward, next_state)
                state = next_state

            # Exponential epsilon decay: EPSILON_START -> EPSILON_END.
            self.agent.epsilon = config.EPSILON_END + (
                config.EPSILON_START - config.EPSILON_END
            ) * np.exp(-config.EPSILON_DECAY_RATE * episode)

            # Periodic progress report every 1000 episodes.
            if (episode + 1) % 1000 == 0:
                print(
                    f"Episode {episode + 1}/{config.TOTAL_EPISODES} | Epsilon: {self.agent.epsilon:.4f}"
                )

        print("\n✅ [UseCase] 학습 완료!")
        self.agent.save_q_table(config.Q_TABLE_SAVE_PATH)
        self.env.close()
|