28 lines
875 B
Python
28 lines
875 B
Python
import unittest
|
|
|
|
from agents.offline_agent import QLearningAgent
|
|
from usecases.update_q_table_usecase import UpdateQTableUseCase
|
|
|
|
class TestUpdateQTableUseCase(unittest.TestCase):
    """Tests for ``UpdateQTableUseCase`` applied to a ``QLearningAgent``."""

    def setUp(self):
        """Build a fresh agent and use case before every test."""
        self.agent_params = {'learning_rate': 0.1, 'discount_factor': 0.9}
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.use_case = UpdateQTableUseCase()

    def test_execute(self):
        """``execute`` should leave ``new_value`` at ``q_table[state, action]``."""
        # Define the update parameters
        state = 4
        action = 0
        new_value = 0.99

        # Execute the use case
        self.use_case.execute(self.agent, state, action, new_value)

        # Assert that the Q-value is updated correctly. Compare with a
        # tolerance rather than exact equality: if the Q-table stores values
        # at a narrower precision than a Python float (e.g. a float32 array
        # -- not visible from here, confirm against QLearningAgent), the
        # round-tripped 0.99 is not bit-identical to the literal and
        # assertEqual would fail spuriously.
        self.assertAlmostEqual(self.agent.q_table[state, action], new_value)
|
|
|
|
if __name__ == '__main__':
    # Only run the test suite when this file is executed as a script,
    # not when it is imported (e.g. by a test collector).
    unittest.main()
|