# KT_Q_Table/negotiation_agent/agent.py
import numpy as np
import random
import os
class QLearningAgent:
    """Tabular Q-learning agent that acts and learns from a Q-table.

    The environment state is a MultiDiscrete triple ``[s0, s1, s2]`` with
    per-dimension sizes given by ``state_dims`` (e.g. ``[4, 3, 3]``); it is
    flattened to a single row index into the Q-table.
    """

    def __init__(self, state_dims, action_size, learning_rate, gamma, epsilon):
        """Initialize the agent and a zero-filled Q-table.

        Args:
            state_dims: Sizes of each discrete state dimension, e.g. [4, 3, 3].
            action_size: Number of discrete actions.
            learning_rate: Step size (alpha) for the Q-update.
            gamma: Discount factor for future rewards.
            epsilon: Exploration probability for the epsilon-greedy policy.
        """
        self.state_dims = state_dims  # e.g. [4, 3, 3]
        self.action_size = action_size
        self.lr = learning_rate
        self.gamma = gamma
        self.epsilon = epsilon
        # Q-table: one row per flattened state, one column per action,
        # initialized to zeros (e.g. 4 * 3 * 3 = 36 rows x 9 actions).
        num_states = np.prod(state_dims)
        self.q_table = np.zeros((num_states, action_size))

    def _state_to_index(self, state):
        """Convert a MultiDiscrete state [s0, s1, s2] to a single row index.

        Uses mixed-radix (row-major) encoding, so index 0 is [0, 0, 0] and
        the last index is prod(state_dims) - 1.
        """
        idx = (
            state[0] * (self.state_dims[1] * self.state_dims[2])
            + state[1] * self.state_dims[2]
            + state[2]
        )
        return int(idx)

    def get_action(self, state):
        """Select an action for *state* using the epsilon-greedy policy.

        Returns a plain Python int in [0, action_size).
        """
        if random.uniform(0, 1) < self.epsilon:
            # Explore: uniformly random action.
            return random.randint(0, self.action_size - 1)
        # Exploit: greedy action with the highest Q-value.
        # int(...) keeps the return type consistent with the explore branch
        # (np.argmax would otherwise return np.int64).
        state_idx = self._state_to_index(state)
        return int(np.argmax(self.q_table[state_idx, :]))

    # ===================================================================
    # ## 3. Learning algorithm: Q-table update rule
    # ===================================================================
    def learn(self, state, action, reward, next_state):
        """Update the Q-table from one (s, a, r, s') transition.

        Applies the standard Q-learning (Bellman) update:
            Q(s, a) <- Q(s, a) + lr * (r + gamma * max_a' Q(s', a') - Q(s, a))
        """
        state_idx = self._state_to_index(state)
        next_state_idx = self._state_to_index(next_state)
        old_value = self.q_table[state_idx, action]
        next_max = np.max(self.q_table[next_state_idx, :])
        # Bellman equation (off-policy TD target uses the max over next actions).
        new_value = old_value + self.lr * (reward + self.gamma * next_max - old_value)
        self.q_table[state_idx, action] = new_value

    def save_q_table(self, file_path):
        """Save the Q-table to *file_path* in NumPy .npy format."""
        np.save(file_path, self.q_table)
        print(f"Q-Table saved to {file_path}")

    def load_q_table(self, file_path):
        """Load the Q-table from *file_path* if it exists.

        Missing files are reported but not raised, so a fresh (zero) table
        can be used on a first run.
        """
        if os.path.exists(file_path):
            self.q_table = np.load(file_path)
            print(f"Q-Table loaded from {file_path}")
        else:
            print(f"Error: No Q-Table found at {file_path}")