64 lines
2.5 KiB
Python
64 lines
2.5 KiB
Python
import numpy as np
|
|
import random
|
|
import os
|
|
|
|
|
|
class QLearningAgent:
|
|
"""Q-Table을 기반으로 행동하고 학습하는 에이전트"""
|
|
|
|
def __init__(self, state_dims, action_size, learning_rate, gamma, epsilon):
|
|
self.state_dims = state_dims # [4, 3, 3]
|
|
self.action_size = action_size
|
|
self.lr = learning_rate
|
|
self.gamma = gamma
|
|
self.epsilon = epsilon
|
|
|
|
# Q-Table 초기화: 36x9 크기의 테이블을 0으로 초기화
|
|
num_states = np.prod(state_dims) # 4 * 3 * 3 = 36
|
|
self.q_table = np.zeros((num_states, action_size))
|
|
|
|
def _state_to_index(self, state):
|
|
"""MultiDiscrete 상태 [s0, s1, s2]를 단일 정수 인덱스로 변환"""
|
|
idx = (
|
|
state[0] * (self.state_dims[1] * self.state_dims[2])
|
|
+ state[1] * self.state_dims[2]
|
|
+ state[2]
|
|
)
|
|
return int(idx)
|
|
|
|
def get_action(self, state):
|
|
"""Epsilon-Greedy 정책에 따라 행동 선택"""
|
|
if random.uniform(0, 1) < self.epsilon:
|
|
return random.randint(0, self.action_size - 1) # 탐험 (무작위 행동)
|
|
else:
|
|
state_idx = self._state_to_index(state)
|
|
return np.argmax(self.q_table[state_idx, :]) # 활용 (Q값이 가장 높은 행동)
|
|
|
|
# ===================================================================
|
|
# ## 3. 학습 알고리즘 (Learning Algorithm): Q-Table 업데이트 규칙
|
|
# ===================================================================
|
|
def learn(self, state, action, reward, next_state):
|
|
"""경험 데이터를 바탕으로 Q-Table을 업데이트 (Q-러닝 공식)"""
|
|
state_idx = self._state_to_index(state)
|
|
next_state_idx = self._state_to_index(next_state)
|
|
|
|
old_value = self.q_table[state_idx, action]
|
|
next_max = np.max(self.q_table[next_state_idx, :])
|
|
|
|
# Bellman Equation
|
|
new_value = old_value + self.lr * (reward + self.gamma * next_max - old_value)
|
|
self.q_table[state_idx, action] = new_value
|
|
|
|
def save_q_table(self, file_path):
|
|
"""Q-Table을 파일로 저장합니다."""
|
|
np.save(file_path, self.q_table)
|
|
print(f"Q-Table saved to {file_path}")
|
|
|
|
def load_q_table(self, file_path):
|
|
"""파일로부터 Q-Table을 불러옵니다."""
|
|
if os.path.exists(file_path):
|
|
self.q_table = np.load(file_path)
|
|
print(f"Q-Table loaded from {file_path}")
|
|
else:
|
|
print(f"Error: No Q-Table found at {file_path}")
|