# o2o-castad-backend/app/utils/prompts/prompts.py
#
# 79 lines
# 2.4 KiB
# Python

import os, json
from pydantic import BaseModel
from config import prompt_settings
from app.utils.logger import get_logger
from app.utils.prompts.schemas import *
from functools import lru_cache
import openpyxl
# Module-level logger; Prompt.build_prompt writes its debug output here.
logger = get_logger("prompt")
class Prompt:
    """A prompt template loaded from one worksheet of the prompt Excel workbook.

    The worksheet is read from row 2 onward (row 1 is skipped, presumably a
    header). Column A holds a key and column B its value. Rows whose key or
    value is falsy are ignored. Two keys are required: ``"input"`` (the
    ``str.format`` template) and ``"model"`` (the model name).

    Attributes:
        sheet_name: Name of the worksheet this prompt is read from.
        prompt_template: Value of the ``"input"`` row.
        prompt_model: Value of the ``"model"`` row.
        prompt_input_class: Pydantic model used by ``build_prompt`` to
            validate its input before formatting the template.
        prompt_output_class: Pydantic model for the expected output; stored
            for callers, not used inside this class.
    """

    sheet_name: str
    prompt_template: str
    prompt_model: str
    prompt_input_class: type[BaseModel] = BaseModel
    prompt_output_class: type[BaseModel] = BaseModel

    def __init__(self, sheet_name, prompt_input_class, prompt_output_class):
        self.sheet_name = sheet_name
        self.prompt_input_class = prompt_input_class
        self.prompt_output_class = prompt_output_class
        # Reads the workbook immediately, so constructing a Prompt does I/O.
        self._reload_prompt()

    def _read_from_excel(self) -> tuple[str, str]:
        """Read this prompt's ``(template, model)`` pair from the workbook.

        Returns:
            The values of the ``"input"`` and ``"model"`` rows.

        Raises:
            KeyError: If the worksheet does not exist, or if it lacks an
                ``"input"`` or ``"model"`` row.
        """
        wb = openpyxl.load_workbook(prompt_settings.PROMPT_EXCEL_FILE, read_only=True)
        try:
            ws = wb[self.sheet_name]
            data = {}
            # max_col=2 makes every row exactly two cells wide (padded with
            # None). Without it, read-only rows can be shorter than two cells
            # (e.g. only column A filled, or a blank row), and indexing would
            # raise IndexError.
            for key, value in ws.iter_rows(min_row=2, max_col=2, values_only=True):
                if key and value:
                    data[key] = value
        finally:
            wb.close()

        missing = [name for name in ("input", "model") if name not in data]
        if missing:
            # Keep KeyError (what the bare dict lookup raised) for existing
            # callers, but name the sheet and keys to make the failure actionable.
            raise KeyError(
                f"prompt sheet {self.sheet_name!r} in "
                f"{prompt_settings.PROMPT_EXCEL_FILE!r} is missing required "
                f"key(s): {', '.join(missing)}"
            )
        return data["input"], data["model"]

    def _reload_prompt(self):
        """Re-read the template and model name from the workbook, replacing the current ones."""
        self.prompt_template, self.prompt_model = self._read_from_excel()

    def build_prompt(self, input_data: dict) -> str:
        """Validate *input_data* and substitute it into the template.

        Args:
            input_data: Keyword arguments for ``prompt_input_class``.

        Returns:
            The template with ``{field}`` placeholders filled from the
            validated model's ``model_dump()``.
        """
        verified_input = self.prompt_input_class(**input_data)
        build_template = self.prompt_template.format(**verified_input.model_dump())
        logger.debug(f"build_template: {build_template}")
        logger.debug(f"input_data: {input_data}")
        return build_template
# Prompts created once at import time. Constructing each Prompt reads its
# worksheet from the prompt workbook straight away.
marketing_prompt = Prompt(
    prompt_input_class=MarketingPromptInput,
    prompt_output_class=MarketingPromptOutput,
    sheet_name="marketing",
)

lyric_prompt = Prompt(
    prompt_input_class=LyricPromptInput,
    prompt_output_class=LyricPromptOutput,
    sheet_name="lyric",
)

yt_upload_prompt = Prompt(
    prompt_input_class=YTUploadPromptInput,
    prompt_output_class=YTUploadPromptOutput,
    sheet_name="yt_upload",
)
@lru_cache(maxsize=128)
def create_dynamic_subtitle_prompt(length: int) -> Prompt:
    """Return the ``"subtitle"`` prompt whose output class is ``SubtitlePromptOutput[length]``.

    Results are memoised per ``length`` for the 128 most recently used
    values. A repeated request returns the same ``Prompt`` instance without
    re-reading the workbook.
    """
    output_class = SubtitlePromptOutput[length]
    return Prompt(
        sheet_name="subtitle",
        prompt_input_class=SubtitlePromptInput,
        prompt_output_class=output_class,
    )
def reload_all_prompt():
    """Re-read every prompt's template and model name from the prompt workbook.

    Refreshes the module-level prompts in place. It also empties the
    ``create_dynamic_subtitle_prompt`` cache, so each subtitle prompt is
    rebuilt from the workbook the next time it is requested. Subtitle
    ``Prompt`` instances a caller obtained before the reload are not refreshed.
    """
    for prompt in (marketing_prompt, lyric_prompt, yt_upload_prompt):
        prompt._reload_prompt()
    # Subtitle prompts exist only inside the lru_cache. Without clearing it,
    # they would keep serving the template read before this reload.
    create_dynamic_subtitle_prompt.cache_clear()