79 lines
2.4 KiB
Python
79 lines
2.4 KiB
Python
import os, json
|
|
from pydantic import BaseModel
|
|
from config import prompt_settings
|
|
from app.utils.logger import get_logger
|
|
from app.utils.prompts.schemas import *
|
|
from functools import lru_cache
|
|
import openpyxl
|
|
|
|
# Module-level logger; Prompt.build_prompt writes its debug output through it.
logger = get_logger("prompt")
|
|
|
|
class Prompt:
    """A prompt template loaded from one worksheet of the prompt Excel workbook.

    The worksheet is read as key/value pairs: column A holds the key and
    column B the value, starting at row 2 (row 1 is skipped, presumably a
    header row). The ``input`` row supplies the template text and the
    ``model`` row supplies the model name.
    """

    # Name of the worksheet inside prompt_settings.PROMPT_EXCEL_FILE.
    sheet_name: str
    # Template text; ``{field}`` placeholders are filled by build_prompt.
    prompt_template: str
    # Model identifier read from the sheet's ``model`` row.
    prompt_model: str

    # Pydantic model used by build_prompt to validate its input.
    prompt_input_class: type[BaseModel] = BaseModel
    # Output schema paired with this prompt; not used inside this class.
    # It may be a parameterised schema (see create_dynamic_subtitle_prompt).
    prompt_output_class = BaseModel

    def __init__(self, sheet_name: str, prompt_input_class, prompt_output_class):
        """Store the schema classes and load the template from *sheet_name*.

        Reads the Excel workbook immediately, so constructing a Prompt
        performs file I/O.
        """
        self.sheet_name = sheet_name
        self.prompt_input_class = prompt_input_class
        self.prompt_output_class = prompt_output_class
        self.prompt_template, self.prompt_model = self._read_from_excel()

    def _read_from_excel(self) -> tuple[str, str]:
        """Read the template and model name from this prompt's worksheet.

        Returns:
            ``(template, model)`` taken from the ``input`` and ``model`` rows,
            coerced to ``str``.

        Raises:
            KeyError: if the worksheet does not exist, or it has no non-empty
                ``input`` or ``model`` row.
        """
        # read_only streams rows from the file rather than building the full
        # in-memory model; the workbook must then be closed explicitly.
        wb = openpyxl.load_workbook(prompt_settings.PROMPT_EXCEL_FILE, read_only=True)
        try:
            ws = wb[self.sheet_name]
            data = {}
            # Only the key/value columns matter. Bounding max_col also makes
            # read-only mode pad missing or short rows to two cells.
            for row in ws.iter_rows(min_row=2, max_col=2, values_only=True):
                if len(row) < 2:  # defensive: never index past a short row
                    continue
                key, value = row[0], row[1]
                # Skip rows whose key or value cell is empty.
                if key and value:
                    data[key] = value
        finally:
            wb.close()

        missing = [name for name in ("input", "model") if name not in data]
        if missing:
            raise KeyError(
                f"prompt sheet {self.sheet_name!r} is missing required row(s): "
                f"{', '.join(missing)}"
            )
        # Cells may hold non-string values (e.g. numbers). Coerce them so the
        # declared return type holds and str.format in build_prompt works.
        return str(data["input"]), str(data["model"])

    def _reload_prompt(self):
        """Re-read the template and model name from the workbook in place."""
        self.prompt_template, self.prompt_model = self._read_from_excel()

    def build_prompt(self, input_data: dict) -> str:
        """Validate *input_data* and substitute it into the template.

        Args:
            input_data: keyword arguments for ``prompt_input_class``.

        Returns:
            The template with its ``{field}`` placeholders filled from the
            validated model's ``model_dump()``.

        Raises:
            pydantic.ValidationError: if *input_data* fails validation.
            KeyError: if the template references a field the dump lacks.
        """
        verified_input = self.prompt_input_class(**input_data)
        build_template = self.prompt_template.format(**verified_input.model_dump())
        logger.debug(f"build_template: {build_template}")
        logger.debug(f"input_data: {input_data}")
        return build_template
|
|
|
|
# Prompts backed by fixed worksheets. Prompt's constructor reads the Excel
# workbook, so each of these definitions does file I/O at import time.
marketing_prompt = Prompt(
    prompt_input_class=MarketingPromptInput,
    prompt_output_class=MarketingPromptOutput,
    sheet_name="marketing",
)

lyric_prompt = Prompt(
    prompt_input_class=LyricPromptInput,
    prompt_output_class=LyricPromptOutput,
    sheet_name="lyric",
)

yt_upload_prompt = Prompt(
    prompt_input_class=YTUploadPromptInput,
    prompt_output_class=YTUploadPromptOutput,
    sheet_name="yt_upload",
)
|
|
|
|
@lru_cache(maxsize=128)
def create_dynamic_subtitle_prompt(length: int) -> Prompt:
    """Return the ``subtitle`` prompt whose output class is ``SubtitlePromptOutput[length]``.

    Results are memoised per *length* (LRU, up to 128 entries). A cached call
    returns the same Prompt instead of re-reading the worksheet.
    NOTE(review): ``SubtitlePromptOutput[length]`` is presumably a schema
    parameterised by subtitle count; confirm in app.utils.prompts.schemas.
    """
    output_schema = SubtitlePromptOutput[length]
    return Prompt(
        prompt_input_class=SubtitlePromptInput,
        prompt_output_class=output_schema,
        sheet_name="subtitle",
    )
|
|
|
|
|
|
def reload_all_prompt():
    """Re-read every prompt template from the Excel workbook.

    The module-level prompts are reloaded in place. Subtitle prompts are
    memoised by ``create_dynamic_subtitle_prompt``'s ``lru_cache`` and would
    otherwise keep serving the old template. The cache is therefore cleared,
    so the next call for each length builds a fresh Prompt from the workbook.
    Prompt objects a caller already holds from that function are not updated.
    """
    marketing_prompt._reload_prompt()
    lyric_prompt._reload_prompt()
    yt_upload_prompt._reload_prompt()
    create_dynamic_subtitle_prompt.cache_clear()