# KT_Q_Table/tests/test_get_q_value_usecase.py
#
# (file-viewer metadata from the original listing: 30 lines, 923 B, Python)

import unittest
import numpy as np
from agents.offline_agent import QLearningAgent
from usecases.get_q_value_usecase import GetQValueUseCase
class TestGetQValueUseCase(unittest.TestCase):
    """Tests for ``GetQValueUseCase.execute`` reading from a ``QLearningAgent``."""

    def setUp(self):
        """Build a fresh agent and use case before each test.

        unittest calls ``setUp`` before every test method, so writes to the
        agent's Q-table in one test do not carry over into another.
        """
        self.agent_params = {'learning_rate': 0.1, 'discount_factor': 0.9}
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.use_case = GetQValueUseCase()

    def test_execute(self):
        """``execute`` returns exactly the value stored at ``q_table[state, action]``."""
        state = 3
        action = 1
        expected_q_value = 0.75
        # Seed the other action in the same state with a different, larger
        # value. A use case that ignores ``action`` (e.g. returns the row max
        # or the first column) would then fail instead of passing by accident
        # against a zero-initialised table.
        other_action = 0
        self.agent.q_table[state, other_action] = 0.9
        self.agent.q_table[state, action] = expected_q_value

        # Execute the use case
        q_value = self.use_case.execute(self.agent, state, action)

        # Assert that the returned Q-value is the one for (state, action)
        self.assertEqual(q_value, expected_q_value)
if __name__ == '__main__':
    # Allow running this file directly: python test_get_q_value_usecase.py
    unittest.main()