# KT_Q_Table/train_offline.py
#
# 50 lines
# 1.6 KiB
# Python

import h5py
import numpy as np
import yaml
import os
from agents.offline_agent import QLearningAgent
def main(config_path="configs/offline_env_config.yaml", num_epochs=10):
    """Train a QLearningAgent on an offline HDF5 dataset and save the model.

    Loads the YAML configuration, reads the five transition arrays from the
    HDF5 file named by ``dataset_params.path``, passes them to
    ``agent.learn`` in consecutive mini-batches of ``dataset_params.batch_size``
    rows for ``num_epochs`` passes, and finally writes the model to
    ``saved_models/q_table.npy``.

    Args:
        config_path: Path of the YAML configuration file. It must provide
            ``dataset_params.path``, ``dataset_params.batch_size`` and
            ``agent_params``.
        num_epochs: Number of full passes over the dataset.

    Raises:
        ValueError: If ``batch_size`` is not a positive integer, if the
            dataset arrays differ in length, or if the dataset is empty.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_params = config["dataset_params"]
    dataset_path = dataset_params["path"]
    batch_size = dataset_params["batch_size"]
    # A negative step would make range() empty and silently save an untrained
    # model; zero or a non-integer would fail later with an opaque range() error.
    if not isinstance(batch_size, int) or batch_size <= 0:
        raise ValueError(
            f"batch_size must be a positive integer, got {batch_size!r}"
        )

    dataset_keys = (
        "observations",
        "actions",
        "rewards",
        "next_observations",
        "terminals",
    )
    with h5py.File(dataset_path, "r") as dataset_file:
        # [:] reads each dataset fully into memory so the file can be closed.
        dataset = {key: dataset_file[key][:] for key in dataset_keys}

    # All arrays must describe the same transitions row-for-row.
    lengths = {key: len(values) for key, values in dataset.items()}
    if len(set(lengths.values())) != 1:
        raise ValueError(f"Dataset arrays have mismatched lengths: {lengths}")
    num_transitions = lengths["observations"]
    if num_transitions == 0:
        raise ValueError(f"Dataset {dataset_path!r} contains no transitions")

    # NOTE(review): sizes are the counts of *distinct* values present in the
    # data. If QLearningAgent indexes its table by raw state/action value,
    # this presumes ids are contiguous from 0 -- confirm against the agent.
    seen_states = np.concatenate(
        (dataset["observations"], dataset["next_observations"])
    )
    state_size = len(np.unique(seen_states))
    action_size = len(np.unique(dataset["actions"]))

    agent = QLearningAgent(config["agent_params"], state_size, action_size)

    for _ in range(num_epochs):
        for start in range(0, num_transitions, batch_size):
            stop = start + batch_size
            # Contiguous slices select the same rows as an index array but
            # avoid copying every batch; the final batch may be shorter.
            batch = {key: values[start:stop] for key, values in dataset.items()}
            agent.learn(batch)

    # Save the model
    saved_models_dir = "saved_models"
    os.makedirs(saved_models_dir, exist_ok=True)
    model_path = os.path.join(saved_models_dir, "q_table.npy")
    agent.save_model(model_path)
    print(f"Model saved to {model_path}")
# Script entry point: start training only when this file is executed
# directly, never as a side effect of importing the module.
if __name__ == "__main__":
    main()