From 1bf179bbaa57f071947e2ae8699c586493c669e0 Mon Sep 17 00:00:00 2001 From: mgjeon Date: Mon, 22 Sep 2025 16:35:43 +0900 Subject: [PATCH] =?UTF-8?q?feat:=20=ED=98=91=EC=83=81=20=EC=97=90=EC=9D=B4?= =?UTF-8?q?=EC=A0=84=ED=8A=B8=20=EA=B5=AC=ED=98=84=20=EA=B0=9C=EC=84=A0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - action_space.py: 행동 공간 관리 로직 추가 - constants.py: 상수값 분리 및 관리 - spaces.py: 상태 및 행동 공간 정의 추가 - environment.py: 협상 환경 구현 개선 --- .../__pycache__/__init__.cpython-39.pyc | Bin 0 -> 215 bytes .../__pycache__/action_space.cpython-39.pyc | Bin 0 -> 4591 bytes .../__pycache__/environment.cpython-39.pyc | Bin 0 -> 2946 bytes .../__pycache__/spaces.cpython-39.pyc | Bin 0 -> 5208 bytes negotiation_agent/action_space.py | 107 +++++++++++++++++ negotiation_agent/constants.py | 41 +++++++ negotiation_agent/environment.py | 34 ++++-- negotiation_agent/spaces.py | 113 ++++++++++++++++++ 8 files changed, 285 insertions(+), 10 deletions(-) create mode 100644 negotiation_agent/__pycache__/__init__.cpython-39.pyc create mode 100644 negotiation_agent/__pycache__/action_space.cpython-39.pyc create mode 100644 negotiation_agent/__pycache__/environment.cpython-39.pyc create mode 100644 negotiation_agent/__pycache__/spaces.cpython-39.pyc create mode 100644 negotiation_agent/action_space.py create mode 100644 negotiation_agent/constants.py create mode 100644 negotiation_agent/spaces.py diff --git a/negotiation_agent/__pycache__/__init__.cpython-39.pyc b/negotiation_agent/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3fb2f54612354deb1f958da89e80198679cec301 GIT binary patch literal 215 zcmYe~<>g`kf|}w(86f&Gh(HF6K#l_t7qb9~6oz01O-8?!3`HPe1o5jzKeRZts8~NM zH9s#mGcUceRNp5vsVK3iQr|fzzceMdB)=#zJyqX5KR-Pu)ukx2ELArvJ2yE$H__Z7 zJvT8kM-Rx>_pMX_%jx?Y`RjX!IO_+-ha@KDr0VCTrstPrCYEI8=fwlf$t%&1kI&4@ bEQycTE2zB1VUwGmQks)$2XfVCAZ7pn2yZ)x literal 0 HcmV?d00001 diff --git a/negotiation_agent/__pycache__/action_space.cpython-39.pyc b/negotiation_agent/__pycache__/action_space.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aea8c29abe29aaeeac771e41d6ada6c9434464b8 GIT binary patch literal 4591 zcmcgw-EZ6073U=>ik9VAaoi+Mn@(touG&&7%|6lvY343o+hm!t&CP}hD>z1%cIZ-+ zb4jJMAOl&gVnLdq9lBs?XBRe*Zu5`_56OlW=;Qu?J@a)>wdI#R4GFO9cP=GMvK=5# z6Tx$NFE8);_?>g^xv~caQVM>{|NiI2GrJV!U)1UT>FA8%4t@#370x`R%!-WBv+Aj3 zt*DjtqFzoE6LL)RjG}>enwKml(dwRA?kn~&#bK{0Ji(2I!i{-#o)uH*C%K8fIjt-OaGA_UM$fk-9HVT zG2Fp8h@%u4SBmO+#nBv{vlEJwK&!S|t<~zSHbE^n7MK)Ghbny#D;1?)DG}S#{$wRm zT^{MaU3Q{=?gS;_N{J$~WQWdFMa)KK5DLei3NNm($Pi9g6@H09miCXHP28t&2hW1k zm1*SezX-^}3KtMy6$v!7VFFSys8i zs~$b4E$dR%_F8W=*AJsX%d&mH5=y=S=4fCchOxHPx{K_gu_OqYDIg1_RFkzXRleEg zSu}l&rvC$^?OFV3@~jcS2Cvb$-asgjr}4ys$_<`uz>A!uqsJG<%>{TBPhs42lIZti zUYZY}--mfA+26$n(eKAMX(xrXGnliR4`FOzb8HXBhWQg1+a>w!MSmaPkN%*O!OA0$ z=Kz1Qp(EIKx7MMtr}#n48R861^_8eUw)h00LNV1}s`!)clr=-4B~gK{^rPnlZt1n( z#6ul{RaHbbgx4;qp%#MCbYeOImDFWzj=At_%cv_$$vRt>83PQpLNOV02x{rrBZATC zpyKDuNUa2sPJAPMy5jo6026|K(v@%lnq(!C25@dzV_!a1cWD97`wHw;| z2j7pLtOUMvo{=UOxw|YdkTcj4N**p3^DT>pC|^6wsuk zrlft6k`wgbK#sjM#&y^+UHt|apP4vU$ZlM}-dy~m`TGx=jfHIU=5q7hTiNE)&GozY z)<69syS}`z{^7#L&6V}{uCBj#do*I=exe#qzVH$)E?_)TMj>G&mumm$nIc*`^tbU)rv$8AR>Xl2aMooiUexBHrT z?tr2}(Nip=s%-OmP}|#W$S%0ls4vFGa(QgGt%*R#deUys;h|+Ww%u;AonFRP2A9>V zq+Ih)SI9X(Sje6pZ!W@9rNc^VEiJ75b+uzHaTL>GAzpEUEOL5mqQZ(QjU&ka@hn50 z($N;>qXFTRE0>*Cb~uh1fka=08Cb6m^~~%Eb<$hIEp;VKHg(xsk0FU zda6~jdn*u6kHagrM`d$)b>rjf9bIdCPWzW_4=xcK9?o$*E53zw1bOCG-BMFd*ivi^ zuQwIj#dd z_cq?WzIJzM{f~DD#`VAadHutsE;QFyZ*6>ZtD|901w>5Ag(x=%2SBEgW;iB0dn`AtJN9Oz^1_9^wZT$;`>uX;I>U#D1*X z!#maptP8VNyIvbT;Z;hu7o5aHdySTpMFv2rIsvRF{Y;urGL`Hl=Y5Y@_Ei$2?Q!PA z+M&k~1$?Ly**WG2p~pdq!ls4N+Um{q_Z$Bge5BkIWnzrTCNkm#x&_>i14+AZ-U$-% zG8SnAA-PE&V;B`X0?}UVEHQsfAll#LkOF1dLFLGHILLQ5D}Z016Yu^FS;@CU#$~7k zuyJsm&ON9Mo|)4^ZCa;|gpL=y?}pA4R0ye_3Zc+lA@CFeufKW>EH-X6nxEX+OeKxm z2t!JnT{VDAHA`2TpM2VhJCVb;+k%gTpTHeF2O@*-={WchfE0uf8|4U_*158fhyyXl z{P0BHbsovMIZOiLH?QpmA2li%K4t7F1Vs8tQOMDgN)i24|z$5MW zK~*512PM~y4AsD zmxr2Zn*Y%s>I2v;_CHE`F@X6!#-+0z4Dx z(l4pY@fN_Uv?cZA`tDK_Tfw#48`nM~Kir!C=^Rh#wF|DDR7umEc17=QYN7UG&^;-~ z3b{Rk>Uy!4$UY)eq6^ws#Ir;W5uwdZj1r-6k_DBZ9ZisfiSH75g$V5*kvSt!(TZ?Z zoS`wwAzfDQIQ^4&L4~T^#Y{GAW(HGC{cn0OV`PkBT$v%2WpqLFa%ucy(EeY;2?Y5` zWQ4ObD8L1oJ(fhuPF^gi433hNf;=3}DTT#S(_%YG6PYLC!|_RSLm7u@HY^XdX$@lZ F{|~rapojnf literal 0 HcmV?d00001 diff --git a/negotiation_agent/__pycache__/environment.cpython-39.pyc b/negotiation_agent/__pycache__/environment.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad9240ffc56443ac8d5fe1fb0aba28be12b5452d GIT binary patch literal 2946 zcmZWr|8FBl6`$E}UOTpv^jv!109Qi28nv+mh@uLS-WQV;?i_kad(u*Aw70YI*|%TJ z>^fY+T8PG|he`#AyOOJNl_QmkN)-ydLjtM%5B+(60`Wt+*g*(>LwtL0Y$tbNR`d4F zo0+$>@5|?nT`1%Tw155m<7?l}6Y^&qJlr%GyZ}9_1L1^Ik60Amibs6KQs_vgS{ley zPxlSW@H1A%H!ai8T3J75J^$AQQLOhsG_H{ zb+-}5O*?Le!Lk&l`f_Z?Zkqj;XgaPH25wqB>o{&Zwgbn#j8o>a;|8{9h7KxxxZ&}! zMTtvO3og2L@96rO%9$yuUgXp8HeV+iv9fUR z$I{|M$E*cwof@b=^U|{`_Epa<4}0(K-RqSH_wMiA>koU|XUcoO+a2s~4exH1_qTpF z-26d#|IXIlhkq!a??iFvmq9~Qgn{n{v7n&o0Y`EJh%@;C;Nk0psN){J0!&O|nvhMU zuj~*?Vs(=wglP$BX)PT%=)f7k8F40|o3NIdi2^j6s6cZZ*H1~Bm$bkM{2r}Wim4uT z+O7~Ld?qy+YX;4jF@c{K3J_~Hyc)S;O*)E6_oaE;LAQ-YB7-|;a9`8*STy!fnvZSK zaAVd+Z>H1nHQ`3rLXVFH1*anf*kLl9>>n-HBMgO#mTHmft)}KgtC+@f%>a3yrdh@u z&yFI--Y5V3e)VNIW>jstVBcoYSnpIXHm?d>tXI!_VTUirp|Bfnbs-EJo_j$w*W71Y z*L^4S?dM)-_;%Bq1HO7`y)5UdOJ7*3eskq~^(7{KT@4Pym4TihuDXZLoNKQONGgH{ zi%LKUEh;9>$;Y6hXXA-SV`*+;YzISO(sW?3pNAg32LybJonEJ+6e}&YOM_={jB*Xa z^z-uu7e4tShn6OHRRR+9`F0RYq+T*AtbBc781QG*wU5(9FwyOqDahfPn6BTP* zftlCFddyZ1<_%ak)m7s>>@OrVQMvl2+STG>e>&0nGdpAlb6o3u2@boz)f>L^L3!A} z4M{Qh!QNy|VaNVAr4Z6<6Z z>_p;II4y04<`Q^<=b%Tj!YM_HmdeqM(h*z8PdNx`llDlmy$aG8Y|w}F;K~lbXG<+% zx*NyQ22*QlFl;6_MDCf$-{*EU0ZN*Xy;M*rp5@_ zt#RgY5w_D@>etD_${*pKVmTG!wP@RVn^)ovy+67>X0z`iHQ6NH-pYi44C#bu#{*v)P# zr$~wLY+?fNDZ=8}K;gL(tnOwLtnE`0unubXxO7hpzkmBz!<$>>r}uVygLi(rf9KZV zrx42@_J;jlrGfuQ$lvD2%wHzVALqp>n5bx@;w~#ZhT8+oGI0{8R5>LoDERyVkwaV# z^H||h)wgd**Hv<(1duR+IAG-`wU9OW4XbzvkvL8%voaA2m~Xc&v%N9mQhG(4!H+%a zHt{)-eFJ)gb&DuDim6kgCu#1#hPqKc;<^6|pUKZ;VY5ItR+M_ZlEYg>VEZm(0A$P$ zdB;OO&)9c5wl|*1Fvdd%E*cmxqd3Gg7EdF=bESoGb$49PgLQ{Tm)S29AMY#$=rZa| zswu^Sd9rAlQ}Q-bDOlOYx*yn4v*TMTu<}Ba>OsecdX4E~9S@F{^|{dvb#wuen2*6J ktAf1gSSnoNr2dDFbbv*v`PI=6e+|7E;Q>tvN2jy@1M+zu#sB~S literal 0 HcmV?d00001 diff --git a/negotiation_agent/__pycache__/spaces.cpython-39.pyc b/negotiation_agent/__pycache__/spaces.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f4a2093205c8795831a470cec7e4e33c6047f31 GIT binary patch literal 5208 zcmb_g+ix6K8K2wU@4h8Y({$TX>V?LeMnJhFAx^xtP3<*}okk22qsg9=&6v5EGqaI) zWuXXLhzEEm5>f>TN<^f{2%(6UH~t93_kw1n?Fv$H$iounXTR`bm{-}%m& z-#OoPM=unz0$hLi*WYhW76jp+*oj^WbS}c}-UC7eBAS9Fs-lQ($&@Tvl`Tb8tc052 zJlRZIDK*90ikY@DYR1Z{S>8{WIV-Q`t%6z*g(ZO`Nor3ZDVo`p)gtuMBm?~n@0XyT zB{}Hlc)v_jz{!&Wa02B*S>eggmSu|Rn zudEub=gSLr7X~v#_w<&jyKYVLm4@B1d`0hg4)k&hEzfZ54V>*~gU*V5!)f8V(Q9xn z^T4HoDiT4Jh^Wd$QWYYr38JV;nxbivxFhnNDu6jd(ttAKr#D*E)|ufj5vG64iG(c~ z$UeYl)ax+2g#+P*?KIq%;dZM)c7!d^2T+knKqVprm5E{~JHo60?@2T}Qy};F=f6I@ zd%Jq@*FQae_@H`t_k+WSxB0-~gZB>}?Zpf{{@I_$XS{$*J$!KZ+eg*7@tWcnmsggS zwUy>(qq(`V-t^0Bjm4F#YY|Rib^ThTJMmuQ$`7<7JlMSAQuD^Xz>dKSKke18x|F%~ zE$Y~oVYj=T`l@lA>8xA7Y&spX;W+fot`gN1mZQ6D`Lr0}+u#WB31IwA;?)o`G-4-*paj*zrTqJ^= z6#4HnDIxv4Hx+7fJ|f(#W&O0K*}6qF&ChC@<&chv?YyS_q@$a`2&=$yYy!z75>%F- z-e%4=WnR~pU5~L@9Ke%^kmK0R0YP(c^VkrT;3r8UA9=dHeg(wi?3IGG5(R8EL6UoD zt7(z~8%q0`?=z!CRmY~>SQ);LL8uT|3CNg2U7uP*wfg-h%2HHHi}4KE!%V+8Ud-29CDivL_v zq|bWOgF`v0_YA{rGxQa95=j-w%Sc{9f~&DpNS>|$W{l%(JQ+5+;xiv!o$Qbt8?exJo8$ugVxfIICF2FIeqr)z~fRFoF~Nq6Y!ManS`edPskXv!q5-Z49)g=kSXs-JV&B;=4hTvVUASP0?C7~7W~A92fFmr zZa=;F*==-Qtq0!gm-K@wuv!<{)821ym zbhATQ4cZJNOPs3wwC89#V|w@7AmL)36U1JTL!A$YAot23!^$w&;vy&p7q$m61&@74 z@Pr+ah_k}3=!wvhcuRtoj4h)43NS+Q8wpo6nL{381qKFD?pO?!hZtI>%P^&ngaSW% zgE>~Pi}NtdbuWNhX7TrVuN>@VC^QxqQ!tC4D;gY6+igqUZmCZ%bRWi+jUd&0;rNO>{&M3IvJ zNj~Wnf?f4xnziC+bi?!WOPZ$JwgZ+2dD7Li;ej%&l6={)J$4#fIbL*G)Vt{rb_O^> zC_Ia;xE%L7415i4o_MoTT7;i8Axq*+is4ub?jzTrD*FepLvJHsB71@g9!C>fNg_jz zQ%Hj6xD@B5IS;E*UZa7>{T99ken;FU-jcx9MZ z;qxkF0{RnunF^PgBvZhfq!r-iV6ADGH$#pAZ)(6h4!l`%0(jGXDa85$ISHH@@*eCZ;!DBQiPd5aH@v0l#925dXDexUBr$U1W6I6p7dTF$3He_>_}yp&W{u;K->d&+6ZbT zK%EW&v$sG{twfJlbB!L21)YzIU4x0iSxdcB<5;7j5U(=*zCq$6=Kwg`0!DLyE)EY{ z3UFrI)a#qWzQi!_vkORuK8Xd1Um&#Q5ZaggA+SQDcxfEy5m`W*a4|K-qa0W)7PLF! zJ>mYqT16f?z-T$Dg&0(MYI9V4eax~UsF>4>iUOt_pA$s&cw~d~!(xmp{>bTJ5aNL> zJT^3v(PPW^kAlBUe>T9V=vW#YfC-d&502`-xHa$`iCx6E^TP7GF*s(1TGk?eR!9WV z+_)a&Cyiml$cvljp&Al&-_Uj==y=(SiH0pA88NPDd{Ll}2)H&biRJ+n!|U;FO>t~wSe2d<0`69k;0Gxx(y-h_`a{V3wbM#OR;VVa@Tb_Crk z<7YZ5-mD!PEmgh=D+JX&LS)}U@*O1SkYG*6u+(6PJ8&fK+-TSA5)y8?+=k!A?h+DI zH^b9ot4M|>&0G9``Zr);4Q?0z7()F5)yqTzZn-Q^!7rdwGDEp$ZPL%QyOyoHM#lsc&d=*9mVg#~n+U&QQI}zQ^A!povRDemFpR3F>|v!4eA<|gCJ$Z(1iQs5 sE>g7deIuN>{rKr~!N-gX_&v(SPZoJ`O3aI8u`HL% str: + return f"{self.description} (ID: {self.id}, Category: {self.category}, Strength: {self.strength})" + + +class ActionSpace: + def __init__(self, config_path: Optional[str] = None): + if config_path is None: + config_path = os.path.join( + Path(__file__).parent.parent, "configs", "actions.json" + ) + self._actions: Dict[int, ActionInfo] = {} + self._load_actions(config_path) + + def _load_actions(self, config_path: str) -> None: + """JSON 파일에서 액션 정보를 로드합니다.""" + with open(config_path, "r", encoding="utf-8") as f: + data = json.load(f) + + for action in data["actions"]: + self.add_action( + id=action["id"], + name=action["name"], + description=action["description"], + category=action["category"], + strength=action["strength"], + ) + + def add_action( + self, id: int, name: str, description: str, category: str, strength: str + ) -> None: + """새로운 액션을 추가합니다.""" + if id in self._actions: + raise ValueError(f"Action with id {id} already exists") + + self._actions[id] = ActionInfo( + id=id, + name=name, + description=description, + category=category, + strength=strength, + ) + + def remove_action(self, action_id: int) -> None: + """지정된 ID의 액션을 제거합니다.""" + if action_id not in self._actions: + raise ValueError(f"Action with id {action_id} does not exist") + del self._actions[action_id] + + def get_action(self, action_id: int) -> ActionInfo: + """액션 ID로 액션 정보를 조회합니다.""" + if action_id not in self._actions: + raise ValueError(f"Invalid action id: {action_id}") + return self._actions[action_id] + + def get_actions_by_category(self, category: str) -> List[ActionInfo]: + """특정 카테고리의 모든 액션을 반환합니다.""" + return [ + action for action in self._actions.values() if action.category == category + ] + + def get_actions_by_strength(self, strength: str) -> List[ActionInfo]: + """특정 강도의 모든 액션을 반환합니다.""" + return [ + action for action in self._actions.values() if action.strength == strength + ] + + def save_actions(self, file_path: str) -> None: + """현재 액션 설정을 JSON 파일로 저장합니다.""" + data = { + "actions": [ + { + "id": action.id, + "name": action.name, + "description": action.description, + "category": action.category, + "strength": action.strength, + } + for action in self._actions.values() + ] + } + + with open(file_path, "w", encoding="utf-8") as f: + json.dump(data, f, ensure_ascii=False, indent=4) + + @property + def action_space_size(self) -> int: + """현재 액션 공간의 크기를 반환합니다.""" + return len(self._actions) + + def list_actions(self) -> List[ActionInfo]: + """모든 액션 정보를 리스트로 반환합니다.""" + return list(self._actions.values()) diff --git a/negotiation_agent/constants.py b/negotiation_agent/constants.py new file mode 100644 index 0000000..208e7f4 --- /dev/null +++ b/negotiation_agent/constants.py @@ -0,0 +1,41 @@ +from gymnasium import spaces + +# Observation Space Constants +SCENARIO_SPACE_SIZE = 4 # 시나리오 상태 수 (0-3) +PRICE_ZONE_SIZE = 3 # 가격 구간 수 (0-2) +ACCEPTANCE_RATE_SIZE = 3 # 수락률 레벨 수 (0-2) + +# Observation Space Mappings +SCENARIO_MAPPING = { + 0: "높은 구매 의지", + 1: "중간 구매 의지", + 2: "낮은 구매 의지", + 3: "매우 낮은 구매 의지", +} + +PRICE_ZONE_MAPPING = {0: "목표가격 이하", 1: "목표가격~임계가격", 2: "임계가격 초과"} + +ACCEPTANCE_RATE_MAPPING = {0: "낮음 (<10%)", 1: "중간 (10-25%)", 2: "높음 (>25%)"} + +# Action Space Constants +ACTION_SPACE_SIZE = 9 + +# Action Space Mappings +ACTION_MAPPING = { + 0: "강한 수락", + 1: "중간 수락", + 2: "약한 수락", + 3: "강한 거절", + 4: "중간 거절", + 5: "약한 거절", + 6: "강한 가격 제안", + 7: "중간 가격 제안", + 8: "약한 가격 제안", +} + +# Spaces Definition +OBSERVATION_SPACE = spaces.MultiDiscrete( + [SCENARIO_SPACE_SIZE, PRICE_ZONE_SIZE, ACCEPTANCE_RATE_SIZE] +) + +ACTION_SPACE = spaces.Discrete(ACTION_SPACE_SIZE) diff --git a/negotiation_agent/environment.py b/negotiation_agent/environment.py index cbc7bc3..4ab3f23 100644 --- a/negotiation_agent/environment.py +++ b/negotiation_agent/environment.py @@ -1,6 +1,13 @@ import gymnasium as gym from gymnasium import spaces import numpy as np +from negotiation_agent.spaces import ( + NegotiationSpaces, + State, + PriceZone, + AcceptanceRate, + Scenario, +) class NegotiationEnv(gym.Env): @@ -8,9 +15,11 @@ class NegotiationEnv(gym.Env): def __init__(self, scenario=0, target_price=100, threshold_price=120): super(NegotiationEnv, self).__init__() - self.observation_space = spaces.MultiDiscrete([4, 3, 3]) - self.action_space = spaces.Discrete(9) - self.initial_scenario = scenario + + self.spaces = NegotiationSpaces() + self.observation_space = self.spaces.observation_space + self.action_space = self.spaces.action_space + self.initial_scenario = Scenario(scenario) self.target_price = target_price self.threshold_price = threshold_price self.current_price = None @@ -20,23 +29,28 @@ class NegotiationEnv(gym.Env): def _get_state(self): """현재 정보를 바탕으로 State 배열을 계산""" if self.current_price <= self.target_price: - price_zone = 0 + price_zone = PriceZone.BELOW_TARGET elif self.target_price < self.current_price <= self.threshold_price: - price_zone = 1 + price_zone = PriceZone.BETWEEN_TARGET_AND_THRESHOLD else: - price_zone = 2 + price_zone = PriceZone.ABOVE_THRESHOLD acceptance_rate_val = ( self.initial_price - self.current_price ) / self.initial_price if acceptance_rate_val < 0.1: - acceptance_rate_level = 0 + acceptance_rate_level = AcceptanceRate.LOW elif 0.1 <= acceptance_rate_val < 0.25: - acceptance_rate_level = 1 + acceptance_rate_level = AcceptanceRate.MEDIUM else: - acceptance_rate_level = 2 + acceptance_rate_level = AcceptanceRate.HIGH - return np.array([self.initial_scenario, price_zone, acceptance_rate_level]) + state = State( + scenario=self.initial_scenario, + price_zone=price_zone, + acceptance_rate=acceptance_rate_level, + ) + return np.array(state.to_array()) def reset(self, seed=None, options=None): """환경을 초기 상태로 리셋""" diff --git a/negotiation_agent/spaces.py b/negotiation_agent/spaces.py new file mode 100644 index 0000000..15f5a1a --- /dev/null +++ b/negotiation_agent/spaces.py @@ -0,0 +1,113 @@ +from gymnasium import spaces +from typing import Dict, List, Any +from dataclasses import dataclass +from enum import Enum, auto +from negotiation_agent.action_space import ActionSpace, ActionInfo + + +class Scenario(Enum): + HIGH_INTENTION = 0 + MEDIUM_INTENTION = 1 + LOW_INTENTION = 2 + VERY_LOW_INTENTION = 3 + + @property + def description(self) -> str: + return { + self.HIGH_INTENTION: "높은 구매 의지", + self.MEDIUM_INTENTION: "중간 구매 의지", + self.LOW_INTENTION: "낮은 구매 의지", + self.VERY_LOW_INTENTION: "매우 낮은 구매 의지", + }[self] + + +class PriceZone(Enum): + BELOW_TARGET = 0 + BETWEEN_TARGET_AND_THRESHOLD = 1 + ABOVE_THRESHOLD = 2 + + @property + def description(self) -> str: + return { + self.BELOW_TARGET: "목표가격 이하", + self.BETWEEN_TARGET_AND_THRESHOLD: "목표가격~임계가격", + self.ABOVE_THRESHOLD: "임계가격 초과", + }[self] + + +class AcceptanceRate(Enum): + LOW = 0 + MEDIUM = 1 + HIGH = 2 + + @property + def description(self) -> str: + return { + self.LOW: "낮음 (<10%)", + self.MEDIUM: "중간 (10-25%)", + self.HIGH: "높음 (>25%)", + }[self] + + +@dataclass +class State: + scenario: Scenario + price_zone: PriceZone + acceptance_rate: AcceptanceRate + + def to_array(self) -> List[int]: + return [self.scenario.value, self.price_zone.value, self.acceptance_rate.value] + + @classmethod + def from_array(cls, arr: List[int]) -> "State": + return cls( + scenario=Scenario(arr[0]), + price_zone=PriceZone(arr[1]), + acceptance_rate=AcceptanceRate(arr[2]), + ) + + def __str__(self) -> str: + return ( + f"State(scenario={self.scenario.description}, " + f"price_zone={self.price_zone.description}, " + f"acceptance_rate={self.acceptance_rate.description})" + ) + + +class NegotiationSpaces: + def __init__(self): + self._action_space = ActionSpace() + + @property + def observation_space(self) -> spaces.MultiDiscrete: + return spaces.MultiDiscrete( + [len(Scenario), len(PriceZone), len(AcceptanceRate)] + ) + + @property + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(self._action_space.action_space_size) + + def decode_action(self, action_id: int) -> ActionInfo: + return self._action_space.get_action(action_id) + + def encode_state(self, state: State) -> List[int]: + return state.to_array() + + def decode_state(self, state_array: List[int]) -> State: + return State.from_array(state_array) + + def get_action_description(self, action_id: int) -> str: + return self.decode_action(action_id).description + + def get_state_description(self, state_array: List[int]) -> str: + return str(self.decode_state(state_array)) + + def get_actions_by_category(self, category: str) -> List[ActionInfo]: + return self._action_space.get_actions_by_category(category) + + def get_actions_by_strength(self, strength: str) -> List[ActionInfo]: + return self._action_space.get_actions_by_strength(strength) + + def list_all_actions(self) -> List[ActionInfo]: + return self._action_space.list_actions()