Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
from src.infuzu import (create_chat_completion, ChatCompletionsHandlerRequestMessage, ChatCompletionsObject)
from src.infuzu import (
create_chat_completion, ChatCompletionsHandlerRequestMessage, ChatCompletionsObject
)
from dotenv import load_dotenv


load_dotenv()
if __name__ == "__main__":
load_dotenv()

messages: list[ChatCompletionsHandlerRequestMessage] = [
ChatCompletionsHandlerRequestMessage(role="system", content="You are a helpful assistant."),
ChatCompletionsHandlerRequestMessage(role="user", content="What is the capital of France?"),
]

messages: list[ChatCompletionsHandlerRequestMessage] = [
ChatCompletionsHandlerRequestMessage(role="system", content="You are a helpful assistant."),
ChatCompletionsHandlerRequestMessage(role="user", content="What is the capital of France?"),
]


try:
response: ChatCompletionsObject = create_chat_completion(messages=messages)
print(response)
except Exception as e:
print(f"Error: {e}")
try:
response: ChatCompletionsObject = create_chat_completion(messages=messages)
print(response)
except Exception as e:
print(f"Error: {e}")
3 changes: 3 additions & 0 deletions src/infuzu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
ChatCompletionsObject,
)
from .errors import (InfuzuAPIError, APIWarning, APIError)
from .utils import get_version


__all__: list[str] = [
Expand All @@ -40,4 +41,6 @@
"InfuzuAPIError",
"APIWarning",
"APIError",

"get_version",
]
139 changes: 64 additions & 75 deletions src/infuzu/api_client.py
Original file line number Diff line number Diff line change
@@ -1,168 +1,155 @@
import platform
import time
import uuid
import httpx
import os
from typing import (Optional, Dict, Union, List)
from pydantic import (BaseModel, validator, Field)
from pydantic import (BaseModel, Field, ConfigDict, model_validator)
from .utils import get_version
from .errors import InfuzuAPIError


class ModelWeights(BaseModel):
model_config = ConfigDict(extra="allow")

price: Optional[float] = None
error: Optional[float] = None
start_latency: Optional[float] = None
end_latency: Optional[float] = None

class Config:
extra: str = "allow"


class InfuzuModelParams(BaseModel):
model_config = ConfigDict(extra="allow")

llms: Optional[List[str]] = None
exclude_llms: Optional[List[str]] = None
weights: Optional[ModelWeights] = None
imsn: Optional[int] = None
max_input_cost: Optional[float] = None
max_output_cost: Optional[float] = None

class Config:
extra: str = "allow"


class ChatCompletionsRequestContentPart(BaseModel):
model_config = ConfigDict(extra="allow")

type: str
text: Optional[str] = None
image_url: Optional[str] = None
input_audio: Optional[str] = None

class Config:
extra: str = "allow"

@validator("text", always=True)
def check_content_fields(cls, value, values):
if "type" in values:
content_type = values["type"]
if content_type == "text" and value is None:
raise ValueError("Text must be provided when type is 'text'")
if content_type != "text" and value is not None:
raise ValueError("Text cannot be provided when type is not 'text'")
return value
@model_validator(mode='after')
def check_content_fields(self) -> 'ChatCompletionsRequestContentPart':
if self.type == "text" and self.text is None:
raise ValueError("Text must be provided when type is 'text'")
if self.type != "text" and self.text is not None:
raise ValueError("Text cannot be provided when type is not 'text'")
return self


class ChatCompletionsHandlerRequestMessage(BaseModel):
model_config = ConfigDict(extra="allow")

content: Union[str, List[ChatCompletionsRequestContentPart]]
role: str
name: Optional[str] = None

class Config:
extra: str = "allow"

@validator('role')
def role_must_be_valid(cls, v):
if v not in ('system', 'user', 'assistant'):
@model_validator(mode='after')
def role_must_be_valid(self) -> 'ChatCompletionsHandlerRequestMessage':
if self.role not in ('system', 'user', 'assistant'):
raise ValueError('Role must be one of: system, user, assistant')
return v
return self


class ChatCompletionsChoiceMessageAudioObject(BaseModel):
model_config = ConfigDict(extra="allow")

id: Optional[str] = None
expired_at: Optional[int] = None
data: Optional[str] = None
transcript: Optional[str] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceMessageFunctionCallObject(BaseModel):
model_config = ConfigDict(extra="allow")

name: Optional[str] = None
arguments: Optional[str] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceMessageToolCallFunctionObject(BaseModel):
model_config = ConfigDict(extra="allow")

name: Optional[str] = None
arguments: Optional[str] = None

class Config:
extra: str = "allow"


class chatCompletionsChoiceMessageToolCallObject(BaseModel):
model_config = ConfigDict(extra="allow")

id: Optional[str] = None
type: Optional[str] = None
function: Optional[ChatCompletionsChoiceMessageToolCallFunctionObject] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceMessageObject(BaseModel):
model_config = ConfigDict(extra="allow")

content: Optional[str] = None
refusal: Optional[str] = None
tool_calls: Optional[List[chatCompletionsChoiceMessageToolCallObject]] = None
role: Optional[str] = None
function_call: Optional[ChatCompletionsChoiceMessageFunctionCallObject] = None
audio: Optional[ChatCompletionsChoiceMessageAudioObject] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceLogprobsItemTopLogprobObject(BaseModel):
model_config = ConfigDict(extra="allow")

