KT_Q_Table/data_collector.py

75 lines
2.4 KiB
Python

import numpy as np
import yaml
import json
import os
from datetime import datetime
from negotiation_agent.environment import NegotiationEnv
from negotiation_agent.spaces import NegotiationSpaces
def _run_episode(env, spaces, episode_id, max_steps):
    """Roll out one random-action episode and return its JSON-ready record.

    Args:
        env: Environment exposing ``reset()``, ``step(action)``,
            ``action_space.sample()`` and a ``current_price`` attribute.
        spaces: Helper exposing ``get_state_description`` and
            ``get_action_description``.
        episode_id: Value stored under the record's ``"episode_id"`` key.
        max_steps: Upper bound on the number of steps taken in the episode.

    Returns:
        A dict with the keys ``episode_id``, ``timestamp``, ``steps``,
        ``total_reward`` and ``num_steps``.
    """
    episode_data = {
        "episode_id": episode_id,
        "timestamp": datetime.now().isoformat(),
        "steps": [],
    }
    obs, _ = env.reset()
    episode_reward = 0

    for step in range(max_steps):
        # Select a random action and interact with the environment.
        action = env.action_space.sample()
        # NOTE(review): the 4th return value is assumed to be the
        # Gymnasium-style ``truncated`` flag -- confirm against NegotiationEnv.
        next_obs, reward, terminated, truncated, _ = env.step(action)
        episode_reward += reward

        # Record the transition. Scalars are cast to builtin types so that
        # numpy scalars (e.g. numpy.bool_) do not break json.dump.
        episode_data["steps"].append(
            {
                "step": step,
                "state": spaces.get_state_description(obs),
                "action": spaces.get_action_description(action),
                "reward": float(reward),
                "next_state": spaces.get_state_description(next_obs),
                "current_price": float(env.current_price),
                "terminated": bool(terminated),
            }
        )
        obs = next_obs

        # Stop on either end-of-episode signal; the environment must be
        # reset before it is stepped again.
        if terminated or truncated:
            break

    episode_data["total_reward"] = float(episode_reward)
    episode_data["num_steps"] = len(episode_data["steps"])
    return episode_data


def main():
    """Collect random-policy episodes and save them as a JSON log.

    Runs 10 episodes of at most 100 steps each against ``NegotiationEnv``,
    sampling every action from ``env.action_space``. The per-step records are
    written to ``logs/collected_data_<timestamp>.json`` and a short summary
    (episode count, average steps, average reward) is printed.

    Raises:
        OSError: If ``configs/offline_env_config.yaml`` cannot be opened or
            the output file cannot be written.
    """
    with open("configs/offline_env_config.yaml", "r") as f:
        # NOTE(review): the parsed config is not used anywhere below; the
        # environment is built with its defaults. Kept so that a missing or
        # malformed config file still fails fast, as before.
        config = yaml.safe_load(f)  # noqa: F841

    env = NegotiationEnv()
    spaces = NegotiationSpaces()
    num_episodes = 10
    max_steps_per_episode = 100

    # Accumulates one record per episode.
    episodes_data = [
        _run_episode(env, spaces, episode, max_steps_per_episode)
        for episode in range(num_episodes)
    ]

    # Save everything as a single JSON file.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    json_path = f"logs/collected_data_{timestamp}.json"
    os.makedirs("logs", exist_ok=True)
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(episodes_data, f, ensure_ascii=False, indent=2)

    print(f"Data collected and saved to {json_path}")
    print(f"Total episodes: {len(episodes_data)}")
    print(f"Average steps per episode: {sum(ep['num_steps'] for ep in episodes_data) / len(episodes_data):.2f}")
    print(f"Average reward per episode: {sum(ep['total_reward'] for ep in episodes_data) / len(episodes_data):.2f}")
# Script entry point: run the collector only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()