# o2o-infinith-backend/app/integrations/llm/prompt.py

import os
from pydantic import BaseModel
from common.utils import get_env
from integrations.llm.schemas.report import (
ReportInput, ReportOutput,
CriticalIssuesInput, CriticalIssuesOutput,
YouTubeDiagnosisInput, YouTubeDiagnosisOutput,
BrandConsistencyInput, BrandConsistencyOutput,
)
from integrations.llm.schemas.plan import PlanInput, PlanOutput
from integrations.llm.schemas.market import (
MarketCompetitorsInput, MarketCompetitorsOutput,
MarketKeywordsInput, MarketKeywordsOutput,
MarketTrendInput, MarketTrendOutput,
MarketTargetAudienceInput, MarketTargetAudienceOutput,
)
# Directory containing the prompt template files; it sits next to this module.
_PROMPT_DIR = os.path.join(
    os.path.dirname(__file__),
    "temp-prompt",
)
class Prompt:
    """A prompt template on disk, paired with its model name and I/O schemas.

    Creating an instance reads the template file from ``_PROMPT_DIR`` and
    resolves the model name via ``get_env``. Construction therefore performs
    file I/O and fails immediately if the template file is missing.
    """

    file_name: str  # template file name, relative to _PROMPT_DIR
    prompt_model: str  # key passed to get_env() to look up the model name
    input_class: type[BaseModel]  # schema used by build() to validate input
    output_class: type[BaseModel]  # response schema; not used in this class, presumably read by callers
    template: str  # raw template text, rendered with str.format in build()
    model: str  # value returned by get_env(prompt_model)

    def __init__(
        self,
        file_name: str,
        prompt_model: str,
        input_class: type[BaseModel],
        output_class: type[BaseModel],
    ):
        """Store the prompt configuration and load the template and model.

        Args:
            file_name: Template file name inside ``_PROMPT_DIR``.
            prompt_model: Key handed to ``get_env`` to obtain the model name.
            input_class: Pydantic model that ``build`` validates input against.
            output_class: Pydantic model describing the expected output.

        Raises:
            OSError: If the template file cannot be opened.
        """
        self.file_name = file_name
        self.prompt_model = prompt_model
        self.input_class = input_class
        self.output_class = output_class
        self.template, self.model = self._load_prompt()

    def _load_prompt(self) -> tuple[str, str]:
        """Read the template file and resolve the model name.

        Returns:
            ``(template_text, get_env(self.prompt_model))``.

        Raises:
            OSError: If the template file cannot be opened.
        """
        path = os.path.join(_PROMPT_DIR, self.file_name)
        with open(path, encoding="utf-8") as f:
            template = f.read()
        return template, get_env(self.prompt_model)

    def _reload_prompt(self):
        """Re-read the template file and re-resolve the model name in place."""
        self.template, self.model = self._load_prompt()

    def build(self, input_data: "dict | BaseModel") -> str:
        """Validate the input and render the template with its field values.

        Args:
            input_data: Field values for ``input_class`` as a mapping, or an
                already-constructed ``input_class`` instance (used as-is,
                without re-validation).

        Returns:
            The template with each ``{field}`` replaced by the matching value
            from the validated model's ``model_dump()``.

        Raises:
            pydantic.ValidationError: If a mapping fails schema validation.
            KeyError: If formatting needs a key the validated data does not
                provide. Literal braces in a template must be written as
                ``{{`` / ``}}``; otherwise they are parsed as placeholders
                and end up here.
        """
        if isinstance(input_data, self.input_class):
            verified = input_data
        else:
            verified = self.input_class(**input_data)
        try:
            return self.template.format(**verified.model_dump())
        except KeyError as err:
            # Keep the KeyError type so existing handlers still work, but say
            # which template and schema are involved; the bare key alone is
            # hard to trace back.
            raise KeyError(
                f"Prompt template {self.file_name!r} failed to format: "
                f"missing key {err} (input schema: {self.input_class.__name__})"
            ) from err
# get_env() keys naming the model used by each group of prompts.
_REPORT_MODEL_ENV = "REPORT_MODEL"
_PLAN_MODEL_ENV = "PLAN_MODEL"
_MARKET_MODEL_ENV = "MARKET_MODEL"

# Shared prompt instances. Constructing each one reads its template file from
# the prompt directory and resolves its model name via get_env, so all of this
# happens when the module is imported.

# --- Full report ---
report_prompt = Prompt(
    file_name="report_prompt.txt",
    prompt_model=_REPORT_MODEL_ENV,
    input_class=ReportInput,
    output_class=ReportOutput,
)

# --- Plan ---
plan_prompt = Prompt(
    file_name="plan_prompt.txt",
    prompt_model=_PLAN_MODEL_ENV,
    input_class=PlanInput,
    output_class=PlanOutput,
)

# --- Market analysis (all share the market model) ---
market_competitors_prompt = Prompt(
    file_name="market_competitors_prompt.txt",
    prompt_model=_MARKET_MODEL_ENV,
    input_class=MarketCompetitorsInput,
    output_class=MarketCompetitorsOutput,
)
market_keywords_prompt = Prompt(
    file_name="market_keywords_prompt.txt",
    prompt_model=_MARKET_MODEL_ENV,
    input_class=MarketKeywordsInput,
    output_class=MarketKeywordsOutput,
)
market_trend_prompt = Prompt(
    file_name="market_trend_prompt.txt",
    prompt_model=_MARKET_MODEL_ENV,
    input_class=MarketTrendInput,
    output_class=MarketTrendOutput,
)
market_target_audience_prompt = Prompt(
    file_name="market_target_audience_prompt.txt",
    prompt_model=_MARKET_MODEL_ENV,
    input_class=MarketTargetAudienceInput,
    output_class=MarketTargetAudienceOutput,
)

# --- Report sections (reuse the report model) ---
youtube_diagnosis_prompt = Prompt(
    file_name="youtube_diagnosis_prompt.txt",
    prompt_model=_REPORT_MODEL_ENV,
    input_class=YouTubeDiagnosisInput,
    output_class=YouTubeDiagnosisOutput,
)
brand_consistency_prompt = Prompt(
    file_name="brand_consistency_prompt.txt",
    prompt_model=_REPORT_MODEL_ENV,
    input_class=BrandConsistencyInput,
    output_class=BrandConsistencyOutput,
)
critical_issues_prompt = Prompt(
    file_name="critical_issues_prompt.txt",
    prompt_model=_REPORT_MODEL_ENV,
    input_class=CriticalIssuesInput,
    output_class=CriticalIssuesOutput,
)