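"""Generate an offline dataset by rolling out random actions in MyCustomEnv.

The output location is read from ``configs/offline_env_config.yaml`` and the
collected transitions are written to an HDF5 file.
"""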
import h5py
|
|
import numpy as np
|
|
import yaml
|
|
|
|
from envs.my_custom_env import MyCustomEnv
|
|
|
|
|
|
def main():
    """Collect random-action transitions from MyCustomEnv and save them as HDF5.

    Reads ``configs/offline_env_config.yaml`` and takes the output location
    from ``config["dataset_params"]["path"]``. Rolls out 10 episodes of at
    most 100 steps each, with actions drawn from ``env.action_space.sample()``,
    then writes five equally long datasets to the file: ``observations``,
    ``actions``, ``rewards``, ``next_observations`` and ``terminals``.

    Any existing file at the output location is overwritten, but only after
    the rollout has finished successfully.

    Raises:
        KeyError: if the config has no ``dataset_params`` / ``path`` entry.
        OSError: if the config file cannot be read or the dataset file
            cannot be written.
    """
    # Local import so this block does not depend on the file's import header.
    from pathlib import Path

    with open("configs/offline_env_config.yaml", "r") as config_file:
        config = yaml.safe_load(config_file)

    # Resolve the output path first so a malformed config fails before the
    # environment is constructed.
    dataset_path = config["dataset_params"]["path"]

    num_episodes = 10
    max_steps_per_episode = 100

    observations = []
    actions = []
    rewards = []
    next_observations = []
    terminals = []

    env = MyCustomEnv()
    try:
        for _episode in range(num_episodes):
            obs, _ = env.reset()
            for _step in range(max_steps_per_episode):
                action = env.action_space.sample()
                # NOTE(review): assumes MyCustomEnv follows the Gymnasium
                # step API, where the fourth value is the truncation flag.
                next_obs, reward, terminated, truncated, _ = env.step(action)

                observations.append(obs)
                actions.append(action)
                rewards.append(reward)
                next_observations.append(next_obs)
                # Only true termination is recorded; a truncated episode is
                # not marked as terminal in the stored data.
                terminals.append(terminated)

                obs = next_obs

                # A truncated episode must be reset as well; stepping past it
                # would record transitions from a finished episode.
                if terminated or truncated:
                    break
    finally:
        # Release environment resources even if the rollout fails.
        env.close()

    # Open the file only once all data is in memory, so a failed rollout
    # cannot clobber an existing dataset with an empty file.
    Path(dataset_path).parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(dataset_path, "w") as h5_file:
        h5_file.create_dataset("observations", data=np.array(observations))
        h5_file.create_dataset("actions", data=np.array(actions))
        h5_file.create_dataset("rewards", data=np.array(rewards))
        h5_file.create_dataset(
            "next_observations", data=np.array(next_observations)
        )
        h5_file.create_dataset("terminals", data=np.array(terminals))
|
|
|
|
# Run the data collection only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|