diff --git a/examples/aio/image_generation.py b/examples/aio/image_generation.py index 796d712..eaaac18 100644 --- a/examples/aio/image_generation.py +++ b/examples/aio/image_generation.py @@ -1,6 +1,6 @@ import asyncio import itertools -from typing import Sequence +from typing import Sequence, cast from absl import app, flags @@ -10,29 +10,66 @@ N = flags.DEFINE_integer("n", 1, "Number of images to generate.") FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.") +MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.") OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.", required=True) -async def generate_single(client: xai_sdk.AsyncClient, image_format: ImageFormat): - """Generate a single image from a prompt.""" - for turn in itertools.count(): - prompt = input("Prompt: ") - result = await client.image.sample(prompt, model="grok-2-image", image_format=image_format) - await _save_images(turn, [result]) +async def generate_multi_turn(client: xai_sdk.AsyncClient, image_format: ImageFormat): + """Multi-turn image generation that builds on the previous output image. + Turn 0 generates an initial image from your prompt. Each subsequent turn reuses the previous + image output as an input image (image-to-image) while you provide a new prompt to refine it. + """ + previous_image: str | None = None -async def generate_batch(client: xai_sdk.AsyncClient, image_format: ImageFormat): - """Generate a batch of images from a prompt.""" for turn in itertools.count(): - prompt = input("Prompt: ") - results = await client.image.sample_batch(prompt, n=N.value, model="grok-2-image", image_format=image_format) - await _save_images(turn, results) + if previous_image is None: + prompt = input("Prompt (blank to stop): ") + else: + prompt = input("Edit prompt (blank to stop): ") + if not prompt: + return + + if N.value == 1: + response = await client.image.sample( + prompt, + model=MODEL.value, + image_format=image_format, + image_url=previous_image, + ) + responses = [response] + else: + responses = await client.image.sample_batch( + prompt, + n=N.value, + model=MODEL.value, + image_format=image_format, + image_url=previous_image, + ) + + await _save_images(turn, responses) + + selected = 0 + if len(responses) > 1: + raw = input(f"Continue from which image? [0-{len(responses) - 1}] (default 0): ").strip() + if raw: + selected = int(raw) + if selected < 0 or selected >= len(responses): + raise ValueError(f"Invalid image index {selected}.") + + chosen = responses[selected] + previous_image = chosen.url if image_format == "url" else chosen.base64 + + if len(responses) > 1: + if image_format == "url": + print(f"Continuing from image {selected}: {chosen.url}") + else: + print(f"Continuing from image {selected} (base64).") async def _save_images(turn: int, responses: Sequence[ImageResponse]): """Save images to a file.""" for i, image in enumerate(responses): - print(image.prompt) with open(f"{OUTPUT_DIR.value}/image_{turn}_{i}.jpg", "wb") as f: f.write(await image.image) @@ -42,16 +79,8 @@ async def main(argv: Sequence[str]) -> None: raise app.UsageError("Unexpected command line arguments.") client = xai_sdk.AsyncClient() - - match (N.value, FORMAT.value): - case (1, "base64"): - await generate_single(client, "base64") - case (_, "base64"): - await generate_batch(client, "base64") - case (1, "url"): - await generate_single(client, "url") - case (_, "url"): - await generate_batch(client, "url") + image_format: ImageFormat = cast(ImageFormat, FORMAT.value) + await generate_multi_turn(client, image_format) if __name__ == "__main__": diff --git a/examples/aio/video_generation.py b/examples/aio/video_generation.py new file mode 100644 index 0000000..6a5b4d8 --- /dev/null +++ b/examples/aio/video_generation.py @@ -0,0 +1,84 @@ +import asyncio +from datetime import timedelta +from typing import Sequence, cast + +from absl import app, flags + +import xai_sdk +from xai_sdk.video import VideoAspectRatio, VideoResolution + +MODEL = flags.DEFINE_string("model", "grok-imagine-video", "Video generation model to use.") +IMAGE_URL = flags.DEFINE_string( + "image-url", + "", + "Optional input image (URL or base64 data URL) to use as the first frame (image-to-video).", +) +VIDEO_URL = flags.DEFINE_string( + "video-url", + "", + "Optional input video (URL or base64 data URL) to edit based on the prompt (video-to-video).", +) +DURATION = flags.DEFINE_integer("duration", 0, "Optional duration in seconds (1-15). Use 0 to omit.") +ASPECT_RATIO = flags.DEFINE_string( + "aspect-ratio", + "", + 'Optional aspect ratio. One of: "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3".', +) +RESOLUTION = flags.DEFINE_string("resolution", "", 'Optional resolution. One of: "480p", "720p".') +TIMEOUT = flags.DEFINE_integer("timeout", 600, "Timeout in seconds for polling.") +INTERVAL = flags.DEFINE_integer("interval", 1, "Polling interval in seconds.") + + +async def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Unexpected command line arguments.") + + client = xai_sdk.AsyncClient() + + duration = DURATION.value or None + image_url = IMAGE_URL.value or None + video_url = VIDEO_URL.value or None + aspect_ratio = cast(VideoAspectRatio, ASPECT_RATIO.value) if ASPECT_RATIO.value else None + resolution = cast(VideoResolution, RESOLUTION.value) if RESOLUTION.value else None + + previous_video_url: str | None = video_url + first_turn = True + + while True: + prompt = input("Prompt (blank to stop): " if first_turn else "Edit prompt (blank to stop): ") + if not prompt: + return + + try: + response = await client.video.generate( + prompt=prompt, + model=MODEL.value, + image_url=image_url if first_turn else None, + video_url=previous_video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + timeout=timedelta(seconds=TIMEOUT.value), + interval=timedelta(seconds=INTERVAL.value), + ) + print(f"Respects moderation: {response.respect_moderation}") + if response.respect_moderation: + print(f"Video URL: {response.url}") + else: + print("Video URL not returned due to moderation.") + print(f"Duration: {response.duration}s") + + # Chain edits: use the returned URL as the next input video. + if response.respect_moderation: + previous_video_url = response.url + first_turn = False + except RuntimeError as e: + # request expired + print(e) + except ValueError as e: + # unknown deferred status + print(e) + + +if __name__ == "__main__": + app.run(lambda argv: asyncio.run(main(argv))) diff --git a/examples/sync/image_generation.py b/examples/sync/image_generation.py index 8091b10..0f15903 100644 --- a/examples/sync/image_generation.py +++ b/examples/sync/image_generation.py @@ -1,5 +1,5 @@ import itertools -from typing import Sequence +from typing import Sequence, cast from absl import app, flags @@ -9,29 +9,66 @@ N = flags.DEFINE_integer("n", 1, "Number of images to generate.") FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.") +MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.") OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.", required=True) -def generate_single(client: xai_sdk.Client, image_format: ImageFormat): - """Generate a single image from a prompt.""" - for turn in itertools.count(): - prompt = input("Prompt: ") - result = client.image.sample(prompt, model="grok-2-image", image_format=image_format) - save_images(turn, [result]) +def generate_multi_turn(client: xai_sdk.Client, image_format: ImageFormat): + """Multi-turn image generation that builds on the previous output image. + Turn 0 generates an initial image from your prompt. Each subsequent turn reuses the previous + image output as an input image (image-to-image) while you provide a new prompt to refine it. + """ + previous_image: str | None = None -def generate_batch(client: xai_sdk.Client, image_format: ImageFormat): - """Generate a batch of images from a prompt.""" for turn in itertools.count(): - prompt = input("Prompt: ") - results = client.image.sample_batch(prompt, n=N.value, model="grok-2-image", image_format=image_format) - save_images(turn, results) + if previous_image is None: + prompt = input("Prompt (blank to stop): ") + else: + prompt = input("Edit prompt (blank to stop): ") + if not prompt: + return + + if N.value == 1: + response = client.image.sample( + prompt, + model=MODEL.value, + image_format=image_format, + image_url=previous_image, + ) + responses = [response] + else: + responses = client.image.sample_batch( + prompt, + n=N.value, + model=MODEL.value, + image_format=image_format, + image_url=previous_image, + ) + + save_images(turn, responses) + + selected = 0 + if len(responses) > 1: + raw = input(f"Continue from which image? [0-{len(responses) - 1}] (default 0): ").strip() + if raw: + selected = int(raw) + if selected < 0 or selected >= len(responses): + raise ValueError(f"Invalid image index {selected}.") + + chosen = responses[selected] + previous_image = chosen.url if image_format == "url" else chosen.base64 + + if len(responses) > 1: + if image_format == "url": + print(f"Continuing from image {selected}: {chosen.url}") + else: + print(f"Continuing from image {selected} (base64).") def save_images(turn: int, responses: Sequence[ImageResponse]): """Save images to a file.""" for i, image in enumerate(responses): - print(image.prompt) with open(f"{OUTPUT_DIR.value}/image_{turn}_{i}.jpg", "wb") as f: f.write(image.image) @@ -41,16 +78,8 @@ def main(argv: Sequence[str]) -> None: raise app.UsageError("Unexpected command line arguments.") client = xai_sdk.Client() - - match (N.value, FORMAT.value): - case (1, "base64"): - generate_single(client, "base64") - case (1, "url"): - generate_single(client, "url") - case (_, "base64"): - generate_batch(client, "base64") - case (_, "url"): - generate_batch(client, "url") + image_format: ImageFormat = cast(ImageFormat, FORMAT.value) + generate_multi_turn(client, image_format) if __name__ == "__main__": diff --git a/examples/sync/video_generation.py b/examples/sync/video_generation.py new file mode 100644 index 0000000..1e2a7a7 --- /dev/null +++ b/examples/sync/video_generation.py @@ -0,0 +1,83 @@ +from datetime import timedelta +from typing import Sequence, cast + +from absl import app, flags + +import xai_sdk +from xai_sdk.video import VideoAspectRatio, VideoResolution + +MODEL = flags.DEFINE_string("model", "grok-imagine-video", "Video generation model to use.") +IMAGE_URL = flags.DEFINE_string( + "image-url", + "", + "Optional input image (URL or base64 data URL) to use as the first frame (image-to-video).", +) +VIDEO_URL = flags.DEFINE_string( + "video-url", + "", + "Optional input video (URL or base64 data URL) to edit based on the prompt (video-to-video).", +) +DURATION = flags.DEFINE_integer("duration", 0, "Optional duration in seconds (1-15). Use 0 to omit.") +ASPECT_RATIO = flags.DEFINE_string( + "aspect-ratio", + "", + 'Optional aspect ratio. One of: "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3".', +) +RESOLUTION = flags.DEFINE_string("resolution", "", 'Optional resolution. One of: "480p", "720p".') +TIMEOUT = flags.DEFINE_integer("timeout", 600, "Timeout in seconds for polling.") +INTERVAL = flags.DEFINE_integer("interval", 1, "Polling interval in seconds.") + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Unexpected command line arguments.") + + client = xai_sdk.Client() + + duration = DURATION.value or None + image_url = IMAGE_URL.value or None + video_url = VIDEO_URL.value or None + aspect_ratio = cast(VideoAspectRatio, ASPECT_RATIO.value) if ASPECT_RATIO.value else None + resolution = cast(VideoResolution, RESOLUTION.value) if RESOLUTION.value else None + + previous_video_url: str | None = video_url + first_turn = True + + while True: + prompt = input("Prompt (blank to stop): " if first_turn else "Edit prompt (blank to stop): ") + if not prompt: + return + + try: + response = client.video.generate( + prompt=prompt, + model=MODEL.value, + image_url=image_url if first_turn else None, + video_url=previous_video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + timeout=timedelta(seconds=TIMEOUT.value), + interval=timedelta(seconds=INTERVAL.value), + ) + print(f"Respects moderation: {response.respect_moderation}") + if response.respect_moderation: + print(f"Video URL: {response.url}") + else: + print("Video URL not returned due to moderation.") + print(f"Duration: {response.duration}s") + + # Chain edits: use the returned URL as the next input video. + if response.respect_moderation: + previous_video_url = response.url + first_turn = False + except RuntimeError as e: + # request expired + print(e) + except ValueError as e: + # unknown deferred status + print(e) + + +if __name__ == "__main__": + app.run(main) diff --git a/src/xai_sdk/aio/__init__.py b/src/xai_sdk/aio/__init__.py index a366f8a..6bf0c19 100644 --- a/src/xai_sdk/aio/__init__.py +++ b/src/xai_sdk/aio/__init__.py @@ -1,3 +1,14 @@ -from . import auth, batch, chat, client, collections, image, models, tokenizer +from . import auth, batch, chat, client, collections, files, image, models, tokenizer, video -__all__ = ["auth", "batch", "chat", "client", "collections", "image", "models", "tokenizer"] +__all__ = [ + "auth", + "batch", + "chat", + "client", + "collections", + "files", + "image", + "models", + "tokenizer", + "video", +] diff --git a/src/xai_sdk/aio/client.py b/src/xai_sdk/aio/client.py index 97bbec0..a163b52 100644 --- a/src/xai_sdk/aio/client.py +++ b/src/xai_sdk/aio/client.py @@ -13,7 +13,7 @@ UnaryUnaryAuthAioInterceptor, UnaryUnaryTimeoutAioInterceptor, ) -from . import auth, batch, chat, collections, files, image, models, tokenizer +from . import auth, batch, chat, collections, files, image, models, tokenizer, video class Client(BaseClient): @@ -27,6 +27,7 @@ class Client(BaseClient): image: "image.Client" models: "models.Client" tokenize: "tokenizer.Client" + video: "video.Client" def _init( self, @@ -73,6 +74,7 @@ def _init( self.image = image.Client(self._api_channel) self.models = models.Client(self._api_channel) self.tokenize = tokenizer.Client(self._api_channel) + self.video = video.Client(self._api_channel) def _make_grpc_channel( self, diff --git a/src/xai_sdk/aio/image.py b/src/xai_sdk/aio/image.py index 221b09b..c1558ce 100644 --- a/src/xai_sdk/aio/image.py +++ b/src/xai_sdk/aio/image.py @@ -7,10 +7,14 @@ from ..image import ( BaseClient, BaseImageResponse, + ImageAspectRatio, ImageFormat, + ImageResolution, _make_span_request_attributes, _make_span_response_attributes, + convert_image_aspect_ratio_to_pb, convert_image_format_to_pb, + convert_image_resolution_to_pb, ) from ..proto import image_pb2 from ..telemetry import get_tracer @@ -26,19 +30,42 @@ async def sample( prompt: str, model: str, *, + image_url: Optional[str] = None, user: Optional[str] = None, image_format: Optional[ImageFormat] = None, + aspect_ratio: Optional[ImageAspectRatio] = None, + resolution: Optional[ImageResolution] = None, ) -> "ImageResponse": """Samples a single image asynchronously based on the provided prompt. Args: prompt: The prompt to generate an image from. model: The model to use for image generation. + image_url: The URL or base64-encoded string of an input image to use as a starting point for generation. + Only supported for grok-imagine models. user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse. image_format: The format of the image to return. One of: - `"url"`: The image is returned as a URL. - `"base64"`: The image is returned as a base64-encoded string. defaults to `"url"` if not specified. + aspect_ratio: The aspect ratio of the image to generate. One of: + - `"1:1"` + - `"16:9"` + - `"9:16"` + - `"4:3"` + - `"3:4"` + - `"3:2"` + - `"2:3"` + - `"2:1"` + - `"1:2"` + - `"20:9"` + - `"9:20"` + - `"19.5:9"` + - `"9:19.5"` + Only supported for grok-imagine models. + resolution: The image resolution to generate. One of: + - `"1k"`: ~1 megapixel total. Dimensions vary by aspect ratio. + Only supported for grok-imagine models. Returns: An `ImageResponse` object allowing access to the generated image. @@ -51,6 +78,17 @@ async def sample( n=1, format=convert_image_format_to_pb(image_format), ) + if image_url is not None: + request.image.CopyFrom( + image_pb2.ImageUrlContent( + image_url=image_url, + detail=image_pb2.ImageDetail.DETAIL_AUTO, + ) + ) + if aspect_ratio is not None: + request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio) + if resolution is not None: + request.resolution = convert_image_resolution_to_pb(resolution) with tracer.start_as_current_span( name=f"image.sample {model}", @@ -68,8 +106,11 @@ async def sample_batch( model: str, n: int, *, + image_url: Optional[str] = None, user: Optional[str] = None, image_format: Optional[ImageFormat] = None, + aspect_ratio: Optional[ImageAspectRatio] = None, + resolution: Optional[ImageResolution] = None, ) -> Sequence["ImageResponse"]: """Samples a batch of images asynchronously based on the provided prompt. @@ -77,11 +118,31 @@ async def sample_batch( prompt: The prompt to generate an image from. model: The model to use for image generation. n: The number of images to generate. + image_url: The URL or base64-encoded string of an input image to use as a starting point for generation. + Only supported for grok-imagine models. user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse. image_format: The format of the image to return. One of: - `"url"`: The image is returned as a URL. - `"base64"`: The image is returned as a base64-encoded string. defaults to `"url"` if not specified. + aspect_ratio: The aspect ratio of the image to generate. One of: + - `"1:1"` + - `"16:9"` + - `"9:16"` + - `"4:3"` + - `"3:4"` + - `"3:2"` + - `"2:3"` + - `"2:1"` + - `"1:2"` + - `"20:9"` + - `"9:20"` + - `"19.5:9"` + - `"9:19.5"` + Only supported for grok-imagine models. + resolution: The image resolution to generate. One of: + - `"1k"`: ~1 megapixel total. Dimensions vary by aspect ratio. + Only supported for grok-imagine models. Returns: A sequence of `ImageResponse` objects, one for each image generated. @@ -94,6 +155,17 @@ async def sample_batch( n=n, format=convert_image_format_to_pb(image_format), ) + if image_url is not None: + request.image.CopyFrom( + image_pb2.ImageUrlContent( + image_url=image_url, + detail=image_pb2.ImageDetail.DETAIL_AUTO, + ) + ) + if aspect_ratio is not None: + request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio) + if resolution is not None: + request.resolution = convert_image_resolution_to_pb(resolution) with tracer.start_as_current_span( name=f"image.sample_batch {model}", diff --git a/src/xai_sdk/aio/video.py b/src/xai_sdk/aio/video.py new file mode 100644 index 0000000..3d66338 --- /dev/null +++ b/src/xai_sdk/aio/video.py @@ -0,0 +1,112 @@ +import asyncio +import datetime +from typing import Optional + +from opentelemetry.trace import SpanKind + +from ..poll_timer import PollTimer +from ..proto import deferred_pb2, video_pb2 +from ..telemetry import get_tracer +from ..video import ( + BaseClient, + VideoAspectRatio, + VideoResolution, + VideoResponse, + _make_generate_request, + _make_span_request_attributes, + _make_span_response_attributes, +) + +tracer = get_tracer(__name__) + + +class Client(BaseClient): + """Asynchronous client for interacting with the `Video` API.""" + + async def start( + self, + prompt: str, + model: str, + *, + image_url: Optional[str] = None, + video_url: Optional[str] = None, + duration: Optional[int] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + resolution: Optional[VideoResolution] = None, + ) -> deferred_pb2.StartDeferredResponse: + """Starts a video generation request and returns a request_id for polling.""" + request = _make_generate_request( + prompt, + model, + image_url=image_url, + video_url=video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + ) + + with tracer.start_as_current_span( + name=f"video.start {model}", + kind=SpanKind.CLIENT, + attributes=_make_span_request_attributes(request), + ): + return await self._stub.GenerateVideo(request) + + async def get(self, request_id: str) -> video_pb2.GetDeferredVideoResponse: + """Gets the current status (and optional result) for a deferred video request.""" + request = video_pb2.GetDeferredVideoRequest(request_id=request_id) + return await self._stub.GetDeferredVideo(request) + + async def generate( + self, + prompt: str, + model: str, + *, + image_url: Optional[str] = None, + video_url: Optional[str] = None, + duration: Optional[int] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + resolution: Optional[VideoResolution] = None, + timeout: Optional[datetime.timedelta] = None, + interval: Optional[datetime.timedelta] = None, + ) -> VideoResponse: + """Generates a video using polling and returns the completed response. + + This wraps `GenerateVideo` + repeated `GetDeferredVideo` calls until the request is complete. + """ + timer = PollTimer(timeout, interval) + request_pb = _make_generate_request( + prompt, + model, + image_url=image_url, + video_url=video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + ) + + with tracer.start_as_current_span( + name=f"video.generate {model}", + kind=SpanKind.CLIENT, + attributes=_make_span_request_attributes(request_pb), + ) as span: + start = await self._stub.GenerateVideo(request_pb) + + while True: + get_req = video_pb2.GetDeferredVideoRequest(request_id=start.request_id) + + r = await self._stub.GetDeferredVideo(get_req) + match r.status: + case deferred_pb2.DeferredStatus.DONE: + if not r.HasField("response"): + raise RuntimeError("Deferred request completed but no response was returned.") + response = VideoResponse(r.response) + span.set_attributes(_make_span_response_attributes(request_pb, response)) + return response + case deferred_pb2.DeferredStatus.EXPIRED: + raise RuntimeError("Deferred request expired.") + case deferred_pb2.DeferredStatus.PENDING: + await asyncio.sleep(timer.sleep_interval_or_raise()) + continue + case unknown_status: + raise ValueError(f"Unknown deferred status: {unknown_status}") diff --git a/src/xai_sdk/image.py b/src/xai_sdk/image.py index f38b0e2..0396bab 100644 --- a/src/xai_sdk/image.py +++ b/src/xai_sdk/image.py @@ -1,13 +1,28 @@ import base64 -from typing import Any, Literal, Sequence, Union +from typing import Any, Sequence, Union import grpc from .meta import ProtoDecorator -from .proto import image_pb2, image_pb2_grpc +from .proto import image_pb2, image_pb2_grpc, usage_pb2 from .telemetry import should_disable_sensitive_attributes - -ImageFormat = Literal["base64", "url"] +from .types.image import ImageAspectRatio, ImageFormat, ImageResolution + +_IMAGE_ASPECT_RATIO_MAP: dict[ImageAspectRatio, image_pb2.ImageAspectRatio] = { + "1:1": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_1_1, + "3:4": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_3_4, + "4:3": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_4_3, + "9:16": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_9_16, + "16:9": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_16_9, + "2:3": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_2_3, + "3:2": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_3_2, + "9:19.5": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_9_19_5, + "19.5:9": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_19_5_9, + "9:20": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_9_20, + "20:9": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_20_9, + "1:2": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_1_2, + "2:1": image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_2_1, +} class BaseClient: @@ -35,6 +50,16 @@ def __init__(self, proto: image_pb2.ImageResponse, index: int) -> None: super().__init__(proto) self._image = proto.images[index] + @property + def model(self) -> str: + """The model used to generate the image (ignoring aliases).""" + return self._proto.model + + @property + def usage(self) -> usage_pb2.SamplingUsage: + """Token and tool usage for this request.""" + return self._proto.usage + @property def prompt(self) -> str: """The actual prompt used to generate the image. @@ -44,11 +69,18 @@ def prompt(self) -> str: """ return self._image.up_sampled_prompt + @property + def respect_moderation(self) -> bool: + """Whether the image respects moderation rules.""" + return self._image.respect_moderation + @property def url(self) -> str: """Returns the URL under which the image is stored or raises an error.""" url = self._image.url if not url: + if not self.respect_moderation: + raise ValueError("Image did not respect moderation rules; URL is not available.") raise ValueError("Image was not returned via URL and cannot be fetched.") return url @@ -57,6 +89,8 @@ def base64(self) -> str: """Returns the image as base64-encoded string or raises an error.""" value = self._image.base64 if not value: + if not self.respect_moderation: + raise ValueError("Image did not respect moderation rules; base64 is not available.") raise ValueError("Image was not returned via base64.") return value @@ -73,7 +107,8 @@ def _make_span_request_attributes(request: image_pb2.GenerateImageRequest) -> di attributes: dict[str, str | int] = { "gen_ai.operation.name": "generate_image", "gen_ai.request.model": request.model, - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", } if should_disable_sensitive_attributes(): @@ -86,6 +121,12 @@ def _make_span_request_attributes(request: image_pb2.GenerateImageRequest) -> di if request.HasField("n"): attributes["gen_ai.request.image.count"] = request.n + if request.HasField("aspect_ratio"): + attributes["gen_ai.request.image.aspect_ratio"] = _format_image_aspect_ratio(request.aspect_ratio) + if request.HasField("resolution"): + attributes["gen_ai.request.image.resolution"] = ( + image_pb2.ImageResolution.Name(request.resolution).removeprefix("IMG_RESOLUTION_").lower() + ) if request.user: attributes["user_id"] = request.user @@ -103,15 +144,29 @@ def _make_span_response_attributes( if should_disable_sensitive_attributes(): return attributes + # All of these attributes are the same for all images in this response. + if responses: + usage = responses[0].usage + attributes["gen_ai.usage.input_tokens"] = usage.prompt_tokens + attributes["gen_ai.usage.output_tokens"] = usage.completion_tokens + attributes["gen_ai.usage.total_tokens"] = usage.total_tokens + attributes["gen_ai.usage.reasoning_tokens"] = usage.reasoning_tokens + attributes["gen_ai.usage.cached_prompt_text_tokens"] = usage.cached_prompt_text_tokens + attributes["gen_ai.usage.prompt_text_tokens"] = usage.prompt_text_tokens + attributes["gen_ai.usage.prompt_image_tokens"] = usage.prompt_image_tokens + attributes["gen_ai.response.image.format"] = ( image_pb2.ImageFormat.Name(request.format).removeprefix("IMG_FORMAT_").lower() ) for index, response in enumerate(responses): attributes[f"gen_ai.response.{index}.image.up_sampled_prompt"] = response.prompt + attributes[f"gen_ai.response.{index}.image.respect_moderation"] = response.respect_moderation if request.format == image_pb2.ImageFormat.IMG_FORMAT_URL: - attributes[f"gen_ai.response.{index}.image.url"] = response.url + if response._image.url: + attributes[f"gen_ai.response.{index}.image.url"] = response._image.url elif request.format == image_pb2.ImageFormat.IMG_FORMAT_BASE64: - attributes[f"gen_ai.response.{index}.image.base64"] = response.base64 + if response._image.base64: + attributes[f"gen_ai.response.{index}.image.base64"] = response._image.base64 return attributes @@ -125,3 +180,29 @@ def convert_image_format_to_pb(image_format: ImageFormat) -> image_pb2.ImageForm return image_pb2.ImageFormat.IMG_FORMAT_URL case _: raise ValueError(f"Invalid image format {image_format}.") + + +def convert_image_aspect_ratio_to_pb(aspect_ratio: ImageAspectRatio) -> image_pb2.ImageAspectRatio: + """Converts a string literal representation of an image aspect ratio to its protobuf enum variant.""" + try: + return _IMAGE_ASPECT_RATIO_MAP[aspect_ratio] + except KeyError as exc: + raise ValueError(f"Invalid image aspect ratio {aspect_ratio}.") from exc + + +def _format_image_aspect_ratio(aspect_ratio: image_pb2.ImageAspectRatio) -> str: + """Formats the protobuf enum into the public string form (e.g. '9:19.5').""" + name = image_pb2.ImageAspectRatio.Name(aspect_ratio).removeprefix("IMG_ASPECT_RATIO_") + if name == "AUTO": + return "auto" + # Protobuf encodes the "19.5" ratio portion as "19_5". + return name.replace("19_5", "19.5").replace("_", ":") + + +def convert_image_resolution_to_pb(resolution: ImageResolution) -> image_pb2.ImageResolution: + """Converts a string literal representation of an image resolution to its protobuf enum variant.""" + match resolution: + case "1k": + return image_pb2.ImageResolution.IMG_RESOLUTION_1K + case _: + raise ValueError(f"Invalid image resolution {resolution}.") diff --git a/src/xai_sdk/proto/__init__.py b/src/xai_sdk/proto/__init__.py index 308b2cd..e604b3d 100644 --- a/src/xai_sdk/proto/__init__.py +++ b/src/xai_sdk/proto/__init__.py @@ -31,6 +31,8 @@ tokenize_pb2_grpc, types_pb2, usage_pb2, + video_pb2, + video_pb2_grpc, ) elif version.parse(google.protobuf.__version__).major == 6: from .v6 import ( @@ -61,6 +63,8 @@ tokenize_pb2_grpc, types_pb2, usage_pb2, + video_pb2, + video_pb2_grpc, ) else: raise ValueError(f"Unsupported protobuf version: {google.protobuf.__version__}") \ No newline at end of file diff --git a/src/xai_sdk/proto/v5/image_pb2.py b/src/xai_sdk/proto/v5/image_pb2.py index eaeeea8..b89b90e 100644 --- a/src/xai_sdk/proto/v5/image_pb2.py +++ b/src/xai_sdk/proto/v5/image_pb2.py @@ -22,9 +22,10 @@ _sym_db = _symbol_database.Default() +from . import usage_pb2 as xai_dot_api_dot_v1_dot_usage__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/image.proto\x12\x07xai_api\"\xd5\x01\n\x14GenerateImageRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x05 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12\x11\n\x01n\x18\x03 \x01(\x05H\x00R\x01n\x88\x01\x01\x12\x12\n\x04user\x18\x04 \x01(\tR\x04user\x12,\n\x06\x66ormat\x18\x0b \x01(\x0e\x32\x14.xai_api.ImageFormatR\x06\x66ormatB\x04\n\x02_nJ\x04\x08\r\x10\x0e\"V\n\rImageResponse\x12/\n\x06images\x18\x01 \x03(\x0b\x32\x17.xai_api.GeneratedImageR\x06images\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\"\xa2\x01\n\x0eGeneratedImage\x12\x18\n\x06\x62\x61se64\x18\x01 \x01(\tH\x00R\x06\x62\x61se64\x12\x12\n\x03url\x18\x03 \x01(\tH\x00R\x03url\x12*\n\x11up_sampled_prompt\x18\x02 \x01(\tR\x0fupSampledPrompt\x12-\n\x12respect_moderation\x18\x04 \x01(\x08R\x11respectModerationB\x07\n\x05image\"\\\n\x0fImageUrlContent\x12\x1b\n\timage_url\x18\x01 \x01(\tR\x08imageUrl\x12,\n\x06\x64\x65tail\x18\x02 \x01(\x0e\x32\x14.xai_api.ImageDetailR\x06\x64\x65tail*S\n\x0bImageDetail\x12\x12\n\x0e\x44\x45TAIL_INVALID\x10\x00\x12\x0f\n\x0b\x44\x45TAIL_AUTO\x10\x01\x12\x0e\n\nDETAIL_LOW\x10\x02\x12\x0f\n\x0b\x44\x45TAIL_HIGH\x10\x03*P\n\x0bImageFormat\x12\x16\n\x12IMG_FORMAT_INVALID\x10\x00\x12\x15\n\x11IMG_FORMAT_BASE64\x10\x01\x12\x12\n\x0eIMG_FORMAT_URL\x10\x02\x32Q\n\x05Image\x12H\n\rGenerateImage\x12\x1d.xai_api.GenerateImageRequest\x1a\x16.xai_api.ImageResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nImageProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/image.proto\x12\x07xai_api\x1a\x16xai/api/v1/usage.proto\"\xf7\x02\n\x14GenerateImageRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x05 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12\x11\n\x01n\x18\x03 \x01(\x05H\x00R\x01n\x88\x01\x01\x12\x12\n\x04user\x18\x04 \x01(\tR\x04user\x12,\n\x06\x66ormat\x18\x0b \x01(\x0e\x32\x14.xai_api.ImageFormatR\x06\x66ormat\x12\x41\n\x0c\x61spect_ratio\x18\x0e \x01(\x0e\x32\x19.xai_api.ImageAspectRatioH\x01R\x0b\x61spectRatio\x88\x01\x01\x12=\n\nresolution\x18\x0f \x01(\x0e\x32\x18.xai_api.ImageResolutionH\x02R\nresolution\x88\x01\x01\x42\x04\n\x02_nB\x0f\n\r_aspect_ratioB\r\n\x0b_resolutionJ\x04\x08\r\x10\x0e\"\x84\x01\n\rImageResponse\x12/\n\x06images\x18\x01 \x03(\x0b\x32\x17.xai_api.GeneratedImageR\x06images\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12,\n\x05usage\x18\x03 \x01(\x0b\x32\x16.xai_api.SamplingUsageR\x05usage\"\xa2\x01\n\x0eGeneratedImage\x12\x18\n\x06\x62\x61se64\x18\x01 \x01(\tH\x00R\x06\x62\x61se64\x12\x12\n\x03url\x18\x03 \x01(\tH\x00R\x03url\x12*\n\x11up_sampled_prompt\x18\x02 \x01(\tR\x0fupSampledPrompt\x12-\n\x12respect_moderation\x18\x04 \x01(\x08R\x11respectModerationB\x07\n\x05image\"\\\n\x0fImageUrlContent\x12\x1b\n\timage_url\x18\x01 \x01(\tR\x08imageUrl\x12,\n\x06\x64\x65tail\x18\x02 \x01(\x0e\x32\x14.xai_api.ImageDetailR\x06\x64\x65tail*S\n\x0bImageDetail\x12\x12\n\x0e\x44\x45TAIL_INVALID\x10\x00\x12\x0f\n\x0b\x44\x45TAIL_AUTO\x10\x01\x12\x0e\n\nDETAIL_LOW\x10\x02\x12\x0f\n\x0b\x44\x45TAIL_HIGH\x10\x03*P\n\x0bImageFormat\x12\x16\n\x12IMG_FORMAT_INVALID\x10\x00\x12\x15\n\x11IMG_FORMAT_BASE64\x10\x01\x12\x12\n\x0eIMG_FORMAT_URL\x10\x02*j\n\x0cImageQuality\x12\x17\n\x13IMG_QUALITY_INVALID\x10\x00\x12\x13\n\x0fIMG_QUALITY_LOW\x10\x01\x12\x16\n\x12IMG_QUALITY_MEDIUM\x10\x02\x12\x14\n\x10IMG_QUALITY_HIGH\x10\x03*\xa7\x03\n\x10ImageAspectRatio\x12\x1c\n\x18IMG_ASPECT_RATIO_INVALID\x10\x00\x12\x18\n\x14IMG_ASPECT_RATIO_1_1\x10\x01\x12\x18\n\x14IMG_ASPECT_RATIO_3_4\x10\x02\x12\x18\n\x14IMG_ASPECT_RATIO_4_3\x10\x03\x12\x19\n\x15IMG_ASPECT_RATIO_9_16\x10\x04\x12\x19\n\x15IMG_ASPECT_RATIO_16_9\x10\x05\x12\x18\n\x14IMG_ASPECT_RATIO_2_3\x10\x06\x12\x18\n\x14IMG_ASPECT_RATIO_3_2\x10\x07\x12\x19\n\x15IMG_ASPECT_RATIO_AUTO\x10\x08\x12\x1b\n\x17IMG_ASPECT_RATIO_9_19_5\x10\t\x12\x1b\n\x17IMG_ASPECT_RATIO_19_5_9\x10\n\x12\x19\n\x15IMG_ASPECT_RATIO_9_20\x10\x0b\x12\x19\n\x15IMG_ASPECT_RATIO_20_9\x10\x0c\x12\x18\n\x14IMG_ASPECT_RATIO_1_2\x10\r\x12\x18\n\x14IMG_ASPECT_RATIO_2_1\x10\x0e*D\n\x0fImageResolution\x12\x1a\n\x16IMG_RESOLUTION_INVALID\x10\x00\x12\x15\n\x11IMG_RESOLUTION_1K\x10\x01\x32Q\n\x05Image\x12H\n\rGenerateImage\x12\x1d.xai_api.GenerateImageRequest\x1a\x16.xai_api.ImageResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nImageProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,18 +33,24 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\013com.xai_apiB\nImageProtoP\001\242\002\003XXX\252\002\006XaiApi\312\002\006XaiApi\342\002\022XaiApi\\GPBMetadata\352\002\006XaiApi' - _globals['_IMAGEDETAIL']._serialized_start=598 - _globals['_IMAGEDETAIL']._serialized_end=681 - _globals['_IMAGEFORMAT']._serialized_start=683 - _globals['_IMAGEFORMAT']._serialized_end=763 - _globals['_GENERATEIMAGEREQUEST']._serialized_start=36 - _globals['_GENERATEIMAGEREQUEST']._serialized_end=249 - _globals['_IMAGERESPONSE']._serialized_start=251 - _globals['_IMAGERESPONSE']._serialized_end=337 - _globals['_GENERATEDIMAGE']._serialized_start=340 - _globals['_GENERATEDIMAGE']._serialized_end=502 - _globals['_IMAGEURLCONTENT']._serialized_start=504 - _globals['_IMAGEURLCONTENT']._serialized_end=596 - _globals['_IMAGE']._serialized_start=765 - _globals['_IMAGE']._serialized_end=846 + _globals['_IMAGEDETAIL']._serialized_start=831 + _globals['_IMAGEDETAIL']._serialized_end=914 + _globals['_IMAGEFORMAT']._serialized_start=916 + _globals['_IMAGEFORMAT']._serialized_end=996 + _globals['_IMAGEQUALITY']._serialized_start=998 + _globals['_IMAGEQUALITY']._serialized_end=1104 + _globals['_IMAGEASPECTRATIO']._serialized_start=1107 + _globals['_IMAGEASPECTRATIO']._serialized_end=1530 + _globals['_IMAGERESOLUTION']._serialized_start=1532 + _globals['_IMAGERESOLUTION']._serialized_end=1600 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=60 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=435 + _globals['_IMAGERESPONSE']._serialized_start=438 + _globals['_IMAGERESPONSE']._serialized_end=570 + _globals['_GENERATEDIMAGE']._serialized_start=573 + _globals['_GENERATEDIMAGE']._serialized_end=735 + _globals['_IMAGEURLCONTENT']._serialized_start=737 + _globals['_IMAGEURLCONTENT']._serialized_end=829 + _globals['_IMAGE']._serialized_start=1602 + _globals['_IMAGE']._serialized_end=1683 # @@protoc_insertion_point(module_scope) diff --git a/src/xai_sdk/proto/v5/image_pb2.pyi b/src/xai_sdk/proto/v5/image_pb2.pyi index ac1eecd..3c9a97c 100644 --- a/src/xai_sdk/proto/v5/image_pb2.pyi +++ b/src/xai_sdk/proto/v5/image_pb2.pyi @@ -1,3 +1,4 @@ +from . import usage_pb2 as _usage_pb2 from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor @@ -18,6 +19,36 @@ class ImageFormat(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): IMG_FORMAT_INVALID: _ClassVar[ImageFormat] IMG_FORMAT_BASE64: _ClassVar[ImageFormat] IMG_FORMAT_URL: _ClassVar[ImageFormat] + +class ImageQuality(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_QUALITY_INVALID: _ClassVar[ImageQuality] + IMG_QUALITY_LOW: _ClassVar[ImageQuality] + IMG_QUALITY_MEDIUM: _ClassVar[ImageQuality] + IMG_QUALITY_HIGH: _ClassVar[ImageQuality] + +class ImageAspectRatio(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_ASPECT_RATIO_INVALID: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_1_1: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_3_4: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_4_3: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_16: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_16_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_2_3: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_3_2: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_AUTO: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_19_5: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_19_5_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_20: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_20_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_1_2: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_2_1: _ClassVar[ImageAspectRatio] + +class ImageResolution(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_RESOLUTION_INVALID: _ClassVar[ImageResolution] + IMG_RESOLUTION_1K: _ClassVar[ImageResolution] DETAIL_INVALID: ImageDetail DETAIL_AUTO: ImageDetail DETAIL_LOW: ImageDetail @@ -25,30 +56,57 @@ DETAIL_HIGH: ImageDetail IMG_FORMAT_INVALID: ImageFormat IMG_FORMAT_BASE64: ImageFormat IMG_FORMAT_URL: ImageFormat +IMG_QUALITY_INVALID: ImageQuality +IMG_QUALITY_LOW: ImageQuality +IMG_QUALITY_MEDIUM: ImageQuality +IMG_QUALITY_HIGH: ImageQuality +IMG_ASPECT_RATIO_INVALID: ImageAspectRatio +IMG_ASPECT_RATIO_1_1: ImageAspectRatio +IMG_ASPECT_RATIO_3_4: ImageAspectRatio +IMG_ASPECT_RATIO_4_3: ImageAspectRatio +IMG_ASPECT_RATIO_9_16: ImageAspectRatio +IMG_ASPECT_RATIO_16_9: ImageAspectRatio +IMG_ASPECT_RATIO_2_3: ImageAspectRatio +IMG_ASPECT_RATIO_3_2: ImageAspectRatio +IMG_ASPECT_RATIO_AUTO: ImageAspectRatio +IMG_ASPECT_RATIO_9_19_5: ImageAspectRatio +IMG_ASPECT_RATIO_19_5_9: ImageAspectRatio +IMG_ASPECT_RATIO_9_20: ImageAspectRatio +IMG_ASPECT_RATIO_20_9: ImageAspectRatio +IMG_ASPECT_RATIO_1_2: ImageAspectRatio +IMG_ASPECT_RATIO_2_1: ImageAspectRatio +IMG_RESOLUTION_INVALID: ImageResolution +IMG_RESOLUTION_1K: ImageResolution class GenerateImageRequest(_message.Message): - __slots__ = ("prompt", "image", "model", "n", "user", "format") + __slots__ = ("prompt", "image", "model", "n", "user", "format", "aspect_ratio", "resolution") PROMPT_FIELD_NUMBER: _ClassVar[int] IMAGE_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int] USER_FIELD_NUMBER: _ClassVar[int] FORMAT_FIELD_NUMBER: _ClassVar[int] + ASPECT_RATIO_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] prompt: str image: ImageUrlContent model: str n: int user: str format: ImageFormat - def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., n: _Optional[int] = ..., user: _Optional[str] = ..., format: _Optional[_Union[ImageFormat, str]] = ...) -> None: ... + aspect_ratio: ImageAspectRatio + resolution: ImageResolution + def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., n: _Optional[int] = ..., user: _Optional[str] = ..., format: _Optional[_Union[ImageFormat, str]] = ..., aspect_ratio: _Optional[_Union[ImageAspectRatio, str]] = ..., resolution: _Optional[_Union[ImageResolution, str]] = ...) -> None: ... class ImageResponse(_message.Message): - __slots__ = ("images", "model") + __slots__ = ("images", "model", "usage") IMAGES_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] + USAGE_FIELD_NUMBER: _ClassVar[int] images: _containers.RepeatedCompositeFieldContainer[GeneratedImage] model: str - def __init__(self, images: _Optional[_Iterable[_Union[GeneratedImage, _Mapping]]] = ..., model: _Optional[str] = ...) -> None: ... + usage: _usage_pb2.SamplingUsage + def __init__(self, images: _Optional[_Iterable[_Union[GeneratedImage, _Mapping]]] = ..., model: _Optional[str] = ..., usage: _Optional[_Union[_usage_pb2.SamplingUsage, _Mapping]] = ...) -> None: ... class GeneratedImage(_message.Message): __slots__ = ("base64", "url", "up_sampled_prompt", "respect_moderation") diff --git a/src/xai_sdk/proto/v5/video_pb2.py b/src/xai_sdk/proto/v5/video_pb2.py new file mode 100644 index 0000000..9f5b200 --- /dev/null +++ b/src/xai_sdk/proto/v5/video_pb2.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: xai/api/v1/video.proto +# Protobuf Python Version: 5.29.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 1, + '', + 'xai/api/v1/video.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import image_pb2 as xai_dot_api_dot_v1_dot_image__pb2 +from . import deferred_pb2 as xai_dot_api_dot_v1_dot_deferred__pb2 +from . import usage_pb2 as xai_dot_api_dot_v1_dot_usage__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/video.proto\x12\x07xai_api\x1a\x16xai/api/v1/image.proto\x1a\x19xai/api/v1/deferred.proto\x1a\x16xai/api/v1/usage.proto\"#\n\x0fVideoUrlContent\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\",\n\x0bVideoOutput\x12\x1d\n\nupload_url\x18\x01 \x01(\tR\tuploadUrl\"\xf4\x02\n\x14GenerateVideoRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x02 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x03 \x01(\tR\x05model\x12\x1f\n\x08\x64uration\x18\x04 \x01(\x05H\x00R\x08\x64uration\x88\x01\x01\x12.\n\x05video\x18\x06 \x01(\x0b\x32\x18.xai_api.VideoUrlContentR\x05video\x12\x41\n\x0c\x61spect_ratio\x18\x07 \x01(\x0e\x32\x19.xai_api.VideoAspectRatioH\x01R\x0b\x61spectRatio\x88\x01\x01\x12=\n\nresolution\x18\x08 \x01(\x0e\x32\x18.xai_api.VideoResolutionH\x02R\nresolution\x88\x01\x01\x42\x0b\n\t_durationB\x0f\n\r_aspect_ratioB\r\n\x0b_resolution\"8\n\x17GetDeferredVideoRequest\x12\x1d\n\nrequest_id\x18\x01 \x01(\tR\trequestId\"\x82\x01\n\rVideoResponse\x12-\n\x05video\x18\x01 \x01(\x0b\x32\x17.xai_api.GeneratedVideoR\x05video\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12,\n\x05usage\x18\x03 \x01(\x0b\x32\x16.xai_api.SamplingUsageR\x05usage\"m\n\x0eGeneratedVideo\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1a\n\x08\x64uration\x18\x04 \x01(\x05R\x08\x64uration\x12-\n\x12respect_moderation\x18\x05 \x01(\x08R\x11respectModeration\"\x91\x01\n\x18GetDeferredVideoResponse\x12/\n\x06status\x18\x01 \x01(\x0e\x32\x17.xai_api.DeferredStatusR\x06status\x12\x37\n\x08response\x18\x02 \x01(\x0b\x32\x16.xai_api.VideoResponseH\x00R\x08response\x88\x01\x01\x42\x0b\n\t_response*\xfc\x01\n\x10VideoAspectRatio\x12\"\n\x1eVIDEO_ASPECT_RATIO_UNSPECIFIED\x10\x00\x12\x1a\n\x16VIDEO_ASPECT_RATIO_1_1\x10\x01\x12\x1b\n\x17VIDEO_ASPECT_RATIO_16_9\x10\x02\x12\x1b\n\x17VIDEO_ASPECT_RATIO_9_16\x10\x03\x12\x1a\n\x16VIDEO_ASPECT_RATIO_4_3\x10\x04\x12\x1a\n\x16VIDEO_ASPECT_RATIO_3_4\x10\x05\x12\x1a\n\x16VIDEO_ASPECT_RATIO_3_2\x10\x06\x12\x1a\n\x16VIDEO_ASPECT_RATIO_2_3\x10\x07*i\n\x0fVideoResolution\x12 \n\x1cVIDEO_RESOLUTION_UNSPECIFIED\x10\x00\x12\x19\n\x15VIDEO_RESOLUTION_480P\x10\x01\x12\x19\n\x15VIDEO_RESOLUTION_720P\x10\x02\x32\xb4\x01\n\x05Video\x12P\n\rGenerateVideo\x12\x1d.xai_api.GenerateVideoRequest\x1a\x1e.xai_api.StartDeferredResponse\"\x00\x12Y\n\x10GetDeferredVideo\x12 .xai_api.GetDeferredVideoRequest\x1a!.xai_api.GetDeferredVideoResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nVideoProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xai.api.v1.video_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\013com.xai_apiB\nVideoProtoP\001\242\002\003XXX\252\002\006XaiApi\312\002\006XaiApi\342\002\022XaiApi\\GPBMetadata\352\002\006XaiApi' + _globals['_VIDEOASPECTRATIO']._serialized_start=1019 + _globals['_VIDEOASPECTRATIO']._serialized_end=1271 + _globals['_VIDEORESOLUTION']._serialized_start=1273 + _globals['_VIDEORESOLUTION']._serialized_end=1378 + _globals['_VIDEOURLCONTENT']._serialized_start=110 + _globals['_VIDEOURLCONTENT']._serialized_end=145 + _globals['_VIDEOOUTPUT']._serialized_start=147 + _globals['_VIDEOOUTPUT']._serialized_end=191 + _globals['_GENERATEVIDEOREQUEST']._serialized_start=194 + _globals['_GENERATEVIDEOREQUEST']._serialized_end=566 + _globals['_GETDEFERREDVIDEOREQUEST']._serialized_start=568 + _globals['_GETDEFERREDVIDEOREQUEST']._serialized_end=624 + _globals['_VIDEORESPONSE']._serialized_start=627 + _globals['_VIDEORESPONSE']._serialized_end=757 + _globals['_GENERATEDVIDEO']._serialized_start=759 + _globals['_GENERATEDVIDEO']._serialized_end=868 + _globals['_GETDEFERREDVIDEORESPONSE']._serialized_start=871 + _globals['_GETDEFERREDVIDEORESPONSE']._serialized_end=1016 + _globals['_VIDEO']._serialized_start=1381 + _globals['_VIDEO']._serialized_end=1561 +# @@protoc_insertion_point(module_scope) diff --git a/src/xai_sdk/proto/v5/video_pb2.pyi b/src/xai_sdk/proto/v5/video_pb2.pyi new file mode 100644 index 0000000..71c8ea6 --- /dev/null +++ b/src/xai_sdk/proto/v5/video_pb2.pyi @@ -0,0 +1,101 @@ +from . import image_pb2 as _image_pb2 +from . import deferred_pb2 as _deferred_pb2 +from . import usage_pb2 as _usage_pb2 +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class VideoAspectRatio(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + VIDEO_ASPECT_RATIO_UNSPECIFIED: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_1_1: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_16_9: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_9_16: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_4_3: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_3_4: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_3_2: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_2_3: _ClassVar[VideoAspectRatio] + +class VideoResolution(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + VIDEO_RESOLUTION_UNSPECIFIED: _ClassVar[VideoResolution] + VIDEO_RESOLUTION_480P: _ClassVar[VideoResolution] + VIDEO_RESOLUTION_720P: _ClassVar[VideoResolution] +VIDEO_ASPECT_RATIO_UNSPECIFIED: VideoAspectRatio +VIDEO_ASPECT_RATIO_1_1: VideoAspectRatio +VIDEO_ASPECT_RATIO_16_9: VideoAspectRatio +VIDEO_ASPECT_RATIO_9_16: VideoAspectRatio +VIDEO_ASPECT_RATIO_4_3: VideoAspectRatio +VIDEO_ASPECT_RATIO_3_4: VideoAspectRatio +VIDEO_ASPECT_RATIO_3_2: VideoAspectRatio +VIDEO_ASPECT_RATIO_2_3: VideoAspectRatio +VIDEO_RESOLUTION_UNSPECIFIED: VideoResolution +VIDEO_RESOLUTION_480P: VideoResolution +VIDEO_RESOLUTION_720P: VideoResolution + +class VideoUrlContent(_message.Message): + __slots__ = ("url",) + URL_FIELD_NUMBER: _ClassVar[int] + url: str + def __init__(self, url: _Optional[str] = ...) -> None: ... + +class VideoOutput(_message.Message): + __slots__ = ("upload_url",) + UPLOAD_URL_FIELD_NUMBER: _ClassVar[int] + upload_url: str + def __init__(self, upload_url: _Optional[str] = ...) -> None: ... + +class GenerateVideoRequest(_message.Message): + __slots__ = ("prompt", "image", "model", "duration", "video", "aspect_ratio", "resolution") + PROMPT_FIELD_NUMBER: _ClassVar[int] + IMAGE_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + DURATION_FIELD_NUMBER: _ClassVar[int] + VIDEO_FIELD_NUMBER: _ClassVar[int] + ASPECT_RATIO_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + prompt: str + image: _image_pb2.ImageUrlContent + model: str + duration: int + video: VideoUrlContent + aspect_ratio: VideoAspectRatio + resolution: VideoResolution + def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[_image_pb2.ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., duration: _Optional[int] = ..., video: _Optional[_Union[VideoUrlContent, _Mapping]] = ..., aspect_ratio: _Optional[_Union[VideoAspectRatio, str]] = ..., resolution: _Optional[_Union[VideoResolution, str]] = ...) -> None: ... + +class GetDeferredVideoRequest(_message.Message): + __slots__ = ("request_id",) + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + request_id: str + def __init__(self, request_id: _Optional[str] = ...) -> None: ... + +class VideoResponse(_message.Message): + __slots__ = ("video", "model", "usage") + VIDEO_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + USAGE_FIELD_NUMBER: _ClassVar[int] + video: GeneratedVideo + model: str + usage: _usage_pb2.SamplingUsage + def __init__(self, video: _Optional[_Union[GeneratedVideo, _Mapping]] = ..., model: _Optional[str] = ..., usage: _Optional[_Union[_usage_pb2.SamplingUsage, _Mapping]] = ...) -> None: ... + +class GeneratedVideo(_message.Message): + __slots__ = ("url", "duration", "respect_moderation") + URL_FIELD_NUMBER: _ClassVar[int] + DURATION_FIELD_NUMBER: _ClassVar[int] + RESPECT_MODERATION_FIELD_NUMBER: _ClassVar[int] + url: str + duration: int + respect_moderation: bool + def __init__(self, url: _Optional[str] = ..., duration: _Optional[int] = ..., respect_moderation: bool = ...) -> None: ... + +class GetDeferredVideoResponse(_message.Message): + __slots__ = ("status", "response") + STATUS_FIELD_NUMBER: _ClassVar[int] + RESPONSE_FIELD_NUMBER: _ClassVar[int] + status: _deferred_pb2.DeferredStatus + response: VideoResponse + def __init__(self, status: _Optional[_Union[_deferred_pb2.DeferredStatus, str]] = ..., response: _Optional[_Union[VideoResponse, _Mapping]] = ...) -> None: ... diff --git a/src/xai_sdk/proto/v5/video_pb2_grpc.py b/src/xai_sdk/proto/v5/video_pb2_grpc.py new file mode 100644 index 0000000..c123b72 --- /dev/null +++ b/src/xai_sdk/proto/v5/video_pb2_grpc.py @@ -0,0 +1,131 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import deferred_pb2 as xai_dot_api_dot_v1_dot_deferred__pb2 +from . import video_pb2 as xai_dot_api_dot_v1_dot_video__pb2 + + +class VideoStub(object): + """An API service for interaction with video generation models. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GenerateVideo = channel.unary_unary( + '/xai_api.Video/GenerateVideo', + request_serializer=xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.SerializeToString, + response_deserializer=xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.FromString, + _registered_method=True) + self.GetDeferredVideo = channel.unary_unary( + '/xai_api.Video/GetDeferredVideo', + request_serializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.SerializeToString, + response_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.FromString, + _registered_method=True) + + +class VideoServicer(object): + """An API service for interaction with video generation models. + """ + + def GenerateVideo(self, request, context): + """Create a video based on a text prompt and optionally an image. + If an image is provided, generates video with the image as the first frame (image-to-video). + If no image is provided, generates video from text only (text-to-video). + + This is an asynchronous operation. The method returns immediately with a request_id + that can be used to poll for the result using GetDeferredVideo. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDeferredVideo(self, request, context): + """Gets the result of a video generation started by calling `GenerateVideo`. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_VideoServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GenerateVideo': grpc.unary_unary_rpc_method_handler( + servicer.GenerateVideo, + request_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.FromString, + response_serializer=xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.SerializeToString, + ), + 'GetDeferredVideo': grpc.unary_unary_rpc_method_handler( + servicer.GetDeferredVideo, + request_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.FromString, + response_serializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'xai_api.Video', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('xai_api.Video', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class Video(object): + """An API service for interaction with video generation models. + """ + + @staticmethod + def GenerateVideo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/xai_api.Video/GenerateVideo', + xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.SerializeToString, + xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetDeferredVideo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/xai_api.Video/GetDeferredVideo', + xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.SerializeToString, + xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/xai_sdk/proto/v6/image_pb2.py b/src/xai_sdk/proto/v6/image_pb2.py index 0f1e481..5a8d14e 100644 --- a/src/xai_sdk/proto/v6/image_pb2.py +++ b/src/xai_sdk/proto/v6/image_pb2.py @@ -22,9 +22,10 @@ _sym_db = _symbol_database.Default() +from . import usage_pb2 as xai_dot_api_dot_v1_dot_usage__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/image.proto\x12\x07xai_api\"\xd5\x01\n\x14GenerateImageRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x05 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12\x11\n\x01n\x18\x03 \x01(\x05H\x00R\x01n\x88\x01\x01\x12\x12\n\x04user\x18\x04 \x01(\tR\x04user\x12,\n\x06\x66ormat\x18\x0b \x01(\x0e\x32\x14.xai_api.ImageFormatR\x06\x66ormatB\x04\n\x02_nJ\x04\x08\r\x10\x0e\"V\n\rImageResponse\x12/\n\x06images\x18\x01 \x03(\x0b\x32\x17.xai_api.GeneratedImageR\x06images\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\"\xa2\x01\n\x0eGeneratedImage\x12\x18\n\x06\x62\x61se64\x18\x01 \x01(\tH\x00R\x06\x62\x61se64\x12\x12\n\x03url\x18\x03 \x01(\tH\x00R\x03url\x12*\n\x11up_sampled_prompt\x18\x02 \x01(\tR\x0fupSampledPrompt\x12-\n\x12respect_moderation\x18\x04 \x01(\x08R\x11respectModerationB\x07\n\x05image\"\\\n\x0fImageUrlContent\x12\x1b\n\timage_url\x18\x01 \x01(\tR\x08imageUrl\x12,\n\x06\x64\x65tail\x18\x02 \x01(\x0e\x32\x14.xai_api.ImageDetailR\x06\x64\x65tail*S\n\x0bImageDetail\x12\x12\n\x0e\x44\x45TAIL_INVALID\x10\x00\x12\x0f\n\x0b\x44\x45TAIL_AUTO\x10\x01\x12\x0e\n\nDETAIL_LOW\x10\x02\x12\x0f\n\x0b\x44\x45TAIL_HIGH\x10\x03*P\n\x0bImageFormat\x12\x16\n\x12IMG_FORMAT_INVALID\x10\x00\x12\x15\n\x11IMG_FORMAT_BASE64\x10\x01\x12\x12\n\x0eIMG_FORMAT_URL\x10\x02\x32Q\n\x05Image\x12H\n\rGenerateImage\x12\x1d.xai_api.GenerateImageRequest\x1a\x16.xai_api.ImageResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nImageProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/image.proto\x12\x07xai_api\x1a\x16xai/api/v1/usage.proto\"\xf7\x02\n\x14GenerateImageRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x05 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12\x11\n\x01n\x18\x03 \x01(\x05H\x00R\x01n\x88\x01\x01\x12\x12\n\x04user\x18\x04 \x01(\tR\x04user\x12,\n\x06\x66ormat\x18\x0b \x01(\x0e\x32\x14.xai_api.ImageFormatR\x06\x66ormat\x12\x41\n\x0c\x61spect_ratio\x18\x0e \x01(\x0e\x32\x19.xai_api.ImageAspectRatioH\x01R\x0b\x61spectRatio\x88\x01\x01\x12=\n\nresolution\x18\x0f \x01(\x0e\x32\x18.xai_api.ImageResolutionH\x02R\nresolution\x88\x01\x01\x42\x04\n\x02_nB\x0f\n\r_aspect_ratioB\r\n\x0b_resolutionJ\x04\x08\r\x10\x0e\"\x84\x01\n\rImageResponse\x12/\n\x06images\x18\x01 \x03(\x0b\x32\x17.xai_api.GeneratedImageR\x06images\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12,\n\x05usage\x18\x03 \x01(\x0b\x32\x16.xai_api.SamplingUsageR\x05usage\"\xa2\x01\n\x0eGeneratedImage\x12\x18\n\x06\x62\x61se64\x18\x01 \x01(\tH\x00R\x06\x62\x61se64\x12\x12\n\x03url\x18\x03 \x01(\tH\x00R\x03url\x12*\n\x11up_sampled_prompt\x18\x02 \x01(\tR\x0fupSampledPrompt\x12-\n\x12respect_moderation\x18\x04 \x01(\x08R\x11respectModerationB\x07\n\x05image\"\\\n\x0fImageUrlContent\x12\x1b\n\timage_url\x18\x01 \x01(\tR\x08imageUrl\x12,\n\x06\x64\x65tail\x18\x02 \x01(\x0e\x32\x14.xai_api.ImageDetailR\x06\x64\x65tail*S\n\x0bImageDetail\x12\x12\n\x0e\x44\x45TAIL_INVALID\x10\x00\x12\x0f\n\x0b\x44\x45TAIL_AUTO\x10\x01\x12\x0e\n\nDETAIL_LOW\x10\x02\x12\x0f\n\x0b\x44\x45TAIL_HIGH\x10\x03*P\n\x0bImageFormat\x12\x16\n\x12IMG_FORMAT_INVALID\x10\x00\x12\x15\n\x11IMG_FORMAT_BASE64\x10\x01\x12\x12\n\x0eIMG_FORMAT_URL\x10\x02*j\n\x0cImageQuality\x12\x17\n\x13IMG_QUALITY_INVALID\x10\x00\x12\x13\n\x0fIMG_QUALITY_LOW\x10\x01\x12\x16\n\x12IMG_QUALITY_MEDIUM\x10\x02\x12\x14\n\x10IMG_QUALITY_HIGH\x10\x03*\xa7\x03\n\x10ImageAspectRatio\x12\x1c\n\x18IMG_ASPECT_RATIO_INVALID\x10\x00\x12\x18\n\x14IMG_ASPECT_RATIO_1_1\x10\x01\x12\x18\n\x14IMG_ASPECT_RATIO_3_4\x10\x02\x12\x18\n\x14IMG_ASPECT_RATIO_4_3\x10\x03\x12\x19\n\x15IMG_ASPECT_RATIO_9_16\x10\x04\x12\x19\n\x15IMG_ASPECT_RATIO_16_9\x10\x05\x12\x18\n\x14IMG_ASPECT_RATIO_2_3\x10\x06\x12\x18\n\x14IMG_ASPECT_RATIO_3_2\x10\x07\x12\x19\n\x15IMG_ASPECT_RATIO_AUTO\x10\x08\x12\x1b\n\x17IMG_ASPECT_RATIO_9_19_5\x10\t\x12\x1b\n\x17IMG_ASPECT_RATIO_19_5_9\x10\n\x12\x19\n\x15IMG_ASPECT_RATIO_9_20\x10\x0b\x12\x19\n\x15IMG_ASPECT_RATIO_20_9\x10\x0c\x12\x18\n\x14IMG_ASPECT_RATIO_1_2\x10\r\x12\x18\n\x14IMG_ASPECT_RATIO_2_1\x10\x0e*D\n\x0fImageResolution\x12\x1a\n\x16IMG_RESOLUTION_INVALID\x10\x00\x12\x15\n\x11IMG_RESOLUTION_1K\x10\x01\x32Q\n\x05Image\x12H\n\rGenerateImage\x12\x1d.xai_api.GenerateImageRequest\x1a\x16.xai_api.ImageResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nImageProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -32,18 +33,24 @@ if not _descriptor._USE_C_DESCRIPTORS: _globals['DESCRIPTOR']._loaded_options = None _globals['DESCRIPTOR']._serialized_options = b'\n\013com.xai_apiB\nImageProtoP\001\242\002\003XXX\252\002\006XaiApi\312\002\006XaiApi\342\002\022XaiApi\\GPBMetadata\352\002\006XaiApi' - _globals['_IMAGEDETAIL']._serialized_start=598 - _globals['_IMAGEDETAIL']._serialized_end=681 - _globals['_IMAGEFORMAT']._serialized_start=683 - _globals['_IMAGEFORMAT']._serialized_end=763 - _globals['_GENERATEIMAGEREQUEST']._serialized_start=36 - _globals['_GENERATEIMAGEREQUEST']._serialized_end=249 - _globals['_IMAGERESPONSE']._serialized_start=251 - _globals['_IMAGERESPONSE']._serialized_end=337 - _globals['_GENERATEDIMAGE']._serialized_start=340 - _globals['_GENERATEDIMAGE']._serialized_end=502 - _globals['_IMAGEURLCONTENT']._serialized_start=504 - _globals['_IMAGEURLCONTENT']._serialized_end=596 - _globals['_IMAGE']._serialized_start=765 - _globals['_IMAGE']._serialized_end=846 + _globals['_IMAGEDETAIL']._serialized_start=831 + _globals['_IMAGEDETAIL']._serialized_end=914 + _globals['_IMAGEFORMAT']._serialized_start=916 + _globals['_IMAGEFORMAT']._serialized_end=996 + _globals['_IMAGEQUALITY']._serialized_start=998 + _globals['_IMAGEQUALITY']._serialized_end=1104 + _globals['_IMAGEASPECTRATIO']._serialized_start=1107 + _globals['_IMAGEASPECTRATIO']._serialized_end=1530 + _globals['_IMAGERESOLUTION']._serialized_start=1532 + _globals['_IMAGERESOLUTION']._serialized_end=1600 + _globals['_GENERATEIMAGEREQUEST']._serialized_start=60 + _globals['_GENERATEIMAGEREQUEST']._serialized_end=435 + _globals['_IMAGERESPONSE']._serialized_start=438 + _globals['_IMAGERESPONSE']._serialized_end=570 + _globals['_GENERATEDIMAGE']._serialized_start=573 + _globals['_GENERATEDIMAGE']._serialized_end=735 + _globals['_IMAGEURLCONTENT']._serialized_start=737 + _globals['_IMAGEURLCONTENT']._serialized_end=829 + _globals['_IMAGE']._serialized_start=1602 + _globals['_IMAGE']._serialized_end=1683 # @@protoc_insertion_point(module_scope) diff --git a/src/xai_sdk/proto/v6/image_pb2.pyi b/src/xai_sdk/proto/v6/image_pb2.pyi index 52d1605..7a0fa59 100644 --- a/src/xai_sdk/proto/v6/image_pb2.pyi +++ b/src/xai_sdk/proto/v6/image_pb2.pyi @@ -1,3 +1,4 @@ +from . import usage_pb2 as _usage_pb2 from google.protobuf.internal import containers as _containers from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper from google.protobuf import descriptor as _descriptor @@ -19,6 +20,36 @@ class ImageFormat(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): IMG_FORMAT_INVALID: _ClassVar[ImageFormat] IMG_FORMAT_BASE64: _ClassVar[ImageFormat] IMG_FORMAT_URL: _ClassVar[ImageFormat] + +class ImageQuality(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_QUALITY_INVALID: _ClassVar[ImageQuality] + IMG_QUALITY_LOW: _ClassVar[ImageQuality] + IMG_QUALITY_MEDIUM: _ClassVar[ImageQuality] + IMG_QUALITY_HIGH: _ClassVar[ImageQuality] + +class ImageAspectRatio(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_ASPECT_RATIO_INVALID: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_1_1: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_3_4: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_4_3: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_16: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_16_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_2_3: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_3_2: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_AUTO: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_19_5: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_19_5_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_9_20: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_20_9: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_1_2: _ClassVar[ImageAspectRatio] + IMG_ASPECT_RATIO_2_1: _ClassVar[ImageAspectRatio] + +class ImageResolution(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + IMG_RESOLUTION_INVALID: _ClassVar[ImageResolution] + IMG_RESOLUTION_1K: _ClassVar[ImageResolution] DETAIL_INVALID: ImageDetail DETAIL_AUTO: ImageDetail DETAIL_LOW: ImageDetail @@ -26,30 +57,57 @@ DETAIL_HIGH: ImageDetail IMG_FORMAT_INVALID: ImageFormat IMG_FORMAT_BASE64: ImageFormat IMG_FORMAT_URL: ImageFormat +IMG_QUALITY_INVALID: ImageQuality +IMG_QUALITY_LOW: ImageQuality +IMG_QUALITY_MEDIUM: ImageQuality +IMG_QUALITY_HIGH: ImageQuality +IMG_ASPECT_RATIO_INVALID: ImageAspectRatio +IMG_ASPECT_RATIO_1_1: ImageAspectRatio +IMG_ASPECT_RATIO_3_4: ImageAspectRatio +IMG_ASPECT_RATIO_4_3: ImageAspectRatio +IMG_ASPECT_RATIO_9_16: ImageAspectRatio +IMG_ASPECT_RATIO_16_9: ImageAspectRatio +IMG_ASPECT_RATIO_2_3: ImageAspectRatio +IMG_ASPECT_RATIO_3_2: ImageAspectRatio +IMG_ASPECT_RATIO_AUTO: ImageAspectRatio +IMG_ASPECT_RATIO_9_19_5: ImageAspectRatio +IMG_ASPECT_RATIO_19_5_9: ImageAspectRatio +IMG_ASPECT_RATIO_9_20: ImageAspectRatio +IMG_ASPECT_RATIO_20_9: ImageAspectRatio +IMG_ASPECT_RATIO_1_2: ImageAspectRatio +IMG_ASPECT_RATIO_2_1: ImageAspectRatio +IMG_RESOLUTION_INVALID: ImageResolution +IMG_RESOLUTION_1K: ImageResolution class GenerateImageRequest(_message.Message): - __slots__ = ("prompt", "image", "model", "n", "user", "format") + __slots__ = ("prompt", "image", "model", "n", "user", "format", "aspect_ratio", "resolution") PROMPT_FIELD_NUMBER: _ClassVar[int] IMAGE_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] N_FIELD_NUMBER: _ClassVar[int] USER_FIELD_NUMBER: _ClassVar[int] FORMAT_FIELD_NUMBER: _ClassVar[int] + ASPECT_RATIO_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] prompt: str image: ImageUrlContent model: str n: int user: str format: ImageFormat - def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., n: _Optional[int] = ..., user: _Optional[str] = ..., format: _Optional[_Union[ImageFormat, str]] = ...) -> None: ... + aspect_ratio: ImageAspectRatio + resolution: ImageResolution + def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., n: _Optional[int] = ..., user: _Optional[str] = ..., format: _Optional[_Union[ImageFormat, str]] = ..., aspect_ratio: _Optional[_Union[ImageAspectRatio, str]] = ..., resolution: _Optional[_Union[ImageResolution, str]] = ...) -> None: ... class ImageResponse(_message.Message): - __slots__ = ("images", "model") + __slots__ = ("images", "model", "usage") IMAGES_FIELD_NUMBER: _ClassVar[int] MODEL_FIELD_NUMBER: _ClassVar[int] + USAGE_FIELD_NUMBER: _ClassVar[int] images: _containers.RepeatedCompositeFieldContainer[GeneratedImage] model: str - def __init__(self, images: _Optional[_Iterable[_Union[GeneratedImage, _Mapping]]] = ..., model: _Optional[str] = ...) -> None: ... + usage: _usage_pb2.SamplingUsage + def __init__(self, images: _Optional[_Iterable[_Union[GeneratedImage, _Mapping]]] = ..., model: _Optional[str] = ..., usage: _Optional[_Union[_usage_pb2.SamplingUsage, _Mapping]] = ...) -> None: ... class GeneratedImage(_message.Message): __slots__ = ("base64", "url", "up_sampled_prompt", "respect_moderation") diff --git a/src/xai_sdk/proto/v6/video_pb2.py b/src/xai_sdk/proto/v6/video_pb2.py new file mode 100644 index 0000000..a97dba8 --- /dev/null +++ b/src/xai_sdk/proto/v6/video_pb2.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: xai/api/v1/video.proto +# Protobuf Python Version: 6.30.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 6, + 30, + 0, + '', + 'xai/api/v1/video.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from . import image_pb2 as xai_dot_api_dot_v1_dot_image__pb2 +from . import deferred_pb2 as xai_dot_api_dot_v1_dot_deferred__pb2 +from . import usage_pb2 as xai_dot_api_dot_v1_dot_usage__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16xai/api/v1/video.proto\x12\x07xai_api\x1a\x16xai/api/v1/image.proto\x1a\x19xai/api/v1/deferred.proto\x1a\x16xai/api/v1/usage.proto\"#\n\x0fVideoUrlContent\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\",\n\x0bVideoOutput\x12\x1d\n\nupload_url\x18\x01 \x01(\tR\tuploadUrl\"\xf4\x02\n\x14GenerateVideoRequest\x12\x16\n\x06prompt\x18\x01 \x01(\tR\x06prompt\x12.\n\x05image\x18\x02 \x01(\x0b\x32\x18.xai_api.ImageUrlContentR\x05image\x12\x14\n\x05model\x18\x03 \x01(\tR\x05model\x12\x1f\n\x08\x64uration\x18\x04 \x01(\x05H\x00R\x08\x64uration\x88\x01\x01\x12.\n\x05video\x18\x06 \x01(\x0b\x32\x18.xai_api.VideoUrlContentR\x05video\x12\x41\n\x0c\x61spect_ratio\x18\x07 \x01(\x0e\x32\x19.xai_api.VideoAspectRatioH\x01R\x0b\x61spectRatio\x88\x01\x01\x12=\n\nresolution\x18\x08 \x01(\x0e\x32\x18.xai_api.VideoResolutionH\x02R\nresolution\x88\x01\x01\x42\x0b\n\t_durationB\x0f\n\r_aspect_ratioB\r\n\x0b_resolution\"8\n\x17GetDeferredVideoRequest\x12\x1d\n\nrequest_id\x18\x01 \x01(\tR\trequestId\"\x82\x01\n\rVideoResponse\x12-\n\x05video\x18\x01 \x01(\x0b\x32\x17.xai_api.GeneratedVideoR\x05video\x12\x14\n\x05model\x18\x02 \x01(\tR\x05model\x12,\n\x05usage\x18\x03 \x01(\x0b\x32\x16.xai_api.SamplingUsageR\x05usage\"m\n\x0eGeneratedVideo\x12\x10\n\x03url\x18\x01 \x01(\tR\x03url\x12\x1a\n\x08\x64uration\x18\x04 \x01(\x05R\x08\x64uration\x12-\n\x12respect_moderation\x18\x05 \x01(\x08R\x11respectModeration\"\x91\x01\n\x18GetDeferredVideoResponse\x12/\n\x06status\x18\x01 \x01(\x0e\x32\x17.xai_api.DeferredStatusR\x06status\x12\x37\n\x08response\x18\x02 \x01(\x0b\x32\x16.xai_api.VideoResponseH\x00R\x08response\x88\x01\x01\x42\x0b\n\t_response*\xfc\x01\n\x10VideoAspectRatio\x12\"\n\x1eVIDEO_ASPECT_RATIO_UNSPECIFIED\x10\x00\x12\x1a\n\x16VIDEO_ASPECT_RATIO_1_1\x10\x01\x12\x1b\n\x17VIDEO_ASPECT_RATIO_16_9\x10\x02\x12\x1b\n\x17VIDEO_ASPECT_RATIO_9_16\x10\x03\x12\x1a\n\x16VIDEO_ASPECT_RATIO_4_3\x10\x04\x12\x1a\n\x16VIDEO_ASPECT_RATIO_3_4\x10\x05\x12\x1a\n\x16VIDEO_ASPECT_RATIO_3_2\x10\x06\x12\x1a\n\x16VIDEO_ASPECT_RATIO_2_3\x10\x07*i\n\x0fVideoResolution\x12 \n\x1cVIDEO_RESOLUTION_UNSPECIFIED\x10\x00\x12\x19\n\x15VIDEO_RESOLUTION_480P\x10\x01\x12\x19\n\x15VIDEO_RESOLUTION_720P\x10\x02\x32\xb4\x01\n\x05Video\x12P\n\rGenerateVideo\x12\x1d.xai_api.GenerateVideoRequest\x1a\x1e.xai_api.StartDeferredResponse\"\x00\x12Y\n\x10GetDeferredVideo\x12 .xai_api.GetDeferredVideoRequest\x1a!.xai_api.GetDeferredVideoResponse\"\x00\x42Q\n\x0b\x63om.xai_apiB\nVideoProtoP\x01\xa2\x02\x03XXX\xaa\x02\x06XaiApi\xca\x02\x06XaiApi\xe2\x02\x12XaiApi\\GPBMetadata\xea\x02\x06XaiApib\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xai.api.v1.video_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + _globals['DESCRIPTOR']._loaded_options = None + _globals['DESCRIPTOR']._serialized_options = b'\n\013com.xai_apiB\nVideoProtoP\001\242\002\003XXX\252\002\006XaiApi\312\002\006XaiApi\342\002\022XaiApi\\GPBMetadata\352\002\006XaiApi' + _globals['_VIDEOASPECTRATIO']._serialized_start=1019 + _globals['_VIDEOASPECTRATIO']._serialized_end=1271 + _globals['_VIDEORESOLUTION']._serialized_start=1273 + _globals['_VIDEORESOLUTION']._serialized_end=1378 + _globals['_VIDEOURLCONTENT']._serialized_start=110 + _globals['_VIDEOURLCONTENT']._serialized_end=145 + _globals['_VIDEOOUTPUT']._serialized_start=147 + _globals['_VIDEOOUTPUT']._serialized_end=191 + _globals['_GENERATEVIDEOREQUEST']._serialized_start=194 + _globals['_GENERATEVIDEOREQUEST']._serialized_end=566 + _globals['_GETDEFERREDVIDEOREQUEST']._serialized_start=568 + _globals['_GETDEFERREDVIDEOREQUEST']._serialized_end=624 + _globals['_VIDEORESPONSE']._serialized_start=627 + _globals['_VIDEORESPONSE']._serialized_end=757 + _globals['_GENERATEDVIDEO']._serialized_start=759 + _globals['_GENERATEDVIDEO']._serialized_end=868 + _globals['_GETDEFERREDVIDEORESPONSE']._serialized_start=871 + _globals['_GETDEFERREDVIDEORESPONSE']._serialized_end=1016 + _globals['_VIDEO']._serialized_start=1381 + _globals['_VIDEO']._serialized_end=1561 +# @@protoc_insertion_point(module_scope) diff --git a/src/xai_sdk/proto/v6/video_pb2.pyi b/src/xai_sdk/proto/v6/video_pb2.pyi new file mode 100644 index 0000000..8d803f2 --- /dev/null +++ b/src/xai_sdk/proto/v6/video_pb2.pyi @@ -0,0 +1,102 @@ +from . import image_pb2 as _image_pb2 +from . import deferred_pb2 as _deferred_pb2 +from . import usage_pb2 as _usage_pb2 +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from collections.abc import Mapping as _Mapping +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class VideoAspectRatio(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + VIDEO_ASPECT_RATIO_UNSPECIFIED: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_1_1: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_16_9: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_9_16: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_4_3: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_3_4: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_3_2: _ClassVar[VideoAspectRatio] + VIDEO_ASPECT_RATIO_2_3: _ClassVar[VideoAspectRatio] + +class VideoResolution(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = () + VIDEO_RESOLUTION_UNSPECIFIED: _ClassVar[VideoResolution] + VIDEO_RESOLUTION_480P: _ClassVar[VideoResolution] + VIDEO_RESOLUTION_720P: _ClassVar[VideoResolution] +VIDEO_ASPECT_RATIO_UNSPECIFIED: VideoAspectRatio +VIDEO_ASPECT_RATIO_1_1: VideoAspectRatio +VIDEO_ASPECT_RATIO_16_9: VideoAspectRatio +VIDEO_ASPECT_RATIO_9_16: VideoAspectRatio +VIDEO_ASPECT_RATIO_4_3: VideoAspectRatio +VIDEO_ASPECT_RATIO_3_4: VideoAspectRatio +VIDEO_ASPECT_RATIO_3_2: VideoAspectRatio +VIDEO_ASPECT_RATIO_2_3: VideoAspectRatio +VIDEO_RESOLUTION_UNSPECIFIED: VideoResolution +VIDEO_RESOLUTION_480P: VideoResolution +VIDEO_RESOLUTION_720P: VideoResolution + +class VideoUrlContent(_message.Message): + __slots__ = ("url",) + URL_FIELD_NUMBER: _ClassVar[int] + url: str + def __init__(self, url: _Optional[str] = ...) -> None: ... + +class VideoOutput(_message.Message): + __slots__ = ("upload_url",) + UPLOAD_URL_FIELD_NUMBER: _ClassVar[int] + upload_url: str + def __init__(self, upload_url: _Optional[str] = ...) -> None: ... + +class GenerateVideoRequest(_message.Message): + __slots__ = ("prompt", "image", "model", "duration", "video", "aspect_ratio", "resolution") + PROMPT_FIELD_NUMBER: _ClassVar[int] + IMAGE_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + DURATION_FIELD_NUMBER: _ClassVar[int] + VIDEO_FIELD_NUMBER: _ClassVar[int] + ASPECT_RATIO_FIELD_NUMBER: _ClassVar[int] + RESOLUTION_FIELD_NUMBER: _ClassVar[int] + prompt: str + image: _image_pb2.ImageUrlContent + model: str + duration: int + video: VideoUrlContent + aspect_ratio: VideoAspectRatio + resolution: VideoResolution + def __init__(self, prompt: _Optional[str] = ..., image: _Optional[_Union[_image_pb2.ImageUrlContent, _Mapping]] = ..., model: _Optional[str] = ..., duration: _Optional[int] = ..., video: _Optional[_Union[VideoUrlContent, _Mapping]] = ..., aspect_ratio: _Optional[_Union[VideoAspectRatio, str]] = ..., resolution: _Optional[_Union[VideoResolution, str]] = ...) -> None: ... + +class GetDeferredVideoRequest(_message.Message): + __slots__ = ("request_id",) + REQUEST_ID_FIELD_NUMBER: _ClassVar[int] + request_id: str + def __init__(self, request_id: _Optional[str] = ...) -> None: ... + +class VideoResponse(_message.Message): + __slots__ = ("video", "model", "usage") + VIDEO_FIELD_NUMBER: _ClassVar[int] + MODEL_FIELD_NUMBER: _ClassVar[int] + USAGE_FIELD_NUMBER: _ClassVar[int] + video: GeneratedVideo + model: str + usage: _usage_pb2.SamplingUsage + def __init__(self, video: _Optional[_Union[GeneratedVideo, _Mapping]] = ..., model: _Optional[str] = ..., usage: _Optional[_Union[_usage_pb2.SamplingUsage, _Mapping]] = ...) -> None: ... + +class GeneratedVideo(_message.Message): + __slots__ = ("url", "duration", "respect_moderation") + URL_FIELD_NUMBER: _ClassVar[int] + DURATION_FIELD_NUMBER: _ClassVar[int] + RESPECT_MODERATION_FIELD_NUMBER: _ClassVar[int] + url: str + duration: int + respect_moderation: bool + def __init__(self, url: _Optional[str] = ..., duration: _Optional[int] = ..., respect_moderation: bool = ...) -> None: ... + +class GetDeferredVideoResponse(_message.Message): + __slots__ = ("status", "response") + STATUS_FIELD_NUMBER: _ClassVar[int] + RESPONSE_FIELD_NUMBER: _ClassVar[int] + status: _deferred_pb2.DeferredStatus + response: VideoResponse + def __init__(self, status: _Optional[_Union[_deferred_pb2.DeferredStatus, str]] = ..., response: _Optional[_Union[VideoResponse, _Mapping]] = ...) -> None: ... diff --git a/src/xai_sdk/proto/v6/video_pb2_grpc.py b/src/xai_sdk/proto/v6/video_pb2_grpc.py new file mode 100644 index 0000000..c123b72 --- /dev/null +++ b/src/xai_sdk/proto/v6/video_pb2_grpc.py @@ -0,0 +1,131 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc + +from . import deferred_pb2 as xai_dot_api_dot_v1_dot_deferred__pb2 +from . import video_pb2 as xai_dot_api_dot_v1_dot_video__pb2 + + +class VideoStub(object): + """An API service for interaction with video generation models. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.GenerateVideo = channel.unary_unary( + '/xai_api.Video/GenerateVideo', + request_serializer=xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.SerializeToString, + response_deserializer=xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.FromString, + _registered_method=True) + self.GetDeferredVideo = channel.unary_unary( + '/xai_api.Video/GetDeferredVideo', + request_serializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.SerializeToString, + response_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.FromString, + _registered_method=True) + + +class VideoServicer(object): + """An API service for interaction with video generation models. + """ + + def GenerateVideo(self, request, context): + """Create a video based on a text prompt and optionally an image. + If an image is provided, generates video with the image as the first frame (image-to-video). + If no image is provided, generates video from text only (text-to-video). + + This is an asynchronous operation. The method returns immediately with a request_id + that can be used to poll for the result using GetDeferredVideo. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def GetDeferredVideo(self, request, context): + """Gets the result of a video generation started by calling `GenerateVideo`. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_VideoServicer_to_server(servicer, server): + rpc_method_handlers = { + 'GenerateVideo': grpc.unary_unary_rpc_method_handler( + servicer.GenerateVideo, + request_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.FromString, + response_serializer=xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.SerializeToString, + ), + 'GetDeferredVideo': grpc.unary_unary_rpc_method_handler( + servicer.GetDeferredVideo, + request_deserializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.FromString, + response_serializer=xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'xai_api.Video', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('xai_api.Video', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class Video(object): + """An API service for interaction with video generation models. + """ + + @staticmethod + def GenerateVideo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/xai_api.Video/GenerateVideo', + xai_dot_api_dot_v1_dot_video__pb2.GenerateVideoRequest.SerializeToString, + xai_dot_api_dot_v1_dot_deferred__pb2.StartDeferredResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def GetDeferredVideo(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/xai_api.Video/GetDeferredVideo', + xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoRequest.SerializeToString, + xai_dot_api_dot_v1_dot_video__pb2.GetDeferredVideoResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/xai_sdk/sync/__init__.py b/src/xai_sdk/sync/__init__.py index a366f8a..6bf0c19 100644 --- a/src/xai_sdk/sync/__init__.py +++ b/src/xai_sdk/sync/__init__.py @@ -1,3 +1,14 @@ -from . import auth, batch, chat, client, collections, image, models, tokenizer +from . import auth, batch, chat, client, collections, files, image, models, tokenizer, video -__all__ = ["auth", "batch", "chat", "client", "collections", "image", "models", "tokenizer"] +__all__ = [ + "auth", + "batch", + "chat", + "client", + "collections", + "files", + "image", + "models", + "tokenizer", + "video", +] diff --git a/src/xai_sdk/sync/client.py b/src/xai_sdk/sync/client.py index 96d706a..e8f393e 100644 --- a/src/xai_sdk/sync/client.py +++ b/src/xai_sdk/sync/client.py @@ -8,7 +8,7 @@ create_channel_credentials, ) from ..interceptors import AuthInterceptor, TimeoutInterceptor -from . import auth, batch, chat, collections, files, image, models, tokenizer +from . import auth, batch, chat, collections, files, image, models, tokenizer, video class Client(BaseClient): @@ -22,6 +22,7 @@ class Client(BaseClient): image: "image.Client" models: "models.Client" tokenize: "tokenizer.Client" + video: "video.Client" def _init( self, @@ -72,6 +73,7 @@ def _init( self.image = image.Client(self._api_channel) self.models = models.Client(self._api_channel) self.tokenize = tokenizer.Client(self._api_channel) + self.video = video.Client(self._api_channel) def _make_grpc_channel( self, diff --git a/src/xai_sdk/sync/image.py b/src/xai_sdk/sync/image.py index 39b4cf9..b9ba05a 100644 --- a/src/xai_sdk/sync/image.py +++ b/src/xai_sdk/sync/image.py @@ -7,10 +7,14 @@ from ..image import ( BaseClient, BaseImageResponse, + ImageAspectRatio, ImageFormat, + ImageResolution, _make_span_request_attributes, _make_span_response_attributes, + convert_image_aspect_ratio_to_pb, convert_image_format_to_pb, + convert_image_resolution_to_pb, ) from ..proto import image_pb2 from ..telemetry import get_tracer @@ -26,19 +30,42 @@ def sample( prompt: str, model: str, *, + image_url: Optional[str] = None, user: Optional[str] = None, image_format: Optional[ImageFormat] = None, + aspect_ratio: Optional[ImageAspectRatio] = None, + resolution: Optional[ImageResolution] = None, ) -> "ImageResponse": """Samples a single image based on the provided prompt. Args: prompt: The prompt to generate an image from. model: The model to use for image generation. + image_url: The URL or base64-encoded string of an input image to use as a starting point for generation. + Only supported for grok-imagine models. user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse. image_format: The format of the image to return. One of: - `"url"`: The image is returned as a URL. - `"base64"`: The image is returned as a base64-encoded string. defaults to `"url"` if not specified. + aspect_ratio: The aspect ratio of the image to generate. One of: + - `"1:1"` + - `"16:9"` + - `"9:16"` + - `"4:3"` + - `"3:4"` + - `"3:2"` + - `"2:3"` + - `"2:1"` + - `"1:2"` + - `"20:9"` + - `"9:20"` + - `"19.5:9"` + - `"9:19.5"` + Only supported for grok-imagine models. + resolution: The image resolution to generate. One of: + - `"1k"`: ~1 megapixel total. Dimensions vary by aspect ratio. + Only supported for grok-imagine models. Returns: An `ImageResponse` object allowing access to the generated image. @@ -51,6 +78,17 @@ def sample( n=1, format=convert_image_format_to_pb(image_format), ) + if image_url is not None: + request.image.CopyFrom( + image_pb2.ImageUrlContent( + image_url=image_url, + detail=image_pb2.ImageDetail.DETAIL_AUTO, + ) + ) + if aspect_ratio is not None: + request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio) + if resolution is not None: + request.resolution = convert_image_resolution_to_pb(resolution) with tracer.start_as_current_span( name=f"image.sample {model}", @@ -68,8 +106,11 @@ def sample_batch( model: str, n: int, *, + image_url: Optional[str] = None, user: Optional[str] = None, image_format: Optional[ImageFormat] = None, + aspect_ratio: Optional[ImageAspectRatio] = None, + resolution: Optional[ImageResolution] = None, ) -> Sequence["ImageResponse"]: """Samples a batch of images based on the provided prompt. @@ -77,11 +118,31 @@ def sample_batch( prompt: The prompt to generate an image from. model: The model to use for image generation. n: The number of images to generate. + image_url: The URL or base64-encoded string of an input image to use as a starting point for generation. + Only supported for grok-imagine models. user: A unique identifier representing your end-user, which can help xAI to monitor and detect abuse. image_format: The format of the image to return. One of: - `"url"`: The image is returned as a URL. - `"base64"`: The image is returned as a base64-encoded string. defaults to `"url"` if not specified. + aspect_ratio: The aspect ratio of the image to generate. One of: + - `"1:1"` + - `"16:9"` + - `"9:16"` + - `"4:3"` + - `"3:4"` + - `"3:2"` + - `"2:3"` + - `"2:1"` + - `"1:2"` + - `"20:9"` + - `"9:20"` + - `"19.5:9"` + - `"9:19.5"` + Only supported for grok-imagine models. + resolution: The image resolution to generate. One of: + - `"1k"`: ~1 megapixel total. Dimensions vary by aspect ratio. + Only supported for grok-imagine models. Returns: A sequence of `ImageResponse` objects, one for each image generated. @@ -94,6 +155,17 @@ def sample_batch( n=n, format=convert_image_format_to_pb(image_format), ) + if image_url is not None: + request.image.CopyFrom( + image_pb2.ImageUrlContent( + image_url=image_url, + detail=image_pb2.ImageDetail.DETAIL_AUTO, + ) + ) + if aspect_ratio is not None: + request.aspect_ratio = convert_image_aspect_ratio_to_pb(aspect_ratio) + if resolution is not None: + request.resolution = convert_image_resolution_to_pb(resolution) with tracer.start_as_current_span( name=f"image.sample_batch {model}", diff --git a/src/xai_sdk/sync/video.py b/src/xai_sdk/sync/video.py new file mode 100644 index 0000000..4999311 --- /dev/null +++ b/src/xai_sdk/sync/video.py @@ -0,0 +1,112 @@ +import datetime +import time +from typing import Optional + +from opentelemetry.trace import SpanKind + +from ..poll_timer import PollTimer +from ..proto import deferred_pb2, video_pb2 +from ..telemetry import get_tracer +from ..video import ( + BaseClient, + VideoAspectRatio, + VideoResolution, + VideoResponse, + _make_generate_request, + _make_span_request_attributes, + _make_span_response_attributes, +) + +tracer = get_tracer(__name__) + + +class Client(BaseClient): + """Synchronous client for interacting with the `Video` API.""" + + def start( + self, + prompt: str, + model: str, + *, + image_url: Optional[str] = None, + video_url: Optional[str] = None, + duration: Optional[int] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + resolution: Optional[VideoResolution] = None, + ) -> deferred_pb2.StartDeferredResponse: + """Starts a video generation request and returns a request_id for polling.""" + request = _make_generate_request( + prompt, + model, + image_url=image_url, + video_url=video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + ) + + with tracer.start_as_current_span( + name=f"video.start {model}", + kind=SpanKind.CLIENT, + attributes=_make_span_request_attributes(request), + ): + return self._stub.GenerateVideo(request) + + def get(self, request_id: str) -> video_pb2.GetDeferredVideoResponse: + """Gets the current status (and optional result) for a deferred video request.""" + request = video_pb2.GetDeferredVideoRequest(request_id=request_id) + return self._stub.GetDeferredVideo(request) + + def generate( + self, + prompt: str, + model: str, + *, + image_url: Optional[str] = None, + video_url: Optional[str] = None, + duration: Optional[int] = None, + aspect_ratio: Optional[VideoAspectRatio] = None, + resolution: Optional[VideoResolution] = None, + timeout: Optional[datetime.timedelta] = None, + interval: Optional[datetime.timedelta] = None, + ) -> VideoResponse: + """Generates a video using polling and returns the completed response. + + This wraps `GenerateVideo` + repeated `GetDeferredVideo` calls until the request is complete. + """ + timer = PollTimer(timeout, interval) + request_pb = _make_generate_request( + prompt, + model, + image_url=image_url, + video_url=video_url, + duration=duration, + aspect_ratio=aspect_ratio, + resolution=resolution, + ) + + with tracer.start_as_current_span( + name=f"video.generate {model}", + kind=SpanKind.CLIENT, + attributes=_make_span_request_attributes(request_pb), + ) as span: + start = self._stub.GenerateVideo(request_pb) + + while True: + get_req = video_pb2.GetDeferredVideoRequest(request_id=start.request_id) + + r = self._stub.GetDeferredVideo(get_req) + match r.status: + case deferred_pb2.DeferredStatus.DONE: + if not r.HasField("response"): + raise RuntimeError("Deferred request completed but no response was returned.") + response = VideoResponse(r.response) + span.set_attributes(_make_span_response_attributes(request_pb, response)) + return response + case deferred_pb2.DeferredStatus.EXPIRED: + raise RuntimeError("Deferred request expired.") + case deferred_pb2.DeferredStatus.PENDING: + time.sleep(timer.sleep_interval_or_raise()) + continue + case unknown_status: + raise ValueError(f"Unknown deferred status: {unknown_status}") diff --git a/src/xai_sdk/types/__init__.py b/src/xai_sdk/types/__init__.py index 26dcbc1..da1ace2 100644 --- a/src/xai_sdk/types/__init__.py +++ b/src/xai_sdk/types/__init__.py @@ -7,17 +7,25 @@ ResponseFormat, ToolMode, ) -from .model import AllModels, ChatModel, ImageGenerationModel +from .image import ImageAspectRatio, ImageFormat, ImageResolution +from .model import AllModels, ChatModel, ImageGenerationModel, VideoGenerationModel +from .video import VideoAspectRatio, VideoResolution __all__ = [ "AllModels", "ChatModel", "Content", + "ImageAspectRatio", "ImageDetail", + "ImageFormat", "ImageGenerationModel", + "ImageResolution", "IncludeOption", "IncludeOptionMap", "ReasoningEffort", "ResponseFormat", "ToolMode", + "VideoAspectRatio", + "VideoGenerationModel", + "VideoResolution", ] diff --git a/src/xai_sdk/types/image.py b/src/xai_sdk/types/image.py new file mode 100644 index 0000000..588bc08 --- /dev/null +++ b/src/xai_sdk/types/image.py @@ -0,0 +1,21 @@ +from typing import Literal, TypeAlias + +__all__ = ["ImageAspectRatio", "ImageFormat", "ImageResolution"] + +ImageFormat: TypeAlias = Literal["base64", "url"] +ImageAspectRatio: TypeAlias = Literal[ + "1:1", + "3:4", + "4:3", + "9:16", + "16:9", + "2:3", + "3:2", + "9:19.5", + "19.5:9", + "9:20", + "20:9", + "1:2", + "2:1", +] +ImageResolution: TypeAlias = Literal["1k"] diff --git a/src/xai_sdk/types/model.py b/src/xai_sdk/types/model.py index a2937e8..a2abe99 100644 --- a/src/xai_sdk/types/model.py +++ b/src/xai_sdk/types/model.py @@ -1,6 +1,6 @@ from typing import Literal, TypeAlias, Union -__all__ = ["AllModels", "ChatModel", "ImageGenerationModel"] +__all__ = ["AllModels", "ChatModel", "ImageGenerationModel", "VideoGenerationModel"] ChatModel: TypeAlias = Literal[ "grok-4", @@ -30,10 +30,14 @@ "grok-2-image", "grok-2-image-1212", "grok-2-image-latest", + "grok-imagine-image", ] +VideoGenerationModel: TypeAlias = Literal["grok-imagine-video"] + AllModels: TypeAlias = Union[ ChatModel, ImageGenerationModel, + VideoGenerationModel, str, ] diff --git a/src/xai_sdk/types/video.py b/src/xai_sdk/types/video.py new file mode 100644 index 0000000..3c9067b --- /dev/null +++ b/src/xai_sdk/types/video.py @@ -0,0 +1,39 @@ +from typing import Literal, TypeAlias + +from ..proto import video_pb2 + +__all__ = [ + "VideoAspectRatio", + "VideoAspectRatioMap", + "VideoResolution", + "VideoResolutionMap", +] + +# Aspect ratio for video generation. +VideoAspectRatio: TypeAlias = Literal[ + "1:1", + "16:9", + "9:16", + "4:3", + "3:4", + "3:2", + "2:3", +] + +# Resolution for video generation. +VideoResolution: TypeAlias = Literal["480p", "720p"] + +VideoAspectRatioMap: dict[VideoAspectRatio, "video_pb2.VideoAspectRatio"] = { + "1:1": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_1_1, + "16:9": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_16_9, + "9:16": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_9_16, + "4:3": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_4_3, + "3:4": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_3_4, + "3:2": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_3_2, + "2:3": video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_2_3, +} + +VideoResolutionMap: dict[VideoResolution, "video_pb2.VideoResolution"] = { + "480p": video_pb2.VideoResolution.VIDEO_RESOLUTION_480P, + "720p": video_pb2.VideoResolution.VIDEO_RESOLUTION_720P, +} diff --git a/src/xai_sdk/video.py b/src/xai_sdk/video.py new file mode 100644 index 0000000..2bc3085 --- /dev/null +++ b/src/xai_sdk/video.py @@ -0,0 +1,164 @@ +from typing import Any, Optional, Union + +import grpc + +from .meta import ProtoDecorator +from .proto import image_pb2, usage_pb2, video_pb2, video_pb2_grpc +from .telemetry import should_disable_sensitive_attributes +from .types.video import VideoAspectRatio, VideoAspectRatioMap, VideoResolution, VideoResolutionMap + + +class BaseClient: + """Base Client for interacting with the `Video` API.""" + + _stub: video_pb2_grpc.VideoStub + + def __init__(self, channel: Union[grpc.Channel, grpc.aio.Channel]): + """Creates a new client based on a gRPC channel.""" + self._stub = video_pb2_grpc.VideoStub(channel) + + +class VideoResponse(ProtoDecorator[video_pb2.VideoResponse]): + """Adds auxiliary functions for handling the video response proto.""" + + _video: video_pb2.GeneratedVideo + + def __init__(self, proto: video_pb2.VideoResponse) -> None: + """Initializes a new instance of the `VideoResponse` wrapper class.""" + super().__init__(proto) + self._video = proto.video + + @property + def model(self) -> str: + """The model used to generate the video (ignoring aliases).""" + return self._proto.model + + @property + def usage(self) -> usage_pb2.SamplingUsage: + """Token and tool usage for this request.""" + return self._proto.usage + + @property + def respect_moderation(self) -> bool: + """Whether the generated video respects moderation rules.""" + return getattr(self._video, "respect_moderation", True) + + @property + def url(self) -> str: + """The URL under which the video is stored or raises an error. + + Note: The returned URL is valid for 24 hours. + """ + url = self._video.url + if not url: + if not self.respect_moderation: + raise ValueError("Video did not respect moderation rules; URL is not available.") + raise ValueError("Video URL missing from response.") + return url + + @property + def duration(self) -> int: + """Duration of the generated video in seconds.""" + return self._video.duration + + +def _make_generate_request( + prompt: str, + model: str, + *, + image_url: Optional[str], + video_url: Optional[str], + duration: Optional[int], + aspect_ratio: Optional[VideoAspectRatio], + resolution: Optional[VideoResolution], +) -> video_pb2.GenerateVideoRequest: + request = video_pb2.GenerateVideoRequest(prompt=prompt, model=model) + + if image_url is not None: + request.image.CopyFrom( + image_pb2.ImageUrlContent( + image_url=image_url, + detail=image_pb2.ImageDetail.DETAIL_AUTO, + ) + ) + if video_url is not None: + request.video.CopyFrom(video_pb2.VideoUrlContent(url=video_url)) + if duration is not None: + request.duration = duration + if aspect_ratio is not None: + request.aspect_ratio = convert_video_aspect_ratio_to_pb(aspect_ratio) + if resolution is not None: + request.resolution = convert_video_resolution_to_pb(resolution) + + return request + + +def _make_span_request_attributes(request: video_pb2.GenerateVideoRequest) -> dict[str, Any]: + """Creates the video generation span request attributes.""" + attributes: dict[str, Any] = { + "gen_ai.operation.name": "generate_video", + "gen_ai.request.model": request.model, + "gen_ai.provider.name": "xai", + "server.address": "api.x.ai", + "gen_ai.output.type": "video", + } + + if should_disable_sensitive_attributes(): + return attributes + + attributes["gen_ai.prompt"] = request.prompt + + if request.HasField("duration"): + attributes["gen_ai.request.video.duration"] = request.duration + if request.HasField("aspect_ratio"): + attributes["gen_ai.request.video.aspect_ratio"] = ( + video_pb2.VideoAspectRatio.Name(request.aspect_ratio).removeprefix("VIDEO_ASPECT_RATIO_").replace("_", ":") + ) + if request.HasField("resolution"): + attributes["gen_ai.request.video.resolution"] = ( + video_pb2.VideoResolution.Name(request.resolution).removeprefix("VIDEO_RESOLUTION_").lower() + ) + + return attributes + + +def _make_span_response_attributes(request: video_pb2.GenerateVideoRequest, response: VideoResponse) -> dict[str, Any]: + """Creates the video generation span response attributes.""" + attributes: dict[str, Any] = { + "gen_ai.response.model": request.model, + } + + if should_disable_sensitive_attributes(): + return attributes + + usage = response.usage + attributes["gen_ai.usage.input_tokens"] = usage.prompt_tokens + attributes["gen_ai.usage.output_tokens"] = usage.completion_tokens + attributes["gen_ai.usage.total_tokens"] = usage.total_tokens + attributes["gen_ai.usage.reasoning_tokens"] = usage.reasoning_tokens + attributes["gen_ai.usage.cached_prompt_text_tokens"] = usage.cached_prompt_text_tokens + attributes["gen_ai.usage.prompt_text_tokens"] = usage.prompt_text_tokens + attributes["gen_ai.usage.prompt_image_tokens"] = usage.prompt_image_tokens + + attributes["gen_ai.response.0.video.respect_moderation"] = response.respect_moderation + if response._video.url: + attributes["gen_ai.response.0.video.url"] = response._video.url + attributes["gen_ai.response.0.video.duration"] = response.duration + + return attributes + + +def convert_video_aspect_ratio_to_pb(aspect_ratio: VideoAspectRatio) -> video_pb2.VideoAspectRatio: + """Converts a string literal representation of a video aspect ratio to its protobuf enum variant.""" + try: + return VideoAspectRatioMap[aspect_ratio] + except KeyError as exc: + raise ValueError(f"Invalid video aspect ratio {aspect_ratio}.") from exc + + +def convert_video_resolution_to_pb(resolution: VideoResolution) -> video_pb2.VideoResolution: + """Converts a string literal representation of a video resolution to its protobuf enum variant.""" + try: + return VideoResolutionMap[resolution] + except KeyError as exc: + raise ValueError(f"Invalid video resolution {resolution}.") from exc diff --git a/tests/aio/image_test.py b/tests/aio/image_test.py index cb2154b..cc391d9 100644 --- a/tests/aio/image_test.py +++ b/tests/aio/image_test.py @@ -6,6 +6,7 @@ from xai_sdk import AsyncClient from xai_sdk.image import ImageFormat +from xai_sdk.proto import image_pb2 from .. import server @@ -48,6 +49,59 @@ async def test_batch(client: AsyncClient, image_asset: bytes): assert image_asset == await r.image +@pytest.mark.asyncio(loop_scope="session") +async def test_sample_passes_aspect_ratio_and_resolution(client: AsyncClient): + server.clear_last_image_request() + + await client.image.sample( + prompt="foo", + model="grok-2-image", + aspect_ratio="1:1", + resolution="1k", + ) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_1_1 + assert request.HasField("resolution") + assert request.resolution == image_pb2.ImageResolution.IMG_RESOLUTION_1K + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sample_batch_passes_aspect_ratio_and_resolution(client: AsyncClient): + server.clear_last_image_request() + + await client.image.sample_batch( + prompt="foo", + model="grok-2-image", + n=2, + aspect_ratio="16:9", + resolution="1k", + ) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_16_9 + assert request.HasField("resolution") + assert request.resolution == image_pb2.ImageResolution.IMG_RESOLUTION_1K + + +@pytest.mark.asyncio(loop_scope="session") +async def test_sample_passes_image_url(client: AsyncClient): + server.clear_last_image_request() + + input_image_url = "https://example.com/image.jpg" + await client.image.sample(prompt="foo", model="grok-imagine-image", image_url=input_image_url) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("image") + assert request.image.image_url == input_image_url + assert request.image.detail == image_pb2.ImageDetail.DETAIL_AUTO + + @mock.patch("xai_sdk.aio.image.tracer") @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("image_format", ["url", "base64"]) @@ -65,7 +119,8 @@ async def test_sample_creates_span_with_correct_attributes( expected_request_attributes = { "gen_ai.prompt": "A beautiful sunset", "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", "gen_ai.request.image.format": image_format, "gen_ai.request.image.count": 1, @@ -81,7 +136,15 @@ async def test_sample_creates_span_with_correct_attributes( expected_response_attributes = { "gen_ai.response.model": "grok-2-image", "gen_ai.response.image.format": image_format, + "gen_ai.usage.input_tokens": response.usage.prompt_tokens, + "gen_ai.usage.output_tokens": response.usage.completion_tokens, + "gen_ai.usage.total_tokens": response.usage.total_tokens, + "gen_ai.usage.reasoning_tokens": response.usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": response.usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": response.usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": response.usage.prompt_image_tokens, "gen_ai.response.0.image.up_sampled_prompt": response.prompt, + "gen_ai.response.0.image.respect_moderation": response.respect_moderation, } if image_format == "url": @@ -110,7 +173,8 @@ async def test_sample_creates_span_without_sensitive_attributes_when_disabled( expected_request_attributes = { "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", } @@ -146,7 +210,8 @@ async def test_sample_batch_creates_span_with_correct_attributes( expected_request_attributes = { "gen_ai.prompt": "A beautiful sunset", "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", "gen_ai.request.image.format": image_format, "gen_ai.request.image.count": 3, @@ -162,9 +227,19 @@ async def test_sample_batch_creates_span_with_correct_attributes( expected_response_attributes = { "gen_ai.response.model": "grok-2-image", "gen_ai.response.image.format": image_format, + "gen_ai.usage.input_tokens": responses[0].usage.prompt_tokens, + "gen_ai.usage.output_tokens": responses[0].usage.completion_tokens, + "gen_ai.usage.total_tokens": responses[0].usage.total_tokens, + "gen_ai.usage.reasoning_tokens": responses[0].usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": responses[0].usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": responses[0].usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": responses[0].usage.prompt_image_tokens, "gen_ai.response.0.image.up_sampled_prompt": responses[0].prompt, "gen_ai.response.1.image.up_sampled_prompt": responses[1].prompt, "gen_ai.response.2.image.up_sampled_prompt": responses[2].prompt, + "gen_ai.response.0.image.respect_moderation": responses[0].respect_moderation, + "gen_ai.response.1.image.respect_moderation": responses[1].respect_moderation, + "gen_ai.response.2.image.respect_moderation": responses[2].respect_moderation, } if image_format == "url": diff --git a/tests/aio/video_test.py b/tests/aio/video_test.py new file mode 100644 index 0000000..002f853 --- /dev/null +++ b/tests/aio/video_test.py @@ -0,0 +1,149 @@ +from unittest import mock + +import pytest +import pytest_asyncio +from opentelemetry.trace import SpanKind + +from xai_sdk import AsyncClient +from xai_sdk.proto import image_pb2, video_pb2 + +from .. import server + + +@pytest_asyncio.fixture(scope="session") +async def client(): + with server.run_test_server() as port: + yield AsyncClient(api_key=server.API_KEY, api_host=f"localhost:{port}") + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_returns_video_url_and_optional_prompt(client: AsyncClient): + response = await client.video.generate(prompt="foo", model="grok-imagine-video") + + assert response.model == "grok-imagine-video" + assert response.url + assert response.duration > 0 + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_passes_duration_aspect_ratio_and_resolution(client: AsyncClient): + server.clear_last_video_request() + + response = await client.video.generate( + prompt="foo", + model="grok-imagine-video", + duration=3, + aspect_ratio="16:9", + resolution="480p", + ) + + assert response.duration == 3 + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("duration") + assert request.duration == 3 + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_16_9 + assert request.HasField("resolution") + assert request.resolution == video_pb2.VideoResolution.VIDEO_RESOLUTION_480P + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_passes_image_url(client: AsyncClient): + server.clear_last_video_request() + + input_image_url = "https://example.com/image.jpg" + await client.video.generate(prompt="foo", model="grok-imagine-video", image_url=input_image_url) + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("image") + assert request.image.image_url == input_image_url + assert request.image.detail == image_pb2.ImageDetail.DETAIL_AUTO + + +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_passes_video_url(client: AsyncClient): + server.clear_last_video_request() + + input_video_url = "https://example.com/video.mp4" + await client.video.generate(prompt="foo", model="grok-imagine-video", video_url=input_video_url) + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("video") + assert request.video.url == input_video_url + + +@mock.patch("xai_sdk.aio.video.tracer") +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_creates_span_with_correct_attributes(mock_tracer: mock.MagicMock, client: AsyncClient): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + response = await client.video.generate(prompt="A beautiful sunset", model="grok-imagine-video") + + expected_request_attributes = { + "gen_ai.prompt": "A beautiful sunset", + "gen_ai.operation.name": "generate_video", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "video", + "server.address": "api.x.ai", + "gen_ai.request.model": "grok-imagine-video", + } + + mock_tracer.start_as_current_span.assert_called_once_with( + name="video.generate grok-imagine-video", + kind=SpanKind.CLIENT, + attributes=expected_request_attributes, + ) + + expected_response_attributes = { + "gen_ai.response.model": "grok-imagine-video", + "gen_ai.usage.input_tokens": response.usage.prompt_tokens, + "gen_ai.usage.output_tokens": response.usage.completion_tokens, + "gen_ai.usage.total_tokens": response.usage.total_tokens, + "gen_ai.usage.reasoning_tokens": response.usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": response.usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": response.usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": response.usage.prompt_image_tokens, + "gen_ai.response.0.video.respect_moderation": response.respect_moderation, + "gen_ai.response.0.video.url": response.url, + "gen_ai.response.0.video.duration": response.duration, + } + + mock_span.set_attributes.assert_called_once_with(expected_response_attributes) + + +@mock.patch("xai_sdk.aio.video.tracer") +@pytest.mark.asyncio(loop_scope="session") +async def test_generate_creates_span_without_sensitive_attributes_when_disabled( + mock_tracer: mock.MagicMock, client: AsyncClient +): + """Test that sensitive attributes are not included when XAI_SDK_DISABLE_SENSITIVE_TELEMETRY_ATTRIBUTES is set.""" + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + with mock.patch.dict("os.environ", {"XAI_SDK_DISABLE_SENSITIVE_TELEMETRY_ATTRIBUTES": "1"}): + await client.video.generate(prompt="A beautiful sunset", model="grok-imagine-video") + + expected_request_attributes = { + "gen_ai.operation.name": "generate_video", + "gen_ai.provider.name": "xai", + "server.address": "api.x.ai", + "gen_ai.request.model": "grok-imagine-video", + "gen_ai.output.type": "video", + } + + mock_tracer.start_as_current_span.assert_called_once_with( + name="video.generate grok-imagine-video", + kind=SpanKind.CLIENT, + attributes=expected_request_attributes, + ) + + expected_response_attributes = { + "gen_ai.response.model": "grok-imagine-video", + } + + mock_span.set_attributes.assert_called_once_with(expected_response_attributes) diff --git a/tests/server.py b/tests/server.py index ead29b0..cbc4c6a 100644 --- a/tests/server.py +++ b/tests/server.py @@ -40,6 +40,8 @@ tokenize_pb2, tokenize_pb2_grpc, usage_pb2, + video_pb2, + video_pb2_grpc, ) # All valid requests should use this API key. @@ -47,6 +49,61 @@ MANAGEMENT_API_KEY = "456" IMAGE_PATH = "test.jpg" +_last_image_request_lock = threading.Lock() +_last_video_request_lock = threading.Lock() + + +class _LastImageRequestState: + def __init__(self) -> None: + self.value: Optional[image_pb2.GenerateImageRequest] = None + + +_last_image_request_state = _LastImageRequestState() + + +class _LastVideoRequestState: + def __init__(self) -> None: + self.value: Optional[video_pb2.GenerateVideoRequest] = None + + +_last_video_request_state = _LastVideoRequestState() + + +def clear_last_image_request() -> None: + with _last_image_request_lock: + _last_image_request_state.value = None + + +def get_last_image_request() -> Optional[image_pb2.GenerateImageRequest]: + with _last_image_request_lock: + if _last_image_request_state.value is None: + return None + # Return a defensive copy so tests can't mutate shared state. + return image_pb2.GenerateImageRequest.FromString(_last_image_request_state.value.SerializeToString()) + + +def _record_last_image_request(request: image_pb2.GenerateImageRequest) -> None: + with _last_image_request_lock: + _last_image_request_state.value = image_pb2.GenerateImageRequest.FromString(request.SerializeToString()) + + +def clear_last_video_request() -> None: + with _last_video_request_lock: + _last_video_request_state.value = None + + +def get_last_video_request() -> Optional[video_pb2.GenerateVideoRequest]: + with _last_video_request_lock: + if _last_video_request_state.value is None: + return None + # Return a defensive copy so tests can't mutate shared state. + return video_pb2.GenerateVideoRequest.FromString(_last_video_request_state.value.SerializeToString()) + + +def _record_last_video_request(request: video_pb2.GenerateVideoRequest) -> None: + with _last_video_request_lock: + _last_video_request_state.value = video_pb2.GenerateVideoRequest.FromString(request.SerializeToString()) + def read_image() -> bytes: path = os.path.join(os.path.dirname(__file__), IMAGE_PATH) @@ -585,6 +642,7 @@ def __init__(self, url): def GenerateImage(self, request: image_pb2.GenerateImageRequest, context: grpc.ServicerContext): _check_auth(context) + _record_last_image_request(request) if request.format == image_pb2.ImageFormat.IMG_FORMAT_URL: return image_pb2.ImageResponse( @@ -606,6 +664,52 @@ def GenerateImage(self, request: image_pb2.GenerateImageRequest, context: grpc.S ) +class VideoServicer(video_pb2_grpc.VideoServicer): + """Minimal Video service used by tests.""" + + def __init__(self, url: str): + self._url = url + self._deferred_requests: dict[str, tuple[video_pb2.GenerateVideoRequest, int]] = {} + + def GenerateVideo(self, request: video_pb2.GenerateVideoRequest, context: grpc.ServicerContext): + _check_auth(context) + _record_last_video_request(request) + + key = f"video-{len(self._deferred_requests)}" + # Store a defensive copy + poll count. + self._deferred_requests[key] = ( + video_pb2.GenerateVideoRequest.FromString(request.SerializeToString()), + 0, + ) + return deferred_pb2.StartDeferredResponse(request_id=key) + + def GetDeferredVideo(self, request: video_pb2.GetDeferredVideoRequest, context: grpc.ServicerContext): + _check_auth(context) + + if request.request_id not in self._deferred_requests: + context.abort(grpc.StatusCode.NOT_FOUND, "Invalid request ID") + + stored_request, polls = self._deferred_requests[request.request_id] + + # Every request needs to be polled three times. + if polls < 2: + self._deferred_requests[request.request_id] = (stored_request, polls + 1) + return video_pb2.GetDeferredVideoResponse(status=deferred_pb2.DeferredStatus.PENDING) + + duration = stored_request.duration if stored_request.HasField("duration") else 5 + + return video_pb2.GetDeferredVideoResponse( + status=deferred_pb2.DeferredStatus.DONE, + response=video_pb2.VideoResponse( + model=stored_request.model, + video=video_pb2.GeneratedVideo( + url=self._url, + duration=duration, + ), + ), + ) + + class ImageHandler(http.server.SimpleHTTPRequestHandler): def do_GET(self): if self.path == "/foo.jpg": @@ -1405,6 +1509,7 @@ def __init__( image_pb2_grpc.add_ImageServicer_to_server( ImageServicer(f"http://localhost:{self._image_port}/foo.jpg"), self._server ) + video_pb2_grpc.add_VideoServicer_to_server(VideoServicer("https://example.com/foo.mp4"), self._server) documents_pb2_grpc.add_DocumentsServicer_to_server(DocumentServicer(), self._server) files_pb2_grpc.add_FilesServicer_to_server(FilesServicer(self._store), self._server) diff --git a/tests/sync/image_test.py b/tests/sync/image_test.py index 30c00db..6260208 100644 --- a/tests/sync/image_test.py +++ b/tests/sync/image_test.py @@ -5,6 +5,7 @@ from xai_sdk import Client from xai_sdk.image import ImageFormat +from xai_sdk.proto import image_pb2 from .. import server @@ -44,6 +45,56 @@ def test_batch(client: Client, image_asset: bytes): assert image_asset == r.image +def test_sample_passes_aspect_ratio_and_resolution(client: Client): + server.clear_last_image_request() + + client.image.sample( + prompt="foo", + model="grok-2-image", + aspect_ratio="1:1", + resolution="1k", + ) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_1_1 + assert request.HasField("resolution") + assert request.resolution == image_pb2.ImageResolution.IMG_RESOLUTION_1K + + +def test_sample_batch_passes_aspect_ratio_and_resolution(client: Client): + server.clear_last_image_request() + + client.image.sample_batch( + prompt="foo", + model="grok-2-image", + n=2, + aspect_ratio="16:9", + resolution="1k", + ) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == image_pb2.ImageAspectRatio.IMG_ASPECT_RATIO_16_9 + assert request.HasField("resolution") + assert request.resolution == image_pb2.ImageResolution.IMG_RESOLUTION_1K + + +def test_sample_passes_image_url(client: Client): + server.clear_last_image_request() + + input_image_url = "https://example.com/image.jpg" + client.image.sample(prompt="foo", model="grok-imagine-image", image_url=input_image_url) + + request = server.get_last_image_request() + assert request is not None + assert request.HasField("image") + assert request.image.image_url == input_image_url + assert request.image.detail == image_pb2.ImageDetail.DETAIL_AUTO + + @mock.patch("xai_sdk.sync.image.tracer") @pytest.mark.parametrize("image_format", ["url", "base64"]) def test_sample_creates_span_with_correct_attributes( @@ -60,7 +111,8 @@ def test_sample_creates_span_with_correct_attributes( expected_request_attributes = { "gen_ai.prompt": "A beautiful sunset", "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", "gen_ai.request.image.format": image_format, "gen_ai.request.image.count": 1, @@ -76,7 +128,15 @@ def test_sample_creates_span_with_correct_attributes( expected_response_attributes = { "gen_ai.response.model": "grok-2-image", "gen_ai.response.image.format": image_format, + "gen_ai.usage.input_tokens": response.usage.prompt_tokens, + "gen_ai.usage.output_tokens": response.usage.completion_tokens, + "gen_ai.usage.total_tokens": response.usage.total_tokens, + "gen_ai.usage.reasoning_tokens": response.usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": response.usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": response.usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": response.usage.prompt_image_tokens, "gen_ai.response.0.image.up_sampled_prompt": response.prompt, + "gen_ai.response.0.image.respect_moderation": response.respect_moderation, } if image_format == "url": @@ -102,7 +162,8 @@ def test_sample_creates_span_without_sensitive_attributes_when_disabled( expected_request_attributes = { "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", } @@ -137,7 +198,8 @@ def test_sample_batch_creates_span_with_correct_attributes( expected_request_attributes = { "gen_ai.prompt": "A beautiful sunset", "gen_ai.operation.name": "generate_image", - "gen_ai.system": "xai", + "gen_ai.provider.name": "xai", + "gen_ai.output.type": "image", "gen_ai.request.model": "grok-2-image", "gen_ai.request.image.format": image_format, "gen_ai.request.image.count": 3, @@ -153,9 +215,19 @@ def test_sample_batch_creates_span_with_correct_attributes( expected_response_attributes = { "gen_ai.response.model": "grok-2-image", "gen_ai.response.image.format": image_format, + "gen_ai.usage.input_tokens": responses[0].usage.prompt_tokens, + "gen_ai.usage.output_tokens": responses[0].usage.completion_tokens, + "gen_ai.usage.total_tokens": responses[0].usage.total_tokens, + "gen_ai.usage.reasoning_tokens": responses[0].usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": responses[0].usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": responses[0].usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": responses[0].usage.prompt_image_tokens, "gen_ai.response.0.image.up_sampled_prompt": responses[0].prompt, "gen_ai.response.1.image.up_sampled_prompt": responses[1].prompt, "gen_ai.response.2.image.up_sampled_prompt": responses[2].prompt, + "gen_ai.response.0.image.respect_moderation": responses[0].respect_moderation, + "gen_ai.response.1.image.respect_moderation": responses[1].respect_moderation, + "gen_ai.response.2.image.respect_moderation": responses[2].respect_moderation, } if image_format == "url": diff --git a/tests/sync/video_test.py b/tests/sync/video_test.py new file mode 100644 index 0000000..673172c --- /dev/null +++ b/tests/sync/video_test.py @@ -0,0 +1,139 @@ +from unittest import mock + +import pytest +from opentelemetry.trace import SpanKind + +from xai_sdk import Client +from xai_sdk.proto import image_pb2, video_pb2 + +from .. import server + + +@pytest.fixture(scope="session") +def client(): + with server.run_test_server() as port: + yield Client(api_key=server.API_KEY, api_host=f"localhost:{port}") + + +def test_generate_returns_video_url_and_optional_prompt(client: Client): + response = client.video.generate(prompt="foo", model="grok-imagine-video") + + assert response.model == "grok-imagine-video" + assert response.url + assert response.duration > 0 + + +def test_generate_passes_duration_aspect_ratio_and_resolution(client: Client): + server.clear_last_video_request() + + response = client.video.generate( + prompt="foo", + model="grok-imagine-video", + duration=3, + aspect_ratio="16:9", + resolution="480p", + ) + + assert response.duration == 3 + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("duration") + assert request.duration == 3 + assert request.HasField("aspect_ratio") + assert request.aspect_ratio == video_pb2.VideoAspectRatio.VIDEO_ASPECT_RATIO_16_9 + assert request.HasField("resolution") + assert request.resolution == video_pb2.VideoResolution.VIDEO_RESOLUTION_480P + + +def test_generate_passes_image_url(client: Client): + server.clear_last_video_request() + + input_image_url = "https://example.com/image.jpg" + client.video.generate(prompt="foo", model="grok-imagine-video", image_url=input_image_url) + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("image") + assert request.image.image_url == input_image_url + assert request.image.detail == image_pb2.ImageDetail.DETAIL_AUTO + + +def test_generate_passes_video_url(client: Client): + server.clear_last_video_request() + + input_video_url = "https://example.com/video.mp4" + client.video.generate(prompt="foo", model="grok-imagine-video", video_url=input_video_url) + + request = server.get_last_video_request() + assert request is not None + assert request.HasField("video") + assert request.video.url == input_video_url + + +@mock.patch("xai_sdk.sync.video.tracer") +def test_generate_creates_span_with_correct_attributes(mock_tracer: mock.MagicMock, client: Client): + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + response = client.video.generate(prompt="A beautiful sunset", model="grok-imagine-video") + + expected_request_attributes = { + "gen_ai.prompt": "A beautiful sunset", + "gen_ai.operation.name": "generate_video", + "gen_ai.provider.name": "xai", + "server.address": "api.x.ai", + "gen_ai.request.model": "grok-imagine-video", + "gen_ai.output.type": "video", + } + + mock_tracer.start_as_current_span.assert_called_once_with( + name="video.generate grok-imagine-video", + kind=SpanKind.CLIENT, + attributes=expected_request_attributes, + ) + + expected_response_attributes = { + "gen_ai.response.model": "grok-imagine-video", + "gen_ai.usage.input_tokens": response.usage.prompt_tokens, + "gen_ai.usage.output_tokens": response.usage.completion_tokens, + "gen_ai.usage.total_tokens": response.usage.total_tokens, + "gen_ai.usage.reasoning_tokens": response.usage.reasoning_tokens, + "gen_ai.usage.cached_prompt_text_tokens": response.usage.cached_prompt_text_tokens, + "gen_ai.usage.prompt_text_tokens": response.usage.prompt_text_tokens, + "gen_ai.usage.prompt_image_tokens": response.usage.prompt_image_tokens, + "gen_ai.response.0.video.respect_moderation": response.respect_moderation, + "gen_ai.response.0.video.url": response.url, + "gen_ai.response.0.video.duration": response.duration, + } + mock_span.set_attributes.assert_called_once_with(expected_response_attributes) + + +@mock.patch("xai_sdk.sync.video.tracer") +def test_generate_creates_span_without_sensitive_attributes_when_disabled(mock_tracer: mock.MagicMock, client: Client): + """Test that sensitive attributes are not included when XAI_SDK_DISABLE_SENSITIVE_TELEMETRY_ATTRIBUTES is set.""" + mock_span = mock.MagicMock() + mock_tracer.start_as_current_span.return_value.__enter__.return_value = mock_span + + with mock.patch.dict("os.environ", {"XAI_SDK_DISABLE_SENSITIVE_TELEMETRY_ATTRIBUTES": "1"}): + client.video.generate(prompt="A beautiful sunset", model="grok-imagine-video") + + expected_request_attributes = { + "gen_ai.operation.name": "generate_video", + "gen_ai.provider.name": "xai", + "server.address": "api.x.ai", + "gen_ai.request.model": "grok-imagine-video", + "gen_ai.output.type": "video", + } + + mock_tracer.start_as_current_span.assert_called_once_with( + name="video.generate grok-imagine-video", + kind=SpanKind.CLIENT, + attributes=expected_request_attributes, + ) + + expected_response_attributes = { + "gen_ai.response.model": "grok-imagine-video", + } + + mock_span.set_attributes.assert_called_once_with(expected_response_attributes)