# KT_Q_Table/train_offline.py
# Offline Q-table training script (Python, ~105 lines / 3.6 KiB in the original listing).
import h5py
import numpy as np
import yaml
import os
import json
from datetime import datetime
from agents.offline_agent import QLearningAgent
from negotiation_agent.spaces import NegotiationSpaces
from negotiation_agent.environment import NegotiationEnv
def _export_q_table_json(agent, env, spaces, state_size, action_size,
                         num_epochs, json_path):
    """Write a human-readable JSON dump of the learned Q-table.

    For every discrete state index, records the state description, the
    Q-value of each action, and the greedy (argmax) action.

    Args:
        agent: trained agent exposing a 2-D ``q_table`` (state x action).
        env: environment providing ``observation_space.nvec`` for decoding
            flat state indices back into multi-discrete components.
        spaces: ``NegotiationSpaces`` used for human-readable descriptions.
        state_size: number of flat states (rows of the Q-table).
        action_size: number of actions (columns of the Q-table).
        num_epochs: training epoch count stored in the metadata.
        json_path: destination path for the JSON file.
    """
    q_table_data = {
        "metadata": {
            "state_size": int(state_size),
            "action_size": int(action_size),
            "timestamp": datetime.now().isoformat(),
            # NOTE(review): the key says "episodes" but the value is the
            # epoch count over the offline dataset — confirm intended meaning.
            "training_episodes": int(num_epochs),
        },
        "q_values": [],
    }
    for state_idx in range(state_size):
        # Decode the flat index into per-dimension indices for description.
        state_indices = np.unravel_index(state_idx, env.observation_space.nvec)
        state_data = {
            "state_idx": int(state_idx),
            "state_desc": spaces.get_state_description(
                [int(idx) for idx in state_indices]
            ),
            "actions": [],
        }
        for action_idx in range(action_size):
            state_data["actions"].append({
                "action_idx": int(action_idx),
                "action_desc": spaces.get_action_description(action_idx),
                "q_value": float(agent.q_table[state_idx, action_idx]),
            })
        # Record the greedy action for this state.
        optimal_action_idx = int(np.argmax(agent.q_table[state_idx]))
        state_data["optimal_action"] = {
            "action_idx": optimal_action_idx,
            "action_desc": spaces.get_action_description(optimal_action_idx),
            "q_value": float(agent.q_table[state_idx, optimal_action_idx]),
        }
        q_table_data["q_values"].append(state_data)

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(q_table_data, f, ensure_ascii=False, indent=2)


def main():
    """Train an offline Q-learning agent on a recorded negotiation dataset.

    Reads ``configs/offline_env_config.yaml`` for dataset and agent
    parameters, replays the HDF5 transition dataset in mini-batches for a
    fixed number of epochs, then saves the learned Q-table both as a
    ``.npy`` array and as a human-readable JSON report under
    ``saved_models/``.
    """
    with open("configs/offline_env_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    dataset_path = config["dataset_params"]["path"]
    batch_size = config["dataset_params"]["batch_size"]

    # Load the full offline dataset into memory ([:] materializes each
    # HDF5 dataset as a NumPy array).
    with h5py.File(dataset_path, "r") as f:
        observations = f["observations"][:]
        actions = f["actions"][:]
        rewards = f["rewards"][:]
        next_observations = f["next_observations"][:]
        terminals = f["terminals"][:]

    # Derive state/action sizes from the environment's spaces.
    # (Duplicate in-function import of NegotiationEnv removed — it is
    # already imported at module level.)
    env = NegotiationEnv()
    state_size = int(np.prod(env.observation_space.nvec))  # e.g. 4 * 3 * 3 = 36
    action_size = env.action_space.n  # e.g. 9
    agent = QLearningAgent(config["agent"], state_size, action_size)

    # Offline training: sweep the dataset in order, in mini-batches.
    num_epochs = 10  # NOTE(review): consider reading this from the config
    n_samples = len(observations)
    for _epoch in range(num_epochs):
        for start in range(0, n_samples, batch_size):
            idx = np.arange(start, min(start + batch_size, n_samples))
            agent.learn({
                "observations": observations[idx],
                "actions": actions[idx],
                "rewards": rewards[idx],
                "next_observations": next_observations[idx],
                "terminals": terminals[idx],
            })

    # Save the raw Q-table in .npy format.
    saved_models_dir = "saved_models"
    os.makedirs(saved_models_dir, exist_ok=True)
    model_path = os.path.join(saved_models_dir, "q_table.npy")
    np.save(model_path, agent.q_table)

    # Also export the Q-table as human-readable JSON.
    spaces = NegotiationSpaces()
    json_path = os.path.join(saved_models_dir, "q_table.json")
    _export_q_table_json(agent, env, spaces, state_size, action_size,
                         num_epochs, json_path)

    print(f"Model saved to {model_path}")
    print(f"Q-table JSON saved to {json_path}")


if __name__ == "__main__":
    main()