token: Optional[str] = None
logprob: Optional[int] = None
bytes: Optional[List[int]] = None

class Config:
extra: str = "allow"


class ChatCompletionsLogprobsItemObject(BaseModel):
model_config = ConfigDict(extra="allow")

token: Optional[str] = None
logprob: Optional[int] = None
bytes: Optional[List[int]] = None
content: Optional[List[ChatCompletionsChoiceLogprobsItemTopLogprobObject]] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceLogprobsObject(BaseModel):
model_config = ConfigDict(extra="allow")

content: Optional[List[ChatCompletionsLogprobsItemObject]] = None
refusal: Optional[List[ChatCompletionsLogprobsItemObject]] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceModelObject(BaseModel):
model_config = ConfigDict(extra="allow")

ref: Optional[str] = None
rank: Optional[int] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceErrorObject(BaseModel):
model_config = ConfigDict(extra="allow")

message: Optional[str] = None
code: Optional[str] = None

class Config:
extra: str = "allow"


class ChatCompletionsChoiceLatencyObject(BaseModel):
model_config = ConfigDict(extra="allow")

start: Optional[int] = Field(None, alias='start_latency')
end: Optional[int] = Field(None, alias='end_latency')

class Config:
extra: str = "allow"


class ChatCompletionsChoiceObject(BaseModel):
model_config = ConfigDict(extra="allow")

finish_reason: Optional[str] = None
index: Optional[int] = None
message: Optional[ChatCompletionsChoiceMessageObject] = None
Expand All @@ -171,11 +158,10 @@ class ChatCompletionsChoiceObject(BaseModel):
error: Optional[ChatCompletionsChoiceErrorObject] = None
latency: Optional[ChatCompletionsChoiceLatencyObject] = None

class Config:
extra: str = "allow"


class ChatCompletionsObject(BaseModel):
model_config = ConfigDict(extra="allow")

id: Optional[str] = None
choices: Optional[List[ChatCompletionsChoiceObject]] = None
created: Optional[int] = None
Expand All @@ -185,11 +171,8 @@ class ChatCompletionsObject(BaseModel):
object: Optional[str] = None
usage: Optional[Dict[str, int]] = None

class Config:
extra: str = "allow"


API_BASE_URL = "https://chat.infuzu.com/api"
API_BASE_URL: str = "https://chat.infuzu.com/api"


def create_chat_completion(
Expand All @@ -204,43 +187,49 @@ def create_chat_completion(
messages: A list of message objects.
api_key: Your Infuzu API key. If not provided, it will be read from the
INFUZU_API_KEY environment variable.
model: The model to use for the chat completion. Can be a string (model name)
model: The model to use for the chat completion. Can be a string (model name)
or a InfuzuModelParams object for more advanced configuration.

Returns:
A dictionary containing the JSON response from the API.
The ChatCompletionsObject Object

Raises:
ValueError: If the API key is not provided and the INFUZU_API_KEY
environment variable is not set.
httpx.HTTPStatusError: If the API request returns an error status code.
InfuzuAPIError: If the API request returns an error status code.
"""

if api_key is None:
api_key = os.environ.get("INFUZU_API_KEY")
api_key: str | None = os.environ.get("INFUZU_API_KEY")
if api_key is None:
raise ValueError(
"API key not provided and INFUZU_API_KEY environment variable not set."
)

headers = {
headers: dict[str, str] = {
"Content-Type": "application/json",
"Infuzu-API-Key": api_key,
"User-Agent": (
f"infuzu-python/{get_version()} "
f"(Python {platform.python_version()}; "
f"httpx/{httpx.__version__}; "
f"{platform.system()} {platform.release()})"
)
}

payload = {
"messages": [message.dict(by_alias=True) for message in messages],
payload: dict[str, any] = {
"messages": [message.model_dump(by_alias=True) for message in messages],
}

if model:
if isinstance(model, str):
payload["model"] = model
else:
payload["model"] = model.dict(by_alias=True)
payload["model"] = model.model_dump(by_alias=True)

try:
with httpx.Client() as client:
response = client.post(
response: httpx.Response = client.post(
f"{API_BASE_URL}/v1/chat/completions",
headers=headers,
json=payload,
Expand Down
12 changes: 5 additions & 7 deletions src/infuzu/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,25 @@
from json import JSONDecodeError
from typing import Optional
import httpx
from pydantic import BaseModel
from pydantic import (BaseModel, ConfigDict)


logger: logging.Logger = logging.getLogger(__name__)


class APIError(BaseModel):
model_config = ConfigDict(extra="allow")

code: Optional[str] = None
message: Optional[str] = None

class Config:
extra: str = "allow"


class APIWarning(BaseModel):
model_config = ConfigDict(extra="allow")

code: Optional[str] = None
message: Optional[str] = None

class Config:
extra: str = "allow"


class InfuzuAPIError(httpx.HTTPStatusError):
def __init__(self, base_error: httpx.HTTPStatusError) -> None:
Expand Down
Loading