refactor: 프로젝트 구조 개선
- 기존 envs/ 디렉토리를 negotiation_agent/로 이동 및 리팩토링 - config.py를 configs/ 디렉토리로 이동 및 yaml 형식으로 변경 - Offline_RL.md를 README.md로 통합 - 불필요한 train.py 제거main
parent
0c2ec47c6b
commit
26442ca9c1
|
|
@ -1,80 +0,0 @@
|
|||
**오프라인 RL 프로젝트 구조**
|
||||
|
||||
가장 큰 변화는 데이터셋을 관리하는 `datasets/` 디렉토리와 데이터 수집을 위한 별도 스크립트(`data_collector.py`)의 추가입니다.
|
||||
|
||||
`my_rl_project/
|
||||
├── configs/
|
||||
│ └── offline_env_config.yaml
|
||||
├── datasets/
|
||||
│ └── collected_data.h5
|
||||
├── envs/
|
||||
│ ├── __init__.py
|
||||
│ └── my_custom_env.py
|
||||
├── agents/
|
||||
│ ├── __init__.py
|
||||
│ └── offline_agent.py
|
||||
├── data_collector.py
|
||||
└── train_offline.py`
|
||||
|
||||
---
|
||||
|
||||
### **구성 요소별 변경 사항 및 관리 방안**
|
||||
|
||||
### **1. `datasets/` (신규)**
|
||||
|
||||
- **Purpose**: 사전 수집된 고정 데이터셋을 저장하는 디렉토리입니다. 오프라인 학습의 가장 핵심적인 자산입니다.
|
||||
- **Format**: 대용량 데이터를 효율적으로 다루기 위해 HDF5(`.h5`), Parquet, 또는 NumPy Archive(`.npz`) 형식을 사용하는 것이 일반적입니다.
|
||||
- **Content**: `(observation, action, reward, next_observation, terminated)` 튜플들의 집합으로 구성된 테이블 또는 배열.
|
||||
|
||||
### **2. `data_collector.py` (신규)**
|
||||
|
||||
- **Purpose**: 오프라인 학습에 사용할 데이터셋을 생성하는 스크립트입니다. 이 스크립트는 **온라인(online) 환경에서 실행**됩니다.
|
||||
- **Implementation**:
|
||||
- `MyCustomEnv`를 인스턴스화합니다.
|
||||
- 미리 정의된 정책(예: 랜덤 정책, 전문가 정책, 기존에 학습된 온라인 RL 정책)을 사용하여 환경과 상호작용(`env.step()`)합니다.
|
||||
- 수집된 `(s, a, r, s', d)` 튜플들을 `datasets/` 디렉토리에 지정된 파일 형식으로 저장합니다.
|
||||
- **데이터 수집 단계와 오프라인 학습 단계를 명확히 분리**하는 역할을 합니다.
|
||||
|
||||
### **3. `configs/offline_env_config.yaml`**
|
||||
|
||||
- **Purpose**: 오프라인 학습에 특화된 설정을 관리합니다.
|
||||
- **Key Definitions (변경점)**:
|
||||
- **`dataset_params`**:
|
||||
- `path`: 사용할 데이터셋 파일의 경로 (`datasets/collected_data.h5`).
|
||||
- `batch_size`: 학습 시 데이터셋에서 샘플링할 배치의 크기.
|
||||
- **`agent_params`**: 오프라인 RL 알고리즘(예: CQL, BCQ)에 특화된 하이퍼파라미터를 정의.
|
||||
|
||||
### **4. `envs/my_custom_env.py`**
|
||||
|
||||
- **Purpose (역할 변경)**:
|
||||
- **학습 단계**: 더 이상 학습 과정에서 실시간으로 상호작용하지 않습니다.
|
||||
- **평가(Evaluation) 단계**: **오프라인 학습이 완료된 후, 학습된 정책의 성능을 검증하는 용도**로 사용됩니다. 즉, 학습된 에이전트를 실제 환경에서 실행해보기 위한 '테스트베드'의 역할이 주가 됩니다.
|
||||
|
||||
### **5. `agents/offline_agent.py`**
|
||||
|
||||
- **Purpose**: 오프라인 데이터셋만으로 정책을 학습하는 알고리즘을 구현합니다.
|
||||
- **Key Definitions (변경점)**:
|
||||
- `class OfflineAgent`: CQL, IQL, BCQ 등 오프라인 RL 알고리즘을 구현.
|
||||
- `__init__(...)`: 온라인 에이전트와 유사하게 모델과 하이퍼파라미터를 초기화.
|
||||
- `get_action(self, state)`: 역할은 동일하나, 주로 학습 후 **평가 단계**에서 사용됩니다.
|
||||
- `learn(self, batch)`:
|
||||
- 메서드의 인자가 단일 경험이 아닌, \**데이터셋에서 샘플링된 `batch`*가 됩니다.
|
||||
- 이 `batch` 데이터를 사용하여 오프라인 RL 알고리즘의 손실 함수를 계산하고 정책을 업데이트합니다. **환경과의 상호작용(`env.step()`)이 전혀 없습니다.**
|
||||
- **Policy & Reward Function Management**:
|
||||
- **정책(Policy)**: 이 파일 내에서 관리되며, 주어진 데이터셋 내의 행동 분포를 모방하거나 보수적으로(conservative) 개선하는 방향으로 학습됩니다.
|
||||
- **보상 함수(Reward Function)**: 에이전트는 보상 함수를 직접 참조하지 않습니다. 대신, 데이터셋에 **기록된 `reward` 값**을 학습의 유일한 감독(supervision) 신호로 사용합니다.
|
||||
|
||||
### **6. `train_offline.py`**
|
||||
|
||||
- **Purpose**: 오프라인 데이터셋을 로드하여 에이전트를 학습시키는 메인 스크립트.
|
||||
- **Key Definitions (핵심 변경점)**:
|
||||
- **Setup**: 설정 파일을 로드하고, 데이터셋 로더와 `OfflineAgent`를 인스턴스화합니다. (`MyCustomEnv`는 이 단계에서 필요 없을 수 있습니다.)
|
||||
- **Training Loop**:
|
||||
- **환경과의 상호작용 루프가 사라집니다.**
|
||||
- 대신, 지도 학습(supervised learning)과 유사한 루프를 가집니다.
|
||||
- `for epoch in range(num_epochs):`
|
||||
- `batch = dataset.sample(batch_size)`
|
||||
- `agent.learn(batch)`
|
||||
- **Evaluation**:
|
||||
- 학습 루프가 끝난 후, `MyCustomEnv`를 인스턴스화합니다.
|
||||
- 학습된 `agent.get_action()`을 사용하여 환경과 상호작용하며 성능을 측정하는 별도의 평가 루프를 실행합니다.
|
||||
19
config.py
19
config.py
|
|
@ -1,19 +0,0 @@
|
|||
# config.py
|
||||
|
||||
# --- Training Hyperparameters ---
|
||||
LEARNING_RATE = 0.1
|
||||
GAMMA = 0.99
|
||||
TOTAL_EPISODES = 10000
|
||||
|
||||
# --- Epsilon Parameters ---
|
||||
EPSILON_START = 1.0
|
||||
EPSILON_END = 0.01
|
||||
EPSILON_DECAY_RATE = 0.0005
|
||||
|
||||
# --- Environment Parameters ---
|
||||
SCENARIO = 0 # 0: A, 1: B, 2: C, 3: D
|
||||
TARGET_PRICE = 100
|
||||
THRESHOLD_PRICE = 120
|
||||
|
||||
# --- File Paths ---
|
||||
Q_TABLE_SAVE_PATH = "saved_models/q_table.npy"
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -1,24 +0,0 @@
|
|||
import gymnasium as gym
|
||||
|
||||
class MyCustomEnv(gym.Env):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.action_space = gym.spaces.Discrete(2)
|
||||
self.observation_space = gym.spaces.Discrete(10)
|
||||
|
||||
def step(self, action):
|
||||
observation = self.observation_space.sample()
|
||||
reward = 1.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
info = {}
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
observation = self.observation_space.sample()
|
||||
info = {}
|
||||
return observation, info
|
||||
|
||||
def render(self, mode='human'):
|
||||
pass
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
import numpy as np
|
||||
import random
|
||||
import os
|
||||
|
||||
|
||||
class QLearningAgent:
|
||||
"""Q-Table을 기반으로 행동하고 학습하는 에이전트"""
|
||||
|
||||
def __init__(self, state_dims, action_size, learning_rate, gamma, epsilon):
|
||||
self.state_dims = state_dims # [4, 3, 3]
|
||||
self.action_size = action_size
|
||||
self.lr = learning_rate
|
||||
self.gamma = gamma
|
||||
self.epsilon = epsilon
|
||||
|
||||
# Q-Table 초기화: 36x9 크기의 테이블을 0으로 초기화
|
||||
num_states = np.prod(state_dims) # 4 * 3 * 3 = 36
|
||||
self.q_table = np.zeros((num_states, action_size))
|
||||
|
||||
def _state_to_index(self, state):
|
||||
"""MultiDiscrete 상태 [s0, s1, s2]를 단일 정수 인덱스로 변환"""
|
||||
idx = (
|
||||
state[0] * (self.state_dims[1] * self.state_dims[2])
|
||||
+ state[1] * self.state_dims[2]
|
||||
+ state[2]
|
||||
)
|
||||
return int(idx)
|
||||
|
||||
def get_action(self, state):
|
||||
"""Epsilon-Greedy 정책에 따라 행동 선택"""
|
||||
if random.uniform(0, 1) < self.epsilon:
|
||||
return random.randint(0, self.action_size - 1) # 탐험 (무작위 행동)
|
||||
else:
|
||||
state_idx = self._state_to_index(state)
|
||||
return np.argmax(self.q_table[state_idx, :]) # 활용 (Q값이 가장 높은 행동)
|
||||
|
||||
# ===================================================================
|
||||
# ## 3. 학습 알고리즘 (Learning Algorithm): Q-Table 업데이트 규칙
|
||||
# ===================================================================
|
||||
def learn(self, state, action, reward, next_state):
|
||||
"""경험 데이터를 바탕으로 Q-Table을 업데이트 (Q-러닝 공식)"""
|
||||
state_idx = self._state_to_index(state)
|
||||
next_state_idx = self._state_to_index(next_state)
|
||||
|
||||
old_value = self.q_table[state_idx, action]
|
||||
next_max = np.max(self.q_table[next_state_idx, :])
|
||||
|
||||
# Bellman Equation
|
||||
new_value = old_value + self.lr * (reward + self.gamma * next_max - old_value)
|
||||
self.q_table[state_idx, action] = new_value
|
||||
|
||||
def save_q_table(self, file_path):
|
||||
"""Q-Table을 파일로 저장합니다."""
|
||||
np.save(file_path, self.q_table)
|
||||
print(f"Q-Table saved to {file_path}")
|
||||
|
||||
def load_q_table(self, file_path):
|
||||
"""파일로부터 Q-Table을 불러옵니다."""
|
||||
if os.path.exists(file_path):
|
||||
self.q_table = np.load(file_path)
|
||||
print(f"Q-Table loaded from {file_path}")
|
||||
else:
|
||||
print(f"Error: No Q-Table found at {file_path}")
|
||||
29
train.py
29
train.py
|
|
@ -1,29 +0,0 @@
|
|||
from negotiation_agent.environment import NegotiationEnv
|
||||
from negotiation_agent.agent import QLearningAgent
|
||||
from usecases.train_agent_usecase import TrainAgentUseCase # 유스케이스 임포트
|
||||
import config
|
||||
|
||||
|
||||
def main():
|
||||
# 1. 의존성(객체) 생성: 필요한 모든 '재료'를 준비합니다.
|
||||
env = NegotiationEnv(
|
||||
scenario=config.SCENARIO,
|
||||
target_price=config.TARGET_PRICE,
|
||||
threshold_price=config.THRESHOLD_PRICE,
|
||||
)
|
||||
|
||||
agent = QLearningAgent(
|
||||
state_dims=env.observation_space.nvec,
|
||||
action_size=env.action_space.n,
|
||||
learning_rate=config.LEARNING_RATE,
|
||||
gamma=config.GAMMA,
|
||||
epsilon=config.EPSILON_START,
|
||||
)
|
||||
|
||||
# 2. 유스케이스 생성 및 실행: 준비된 재료로 '요리사'에게 '요리'를 지시합니다.
|
||||
train_use_case = TrainAgentUseCase(env=env, agent=agent)
|
||||
train_use_case.execute()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue