Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 52 additions & 23 deletions examples/aio/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import itertools
from typing import Sequence
from typing import Sequence, cast

from absl import app, flags

Expand All @@ -10,29 +10,66 @@

N = flags.DEFINE_integer("n", 1, "Number of images to generate.")
FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.")
MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.")
OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.", required=True)


async def generate_single(client: xai_sdk.AsyncClient, image_format: ImageFormat):
"""Generate a single image from a prompt."""
for turn in itertools.count():
prompt = input("Prompt: ")
result = await client.image.sample(prompt, model="grok-2-image", image_format=image_format)
await _save_images(turn, [result])
async def generate_multi_turn(client: xai_sdk.AsyncClient, image_format: ImageFormat):
"""Multi-turn image generation that builds on the previous output image.

Turn 0 generates an initial image from your prompt. Each subsequent turn reuses the previous
image output as an input image (image-to-image) while you provide a new prompt to refine it.
"""
previous_image: str | None = None

async def generate_batch(client: xai_sdk.AsyncClient, image_format: ImageFormat):
"""Generate a batch of images from a prompt."""
for turn in itertools.count():
prompt = input("Prompt: ")
results = await client.image.sample_batch(prompt, n=N.value, model="grok-2-image", image_format=image_format)
await _save_images(turn, results)
if previous_image is None:
prompt = input("Prompt (blank to stop): ")
else:
prompt = input("Edit prompt (blank to stop): ")
if not prompt:
return

if N.value == 1:
response = await client.image.sample(
prompt,
model=MODEL.value,
image_format=image_format,
image_url=previous_image,
)
responses = [response]
else:
responses = await client.image.sample_batch(
prompt,
n=N.value,
model=MODEL.value,
image_format=image_format,
image_url=previous_image,
)

await _save_images(turn, responses)

selected = 0
if len(responses) > 1:
raw = input(f"Continue from which image? [0-{len(responses) - 1}] (default 0): ").strip()
if raw:
selected = int(raw)
if selected < 0 or selected >= len(responses):
raise ValueError(f"Invalid image index {selected}.")

chosen = responses[selected]
previous_image = chosen.url if image_format == "url" else chosen.base64

if len(responses) > 1:
if image_format == "url":
print(f"Continuing from image {selected}: {chosen.url}")
else:
print(f"Continuing from image {selected} (base64).")


async def _save_images(turn: int, responses: Sequence[ImageResponse]):
"""Save images to a file."""
for i, image in enumerate(responses):
print(image.prompt)
with open(f"{OUTPUT_DIR.value}/image_{turn}_{i}.jpg", "wb") as f:
f.write(await image.image)

Expand All @@ -42,16 +79,8 @@ async def main(argv: Sequence[str]) -> None:
raise app.UsageError("Unexpected command line arguments.")

client = xai_sdk.AsyncClient()

match (N.value, FORMAT.value):
case (1, "base64"):
await generate_single(client, "base64")
case (_, "base64"):
await generate_batch(client, "base64")
case (1, "url"):
await generate_single(client, "url")
case (_, "url"):
await generate_batch(client, "url")
image_format: ImageFormat = cast(ImageFormat, FORMAT.value)
await generate_multi_turn(client, image_format)


if __name__ == "__main__":
Expand Down
84 changes: 84 additions & 0 deletions examples/aio/video_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import asyncio
from datetime import timedelta
from typing import Sequence, cast

from absl import app, flags

import xai_sdk
from xai_sdk.video import VideoAspectRatio, VideoResolution

MODEL = flags.DEFINE_string("model", "grok-imagine-video", "Video generation model to use.")
IMAGE_URL = flags.DEFINE_string(
"image-url",
"",
"Optional input image (URL or base64 data URL) to use as the first frame (image-to-video).",
)
VIDEO_URL = flags.DEFINE_string(
"video-url",
"",
"Optional input video (URL or base64 data URL) to edit based on the prompt (video-to-video).",
)
DURATION = flags.DEFINE_integer("duration", 0, "Optional duration in seconds (1-15). Use 0 to omit.")
ASPECT_RATIO = flags.DEFINE_string(
"aspect-ratio",
"",
'Optional aspect ratio. One of: "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3".',
)
RESOLUTION = flags.DEFINE_string("resolution", "", 'Optional resolution. One of: "480p", "720p".')
TIMEOUT = flags.DEFINE_integer("timeout", 600, "Timeout in seconds for polling.")
INTERVAL = flags.DEFINE_integer("interval", 1, "Polling interval in seconds.")


