30 lines
923 B
Python
30 lines
923 B
Python
import unittest
|
|
import numpy as np
|
|
|
|
from agents.offline_agent import QLearningAgent
|
|
from usecases.get_q_value_usecase import GetQValueUseCase
|
|
|
|
class TestGetQValueUseCase(unittest.TestCase):
    """Unit tests for ``GetQValueUseCase.execute``."""

    def setUp(self):
        """Build a fresh agent and use case before every test."""
        self.agent_params = {'learning_rate': 0.1, 'discount_factor': 0.9}
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.use_case = GetQValueUseCase()

    def test_execute(self):
        """``execute`` returns exactly the Q-value stored at (state, action)."""
        state = 3
        action = 1
        other_action = 0
        expected_q_value = 0.75

        # Seed neighbouring cells with distinct values so the test fails if the
        # use case returns the best Q-value of the state (0.9 > 0.75) or reads
        # the wrong state row, rather than the requested (state, action) entry.
        self.agent.q_table[state, other_action] = 0.9
        self.agent.q_table[state + 1, action] = 0.25
        self.agent.q_table[state, action] = expected_q_value

        # Execute the use case
        q_value = self.use_case.execute(self.agent, state, action)

        # Compare numerically: the returned value may be a float or a NumPy scalar.
        self.assertAlmostEqual(q_value, expected_q_value)
|
|
|
|
if '__main__' == __name__:
    # Executed directly (not imported by a test runner): discover and run
    # the TestCase classes defined in this module.
    unittest.main()
|