89 lines
2.8 KiB
Python
89 lines
2.8 KiB
Python
import gspread
|
|
from pydantic import BaseModel
|
|
from google.oauth2.service_account import Credentials
|
|
from config import prompt_settings
|
|
from app.utils.logger import get_logger
|
|
from app.utils.prompts.schemas import *
|
|
from functools import lru_cache
|
|
|
|
# Module-level logger obtained from the project's get_logger helper under the name "prompt".
logger = get_logger("prompt")

# Read-only OAuth scopes requested for the service account credentials used in
# Prompt._read_from_sheets: Sheets access to read prompt cells, plus Drive
# read access (presumably required by gspread to open a spreadsheet by key —
# confirm).
_SCOPES = [
    "https://www.googleapis.com/auth/spreadsheets.readonly",
    "https://www.googleapis.com/auth/drive.readonly",
]
|
|
|
|
class Prompt:
    """A prompt template and model name loaded from one Google Sheets worksheet.

    The worksheet titled ``sheet_name`` is opened inside the spreadsheet whose
    key is ``prompt_settings.PROMPT_SPREADSHEET``. Cell B2 holds the model name
    and cell B3 holds the prompt template. Both are read at construction time
    and again whenever ``_reload_prompt`` is called.
    """

    sheet_name: str
    prompt_template: str
    prompt_model: str

    # Class-level placeholders; __init__ overwrites both on every instance.
    prompt_input_class = BaseModel
    prompt_output_class = BaseModel

    def __init__(self, sheet_name: str, prompt_input_class, prompt_output_class):
        """Store the schema classes and fetch the prompt from the sheet.

        Args:
            sheet_name: Title of the worksheet inside the prompt spreadsheet.
            prompt_input_class: Class called with ``**input_data`` in
                ``build_prompt``; its instances must provide ``model_dump()``.
            prompt_output_class: Stored for callers; not used by this class.

        Note:
            Performs network I/O against Google Sheets (via ``_read_from_sheets``).
        """
        self.sheet_name = sheet_name
        self.prompt_input_class = prompt_input_class
        self.prompt_output_class = prompt_output_class
        self.prompt_template, self.prompt_model = self._read_from_sheets()

    def _read_from_sheets(self) -> "tuple[str | None, str | None]":
        """Fetch ``(template, model)`` from cells B3 and B2 of this worksheet.

        Returns:
            ``(template_text, model_name)``. gspread reports an empty cell's
            value as ``None``, so either element may be ``None``.
        """
        creds = Credentials.from_service_account_file(
            prompt_settings.GOOGLE_SERVICE_ACCOUNT_JSON, scopes=_SCOPES
        )
        gc = gspread.authorize(creds)
        ws = gc.open_by_key(prompt_settings.PROMPT_SPREADSHEET).worksheet(self.sheet_name)
        # (row 2, col 2) == B2: model name; (row 3, col 2) == B3: template text.
        model = ws.cell(2, 2).value
        input_text = ws.cell(3, 2).value
        return input_text, model

    def _reload_prompt(self):
        """Re-read the template and model name from the sheet, replacing the cached values."""
        self.prompt_template, self.prompt_model = self._read_from_sheets()

    def build_prompt(self, input_data: dict, silent: bool = False) -> str:
        """Validate ``input_data`` and substitute it into the prompt template.

        Args:
            input_data: Keyword arguments for ``prompt_input_class``. After
                validation, the instance's ``model_dump()`` fields fill the
                template's ``str.format`` placeholders.
            silent: When False (the default), log the rendered prompt and the raw
                input at debug level.

        Returns:
            The rendered prompt string.

        Raises:
            ValueError: If the template cell (B3) was empty when last read.
            KeyError / ValueError: From ``str.format`` when the template names a
                placeholder the validated input lacks, or contains stray braces.
        """
        verified_input = self.prompt_input_class(**input_data)
        # An empty B3 cell leaves the template as None; fail with a clear message
        # instead of "'NoneType' object has no attribute 'format'".
        if self.prompt_template is None:
            raise ValueError(
                f"Prompt template for sheet '{self.sheet_name}' is empty (cell B3)."
            )
        build_template = self.prompt_template.format(**verified_input.model_dump())
        if not silent:
            logger.debug(f"build_template: {build_template}")
            logger.debug(f"input_data: {input_data}")
        return build_template
|
|
|
|
# Shared prompt instances, built in a fixed order. Each Prompt(...) call reads
# its worksheet from Google Sheets, so importing this module performs network I/O.
# Positional arguments follow Prompt.__init__: (sheet_name, input class, output class).
marketing_prompt = Prompt("marketing", MarketingPromptInput, MarketingPromptOutput)
lyric_prompt = Prompt("lyric", LyricPromptInput, LyricPromptOutput)
yt_upload_prompt = Prompt("yt_upload", YTUploadPromptInput, YTUploadPromptOutput)
image_autotag_prompt = Prompt("image_tag", ImageTagPromptInput, ImageTagPromptOutput)
|
|
|
|
@lru_cache(maxsize=128)
def create_dynamic_subtitle_prompt(length: int) -> Prompt:
    """Return the "subtitle" prompt with its output class selected by ``length``.

    Results are memoised per ``length`` (LRU, at most 128 entries), so the
    "subtitle" worksheet is fetched only on the first request for each length.
    """
    # NOTE(review): SubtitlePromptOutput[length] presumably yields an output
    # schema parameterised by the requested length — confirm in schemas.
    output_cls = SubtitlePromptOutput[length]
    return Prompt(
        sheet_name="subtitle",
        prompt_input_class=SubtitlePromptInput,
        prompt_output_class=output_cls,
    )
|
|
|
|
|
|
def reload_all_prompt():
    """Re-read every prompt from Google Sheets.

    The four module-level prompts are refreshed in place, in a fixed order.
    Prompts produced by ``create_dynamic_subtitle_prompt`` are memoised by
    ``lru_cache``. Its cache is cleared so the next call for any length builds
    a fresh Prompt from the sheet, rather than returning one with a stale
    template. Prompt objects that callers already hold from that factory are
    not updated.
    """
    for prompt in (
        marketing_prompt,
        lyric_prompt,
        yt_upload_prompt,
        image_autotag_prompt,
    ):
        prompt._reload_prompt()
    create_dynamic_subtitle_prompt.cache_clear()