import numpy as np import random import os class QLearningAgent: """Q-Table을 기반으로 행동하고 학습하는 에이전트""" def __init__(self, state_dims, action_size, learning_rate, gamma, epsilon): self.state_dims = state_dims # [4, 3, 3] self.action_size = action_size self.lr = learning_rate self.gamma = gamma self.epsilon = epsilon # Q-Table 초기화: 36x9 크기의 테이블을 0으로 초기화 num_states = np.prod(state_dims) # 4 * 3 * 3 = 36 self.q_table = np.zeros((num_states, action_size)) def _state_to_index(self, state): """MultiDiscrete 상태 [s0, s1, s2]를 단일 정수 인덱스로 변환""" idx = ( state[0] * (self.state_dims[1] * self.state_dims[2]) + state[1] * self.state_dims[2] + state[2] ) return int(idx) def get_action(self, state): """Epsilon-Greedy 정책에 따라 행동 선택""" if random.uniform(0, 1) < self.epsilon: return random.randint(0, self.action_size - 1) # 탐험 (무작위 행동) else: state_idx = self._state_to_index(state) return np.argmax(self.q_table[state_idx, :]) # 활용 (Q값이 가장 높은 행동) # =================================================================== # ## 3. 학습 알고리즘 (Learning Algorithm): Q-Table 업데이트 규칙 # =================================================================== def learn(self, state, action, reward, next_state): """경험 데이터를 바탕으로 Q-Table을 업데이트 (Q-러닝 공식)""" state_idx = self._state_to_index(state) next_state_idx = self._state_to_index(next_state) old_value = self.q_table[state_idx, action] next_max = np.max(self.q_table[next_state_idx, :]) # Bellman Equation new_value = old_value + self.lr * (reward + self.gamma * next_max - old_value) self.q_table[state_idx, action] = new_value def save_q_table(self, file_path): """Q-Table을 파일로 저장합니다.""" np.save(file_path, self.q_table) print(f"Q-Table saved to {file_path}") def load_q_table(self, file_path): """파일로부터 Q-Table을 불러옵니다.""" if os.path.exists(file_path): self.q_table = np.load(file_path) print(f"Q-Table loaded from {file_path}") else: print(f"Error: No Q-Table found at {file_path}")