from pydantic import BaseModel from openai import AsyncOpenAI from common.utils import get_env from .prompt import Prompt class LLMResponseError(Exception): def __init__(self, status: str, code: str = None, message: str = None): self.status = status self.code = code self.message = message super().__init__(f"LLM response failed: status={status}, code={code}, message={message}") class LLMService: def __init__(self, provider: str = "openai", max_retries: int = 2): self.max_retries = max_retries match provider: case "openai": self.client = AsyncOpenAI(api_key=get_env("OPENAI_API_KEY")) case "perplexity": self.client = AsyncOpenAI( api_key=get_env("PERPLEXITY_API_KEY"), base_url="https://api.perplexity.ai", ) case "gemini": self.client = AsyncOpenAI( api_key=get_env("GEMINI_API_KEY"), base_url="https://generativelanguage.googleapis.com/v1beta/openai/", ) case _: raise NotImplementedError(f"Unknown provider: {provider}") async def generate( self, prompt: Prompt, input_data: dict, ) -> BaseModel: prompt_text = prompt.build(input_data) last_error = None for attempt in range(self.max_retries + 1): response = await self.client.beta.chat.completions.parse( model=prompt.model, messages=[{"role": "user", "content": prompt_text}], response_format=prompt.output_class, ) choice = response.choices[0] finish_reason = choice.finish_reason if finish_reason == "stop": return choice.message.parsed if finish_reason == "length": last_error = LLMResponseError("incomplete", finish_reason, "max tokens reached") elif finish_reason == "content_filter": last_error = LLMResponseError("failed", finish_reason, "blocked by content filter") else: last_error = LLMResponseError("failed", finish_reason, f"unexpected finish_reason: {finish_reason}") raise last_error