async def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Unexpected command line arguments.")

client = xai_sdk.AsyncClient()

duration = DURATION.value or None
image_url = IMAGE_URL.value or None
video_url = VIDEO_URL.value or None
aspect_ratio = cast(VideoAspectRatio, ASPECT_RATIO.value) if ASPECT_RATIO.value else None
resolution = cast(VideoResolution, RESOLUTION.value) if RESOLUTION.value else None

previous_video_url: str | None = video_url
first_turn = True

while True:
prompt = input("Prompt (blank to stop): " if first_turn else "Edit prompt (blank to stop): ")
if not prompt:
return

try:
response = await client.video.generate(
prompt=prompt,
model=MODEL.value,
image_url=image_url if first_turn else None,
video_url=previous_video_url,
duration=duration,
aspect_ratio=aspect_ratio,
resolution=resolution,
timeout=timedelta(seconds=TIMEOUT.value),
interval=timedelta(seconds=INTERVAL.value),
)
print(f"Respects moderation: {response.respect_moderation}")
if response.respect_moderation:
print(f"Video URL: {response.url}")
else:
print("Video URL not returned due to moderation.")
print(f"Duration: {response.duration}s")

# Chain edits: use the returned URL as the next input video.
if response.respect_moderation:
previous_video_url = response.url
first_turn = False
except RuntimeError as e:
# request expired
print(e)
except ValueError as e:
# unknown deferred status
print(e)


if __name__ == "__main__":
app.run(lambda argv: asyncio.run(main(argv)))
75 changes: 52 additions & 23 deletions examples/sync/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import Sequence
from typing import Sequence, cast

from absl import app, flags

Expand All @@ -9,29 +9,66 @@

N = flags.DEFINE_integer("n", 1, "Number of images to generate.")
FORMAT = flags.DEFINE_enum("format", "base64", ["base64", "url"], "Image format used to return the result.")
MODEL = flags.DEFINE_string("model", "grok-imagine-image", "Image generation model to use.")
OUTPUT_DIR = flags.DEFINE_string("output-dir", None, "Directory to save the generated images.", required=True)


def generate_single(client: xai_sdk.Client, image_format: ImageFormat):
"""Generate a single image from a prompt."""
for turn in itertools.count():
prompt = input("Prompt: ")
result = client.image.sample(prompt, model="grok-2-image", image_format=image_format)
save_images(turn, [result])
def generate_multi_turn(client: xai_sdk.Client, image_format: ImageFormat):
"""Multi-turn image generation that builds on the previous output image.

Turn 0 generates an initial image from your prompt. Each subsequent turn reuses the previous
image output as an input image (image-to-image) while you provide a new prompt to refine it.
"""
previous_image: str | None = None

def generate_batch(client: xai_sdk.Client, image_format: ImageFormat):
"""Generate a batch of images from a prompt."""
for turn in itertools.count():
prompt = input("Prompt: ")
results = client.image.sample_batch(prompt, n=N.value, model="grok-2-image", image_format=image_format)
save_images(turn, results)
if previous_image is None:
prompt = input("Prompt (blank to stop): ")
else:
prompt = input("Edit prompt (blank to stop): ")
if not prompt:
return

if N.value == 1:
response = client.image.sample(
prompt,
model=MODEL.value,
image_format=image_format,
image_url=previous_image,
)
responses = [response]
else:
responses = client.image.sample_batch(
prompt,
n=N.value,
model=MODEL.value,
image_format=image_format,
image_url=previous_image,
)

save_images(turn, responses)

selected = 0
if len(responses) > 1:
raw = input(f"Continue from which image? [0-{len(responses) - 1}] (default 0): ").strip()
if raw:
selected = int(raw)
if selected < 0 or selected >= len(responses):
raise ValueError(f"Invalid image index {selected}.")

chosen = responses[selected]
previous_image = chosen.url if image_format == "url" else chosen.base64

