# o2o-infinith-backend/app/integrations/llm/llm_service.py (77 lines, 3.2 KiB, Python)

from openai import AsyncOpenAI, ContentFilterFinishReasonError, LengthFinishReasonError
from pydantic import BaseModel

from common.utils import get_env

from .prompt import Prompt
class LLMResponseError(Exception):
def __init__(self, status: str, code: str = None, message: str = None):
self.status = status
self.code = code
self.message = message
super().__init__(f"LLM response failed: status={status}, code={code}, message={message}")
class LLMService:
def __init__(self, provider: str = "openai", max_retries: int = 2):
self.provider = provider
self.max_retries = max_retries
match provider:
case "openai":
self.client = AsyncOpenAI(api_key=get_env("OPENAI_API_KEY"))
case "perplexity":
self.client = AsyncOpenAI(
api_key=get_env("PERPLEXITY_API_KEY"),
base_url="https://api.perplexity.ai",
)
case "gemini":
self.client = AsyncOpenAI(
api_key=get_env("GEMINI_API_KEY"),
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
)
case _:
raise NotImplementedError(f"Unknown provider: {provider}")
async def generate(
self,
prompt: Prompt,
input_data: dict,
) -> BaseModel:
prompt_text = prompt.build(input_data)
last_error = None
for attempt in range(self.max_retries + 1):
if self.provider == "perplexity":
response = await self.client.chat.completions.create(
model=prompt.model,
messages=[{"role": "user", "content": prompt_text}],
response_format={
"type": "json_schema",
"json_schema": {"name": prompt.output_class.__name__, "schema": prompt.output_class.model_json_schema()},
},
)
choice = response.choices[0]
if choice.finish_reason == "stop":
return prompt.output_class.model_validate_json(choice.message.content)
last_error = LLMResponseError("failed", choice.finish_reason, f"unexpected finish_reason: {choice.finish_reason}")
else:
response = await self.client.beta.chat.completions.parse(
model=prompt.model,
messages=[{"role": "user", "content": prompt_text}],
response_format=prompt.output_class,
)
choice = response.choices[0]
finish_reason = choice.finish_reason
if finish_reason == "stop":
return choice.message.parsed
if finish_reason == "length":
last_error = LLMResponseError("incomplete", finish_reason, "max tokens reached")
elif finish_reason == "content_filter":
last_error = LLMResponseError("failed", finish_reason, "blocked by content filter")
else:
last_error = LLMResponseError("failed", finish_reason, f"unexpected finish_reason: {finish_reason}")
raise last_error