o2o-infinith-backend/app/integrations/llm/prompt.py

49 lines
1.5 KiB
Python

import os
from typing import Union

from pydantic import BaseModel

from common.utils import get_env
from integrations.llm.schemas.plan import PlanInput, PlanOutput
from integrations.llm.schemas.report import ReportInput, ReportOutput
# Prompt template files are looked up in the "temp-prompt" folder that sits
# alongside this module.
_PROMPT_DIR = os.path.join(os.path.split(__file__)[0], "temp-prompt")
class Prompt:
    """A prompt template file paired with its input/output schemas.

    Constructing an instance immediately reads the template text from
    ``os.path.join(_PROMPT_DIR, file_name)`` and calls
    ``get_env(prompt_model)``. Both happen at instantiation time, not lazily.

    Attributes:
        file_name: Template file name, relative to ``_PROMPT_DIR``.
        prompt_model: Key passed to ``get_env`` to obtain ``model``.
        input_class: Pydantic model used by ``build`` to validate input.
        output_class: Pydantic model stored for callers. It is not used
            within this class (presumably the schema for the LLM response;
            confirm with callers).
        template: Raw template text, rendered with ``str.format``.
        model: Value returned by ``get_env(prompt_model)``.
    """

    file_name: str
    prompt_model: str
    input_class: type[BaseModel]
    output_class: type[BaseModel]
    template: str
    model: str

    def __init__(self, file_name: str, prompt_model: str, input_class: type[BaseModel], output_class: type[BaseModel]):
        """Store the configuration and eagerly load the template and model.

        Raises:
            OSError: If the template file cannot be opened.
            UnicodeDecodeError: If the template file is not valid UTF-8.
        """
        self.file_name = file_name
        self.prompt_model = prompt_model
        self.input_class = input_class
        self.output_class = output_class
        self.template, self.model = self._load_prompt()

    def _load_prompt(self) -> tuple[str, str]:
        """Read the template file and resolve the model value.

        Returns:
            ``(template_text, get_env(self.prompt_model))``.
        """
        path = os.path.join(_PROMPT_DIR, self.file_name)
        with open(path, encoding="utf-8") as f:
            template = f.read()
        return template, get_env(self.prompt_model)

    def _reload_prompt(self) -> None:
        """Re-read the template file and re-query ``get_env``.

        Picks up edits to the template on disk without recreating the instance.
        """
        self.template, self.model = self._load_prompt()

    def build(self, input_data: Union[dict, BaseModel]) -> str:
        """Validate ``input_data`` and render it into the template.

        Args:
            input_data: Either a mapping of field values, validated via
                ``input_class(**input_data)``, or an existing instance of
                ``input_class``, which is used as-is.

        Returns:
            The template with its ``str.format`` placeholders filled from
            ``model_dump()`` of the validated input.

        Raises:
            pydantic.ValidationError: If a mapping fails schema validation.
            KeyError: If the template references a name that is missing from
                the dumped input.

        Note:
            Because rendering uses ``str.format``, literal braces in the
            template (e.g. JSON examples) must be written as ``{{`` and ``}}``.
        """
        if isinstance(input_data, self.input_class):
            # Already validated. Unpacking a model instance with ``**`` would
            # raise TypeError, because models are not mappings.
            verified = input_data
        else:
            verified = self.input_class(**input_data)
        return self.template.format(**verified.model_dump())
# Shared prompt instances. Building each one reads its template file and calls
# get_env right away, so importing this module fails if a template is missing.
report_prompt = Prompt(
    input_class=ReportInput,
    output_class=ReportOutput,
    file_name="report_prompt.txt",
    prompt_model="REPORT_MODEL",
)
plan_prompt = Prompt(
    input_class=PlanInput,
    output_class=PlanOutput,
    file_name="plan_prompt.txt",
    prompt_model="PLAN_MODEL",
)