105 lines
3.6 KiB
Python
105 lines
3.6 KiB
Python
import h5py
|
|
import numpy as np
|
|
import yaml
|
|
import os
|
|
import json
|
|
from datetime import datetime
|
|
|
|
from agents.offline_agent import QLearningAgent
|
|
from negotiation_agent.spaces import NegotiationSpaces
|
|
from negotiation_agent.environment import NegotiationEnv
|
|
|
|
def main():
    """Train a tabular Q-learning agent on an offline negotiation dataset.

    Reads dataset/agent settings from ``configs/offline_env_config.yaml``,
    loads transitions from an HDF5 file, trains the agent in mini-batches,
    and saves the learned Q-table both as a ``.npy`` array and as a
    human-readable JSON report.
    """
    with open("configs/offline_env_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    dataset_path = config["dataset_params"]["path"]
    batch_size = config["dataset_params"]["batch_size"]

    dataset = _load_dataset(dataset_path)

    # NegotiationEnv is already imported at the top of the file; the
    # original duplicated the import inside this function.
    env = NegotiationEnv()
    # Flattened discrete state space: product of per-dimension sizes
    # (4 * 3 * 3 = 36 per the original comment). Convert the numpy scalar
    # returned by np.prod to a plain int once, at the edge.
    state_size = int(np.prod(env.observation_space.nvec))
    action_size = env.action_space.n  # 9 discrete actions

    agent = QLearningAgent(config["agent"], state_size, action_size)

    num_epochs = 10
    _train(agent, dataset, batch_size, num_epochs)

    # Save the raw Q-table (npy format).
    saved_models_dir = "saved_models"
    os.makedirs(saved_models_dir, exist_ok=True)
    model_path = os.path.join(saved_models_dir, "q_table.npy")
    np.save(model_path, agent.q_table)

    # Also export the Q-table as a human-readable JSON report.
    q_table_data = _build_q_table_report(
        agent.q_table, env.observation_space.nvec, state_size, action_size, num_epochs
    )
    json_path = os.path.join(saved_models_dir, "q_table.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(q_table_data, f, ensure_ascii=False, indent=2)

    print(f"Model saved to {model_path}")
    print(f"Q-table JSON saved to {json_path}")


def _load_dataset(dataset_path):
    """Load all transition arrays from the HDF5 dataset into memory.

    Returns a dict mapping each transition field name to a fully
    materialized numpy array (the file handle is closed on return).
    """
    keys = ("observations", "actions", "rewards", "next_observations", "terminals")
    with h5py.File(dataset_path, "r") as f:
        return {key: f[key][:] for key in keys}


def _train(agent, dataset, batch_size, num_epochs):
    """Run mini-batch offline training over the full dataset each epoch.

    Batches are taken in order (no shuffling, matching the original
    behavior). numpy slicing clamps at the array end, so plain slices
    replace the original's explicit np.arange index-array construction.
    """
    num_samples = len(dataset["observations"])
    for _epoch in range(num_epochs):
        for start in range(0, num_samples, batch_size):
            end = start + batch_size
            batch = {key: arr[start:end] for key, arr in dataset.items()}
            agent.learn(batch)


def _build_q_table_report(q_table, nvec, state_size, action_size, num_epochs):
    """Build a JSON-serializable description of the learned Q-table.

    For every flattened state index the report lists a human-readable
    state description, the Q-value of each action, and the greedy
    (optimal) action for that state.
    """
    spaces = NegotiationSpaces()
    report = {
        "metadata": {
            "state_size": int(state_size),
            "action_size": int(action_size),
            "timestamp": datetime.now().isoformat(),
            # NOTE(review): this value counts training *epochs*, not
            # episodes — key name kept for backward compatibility.
            "training_episodes": int(num_epochs),
        },
        "q_values": [],
    }

    for state_idx in range(state_size):
        # Map the flat state index back to per-dimension indices so the
        # spaces helper can describe it.
        state_indices = np.unravel_index(state_idx, nvec)
        state_data = {
            "state_idx": int(state_idx),
            "state_desc": spaces.get_state_description(
                [int(idx) for idx in state_indices]
            ),
            "actions": [
                {
                    "action_idx": int(action_idx),
                    "action_desc": spaces.get_action_description(action_idx),
                    "q_value": float(q_table[state_idx, action_idx]),
                }
                for action_idx in range(action_size)
            ],
        }

        # Record the greedy action for this state.
        optimal_action_idx = int(np.argmax(q_table[state_idx]))
        state_data["optimal_action"] = {
            "action_idx": optimal_action_idx,
            "action_desc": spaces.get_action_description(optimal_action_idx),
            "q_value": float(q_table[state_idx, optimal_action_idx]),
        }

        report["q_values"].append(state_data)

    return report
|
|
|
|
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|