34 lines
1.1 KiB
Python
34 lines
1.1 KiB
Python
import os
import tempfile
import unittest

import numpy as np

from agents.offline_agent import QLearningAgent
from usecases.load_q_table_usecase import LoadQTableUseCase
|
|
|
|
class TestLoadQTableUseCase(unittest.TestCase):
    """Tests that ``LoadQTableUseCase.execute`` loads a Q-table saved with ``np.save`` into an agent."""

    def setUp(self):
        """Build a fresh agent and use case, and reserve an isolated ``.npy`` path."""
        self.agent_params = {'learning_rate': 0.1, 'discount_factor': 0.9}
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.use_case = LoadQTableUseCase()
        # Write the fixture into a private temporary directory instead of the
        # current working directory: this never clobbers a real file named
        # 'test_q_table.npy' and parallel test runs cannot collide on one path.
        # Created last so a failure earlier in setUp leaves nothing to clean up.
        self._tmp_dir = tempfile.TemporaryDirectory()
        self.test_q_table_path = os.path.join(self._tmp_dir.name, 'test_q_table.npy')

    def tearDown(self):
        """Delete the temporary directory together with any Q-table file written into it."""
        self._tmp_dir.cleanup()

    def test_execute(self):
        """After ``execute``, ``agent.q_table`` must equal the array stored at the given path."""
        # Seeded generator keeps the fixture reproducible from run to run.
        # NOTE(review): assumes the agent's table is (state_size, action_size);
        # confirm against QLearningAgent.
        rng = np.random.default_rng(0)
        dummy_q_table = rng.random((self.state_size, self.action_size))
        np.save(self.test_q_table_path, dummy_q_table)

        self.use_case.execute(self.agent, self.test_q_table_path)

        # Unlike assertTrue(np.array_equal(...)), this reports which shapes or
        # elements differ when the assertion fails.
        np.testing.assert_array_equal(self.agent.q_table, dummy_q_table)
|
|
|
|
if __name__ == "__main__":
    # Allow running this module directly (e.g. `python <this file>`)
    # in addition to discovery by a test runner.
    unittest.main()
|