Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/sdk_generation_mistralai_azure_sdk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
mode: pr
set_version: ${{ github.event.inputs.set_version }}
speakeasy_version: latest
target: mistral-python-sdk-azure
target: mistralai-azure-sdk
secrets:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
pypi_token: ${{ secrets.PYPI_TOKEN }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/sdk_generation_mistralai_gcp_sdk.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
mode: pr
set_version: ${{ github.event.inputs.set_version }}
speakeasy_version: latest
target: mistral-python-sdk-google-cloud
target: mistralai-gcp-sdk
secrets:
github_access_token: ${{ secrets.GITHUB_TOKEN }}
pypi_token: ${{ secrets.PYPI_TOKEN }}
Expand Down
22 changes: 12 additions & 10 deletions packages/mistralai_gcp/src/mistralai_gcp/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import json
import weakref
from typing import Any, Optional, cast
from typing import Any, Optional, Union, cast

import google.auth
import google.auth.credentials
Expand Down Expand Up @@ -67,30 +67,32 @@ def __init__(
:param timeout_ms: Optional request timeout applied to each operation in milliseconds
"""

credentials = None
if not access_token:
credentials, loaded_project_id = google.auth.default(
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
credentials.refresh(google.auth.transport.requests.Request())

if not isinstance(credentials, google.auth.credentials.Credentials):
raise models.SDKError(
"credentials must be an instance of google.auth.credentials.Credentials"
)
# default will already raise a google.auth.exceptions.DefaultCredentialsError if no credentials are found
assert isinstance(
credentials, google.auth.credentials.Credentials
), "credentials must be an instance of google.auth.credentials.Credentials"

credentials.refresh(google.auth.transport.requests.Request())
project_id = project_id or loaded_project_id

if project_id is None:
raise models.SDKError("project_id must be provided")
raise ValueError("project_id must be provided")

def auth_token() -> str:
if access_token:
return access_token

assert credentials is not None, "credentials must be initialized"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should normally never assert because of the logic above. That's why I used an assert

credentials.refresh(google.auth.transport.requests.Request())
token = credentials.token
if not token:
raise models.SDKError("Failed to get token from credentials")
raise Exception("Failed to get token from credentials")
return token

client_supplied = True
Expand Down Expand Up @@ -197,7 +199,7 @@ def __init__(self, region: str, project_id: str):

def before_request(
self, hook_ctx, request: httpx.Request
) -> httpx.Request | Exception:
) -> Union[httpx.Request, Exception]:
# The goal of this function is to template in the region, project and model into the URL path
# We do this here so that the API remains more user-friendly
model_id = None
Expand All @@ -210,7 +212,7 @@ def before_request(
new_content = json.dumps(parsed).encode("utf-8")

if model_id == "":
raise models.SDKError("model must be provided")
raise ValueError("model must be provided")

stream = "streamRawPredict" in request.url.path
specifier = "streamRawPredict" if stream else "rawPredict"
Expand Down