78 lines
3.2 KiB
Python
78 lines
3.2 KiB
Python
from pydantic import BaseModel
|
|
from openai import AsyncOpenAI
|
|
from common.utils import get_env
|
|
from .prompt import Prompt
|
|
|
|
|
|
class LLMResponseError(Exception):
|
|
def __init__(self, status: str, code: str = None, message: str = None):
|
|
self.status = status
|
|
self.code = code
|
|
self.message = message
|
|
super().__init__(f"LLM response failed: status={status}, code={code}, message={message}")
|
|
|
|
|
|
class LLMService:
|
|
def __init__(self, provider: str = "openai", max_retries: int = 2):
|
|
self.provider = provider
|
|
self.max_retries = max_retries
|
|
match provider:
|
|
case "openai":
|
|
self.client = AsyncOpenAI(api_key=get_env("OPENAI_API_KEY"))
|
|
case "perplexity":
|
|
self.client = AsyncOpenAI(
|
|
api_key=get_env("PERPLEXITY_API_KEY"),
|
|
base_url="https://api.perplexity.ai",
|
|
)
|
|
case "gemini":
|
|
self.client = AsyncOpenAI(
|
|
api_key=get_env("GEMINI_API_KEY"),
|
|
base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
|
|
)
|
|
case _:
|
|
raise NotImplementedError(f"Unknown provider: {provider}")
|
|
|
|
async def generate(
|
|
self,
|
|
prompt: Prompt,
|
|
input_data: dict,
|
|
) -> BaseModel:
|
|
prompt_text = prompt.build(input_data)
|
|
last_error = None
|
|
|
|
for attempt in range(self.max_retries + 1):
|
|
if self.provider == "perplexity":
|
|
response = await self.client.chat.completions.create(
|
|
model=prompt.model,
|
|
messages=[{"role": "user", "content": prompt_text}],
|
|
max_tokens=16000,
|
|
response_format={
|
|
"type": "json_schema",
|
|
"json_schema": {"name": prompt.output_class.__name__, "schema": prompt.output_class.model_json_schema()},
|
|
},
|
|
)
|
|
choice = response.choices[0]
|
|
if choice.finish_reason == "stop":
|
|
return prompt.output_class.model_validate_json(choice.message.content)
|
|
last_error = LLMResponseError("failed", choice.finish_reason, f"unexpected finish_reason: {choice.finish_reason}")
|
|
else:
|
|
response = await self.client.beta.chat.completions.parse(
|
|
model=prompt.model,
|
|
messages=[{"role": "user", "content": prompt_text}],
|
|
response_format=prompt.output_class,
|
|
)
|
|
choice = response.choices[0]
|
|
finish_reason = choice.finish_reason
|
|
|
|
if finish_reason == "stop":
|
|
return choice.message.parsed
|
|
|
|
if finish_reason == "length":
|
|
last_error = LLMResponseError("incomplete", finish_reason, "max tokens reached")
|
|
elif finish_reason == "content_filter":
|
|
last_error = LLMResponseError("failed", finish_reason, "blocked by content filter")
|
|
else:
|
|
last_error = LLMResponseError("failed", finish_reason, f"unexpected finish_reason: {finish_reason}")
|
|
|
|
raise last_error
|