if len(responses) > 1:
if image_format == "url":
print(f"Continuing from image {selected}: {chosen.url}")
else:
print(f"Continuing from image {selected} (base64).")


def save_images(turn: int, responses: Sequence[ImageResponse]):
"""Save images to a file."""
for i, image in enumerate(responses):
print(image.prompt)
with open(f"{OUTPUT_DIR.value}/image_{turn}_{i}.jpg", "wb") as f:
f.write(image.image)

Expand All @@ -41,16 +78,8 @@ def main(argv: Sequence[str]) -> None:
raise app.UsageError("Unexpected command line arguments.")

client = xai_sdk.Client()

match (N.value, FORMAT.value):
case (1, "base64"):
generate_single(client, "base64")
case (1, "url"):
generate_single(client, "url")
case (_, "base64"):
generate_batch(client, "base64")
case (_, "url"):
generate_batch(client, "url")
image_format: ImageFormat = cast(ImageFormat, FORMAT.value)
generate_multi_turn(client, image_format)


if __name__ == "__main__":
Expand Down
83 changes: 83 additions & 0 deletions examples/sync/video_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from datetime import timedelta
from typing import Sequence, cast

from absl import app, flags

import xai_sdk
from xai_sdk.video import VideoAspectRatio, VideoResolution

MODEL = flags.DEFINE_string("model", "grok-imagine-video", "Video generation model to use.")
IMAGE_URL = flags.DEFINE_string(
"image-url",
"",
"Optional input image (URL or base64 data URL) to use as the first frame (image-to-video).",
)
VIDEO_URL = flags.DEFINE_string(
"video-url",
"",
"Optional input video (URL or base64 data URL) to edit based on the prompt (video-to-video).",
)
DURATION = flags.DEFINE_integer("duration", 0, "Optional duration in seconds (1-15). Use 0 to omit.")
ASPECT_RATIO = flags.DEFINE_string(
"aspect-ratio",
"",
'Optional aspect ratio. One of: "1:1", "16:9", "9:16", "4:3", "3:4", "3:2", "2:3".',
)
RESOLUTION = flags.DEFINE_string("resolution", "", 'Optional resolution. One of: "480p", "720p".')
TIMEOUT = flags.DEFINE_integer("timeout", 600, "Timeout in seconds for polling.")
INTERVAL = flags.DEFINE_integer("interval", 1, "Polling interval in seconds.")


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Unexpected command line arguments.")

client = xai_sdk.Client()

duration = DURATION.value or None
image_url = IMAGE_URL.value or None
video_url = VIDEO_URL.value or None
aspect_ratio = cast(VideoAspectRatio, ASPECT_RATIO.value) if ASPECT_RATIO.value else None
resolution = cast(VideoResolution, RESOLUTION.value) if RESOLUTION.value else None

previous_video_url: str | None = video_url
first_turn = True

while True:
prompt = input("Prompt (blank to stop): " if first_turn else "Edit prompt (blank to stop): ")
if not prompt:
return

try:
response = client.video.generate(
prompt=prompt,
model=MODEL.value,
image_url=image_url if first_turn else None,
video_url=previous_video_url,
duration=duration,
aspect_ratio=aspect_ratio,
resolution=resolution,
timeout=timedelta(seconds=TIMEOUT.value),
interval=timedelta(seconds=INTERVAL.value),
)
print(f"Respects moderation: {response.respect_moderation}")
if response.respect_moderation:
print(f"Video URL: {response.url}")
else:
print("Video URL not returned due to moderation.")
print(f"Duration: {response.duration}s")

# Chain edits: use the returned URL as the next input video.
if response.respect_moderation:
previous_video_url = response.url
first_turn = False
except RuntimeError as e:
# request expired
print(e)
except ValueError as e:
# unknown deferred status
print(e)


if __name__ == "__main__":
app.run(main)
15 changes: 13 additions & 2 deletions src/xai_sdk/aio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
from . import auth, batch, chat, client, collections, image, models, tokenizer
from . import auth, batch, chat, client, collections, files, image, models, tokenizer, video

__all__ = ["auth", "batch", "chat", "client", "collections", "image", "models", "tokenizer"]
__all__ = [
"auth",
"batch",
"chat",
"client",
"collections",
"files",
"image",
"models",
"tokenizer",
"video",
]
Loading