Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

import torch

from typing import Union
from PIL import Image

from ltx_core.components.diffusion_steps import EulerDiffusionStep
from ltx_core.components.guiders import MultiModalGuider, MultiModalGuiderParams
from ltx_core.components.noisers import GaussianNoiser
Expand Down Expand Up @@ -91,7 +94,7 @@ def __call__( # noqa: PLR0913
num_inference_steps: int,
video_guider_params: MultiModalGuiderParams,
audio_guider_params: MultiModalGuiderParams,
images: list[tuple[str, int, float]],
images: list[tuple[Union[str, Image.Image], int, float]],
tiling_config: TilingConfig | None = None,
enhance_prompt: bool = False,
) -> tuple[Iterator[torch.Tensor], torch.Tensor]:
Expand Down
6 changes: 4 additions & 2 deletions packages/ltx-pipelines/src/ltx_pipelines/utils/helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gc
import logging
from dataclasses import replace
from typing import Union
from PIL import Image

import torch
from tqdm import tqdm
Expand Down Expand Up @@ -46,7 +48,7 @@ def cleanup_memory() -> None:


def image_conditionings_by_replacing_latent(
images: list[tuple[str, int, float]],
images: list[tuple[Union[str, Image.Image], int, float]],
height: int,
width: int,
video_encoder: VideoEncoder,
Expand Down Expand Up @@ -75,7 +77,7 @@ def image_conditionings_by_replacing_latent(


def image_conditionings_by_adding_guiding_latent(
images: list[tuple[str, int, float]],
images: list[tuple[Union[str, Image.Image], int, float]],
height: int,
width: int,
video_encoder: VideoEncoder,
Expand Down
11 changes: 9 additions & 2 deletions packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections.abc import Generator, Iterator
from fractions import Fraction
from io import BytesIO
from typing import Union

import av
import numpy as np
Expand Down Expand Up @@ -79,13 +80,16 @@ def normalize_latent(latent: torch.Tensor, device: torch.device, dtype: torch.dt


def load_image_conditioning(
image_path: str, height: int, width: int, dtype: torch.dtype, device: torch.device
image_path: Union[str, Image.Image], height: int, width: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
"""
Loads an image from a path and preprocesses it for conditioning.
Note: The image is resized to the nearest multiple of 2 for compatibility with video codecs.
"""
image = decode_image(image_path=image_path)
if isinstance(image_path, Image.Image):
image = decode_image_pil(image_path=image_path)
else:
image = decode_image(image_path=image_path)
image = preprocess(image=image)
image = torch.tensor(image, dtype=torch.float32, device=device)
image = resize_and_center_crop(image, height, width)
Expand Down Expand Up @@ -114,6 +118,9 @@ def decode_image(image_path: str) -> np.ndarray:
np_array = np.array(image)[..., :3]
return np_array

def decode_image_pil(image: Image.Image) -> np.ndarray:
np_array = np.array(image)[..., :3]
return np_array

def _write_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, samples: torch.Tensor, audio_sample_rate: int
Expand Down