# KT_Q_Table/negotiation_agent/action_space.py
# (108 lines, 3.6 KiB, Python)

from dataclasses import dataclass
from typing import Dict, List, Optional
import json
import os
from pathlib import Path
@dataclass
class ActionInfo:
    """Metadata record for a single action.

    Instances are plain (non-frozen) dataclasses, so equality and ``repr``
    are generated from the five fields below.
    """

    # Integer identifier; ActionSpace stores records keyed by this value.
    id: int
    # Action name.
    name: str
    # Free-text description; rendered first by ``__str__``.
    description: str
    # Category label, matched exactly by ActionSpace category lookups.
    category: str
    # Strength label, matched exactly by ActionSpace strength lookups.
    strength: str

    def __str__(self) -> str:
        """Return ``"<description> (ID: <id>, Category: <category>, Strength: <strength>)"``."""
        details = "ID: {}, Category: {}, Strength: {}".format(
            self.id, self.category, self.strength
        )
        return "{} ({})".format(self.description, details)
class ActionSpace:
    """Registry of available actions, keyed by integer action id.

    Actions are read from a JSON file when the object is constructed and are
    kept in a dict in insertion order. The registry can be queried by id,
    category or strength, edited in place, and written back to JSON.
    """

    def __init__(self, config_path: Optional[str] = None):
        """Create the action space and load actions from a JSON config file.

        Args:
            config_path: Path to the actions JSON file. When ``None``, the path
                ``configs/actions.json`` under the parent of this module's
                directory (``Path(__file__).parent.parent``) is used.

        Raises:
            OSError: If the config file cannot be opened.
            json.JSONDecodeError: If the file does not contain valid JSON.
            KeyError: If the top-level ``actions`` key or a required action
                field is missing.
            ValueError: If the file lists the same action id more than once.
        """
        resolved_path = (
            config_path
            if config_path is not None
            else os.path.join(Path(__file__).parent.parent, "configs", "actions.json")
        )
        self._actions: Dict[int, ActionInfo] = {}
        self._load_actions(resolved_path)

    def _load_actions(self, config_path: str) -> None:
        """Read ``config_path`` and register each entry of its ``actions`` list.

        Each entry must provide ``id``, ``name``, ``description``, ``category``
        and ``strength``; the fields are read in that order, so the first
        missing one raises ``KeyError``.
        """
        field_order = ("id", "name", "description", "category", "strength")
        with open(config_path, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        for entry in payload["actions"]:
            self.add_action(**{field: entry[field] for field in field_order})

    def add_action(
        self, id: int, name: str, description: str, category: str, strength: str
    ) -> None:
        """Register a new action under ``id``.

        Raises:
            ValueError: If an action with the same id is already registered.
        """
        already_registered = id in self._actions
        if already_registered:
            raise ValueError(f"Action with id {id} already exists")
        record = ActionInfo(
            id=id,
            name=name,
            description=description,
            category=category,
            strength=strength,
        )
        self._actions[id] = record

    def remove_action(self, action_id: int) -> None:
        """Unregister the action with ``action_id``.

        Raises:
            ValueError: If no action with that id is registered.
        """
        if action_id in self._actions:
            del self._actions[action_id]
            return
        raise ValueError(f"Action with id {action_id} does not exist")

    def get_action(self, action_id: int) -> ActionInfo:
        """Return the action registered under ``action_id``.

        Raises:
            ValueError: If no action with that id is registered.
        """
        if action_id in self._actions:
            return self._actions[action_id]
        raise ValueError(f"Invalid action id: {action_id}")

    def _filter_by(self, attribute: str, value: str) -> List[ActionInfo]:
        """Return actions whose ``attribute`` equals ``value``, in insertion order."""
        return [
            candidate
            for candidate in self._actions.values()
            if getattr(candidate, attribute) == value
        ]

    def get_actions_by_category(self, category: str) -> List[ActionInfo]:
        """Return every action whose category equals ``category``.

        NOTE: the returned list holds the registry's own (mutable) records.
        """
        return self._filter_by("category", category)

    def get_actions_by_strength(self, strength: str) -> List[ActionInfo]:
        """Return every action whose strength equals ``strength``.

        NOTE: the returned list holds the registry's own (mutable) records.
        """
        return self._filter_by("strength", strength)

    def save_actions(self, file_path: str) -> None:
        """Write all registered actions to ``file_path`` as UTF-8 JSON.

        The output has the form ``{"actions": [...]}`` with one object per
        action, in insertion order. It is indented by 4 spaces and non-ASCII
        text is written as-is rather than escaped.
        """
        records = []
        for action in self._actions.values():
            records.append(
                {
                    "id": action.id,
                    "name": action.name,
                    "description": action.description,
                    "category": action.category,
                    "strength": action.strength,
                }
            )
        with open(file_path, "w", encoding="utf-8") as out:
            json.dump({"actions": records}, out, ensure_ascii=False, indent=4)

    @property
    def action_space_size(self) -> int:
        """Number of actions currently registered."""
        return len(self._actions)

    def list_actions(self) -> List[ActionInfo]:
        """Return a new list of all registered actions, in insertion order."""
        return [*self._actions.values()]