26 lines
1013 B
Python
26 lines
1013 B
Python
import unittest
|
|
|
|
from agents.offline_agent import QLearningAgent
|
|
from envs.my_custom_env import MyCustomEnv
|
|
from usecases.evaluate_agent_usecase import EvaluateAgentUseCase
|
|
|
|
class TestEvaluateAgentUseCase(unittest.TestCase):
    """Smoke test for ``EvaluateAgentUseCase.execute``.

    Evaluates a ``QLearningAgent`` configured with ``epsilon = 0.0`` on a
    ``MyCustomEnv`` for a fixed number of episodes and checks that the use
    case reports a sensible average reward.
    """

    # Number of evaluation episodes passed to ``execute``; a small value keeps
    # the test quick.
    NUM_EPISODES = 10

    def setUp(self):
        """Create a fresh agent, environment and use case before each test."""
        self.agent_params = {
            'learning_rate': 0.1,
            'discount_factor': 0.9,
            'epsilon': 0.0,  # Epsilon 0 for deterministic actions
        }
        # NOTE(review): these sizes are hard-coded rather than read from
        # MyCustomEnv; confirm they match the environment's state/action spaces.
        self.state_size = 10
        self.action_size = 2
        self.agent = QLearningAgent(self.agent_params, self.state_size, self.action_size)
        self.env = MyCustomEnv()
        self.use_case = EvaluateAgentUseCase()

    def test_execute(self):
        """``execute`` returns a value and that value is non-negative."""
        average_reward = self.use_case.execute(
            self.agent, self.env, num_episodes=self.NUM_EPISODES
        )

        # Fail with a readable message instead of a TypeError from the
        # comparison below if execute() does not return anything.
        self.assertIsNotNone(average_reward, "execute() returned None")

        # The environment is believed to emit a reward of 1.0 on every step,
        # so the average can never be negative. The exact figure depends on
        # how execute() aggregates rewards (per step vs. summed per episode),
        # so only the lower bound is asserted.
        # TODO(review): if execute() averages per-step rewards, tighten this
        # to self.assertAlmostEqual(average_reward, 1.0).
        self.assertGreaterEqual(average_reward, 0.0)
|
|
|
|
# Run this module's tests via the unittest CLI runner when it is executed
# directly as a script; importing the module does not trigger the run.
if __name__ == "__main__":
    unittest.main()
|