KT_Q_Table/tests/test_load_q_table_usecase.py

34 lines
1.1 KiB
Python

import os
import tempfile
import unittest

import numpy as np

from agents.offline_agent import QLearningAgent
from usecases.load_q_table_usecase import LoadQTableUseCase
class TestLoadQTableUseCase(unittest.TestCase):
    """Tests for ``LoadQTableUseCase.execute`` loading a saved Q-table into an agent."""

    def setUp(self):
        """Create a fresh agent and use case, and reserve a private fixture path.

        The fixture file lives in a per-test temporary directory rather than
        the current working directory, so a run never overwrites an unrelated
        ``test_q_table.npy`` and concurrent runs cannot collide on the path.
        """
        self.agent_params = {'learning_rate': 0.1, 'discount_factor': 0.9}
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.use_case = LoadQTableUseCase()
        self._tmp_dir = tempfile.TemporaryDirectory()
        # Cleanups run even if a later setUp step or the test itself fails,
        # so the directory (and the .npy inside it) is always removed.
        self.addCleanup(self._tmp_dir.cleanup)
        self.test_q_table_path = os.path.join(self._tmp_dir.name, 'test_q_table.npy')

    def test_execute(self):
        """``execute`` should place the table saved at the given path into ``agent.q_table``."""
        # Seeded generator keeps the fixture reproducible between runs.
        rng = np.random.default_rng(0)
        dummy_q_table = rng.random((self.state_size, self.action_size))
        np.save(self.test_q_table_path, dummy_q_table)

        self.use_case.execute(self.agent, self.test_q_table_path)

        # Unlike assertTrue(np.array_equal(...)), this reports the shape or
        # element mismatch when it fails.
        np.testing.assert_array_equal(self.agent.q_table, dummy_q_table)
if __name__ == '__main__':
    # Running this file directly hands control to the unittest CLI runner.
    unittest.main()