feat: 협상 에이전트 구현 개선
- action_space.py: 행동 공간 관리 로직 추가 - constants.py: 상수값 분리 및 관리 - spaces.py: 상태 및 행동 공간 정의 추가 - environment.py: 협상 환경 구현 개선main
parent
26442ca9c1
commit
1bf179bbaa
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,107 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionInfo:
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
category: str
|
||||
strength: str
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.description} (ID: {self.id}, Category: {self.category}, Strength: {self.strength})"
|
||||
|
||||
|
||||
class ActionSpace:
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
if config_path is None:
|
||||
config_path = os.path.join(
|
||||
Path(__file__).parent.parent, "configs", "actions.json"
|
||||
)
|
||||
self._actions: Dict[int, ActionInfo] = {}
|
||||
self._load_actions(config_path)
|
||||
|
||||
def _load_actions(self, config_path: str) -> None:
|
||||
"""JSON 파일에서 액션 정보를 로드합니다."""
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
|
||||
for action in data["actions"]:
|
||||
self.add_action(
|
||||
id=action["id"],
|
||||
name=action["name"],
|
||||
description=action["description"],
|
||||
category=action["category"],
|
||||
strength=action["strength"],
|
||||
)
|
||||
|
||||
def add_action(
|
||||
self, id: int, name: str, description: str, category: str, strength: str
|
||||
) -> None:
|
||||
"""새로운 액션을 추가합니다."""
|
||||
if id in self._actions:
|
||||
raise ValueError(f"Action with id {id} already exists")
|
||||
|
||||
self._actions[id] = ActionInfo(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
category=category,
|
||||
strength=strength,
|
||||
)
|
||||
|
||||
def remove_action(self, action_id: int) -> None:
|
||||
"""지정된 ID의 액션을 제거합니다."""
|
||||
if action_id not in self._actions:
|
||||
raise ValueError(f"Action with id {action_id} does not exist")
|
||||
del self._actions[action_id]
|
||||
|
||||
def get_action(self, action_id: int) -> ActionInfo:
|
||||
"""액션 ID로 액션 정보를 조회합니다."""
|
||||
if action_id not in self._actions:
|
||||
raise ValueError(f"Invalid action id: {action_id}")
|
||||
return self._actions[action_id]
|
||||
|
||||
def get_actions_by_category(self, category: str) -> List[ActionInfo]:
|
||||
"""특정 카테고리의 모든 액션을 반환합니다."""
|
||||
return [
|
||||
action for action in self._actions.values() if action.category == category
|
||||
]
|
||||
|
||||
def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
|
||||
"""특정 강도의 모든 액션을 반환합니다."""
|
||||
return [
|
||||
action for action in self._actions.values() if action.strength == strength
|
||||
]
|
||||
|
||||
def save_actions(self, file_path: str) -> None:
|
||||
"""현재 액션 설정을 JSON 파일로 저장합니다."""
|
||||
data = {
|
||||
"actions": [
|
||||
{
|
||||
"id": action.id,
|
||||
"name": action.name,
|
||||
"description": action.description,
|
||||
"category": action.category,
|
||||
"strength": action.strength,
|
||||
}
|
||||
for action in self._actions.values()
|
||||
]
|
||||
}
|
||||
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(data, f, ensure_ascii=False, indent=4)
|
||||
|
||||
@property
|
||||
def action_space_size(self) -> int:
|
||||
"""현재 액션 공간의 크기를 반환합니다."""
|
||||
return len(self._actions)
|
||||
|
||||
def list_actions(self) -> List[ActionInfo]:
|
||||
"""모든 액션 정보를 리스트로 반환합니다."""
|
||||
return list(self._actions.values())
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
from gymnasium import spaces
|
||||
|
||||
# Observation Space Constants
|
||||
SCENARIO_SPACE_SIZE = 4 # 시나리오 상태 수 (0-3)
|
||||
PRICE_ZONE_SIZE = 3 # 가격 구간 수 (0-2)
|
||||
ACCEPTANCE_RATE_SIZE = 3 # 수락률 레벨 수 (0-2)
|
||||
|
||||
# Observation Space Mappings
|
||||
SCENARIO_MAPPING = {
|
||||
0: "높은 구매 의지",
|
||||
1: "중간 구매 의지",
|
||||
2: "낮은 구매 의지",
|
||||
3: "매우 낮은 구매 의지",
|
||||
}
|
||||
|
||||
PRICE_ZONE_MAPPING = {0: "목표가격 이하", 1: "목표가격~임계가격", 2: "임계가격 초과"}
|
||||
|
||||
ACCEPTANCE_RATE_MAPPING = {0: "낮음 (<10%)", 1: "중간 (10-25%)", 2: "높음 (>25%)"}
|
||||
|
||||
# Action Space Constants
|
||||
ACTION_SPACE_SIZE = 9
|
||||
|
||||
# Action Space Mappings
|
||||
ACTION_MAPPING = {
|
||||
0: "강한 수락",
|
||||
1: "중간 수락",
|
||||
2: "약한 수락",
|
||||
3: "강한 거절",
|
||||
4: "중간 거절",
|
||||
5: "약한 거절",
|
||||
6: "강한 가격 제안",
|
||||
7: "중간 가격 제안",
|
||||
8: "약한 가격 제안",
|
||||
}
|
||||
|
||||
# Spaces Definition
|
||||
OBSERVATION_SPACE = spaces.MultiDiscrete(
|
||||
[SCENARIO_SPACE_SIZE, PRICE_ZONE_SIZE, ACCEPTANCE_RATE_SIZE]
|
||||
)
|
||||
|
||||
ACTION_SPACE = spaces.Discrete(ACTION_SPACE_SIZE)
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
from negotiation_agent.spaces import (
|
||||
NegotiationSpaces,
|
||||
State,
|
||||
PriceZone,
|
||||
AcceptanceRate,
|
||||
Scenario,
|
||||
)
|
||||
|
||||
|
||||
class NegotiationEnv(gym.Env):
|
||||
|
|
@ -8,9 +15,11 @@ class NegotiationEnv(gym.Env):
|
|||
|
||||
def __init__(self, scenario=0, target_price=100, threshold_price=120):
|
||||
super(NegotiationEnv, self).__init__()
|
||||
self.observation_space = spaces.MultiDiscrete([4, 3, 3])
|
||||
self.action_space = spaces.Discrete(9)
|
||||
self.initial_scenario = scenario
|
||||
|
||||
self.spaces = NegotiationSpaces()
|
||||
self.observation_space = self.spaces.observation_space
|
||||
self.action_space = self.spaces.action_space
|
||||
self.initial_scenario = Scenario(scenario)
|
||||
self.target_price = target_price
|
||||
self.threshold_price = threshold_price
|
||||
self.current_price = None
|
||||
|
|
@ -20,23 +29,28 @@ class NegotiationEnv(gym.Env):
|
|||
def _get_state(self):
|
||||
"""현재 정보를 바탕으로 State 배열을 계산"""
|
||||
if self.current_price <= self.target_price:
|
||||
price_zone = 0
|
||||
price_zone = PriceZone.BELOW_TARGET
|
||||
elif self.target_price < self.current_price <= self.threshold_price:
|
||||
price_zone = 1
|
||||
price_zone = PriceZone.BETWEEN_TARGET_AND_THRESHOLD
|
||||
else:
|
||||
price_zone = 2
|
||||
price_zone = PriceZone.ABOVE_THRESHOLD
|
||||
|
||||
acceptance_rate_val = (
|
||||
self.initial_price - self.current_price
|
||||
) / self.initial_price
|
||||
if acceptance_rate_val < 0.1:
|
||||
acceptance_rate_level = 0
|
||||
acceptance_rate_level = AcceptanceRate.LOW
|
||||
elif 0.1 <= acceptance_rate_val < 0.25:
|
||||
acceptance_rate_level = 1
|
||||
acceptance_rate_level = AcceptanceRate.MEDIUM
|
||||
else:
|
||||
acceptance_rate_level = 2
|
||||
acceptance_rate_level = AcceptanceRate.HIGH
|
||||
|
||||
return np.array([self.initial_scenario, price_zone, acceptance_rate_level])
|
||||
state = State(
|
||||
scenario=self.initial_scenario,
|
||||
price_zone=price_zone,
|
||||
acceptance_rate=acceptance_rate_level,
|
||||
)
|
||||
return np.array(state.to_array())
|
||||
|
||||
def reset(self, seed=None, options=None):
|
||||
"""환경을 초기 상태로 리셋"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,113 @@
|
|||
from gymnasium import spaces
|
||||
from typing import Dict, List, Any
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from negotiation_agent.action_space import ActionSpace, ActionInfo
|
||||
|
||||
|
||||
class Scenario(Enum):
|
||||
HIGH_INTENTION = 0
|
||||
MEDIUM_INTENTION = 1
|
||||
LOW_INTENTION = 2
|
||||
VERY_LOW_INTENTION = 3
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return {
|
||||
self.HIGH_INTENTION: "높은 구매 의지",
|
||||
self.MEDIUM_INTENTION: "중간 구매 의지",
|
||||
self.LOW_INTENTION: "낮은 구매 의지",
|
||||
self.VERY_LOW_INTENTION: "매우 낮은 구매 의지",
|
||||
}[self]
|
||||
|
||||
|
||||
class PriceZone(Enum):
|
||||
BELOW_TARGET = 0
|
||||
BETWEEN_TARGET_AND_THRESHOLD = 1
|
||||
ABOVE_THRESHOLD = 2
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return {
|
||||
self.BELOW_TARGET: "목표가격 이하",
|
||||
self.BETWEEN_TARGET_AND_THRESHOLD: "목표가격~임계가격",
|
||||
self.ABOVE_THRESHOLD: "임계가격 초과",
|
||||
}[self]
|
||||
|
||||
|
||||
class AcceptanceRate(Enum):
|
||||
LOW = 0
|
||||
MEDIUM = 1
|
||||
HIGH = 2
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return {
|
||||
self.LOW: "낮음 (<10%)",
|
||||
self.MEDIUM: "중간 (10-25%)",
|
||||
self.HIGH: "높음 (>25%)",
|
||||
}[self]
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
scenario: Scenario
|
||||
price_zone: PriceZone
|
||||
acceptance_rate: AcceptanceRate
|
||||
|
||||
def to_array(self) -> List[int]:
|
||||
return [self.scenario.value, self.price_zone.value, self.acceptance_rate.value]
|
||||
|
||||
@classmethod
|
||||
def from_array(cls, arr: List[int]) -> "State":
|
||||
return cls(
|
||||
scenario=Scenario(arr[0]),
|
||||
price_zone=PriceZone(arr[1]),
|
||||
acceptance_rate=AcceptanceRate(arr[2]),
|
||||
)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return (
|
||||
f"State(scenario={self.scenario.description}, "
|
||||
f"price_zone={self.price_zone.description}, "
|
||||
f"acceptance_rate={self.acceptance_rate.description})"
|
||||
)
|
||||
|
||||
|
||||
class NegotiationSpaces:
|
||||
def __init__(self):
|
||||
self._action_space = ActionSpace()
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.MultiDiscrete:
|
||||
return spaces.MultiDiscrete(
|
||||
[len(Scenario), len(PriceZone), len(AcceptanceRate)]
|
||||
)
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Discrete:
|
||||
return spaces.Discrete(self._action_space.action_space_size)
|
||||
|
||||
def decode_action(self, action_id: int) -> ActionInfo:
|
||||
return self._action_space.get_action(action_id)
|
||||
|
||||
def encode_state(self, state: State) -> List[int]:
|
||||
return state.to_array()
|
||||
|
||||
def decode_state(self, state_array: List[int]) -> State:
|
||||
return State.from_array(state_array)
|
||||
|
||||
def get_action_description(self, action_id: int) -> str:
|
||||
return self.decode_action(action_id).description
|
||||
|
||||
def get_state_description(self, state_array: List[int]) -> str:
|
||||
return str(self.decode_state(state_array))
|
||||
|
||||
def get_actions_by_category(self, category: str) -> List[ActionInfo]:
|
||||
return self._action_space.get_actions_by_category(category)
|
||||
|
||||
def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
|
||||
return self._action_space.get_actions_by_strength(strength)
|
||||
|
||||
def list_all_actions(self) -> List[ActionInfo]:
|
||||
return self._action_space.list_actions()
|
||||
Loading…
Reference in New Issue