114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
from gymnasium import spaces
|
|
from typing import Dict, List, Any
|
|
from dataclasses import dataclass
|
|
from enum import Enum, auto
|
|
from negotiation_agent.action_space import ActionSpace, ActionInfo
|
|
|
|
|
|
class Scenario(Enum):
    """Buyer purchase-intention scenario.

    Members are ordered from strongest (0) to weakest (3) intention; the
    integer ``value`` is what appears in the encoded observation vector.
    """

    HIGH_INTENTION = 0
    MEDIUM_INTENTION = 1
    LOW_INTENTION = 2
    VERY_LOW_INTENTION = 3

    @property
    def description(self) -> str:
        """Return the (Korean) human-readable label for this scenario.

        Labels translate to: high / medium / low / very low purchase intention.
        """
        # Look members up on the class rather than through another member.
        cls = type(self)
        labels = {
            cls.HIGH_INTENTION: "높은 구매 의지",
            cls.MEDIUM_INTENTION: "중간 구매 의지",
            cls.LOW_INTENTION: "낮은 구매 의지",
            cls.VERY_LOW_INTENTION: "매우 낮은 구매 의지",
        }
        return labels[self]
|
|
|
|
|
|
class PriceZone(Enum):
    """Where the current price lies relative to the target and threshold prices.

    The integer ``value`` is what appears in the encoded observation vector.
    """

    BELOW_TARGET = 0
    BETWEEN_TARGET_AND_THRESHOLD = 1
    ABOVE_THRESHOLD = 2

    @property
    def description(self) -> str:
        """Return the (Korean) human-readable label for this price zone.

        Labels translate to: at/below target price, between target and
        threshold price, above threshold price.
        """
        # Look members up on the class rather than through another member.
        cls = type(self)
        labels = {
            cls.BELOW_TARGET: "목표가격 이하",
            cls.BETWEEN_TARGET_AND_THRESHOLD: "목표가격~임계가격",
            cls.ABOVE_THRESHOLD: "임계가격 초과",
        }
        return labels[self]
|
|
|
|
|
|
class AcceptanceRate(Enum):
    """Bucketed acceptance rate.

    The bucket boundaries appear only in the labels (<10%, 10-25%, >25%);
    the integer ``value`` is what appears in the encoded observation vector.
    """

    LOW = 0
    MEDIUM = 1
    HIGH = 2

    @property
    def description(self) -> str:
        """Return the (Korean) human-readable label for this bucket.

        Labels translate to: low (<10%), medium (10-25%), high (>25%).
        """
        # Look members up on the class rather than through another member.
        cls = type(self)
        labels = {
            cls.LOW: "낮음 (<10%)",
            cls.MEDIUM: "중간 (10-25%)",
            cls.HIGH: "높음 (>25%)",
        }
        return labels[self]
|
|
|
|
|
|
@dataclass
class State:
    """Discrete observation of the negotiation environment.

    Each field is an enum member; their integer ``value``s, in field order,
    form the flat observation vector.
    """

    scenario: Scenario  # buyer purchase-intention scenario
    price_zone: PriceZone  # price position relative to target/threshold
    acceptance_rate: AcceptanceRate  # bucketed acceptance rate

    def to_array(self) -> List[int]:
        """Encode as ``[scenario.value, price_zone.value, acceptance_rate.value]``."""
        ordered_members = (self.scenario, self.price_zone, self.acceptance_rate)
        return [member.value for member in ordered_members]

    @classmethod
    def from_array(cls, arr: List[int]) -> "State":
        """Decode a state from its integer encoding (inverse of ``to_array``).

        Only the first three entries of ``arr`` are read; extra entries are
        ignored.

        Raises:
            IndexError: if ``arr`` has fewer than three entries.
            ValueError: if an entry is not a valid value for its enum.
        """
        # Converted one at a time, in field order, so errors surface in that order.
        scenario = Scenario(arr[0])
        price_zone = PriceZone(arr[1])
        acceptance_rate = AcceptanceRate(arr[2])
        return cls(
            scenario=scenario,
            price_zone=price_zone,
            acceptance_rate=acceptance_rate,
        )

    def __str__(self) -> str:
        """Render the state using each field's human-readable description."""
        text = f"State(scenario={self.scenario.description}, "
        text += f"price_zone={self.price_zone.description}, "
        text += f"acceptance_rate={self.acceptance_rate.description})"
        return text
|
|
|
|
|
|
class NegotiationSpaces:
    """Gymnasium space definitions plus encode/decode helpers for negotiation.

    Observations are ``State`` triples (scenario, price zone, acceptance rate)
    encoded as a ``MultiDiscrete`` vector; actions are integer ids resolved
    through the project's ``ActionSpace`` catalogue.
    """

    def __init__(self):
        self._action_space = ActionSpace()
        # Gymnasium spaces own their RNG (``Space.seed`` / ``Space.sample``).
        # Handing out a fresh space object on every property access would
        # silently discard any seed a caller set, so the spaces are cached.
        self._observation_space = spaces.MultiDiscrete(
            [len(Scenario), len(PriceZone), len(AcceptanceRate)]
        )
        self._gym_action_space = None  # built lazily by ``action_space``

    @property
    def observation_space(self) -> spaces.MultiDiscrete:
        """Return the (cached) observation space: one dimension per ``State`` field."""
        return self._observation_space

    @property
    def action_space(self) -> spaces.Discrete:
        """Return the (cached) action space sized by the ``ActionSpace`` catalogue.

        The space is rebuilt only if the catalogue size differs from the cached
        space's ``n``, so a seeded space keeps its RNG across accesses.
        """
        size = self._action_space.action_space_size
        cached = self._gym_action_space
        # NOTE(review): rebuild-on-size-change is defensive; unclear from here
        # whether ActionSpace's size can change after construction — confirm.
        if cached is None or cached.n != size:
            cached = spaces.Discrete(size)
            self._gym_action_space = cached
        return cached

    def decode_action(self, action_id: int) -> ActionInfo:
        """Return the ``ActionInfo`` that ``ActionSpace.get_action`` maps ``action_id`` to."""
        return self._action_space.get_action(action_id)

    def encode_state(self, state: State) -> List[int]:
        """Encode ``state`` into its integer observation vector."""
        return state.to_array()

    def decode_state(self, state_array: List[int]) -> State:
        """Decode an integer observation vector back into a ``State``."""
        return State.from_array(state_array)

    def get_action_description(self, action_id: int) -> str:
        """Return the ``description`` of the action identified by ``action_id``."""
        return self.decode_action(action_id).description

    def get_state_description(self, state_array: List[int]) -> str:
        """Return the human-readable string for an encoded observation vector."""
        return str(self.decode_state(state_array))

    def get_actions_by_category(self, category: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_category``."""
        return self._action_space.get_actions_by_category(category)

    def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_strength``."""
        return self._action_space.get_actions_by_strength(strength)

    def list_all_actions(self) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.list_actions``."""
        return self._action_space.list_actions()
|