# o2o-castad-backend/app/home/worker/main_task.py
#
# 88 lines
# 3.1 KiB
# Python

import asyncio
from sqlalchemy import select
from app.database.session import get_worker_session
from app.home.schemas.home import GenerateRequest
from app.lyric.models import Lyric
from app.utils.chatgpt_prompt import ChatgptService
async def _save_lyric(task_id: str, project_id: int, lyric_prompt: str) -> int:
    """Insert a new Lyric row in the "processing" state and return its id.

    The row starts with ``lyric_result`` set to ``None``; the result is
    written later, once lyric generation has finished.

    Args:
        task_id: Identifier of the background task that owns this lyric.
        project_id: Project the lyric belongs to.
        lyric_prompt: Prompt text that will be sent to ChatGPT.

    Returns:
        The ``id`` of the newly inserted Lyric row.
    """
    async with get_worker_session() as session:
        row_fields = dict(
            task_id=task_id,
            project_id=project_id,
            status="processing",
            lyric_prompt=lyric_prompt,
            lyric_result=None,
        )
        new_row = Lyric(**row_fields)
        session.add(new_row)
        await session.commit()
        # Reload the instance after commit so its `id` attribute is populated.
        await session.refresh(new_row)
        print(f"Lyric saved: id={new_row.id}, task_id={task_id}, status=processing")
        return new_row.id
async def _update_lyric_status(lyric_id: int, status: str, lyric_result: str | None = None) -> None:
    """Update the status (and optionally the result text) of an existing Lyric row.

    Args:
        lyric_id: Primary key of the Lyric row to update.
        status: New status value. This module writes "processing",
            "completed" and "failed".
        lyric_result: New result text. When ``None``, the stored result is
            left unchanged.

    If no row with ``lyric_id`` exists, nothing is written. A message is
    printed so that a missing record is not silently ignored.
    """
    async with get_worker_session() as session:
        result = await session.execute(select(Lyric).where(Lyric.id == lyric_id))
        lyric = result.scalar_one_or_none()
        if lyric is None:
            # Previously this case was skipped silently (or reported as "updated").
            print(f"Lyric not found: id={lyric_id}, status={status} was not applied")
            return
        lyric.status = status
        if lyric_result is not None:
            lyric.lyric_result = lyric_result
        await session.commit()
        print(f"Lyric updated: id={lyric_id}, status={status}")
async def lyric_task(
    task_id: str,
    project_id: int,
    customer_name: str,
    region: str,
    detail_region_info: str,
) -> None:
    """Generate lyrics with ChatGPT and record progress in the Lyric table.

    Flow:
      1. Build the prompt and insert a Lyric row with status "processing".
      2. Send the prompt to ChatGPT.
      3. If the response contains "ERROR:", set the row to "failed";
         otherwise set it to "completed". Store the response text as
         ``lyric_result`` in both cases.

    If the ChatGPT call raises, the row is set to "failed" with the error
    text as its result, and the exception is then re-raised. This keeps the
    row from being left in "processing" forever.
    """
    service = ChatgptService(
        customer_name=customer_name,
        region=region,
        detail_region_info=detail_region_info,
    )
    # Save the Lyric row first (status=processing, lyric_result=null) so progress is visible.
    lyric_prompt = service.build_lyrics_prompt()
    lyric_id = await _save_lyric(task_id, project_id, lyric_prompt)

    # Call GPT. Any exception must still move the row out of "processing".
    try:
        result = await service.generate_lyrics(prompt=lyric_prompt)
    except Exception as exc:
        await _update_lyric_status(lyric_id, "failed", lyric_result=f"ERROR: {exc}")
        raise
    print(f"GPT Response:\n{result}")

    # A response containing "ERROR:" is treated as a failed generation.
    status = "failed" if "ERROR:" in result else "completed"
    await _update_lyric_status(lyric_id, status, lyric_result=result)
async def _task_process_async(request_body: GenerateRequest, task_id: str, project_id: int) -> None:
    """Run the background pipeline for a single request (async implementation).

    Reads the customer and region fields from ``request_body`` and prints
    them to stdout. Then runs lyric generation. A falsy
    ``detail_region_info`` is replaced with an empty string.
    """
    customer_name, region = request_body.customer_name, request_body.region
    detail_region_info = request_body.detail_region_info
    if not detail_region_info:
        detail_region_info = ""

    print(f"customer_name: {customer_name}")
    print(f"region: {region}")
    print(f"detail_region_info: {detail_region_info}")

    # Lyric generation step.
    await lyric_task(
        task_id=task_id,
        project_id=project_id,
        customer_name=customer_name,
        region=region,
        detail_region_info=detail_region_info,
    )
def task_process(request_body: GenerateRequest, task_id: str, project_id: int) -> None:
    """Synchronous entry point for the background job.

    Runs ``_task_process_async`` on a fresh event loop created by
    ``asyncio.run`` and blocks until it finishes. ``asyncio.run`` refuses to
    start when the calling thread already has a running event loop, so call
    this from a plain (non-async) context.
    """
    pipeline = _task_process_async(request_body, task_id, project_id)
    asyncio.run(pipeline)