From 8ad1671654ea3af715183a8eb8cd655a35a80ad5 Mon Sep 17 00:00:00 2001 From: F4k3r22 Date: Sun, 22 Feb 2026 10:22:44 -0600 Subject: [PATCH] support PIL Image objects in image conditioning functions --- .../src/ltx_pipelines/ti2vid_two_stages.py | 5 ++++- .../ltx-pipelines/src/ltx_pipelines/utils/helpers.py | 6 ++++-- .../ltx-pipelines/src/ltx_pipelines/utils/media_io.py | 11 +++++++++-- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py index 1369454e..941e9c61 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/ti2vid_two_stages.py @@ -3,6 +3,9 @@ import torch +from typing import Union +from PIL import Image + from ltx_core.components.diffusion_steps import EulerDiffusionStep from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams from ltx_core.components.noisers import GaussianNoiser @@ -91,7 +94,7 @@ def __call__( # noqa: PLR0913 num_inference_steps: int, video_guider_params: MultiModalGuiderParams, audio_guider_params: MultiModalGuiderParams, - images: list[tuple[str, int, float]], + images: list[tuple[Union[str, Image.Image], int, float]], tiling_config: TilingConfig | None = None, enhance_prompt: bool = False, ) -> tuple[Iterator[torch.Tensor], torch.Tensor]: diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py index 98577d45..55b5aab1 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py @@ -1,6 +1,8 @@ import gc import logging from dataclasses import replace +from typing import Union +from PIL import Image import torch from tqdm import tqdm @@ -46,7 +48,7 @@ def cleanup_memory() -> None: def image_conditionings_by_replacing_latent( - images: list[tuple[str, int, float]], + images: list[tuple[Union[str, Image.Image], int, float]], height: int, width: int, video_encoder: VideoEncoder, @@ -75,7 +77,7 @@ def image_conditionings_by_replacing_latent( def image_conditionings_by_adding_guiding_latent( - images: list[tuple[str, int, float]], + images: list[tuple[Union[str, Image.Image], int, float]], height: int, width: int, video_encoder: VideoEncoder, diff --git a/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py b/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py index 349499b1..5e612b79 100644 --- a/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py +++ b/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py @@ -3,6 +3,7 @@ from collections.abc import Generator, Iterator from fractions import Fraction from io import BytesIO +from typing import Union import av import numpy as np @@ -79,13 +80,16 @@ def normalize_latent(latent: torch.Tensor, device: torch.device, dtype: torch.dt def load_image_conditioning( - image_path: str, height: int, width: int, dtype: torch.dtype, device: torch.device + image_path: Union[str, Image.Image], height: int, width: int, dtype: torch.dtype, device: torch.device ) -> torch.Tensor: """ Loads an image from a path and preprocesses it for conditioning. Note: The image is resized to the nearest multiple of 2 for compatibility with video codecs. """ - image = decode_image(image_path=image_path) + if isinstance(image_path, Image.Image): + image = decode_image_pil(image_path=image_path) + else: + image = decode_image(image_path=image_path) image = preprocess(image=image) image = torch.tensor(image, dtype=torch.float32, device=device) image = resize_and_center_crop(image, height, width) @@ -114,6 +118,9 @@ def decode_image(image_path: str) -> np.ndarray: np_array = np.array(image)[..., :3] return np_array +def decode_image_pil(image: Image.Image) -> np.ndarray: + np_array = np.array(image)[..., :3] + return np_array def _write_audio( container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int