67 lines
2.3 KiB
Python
67 lines
2.3 KiB
Python
from negotiation_agent.environment import NegotiationEnv
|
|
from agents.offline_agent import QLearningAgent
|
|
import yaml
|
|
import numpy as np
|
|
|
|
|
|
def main(
    *,
    config_path: str = "configs/offline_env_config.yaml",
    q_table_path: str = "saved_models/q_table.npy",
    num_episodes: int = 10,
) -> None:
    """Evaluate a saved Q-table greedily on ``NegotiationEnv``.

    Loads settings from *config_path* and builds a ``NegotiationEnv`` from its
    ``env`` section (``scenario``, ``target_price``, ``threshold_price``). It
    then builds a ``QLearningAgent`` from its ``agent`` section and loads the
    Q-table from *q_table_path*. Finally it runs *num_episodes* episodes, each
    time picking the action with the highest Q-value. Every step prints the
    state/action descriptions, the step reward and ``env.current_price``. Each
    episode prints its total reward, and the average total reward over all
    episodes is printed at the end.

    Args:
        config_path: Path to the YAML configuration file.
        q_table_path: Path passed to ``QLearningAgent.load_q_table``.
        num_episodes: Number of evaluation episodes to run (must be >= 1).

    Raises:
        ValueError: If *num_episodes* is less than 1.
    """
    if num_episodes < 1:
        # Guard so the final np.mean never runs over an empty list.
        raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")

    # Load environment/agent configuration.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # Initialize the environment.
    env_cfg = config["env"]
    env = NegotiationEnv(
        scenario=env_cfg["scenario"],
        target_price=env_cfg["target_price"],
        threshold_price=env_cfg["threshold_price"],
    )

    # Initialize the agent and load its Q-table.
    # observation_space exposes per-dimension sizes via `nvec`
    # (presumably a gym MultiDiscrete space -- confirm against NegotiationEnv).
    state_dims = env.observation_space.nvec
    state_size = np.prod(state_dims)  # size of the flattened state space
    action_size = env.action_space.n
    agent = QLearningAgent(config["agent"], state_size, action_size)
    agent.load_q_table(q_table_path)

    print(f"State space size: {state_size}")
    print(f"Action space size: {action_size}")
    print(f"Q-table shape: {agent.q_table.shape}")

    # Run the evaluation.
    total_rewards = []
    for episode in range(num_episodes):
        state, _ = env.reset()
        episode_reward = 0
        done = False

        while not done:
            # Convert the multi-dimensional observation into a flat Q-table row index.
            state_idx = np.ravel_multi_index(tuple(state), state_dims)
            # Greedy choice: take the action with the highest Q-value (no exploration).
            action = np.argmax(agent.q_table[state_idx])

            # step() returns (obs, reward, terminated, truncated, info). The
            # episode must end on either flag; ignoring `truncated` would loop
            # forever on an env that truncates without terminating.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            episode_reward += reward
            state = next_state

            # Print the current step.
            print(f"Episode {episode + 1}")
            print(f"State: {env.spaces.get_state_description(state)}")
            print(f"Action: {env.spaces.get_action_description(action)}")
            print(f"Reward: {reward:.2f}")
            print(f"Current Price: {env.current_price:.2f}")
            print("--------------------")

        total_rewards.append(episode_reward)
        print(f"Episode {episode + 1} finished with total reward: {episode_reward:.2f}")
        print("========================================")

    print(f"Average reward over {num_episodes} episodes: {np.mean(total_rewards):.2f}")
|
|
|
|
|
|
# Entry point: run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()  # uses main()'s default config path, Q-table path and episode count
|