o2o-castad-backend/app/utils/prompts/prompts.py

65 lines
2.4 KiB
Python

import os, json
from pydantic import BaseModel
from config import prompt_settings
from app.utils.logger import get_logger
from app.utils.prompts.schemas import *
# Module-level logger; Prompt.build_prompt writes its debug output through it.
logger = get_logger("prompt")
class Prompt:
    """File-backed prompt template paired with its input/output schema classes.

    The template text is read from ``prompt_template_path`` when the instance
    is created and is filled in with ``str.format`` by :meth:`build_prompt`.
    """

    prompt_template_path: str  # path of the template file
    prompt_template: str  # template text with str.format-style placeholders
    prompt_model: str  # model setting supplied by the caller
    # The pydantic classes themselves (not instances) are stored here.
    prompt_input_class: type[BaseModel] = BaseModel
    prompt_output_class: type[BaseModel] = BaseModel

    def __init__(
        self,
        prompt_template_path: str,
        prompt_input_class: type[BaseModel],
        prompt_output_class: type[BaseModel],
        prompt_model: str,
    ):
        """Store the configuration and load the template text from disk.

        Args:
            prompt_template_path: Path of the template file to read.
            prompt_input_class: Pydantic class used to validate the input
                passed to :meth:`build_prompt`.
            prompt_output_class: Pydantic class kept on the instance for
                callers; it is not used inside this class.
            prompt_model: Model setting kept on the instance for callers.

        Raises:
            OSError: If the template file cannot be opened or read.
        """
        self.prompt_template_path = prompt_template_path
        self.prompt_input_class = prompt_input_class
        self.prompt_output_class = prompt_output_class
        # The path must be set before this call: read_prompt reads it.
        self.prompt_template = self.read_prompt()
        self.prompt_model = prompt_model

    def _reload_prompt(self) -> None:
        """Re-read the template file and replace the cached template text."""
        self.prompt_template = self.read_prompt()

    def read_prompt(self) -> str:
        """Return the full text of the template file.

        Raises:
            OSError: If the file cannot be opened or read.
            UnicodeDecodeError: If the file is not valid UTF-8.
        """
        # Decode explicitly as UTF-8 so non-ASCII templates do not depend on
        # the platform's locale encoding.
        with open(self.prompt_template_path, "r", encoding="utf-8") as fp:
            return fp.read()

    def build_prompt(self, input_data: dict) -> str:
        """Validate ``input_data`` and substitute it into the template.

        Args:
            input_data: Keyword values used to construct ``prompt_input_class``.

        Returns:
            The template text with every placeholder filled in.

        Raises:
            pydantic.ValidationError: If ``input_data`` fails validation by
                ``prompt_input_class``.
            KeyError: If the template has a placeholder that is not a field
                of the validated model.
        """
        verified_input = self.prompt_input_class(**input_data)
        # str.format is used, so literal braces in the template file must be
        # doubled ("{{" / "}}") to survive substitution.
        build_template = self.prompt_template.format(**verified_input.model_dump())
        logger.debug(f"build_template: {build_template}")
        logger.debug(f"input_data: {input_data}")
        return build_template
# Module-level prompt objects, created once at import time. Each constructor
# call reads its template file from PROMPT_FOLDER_ROOT.

# Marketing prompt.
marketing_prompt = Prompt(
    prompt_template_path=os.path.join(
        prompt_settings.PROMPT_FOLDER_ROOT,
        prompt_settings.MARKETING_PROMPT_FILE_NAME,
    ),
    prompt_input_class=MarketingPromptInput,
    prompt_output_class=MarketingPromptOutput,
    prompt_model=prompt_settings.MARKETING_PROMPT_MODEL,
)

# Lyric prompt.
lyric_prompt = Prompt(
    prompt_template_path=os.path.join(
        prompt_settings.PROMPT_FOLDER_ROOT,
        prompt_settings.LYRIC_PROMPT_FILE_NAME,
    ),
    prompt_input_class=LyricPromptInput,
    prompt_output_class=LyricPromptOutput,
    prompt_model=prompt_settings.LYRIC_PROMPT_MODEL,
)

# YouTube upload prompt.
yt_upload_prompt = Prompt(
    prompt_template_path=os.path.join(
        prompt_settings.PROMPT_FOLDER_ROOT,
        prompt_settings.YOUTUBE_PROMPT_FILE_NAME,
    ),
    prompt_input_class=YTUploadPromptInput,
    prompt_output_class=YTUploadPromptOutput,
    prompt_model=prompt_settings.YOUTUBE_PROMPT_MODEL,
)
def reload_all_prompt():
    """Re-read the template file of every module-level prompt object.

    Prompts are reloaded in the order marketing, lyric, YouTube upload; an
    exception raised by one reload propagates and skips the remaining ones.
    """
    for prompt in (marketing_prompt, lyric_prompt, yt_upload_prompt):
        prompt._reload_prompt()