"""Train a tabular Q-learning agent offline from an HDF5 transition dataset.

Settings are read from a YAML config. The stored transitions are replayed in
fixed-order mini-batches for a number of epochs, and the learned model is
saved to ``saved_models/q_table.npy``.
"""
import os

import h5py
import numpy as np
import yaml

from agents.offline_agent import QLearningAgent

# Transition arrays read from the HDF5 dataset. Each name is also the key under
# which the matching slice is handed to ``agent.learn``. The order matches the
# batch dict the script has always built.
_DATASET_KEYS = (
    "observations",
    "actions",
    "rewards",
    "next_observations",
    "terminals",
)


def _load_transitions(dataset_path):
    """Read every array named in ``_DATASET_KEYS`` from the HDF5 file into memory."""
    with h5py.File(dataset_path, "r") as f:
        return {key: f[key][:] for key in _DATASET_KEYS}


def _table_size(*arrays):
    """Return a table dimension large enough to index every id in *arrays*.

    Counting distinct ids alone under-sizes the table when the ids are not the
    contiguous range 0..k-1 (e.g. a state never visited in the dataset), which
    leaves the largest ids out of range. Taking the max with ``max_id + 1``
    fixes that while never returning less than the distinct count.
    NOTE(review): assumes ids are non-negative integers, possibly stored as
    floats -- confirm against QLearningAgent's indexing.
    """
    values = np.concatenate([np.ravel(a) for a in arrays])
    if values.size == 0:
        return 0
    return max(len(np.unique(values)), int(values.max()) + 1)


def main(config_path="configs/offline_env_config.yaml", num_epochs=10):
    """Train a QLearningAgent on the configured dataset and save its model.

    Args:
        config_path: YAML config providing ``dataset_params.path``,
            ``dataset_params.batch_size`` and ``agent_params``.
        num_epochs: Number of full passes over the dataset.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer, the dataset is
            empty, or the transition arrays have different lengths.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    dataset_params = config["dataset_params"]
    dataset_path = dataset_params["path"]
    batch_size = dataset_params["batch_size"]
    # range() rejects a step of 0 with an obscure error, and a negative step
    # would silently skip all training; fail early with a clear message.
    if not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    transitions = _load_transitions(dataset_path)

    lengths = {key: len(array) for key, array in transitions.items()}
    num_transitions = lengths["observations"]
    # Slicing below would silently truncate mismatched arrays, so check up front.
    if len(set(lengths.values())) != 1:
        raise ValueError(f"transition arrays have mismatched lengths: {lengths}")
    if num_transitions == 0:
        raise ValueError(f"dataset {dataset_path!r} contains no transitions")

    state_size = _table_size(
        transitions["observations"], transitions["next_observations"]
    )
    action_size = _table_size(transitions["actions"])

    agent = QLearningAgent(config["agent_params"], state_size, action_size)

    for _ in range(num_epochs):
        for start in range(0, num_transitions, batch_size):
            stop = start + batch_size  # slicing clamps at the end of the arrays
            # NOTE: slices are views of the loaded arrays; agent.learn should
            # not modify them in place.
            batch = {key: array[start:stop] for key, array in transitions.items()}
            agent.learn(batch)

    # Save the model
    saved_models_dir = "saved_models"
    os.makedirs(saved_models_dir, exist_ok=True)
    model_path = os.path.join(saved_models_dir, "q_table.npy")
    agent.save_model(model_path)
    print(f"Model saved to {model_path}")


if __name__ == "__main__":
    main()