# KT_Q_Table/data_collector.py
#
# 50 lines
# 1.4 KiB
# Python

import h5py
import numpy as np
import yaml
from envs.my_custom_env import MyCustomEnv
def _collect_transitions(env, num_episodes, max_steps_per_episode):
    """Roll out randomly sampled actions and return the recorded transitions.

    Args:
        env: Environment exposing ``reset()`` -> ``(obs, info)``,
            ``step(action)`` -> 5-tuple, and ``action_space.sample()``.
        num_episodes: Number of episodes to roll out.
        max_steps_per_episode: Upper bound on steps taken in one episode.

    Returns:
        Dict mapping each dataset name to a list with one entry per step,
        in the order the datasets are written to disk.
    """
    buffers = {
        "observations": [],
        "actions": [],
        "rewards": [],
        "next_observations": [],
        "terminals": [],
    }
    for _ in range(num_episodes):
        obs, _info = env.reset()
        for _ in range(max_steps_per_episode):
            action = env.action_space.sample()
            next_obs, reward, terminated, truncated, _info = env.step(action)
            buffers["observations"].append(obs)
            buffers["actions"].append(action)
            buffers["rewards"].append(reward)
            buffers["next_observations"].append(next_obs)
            # Only true termination is stored; truncation is not recorded
            # in "terminals", matching the original dataset layout.
            buffers["terminals"].append(terminated)
            obs = next_obs
            # The fourth value of step() is the truncation flag in the
            # Gymnasium-style API this env appears to follow; the episode
            # must end on it too instead of stepping a finished env.
            if terminated or truncated:
                break
    return buffers


def main(num_episodes=10, max_steps_per_episode=100):
    """Collect random-policy transitions and save them to an HDF5 file.

    The output path is read from ``dataset_params.path`` in
    ``configs/offline_env_config.yaml``. The file receives five datasets:
    ``observations``, ``actions``, ``rewards``, ``next_observations`` and
    ``terminals``.

    Args:
        num_episodes: Number of episodes to roll out (default 10).
        max_steps_per_episode: Step cap per episode (default 100).

    Raises:
        KeyError: If the config lacks ``dataset_params`` or its ``path``.
    """
    with open("configs/offline_env_config.yaml", "r") as f:
        config = yaml.safe_load(f)
    dataset_path = config["dataset_params"]["path"]

    env = MyCustomEnv()
    try:
        buffers = _collect_transitions(env, num_episodes, max_steps_per_episode)
    finally:
        # Release environment resources even if a rollout fails. Guarded
        # because the env class is defined outside this file.
        close = getattr(env, "close", None)
        if callable(close):
            close()

    # Open the output only after collection succeeded: mode "w" truncates,
    # so opening earlier would wipe an existing dataset on a failed run.
    with h5py.File(dataset_path, "w") as f:
        for name, values in buffers.items():
            f.create_dataset(name, data=np.array(values))
if __name__ == "__main__":
    # Run the collector only when executed as a script, not on import.
    main()