50 lines
1.6 KiB
Python
50 lines
1.6 KiB
Python
import h5py
|
|
import numpy as np
|
|
import yaml
|
|
import os
|
|
|
|
from agents.offline_agent import QLearningAgent
|
|
|
|
def main(
    config_path="configs/offline_env_config.yaml",
    num_epochs=10,
    saved_models_dir="saved_models",
):
    """Train a QLearningAgent on an offline HDF5 dataset and save the result.

    The YAML config at ``config_path`` supplies ``dataset_params.path`` (the
    HDF5 file), ``dataset_params.batch_size`` and ``agent_params``. The five
    datasets ``observations``, ``actions``, ``rewards``, ``next_observations``
    and ``terminals`` are read from the file, fed to ``agent.learn`` in
    consecutive batches for ``num_epochs`` passes, and the agent is then saved
    to ``<saved_models_dir>/q_table.npy`` via ``agent.save_model``.

    Args:
        config_path: Path of the YAML configuration file.
        num_epochs: Number of full passes over the dataset.
        saved_models_dir: Directory the model file is written to; created if
            it does not exist.

    Raises:
        KeyError: If a required config key or HDF5 dataset is missing.
        ValueError: If the configured batch size is not positive.
    """
    # Explicit UTF-8 so parsing does not depend on the platform locale.
    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    dataset_params = config["dataset_params"]
    dataset_path = dataset_params["path"]
    batch_size = dataset_params["batch_size"]

    # Fail fast: a negative step would make the batch loop run zero times and
    # an untrained model would be saved without any warning.
    if batch_size <= 0:
        raise ValueError(
            f"dataset_params.batch_size must be positive, got {batch_size!r}"
        )

    # `[:]` copies each dataset into memory, so the arrays stay usable after
    # the file is closed.
    with h5py.File(dataset_path, "r") as dataset:
        observations = dataset["observations"][:]
        actions = dataset["actions"][:]
        rewards = dataset["rewards"][:]
        next_observations = dataset["next_observations"][:]
        terminals = dataset["terminals"][:]

    # Sizes are the number of distinct values present in the data (np.unique
    # flattens its input).
    # NOTE(review): if the agent indexes its table by raw state/action id,
    # ids absent from the dataset would make these counts too small - confirm
    # against QLearningAgent.
    all_states = np.concatenate((observations, next_observations))
    state_size = len(np.unique(all_states))
    action_size = len(np.unique(actions))

    agent = QLearningAgent(config["agent_params"], state_size, action_size)

    num_samples = len(observations)
    for _ in range(num_epochs):
        # Batches are consecutive, unshuffled windows; the last may be short.
        for start in range(0, num_samples, batch_size):
            stop = min(start + batch_size, num_samples)
            batch_indices = np.arange(start, stop)

            batch = {
                "observations": observations[batch_indices],
                "actions": actions[batch_indices],
                "rewards": rewards[batch_indices],
                "next_observations": next_observations[batch_indices],
                "terminals": terminals[batch_indices],
            }
            agent.learn(batch)

    # Save the model
    os.makedirs(saved_models_dir, exist_ok=True)
    model_path = os.path.join(saved_models_dir, "q_table.npy")
    agent.save_model(model_path)
    print(f"Model saved to {model_path}")
|
|
|
|
# Run training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|