"""Observation/action space definitions for the negotiation environment.

An observation is the 3-element discrete vector
``[scenario, price_zone, acceptance_rate]`` (see ``State.to_array``); an
action is a single discrete index resolved through the project's
``ActionSpace``.
"""

from typing import Dict, List, Any
from dataclasses import dataclass
from enum import Enum, auto

from gymnasium import spaces

from negotiation_agent.action_space import ActionSpace, ActionInfo


class Scenario(Enum):
    """Purchase-intention level, ordered from high (0) to very low (3)."""

    HIGH_INTENTION = 0
    MEDIUM_INTENTION = 1
    LOW_INTENTION = 2
    VERY_LOW_INTENTION = 3

    @property
    def description(self) -> str:
        """Human-readable (Korean) label for this scenario."""
        # Look members up on the class rather than via member-to-member
        # access (``self.HIGH_INTENTION``), which the enum docs discourage.
        cls = type(self)
        return {
            cls.HIGH_INTENTION: "높은 구매 의지",
            cls.MEDIUM_INTENTION: "중간 구매 의지",
            cls.LOW_INTENTION: "낮은 구매 의지",
            cls.VERY_LOW_INTENTION: "매우 낮은 구매 의지",
        }[self]


class PriceZone(Enum):
    """Where the current price sits relative to the target and threshold prices."""

    BELOW_TARGET = 0
    BETWEEN_TARGET_AND_THRESHOLD = 1
    ABOVE_THRESHOLD = 2

    @property
    def description(self) -> str:
        """Human-readable (Korean) label for this price zone."""
        cls = type(self)
        return {
            cls.BELOW_TARGET: "목표가격 이하",
            cls.BETWEEN_TARGET_AND_THRESHOLD: "목표가격~임계가격",
            cls.ABOVE_THRESHOLD: "임계가격 초과",
        }[self]


class AcceptanceRate(Enum):
    """Bucketed acceptance rate: LOW (<10%), MEDIUM (10-25%), HIGH (>25%)."""

    LOW = 0
    MEDIUM = 1
    HIGH = 2

    @property
    def description(self) -> str:
        """Human-readable (Korean) label, including the bucket's range."""
        cls = type(self)
        return {
            cls.LOW: "낮음 (<10%)",
            cls.MEDIUM: "중간 (10-25%)",
            cls.HIGH: "높음 (>25%)",
        }[self]


@dataclass
class State:
    """Structured form of one observation vector."""

    scenario: Scenario
    price_zone: PriceZone
    acceptance_rate: AcceptanceRate

    def to_array(self) -> List[int]:
        """Encode as ``[scenario, price_zone, acceptance_rate]`` enum values."""
        return [self.scenario.value, self.price_zone.value, self.acceptance_rate.value]

    @classmethod
    def from_array(cls, arr: List[int]) -> "State":
        """Decode an observation vector produced by ``to_array``.

        Only the first three elements are read.

        Raises:
            IndexError: if ``arr`` has fewer than three elements.
            ValueError: if an element is not a valid value of its enum.
        """
        return cls(
            scenario=Scenario(arr[0]),
            price_zone=PriceZone(arr[1]),
            acceptance_rate=AcceptanceRate(arr[2]),
        )

    def __str__(self) -> str:
        return (
            f"State(scenario={self.scenario.description}, "
            f"price_zone={self.price_zone.description}, "
            f"acceptance_rate={self.acceptance_rate.description})"
        )


class NegotiationSpaces:
    """Gymnasium spaces plus state/action encode-decode helpers."""

    def __init__(self):
        self._action_space = ActionSpace()
        # Build the gymnasium spaces once and hand out the same instances.
        # Every ``spaces.Space`` owns its own RNG, so returning a fresh
        # object per property access would silently discard any ``seed()``
        # call and make ``sample()`` non-reproducible.
        self._observation_space = spaces.MultiDiscrete(
            [len(Scenario), len(PriceZone), len(AcceptanceRate)]
        )
        self._discrete_action_space = None

    @property
    def observation_space(self) -> spaces.MultiDiscrete:
        """Cached ``MultiDiscrete`` space over (scenario, price zone, acceptance rate)."""
        return self._observation_space

    @property
    def action_space(self) -> spaces.Discrete:
        """Cached ``Discrete`` space sized by ``ActionSpace.action_space_size``."""
        size = self._action_space.action_space_size
        # Rebuild only if the catalogue size differs from the cached space;
        # this keeps the original's read-size-on-access behaviour while
        # preserving the space instance (and its seeded RNG) otherwise.
        if self._discrete_action_space is None or self._discrete_action_space.n != size:
            self._discrete_action_space = spaces.Discrete(size)
        return self._discrete_action_space

    def decode_action(self, action_id: int) -> ActionInfo:
        """Return the ``ActionInfo`` that ``ActionSpace`` maps to ``action_id``."""
        return self._action_space.get_action(action_id)

    def encode_state(self, state: State) -> List[int]:
        """Convert a ``State`` into its observation vector."""
        return state.to_array()

    def decode_state(self, state_array: List[int]) -> State:
        """Convert an observation vector back into a ``State``."""
        return State.from_array(state_array)

    def get_action_description(self, action_id: int) -> str:
        """Return the ``description`` of the decoded action."""
        return self.decode_action(action_id).description

    def get_state_description(self, state_array: List[int]) -> str:
        """Return the human-readable string form of the decoded state."""
        return str(self.decode_state(state_array))

    def get_actions_by_category(self, category: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_category``."""
        return self._action_space.get_actions_by_category(category)

    def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_strength``."""
        return self._action_space.get_actions_by_strength(strength)

    def list_all_actions(self) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.list_actions``."""
        return self._action_space.list_actions()