KT_Q_Table/negotiation_agent/spaces.py

114 lines
3.3 KiB
Python

from gymnasium import spaces
from typing import Dict, List, Any
from dataclasses import dataclass
from enum import Enum, auto
from negotiation_agent.action_space import ActionSpace, ActionInfo
class Scenario(Enum):
    """Buyer purchase-intention scenario.

    Each member's ``value`` is the index used in the first slot of the
    observation vector.
    """

    HIGH_INTENTION = 0
    MEDIUM_INTENTION = 1
    LOW_INTENTION = 2
    VERY_LOW_INTENTION = 3

    @property
    def description(self) -> str:
        """Return the Korean display label for this scenario."""
        # Keyed by member name; a local mapping avoids adding a class-level
        # attribute, which an Enum would turn into an extra member.
        labels_by_name = {
            "HIGH_INTENTION": "높은 구매 의지",
            "MEDIUM_INTENTION": "중간 구매 의지",
            "LOW_INTENTION": "낮은 구매 의지",
            "VERY_LOW_INTENTION": "매우 낮은 구매 의지",
        }
        return labels_by_name[self.name]
class PriceZone(Enum):
    """Where the current price sits relative to the target and threshold prices.

    Each member's ``value`` is the index used in the second slot of the
    observation vector.
    """

    BELOW_TARGET = 0
    BETWEEN_TARGET_AND_THRESHOLD = 1
    ABOVE_THRESHOLD = 2

    @property
    def description(self) -> str:
        """Return the Korean display label for this price zone."""
        # Local name-keyed table: a class attribute would become an Enum member.
        labels_by_name = {
            "BELOW_TARGET": "목표가격 이하",
            "BETWEEN_TARGET_AND_THRESHOLD": "목표가격~임계가격",
            "ABOVE_THRESHOLD": "임계가격 초과",
        }
        return labels_by_name[self.name]
class AcceptanceRate(Enum):
    """Bucketed acceptance rate.

    The labels describe the buckets as <10%, 10-25% and >25%. Each member's
    ``value`` is the index used in the third slot of the observation vector.
    """

    LOW = 0
    MEDIUM = 1
    HIGH = 2

    @property
    def description(self) -> str:
        """Return the Korean display label, including the bucket's range."""
        # Local name-keyed table: a class attribute would become an Enum member.
        labels_by_name = {
            "LOW": "낮음 (<10%)",
            "MEDIUM": "중간 (10-25%)",
            "HIGH": "높음 (>25%)",
        }
        return labels_by_name[self.name]
@dataclass
class State:
    """Discrete negotiation state observed by the agent.

    Converts to and from the 3-element integer vector
    ``[scenario, price_zone, acceptance_rate]``. Each entry is the ``.value``
    of the corresponding enum member.
    """

    scenario: Scenario
    price_zone: PriceZone
    acceptance_rate: AcceptanceRate

    def to_array(self) -> List[int]:
        """Return ``[scenario.value, price_zone.value, acceptance_rate.value]``."""
        members = (self.scenario, self.price_zone, self.acceptance_rate)
        return [member.value for member in members]

    @classmethod
    def from_array(cls, arr: List[int]) -> "State":
        """Build a ``State`` from the first three entries of *arr*.

        Any extra trailing entries are ignored.

        Raises:
            ValueError: an entry is not a valid value of its enum.
            IndexError: *arr* has fewer than three entries.
        """
        # Each enum is constructed right after its index is read. This keeps the
        # original evaluation order, so the same error wins on malformed input.
        scenario = Scenario(arr[0])
        price_zone = PriceZone(arr[1])
        acceptance_rate = AcceptanceRate(arr[2])
        return cls(scenario, price_zone, acceptance_rate)

    def __str__(self) -> str:
        """Return a human-readable summary using each field's Korean label."""
        fields_text = ", ".join(
            (
                f"scenario={self.scenario.description}",
                f"price_zone={self.price_zone.description}",
                f"acceptance_rate={self.acceptance_rate.description}",
            )
        )
        return f"State({fields_text})"
class NegotiationSpaces:
    """Gymnasium observation/action spaces plus encode/decode helpers.

    Observations are ``MultiDiscrete`` vectors
    ``[scenario, price_zone, acceptance_rate]`` holding the ``.value`` of each
    enum member, as produced by ``State.to_array``. Actions are indices into a
    ``Discrete`` space whose size comes from the wrapped ``ActionSpace``.

    The Gymnasium space objects are cached and reused. Each space carries its
    own RNG, so ``ns.action_space.seed(0)`` only makes a later
    ``ns.action_space.sample()`` reproducible if both calls see the same
    object. Building a new space on every access would discard the seed.
    """

    def __init__(self):
        # Project-level action catalogue; it also sizes the Discrete action space.
        self._action_space = ActionSpace()
        # Gymnasium space objects, built lazily on first access and then reused.
        self._observation_gym_space = None
        self._action_gym_space = None

    @property
    def observation_space(self) -> spaces.MultiDiscrete:
        """Return the cached ``MultiDiscrete`` space covering every ``State`` vector."""
        if self._observation_gym_space is None:
            # Enum sizes are fixed once the classes are defined, so build only once.
            self._observation_gym_space = spaces.MultiDiscrete(
                [len(Scenario), len(PriceZone), len(AcceptanceRate)]
            )
        return self._observation_gym_space

    @property
    def action_space(self) -> spaces.Discrete:
        """Return the cached ``Discrete`` space with one index per action."""
        size = self._action_space.action_space_size
        # Rebuild only if the space is missing or the catalogue size has changed,
        # so a seeded space survives repeated property access.
        if self._action_gym_space is None or self._action_gym_space.n != size:
            self._action_gym_space = spaces.Discrete(size)
        return self._action_gym_space

    def decode_action(self, action_id: int) -> ActionInfo:
        """Return the ``ActionInfo`` for *action_id* via ``ActionSpace.get_action``.

        NOTE(review): behaviour for out-of-range ids depends on ``ActionSpace``;
        confirm there.
        """
        return self._action_space.get_action(action_id)

    def encode_state(self, state: State) -> List[int]:
        """Encode *state* as its observation vector (``State.to_array``)."""
        return state.to_array()

    def decode_state(self, state_array: List[int]) -> State:
        """Decode an observation vector into a ``State`` (``State.from_array``)."""
        return State.from_array(state_array)

    def get_action_description(self, action_id: int) -> str:
        """Return the ``description`` of the action with id *action_id*."""
        return self.decode_action(action_id).description

    def get_state_description(self, state_array: List[int]) -> str:
        """Return the human-readable ``str`` of the decoded state."""
        return str(self.decode_state(state_array))

    def get_actions_by_category(self, category: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_category``.

        The valid category values are defined by ``ActionSpace``.
        """
        return self._action_space.get_actions_by_category(category)

    def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
        """Delegate to ``ActionSpace.get_actions_by_strength``.

        The valid strength values are defined by ``ActionSpace``.
        """
        return self._action_space.get_actions_by_strength(strength)

    def list_all_actions(self) -> List[ActionInfo]:
        """Return every action known to the wrapped ``ActionSpace``."""
        return self._action_space.list_actions()