o2o-castad-backend/app/utils/prompts/prompts.py

80 lines
2.5 KiB
Python

import gspread
from pydantic import BaseModel
from google.oauth2.service_account import Credentials
from config import prompt_settings
from app.utils.logger import get_logger
from app.utils.prompts.schemas import *
from functools import lru_cache
# Logger for this module, obtained from the project's logging helper under
# the name "prompt".
logger = get_logger("prompt")

# OAuth scopes requested for the service-account credentials that read the
# prompt spreadsheet: read-only Sheets access plus read-only Drive access.
_SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets.readonly",
    "https://www.googleapis.com/auth/drive.readonly",
]
class Prompt:
    """A prompt template plus model name loaded from one Google Sheets worksheet.

    The worksheet named ``sheet_name`` inside the spreadsheet whose key is
    ``prompt_settings.PROMPT_SPREADSHEET`` supplies two values:

    * cell B2 (row 2, column 2) -> ``prompt_model``
    * cell B3 (row 3, column 2) -> ``prompt_template``

    The sheet is read on construction (network I/O) and again on every
    ``_reload_prompt()`` call.
    """

    sheet_name: str
    prompt_template: str
    prompt_model: str
    # Pydantic model classes for input validation / output description.
    # The BaseModel defaults are placeholders; __init__ always overrides them.
    prompt_input_class: type[BaseModel] = BaseModel
    prompt_output_class: type[BaseModel] = BaseModel

    def __init__(
        self,
        sheet_name: str,
        prompt_input_class: type[BaseModel],
        prompt_output_class: type[BaseModel],
    ):
        """Store the schema classes and load template/model from the sheet.

        Args:
            sheet_name: Worksheet (tab) name inside the prompt spreadsheet.
            prompt_input_class: Model used by ``build_prompt`` to validate input.
            prompt_output_class: Model describing the expected output; stored
                only, not used by this class.
        """
        self.sheet_name = sheet_name
        self.prompt_input_class = prompt_input_class
        self.prompt_output_class = prompt_output_class
        # Initial load goes through the same path as a reload so the two
        # cannot drift apart.
        self._reload_prompt()

    def _read_from_sheets(self) -> tuple[str, str]:
        """Fetch ``(template, model)`` from this prompt's worksheet.

        A fresh service-account credential and gspread client are created on
        every call.

        Returns:
            ``(template, model)``: the values of cells B3 and B2 respectively.
            Note the tuple order is the reverse of the on-sheet row order.
        """
        creds = Credentials.from_service_account_file(
            prompt_settings.GOOGLE_SERVICE_ACCOUNT_JSON, scopes=_SCOPES
        )
        gc = gspread.authorize(creds)
        ws = gc.open_by_key(prompt_settings.PROMPT_SPREADSHEET).worksheet(self.sheet_name)
        model = ws.cell(2, 2).value
        input_text = ws.cell(3, 2).value
        # An empty cell is still returned (behavior unchanged), but flag it now.
        # Otherwise it only surfaces later, e.g. as an AttributeError when
        # build_prompt calls .format on a missing template.
        if not model:
            logger.warning(f"[{self.sheet_name}] model cell B2 is empty")
        if not input_text:
            logger.warning(f"[{self.sheet_name}] prompt template cell B3 is empty")
        return input_text, model

    def _reload_prompt(self):
        """Re-read template and model from the sheet, replacing the current values.

        If the sheet read raises, both attributes keep their previous values,
        because the tuple assignment never happens.
        """
        self.prompt_template, self.prompt_model = self._read_from_sheets()

    def build_prompt(self, input_data: dict) -> str:
        """Validate ``input_data`` and substitute it into the template.

        Args:
            input_data: Keyword arguments for ``prompt_input_class``.

        Returns:
            The template with its ``str.format`` placeholders filled from the
            validated model's ``model_dump()``.

        Raises:
            pydantic.ValidationError: if ``input_data`` fails validation.
            KeyError / IndexError / ValueError: from ``str.format`` when the
                template references an unknown field or has unescaped braces.
        """
        verified_input = self.prompt_input_class(**input_data)
        build_template = self.prompt_template.format(**verified_input.model_dump())
        logger.debug(f"build_template: {build_template}")
        logger.debug(f"input_data: {input_data}")
        return build_template
# Module-level prompt singletons. Each construction reads its worksheet from
# Google Sheets, so importing this module performs network I/O. They are
# built in this order: marketing, lyric, yt_upload.

# Worksheet "marketing".
marketing_prompt = Prompt(
    sheet_name="marketing",
    prompt_input_class=MarketingPromptInput,
    prompt_output_class=MarketingPromptOutput,
)

# Worksheet "lyric".
lyric_prompt = Prompt(
    sheet_name="lyric",
    prompt_input_class=LyricPromptInput,
    prompt_output_class=LyricPromptOutput,
)

# Worksheet "yt_upload".
yt_upload_prompt = Prompt(
    sheet_name="yt_upload",
    prompt_input_class=YTUploadPromptInput,
    prompt_output_class=YTUploadPromptOutput,
)
@lru_cache(maxsize=128)
def create_dynamic_subtitle_prompt(length: int) -> Prompt:
    """Return the "subtitle" prompt whose output schema is ``SubtitlePromptOutput[length]``.

    Results are memoised per ``length`` (up to 128 entries), so the worksheet
    is only fetched on the first call for a given length. Call
    ``create_dynamic_subtitle_prompt.cache_clear()`` to force a fresh read.

    Args:
        length: Value used to index ``SubtitlePromptOutput`` to select the
            output schema.
    """
    output_schema = SubtitlePromptOutput[length]
    return Prompt(
        sheet_name="subtitle",
        prompt_input_class=SubtitlePromptInput,
        prompt_output_class=output_schema,
    )
def reload_all_prompt():
    """Re-read every prompt's template and model from Google Sheets.

    The marketing, lyric and yt_upload prompts are refreshed in place.
    Subtitle prompts are memoised by ``create_dynamic_subtitle_prompt``.
    Clearing that cache makes the next call for any length build a fresh
    ``Prompt`` from the sheet. Subtitle ``Prompt`` objects already held by
    callers are not refreshed.
    """
    # Clear first so subtitle prompts refresh even if a sheet read below fails.
    # Without this, cached subtitle prompts kept serving stale templates.
    create_dynamic_subtitle_prompt.cache_clear()
    marketing_prompt._reload_prompt()
    lyric_prompt._reload_prompt()
    yt_upload_prompt._reload_prompt()