From 01ea09e00684b5520e1a937078ff720f0b93d818 Mon Sep 17 00:00:00 2001 From: wangshankun Date: Fri, 6 Feb 2026 06:52:46 +0000 Subject: [PATCH 1/4] v2.7 using format inputinfo --- configs/seko_talk/shot/rs2v/rs2v.json | 17 ++- lightx2v/models/runners/default_runner.py | 4 +- .../models/runners/wan/wan_audio_runner.py | 5 +- lightx2v/models/schedulers/scheduler.py | 2 - .../models/schedulers/wan/audio/scheduler.py | 3 +- .../wan/feature_caching/scheduler.py | 34 ++++- lightx2v/models/schedulers/wan/scheduler.py | 7 +- lightx2v/shot_runner/rs2v_infer.py | 64 +++++----- lightx2v/shot_runner/shot_base.py | 78 ++++++------ lightx2v/utils/input_info.py | 119 ++++++++++-------- lightx2v/utils/set_config.py | 2 - 11 files changed, 193 insertions(+), 142 deletions(-) diff --git a/configs/seko_talk/shot/rs2v/rs2v.json b/configs/seko_talk/shot/rs2v/rs2v.json index c071ffae8..8c766f0b6 100644 --- a/configs/seko_talk/shot/rs2v/rs2v.json +++ b/configs/seko_talk/shot/rs2v/rs2v.json @@ -1,11 +1,9 @@ { "model_cls": "seko_talk", "task": "rs2v", - "model_path":"/data/temp/SekoTalk-v2.7_beta1-bf16-step4", - "infer_steps": 4, + "model_path":"/data/temp/SekoTalk-v2.7_beta2-bf16-step4", "target_fps": 16, "audio_sr": 16000, - "resize_mode": "adaptive", "self_attn_1_type": "flash_attn3", "cross_attn_1_type": "flash_attn3", "cross_attn_2_type": "flash_attn3", @@ -14,5 +12,16 @@ "enable_cfg": false, "use_31_block": true, "target_video_length": 81, - "prev_frame_length": 0 + "prev_frame_length": 0, + + "default_input_info": + { + "infer_steps": 4, + "resize_mode": "adaptive", + "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "assets/inputs/audio/seko_input.png", + "audio_path": "assets/inputs/audio/seko_input.mp3", + "save_result_path": "save_results/output_seko_talk_shot_rs2v.mp4" + } } diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py index d62963368..7003cd358 100755 --- a/lightx2v/models/runners/default_runner.py +++ b/lightx2v/models/runners/default_runner.py @@ -346,7 +346,9 @@ def init_run(self): self.model = self.load_transformer() self.model.set_scheduler(self.scheduler) - self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) + self.model.scheduler.prepare( + seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.input_info.infer_steps, image_encoder_output=self.inputs["image_encoder_output"] + ) if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.inputs["image_encoder_output"]["vae_encoder_out"] = None diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py index 254257154..1f2af6282 100755 --- a/lightx2v/models/runners/wan/wan_audio_runner.py +++ b/lightx2v/models/runners/wan/wan_audio_runner.py @@ -817,7 +817,10 @@ def run_clip(self): def run_clip_main(self): self.scheduler.set_audio_adapter(self.audio_adapter) - self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) + self.model.scheduler.prepare( + seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, 
infer_steps=self.input_info.infer_steps, image_encoder_output=self.inputs["image_encoder_output"] + ) + if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.inputs["image_encoder_output"]["vae_encoder_out"] = None diff --git a/lightx2v/models/schedulers/scheduler.py b/lightx2v/models/schedulers/scheduler.py index a14403d0d..b29786f47 100755 --- a/lightx2v/models/schedulers/scheduler.py +++ b/lightx2v/models/schedulers/scheduler.py @@ -6,8 +6,6 @@ def __init__(self, config): self.config = config self.latents = None self.step_index = 0 - self.infer_steps = config["infer_steps"] - self.caching_records = [True] * config["infer_steps"] self.flag_df = False self.transformer_infer = None self.infer_condition = True # cfg status diff --git a/lightx2v/models/schedulers/wan/audio/scheduler.py b/lightx2v/models/schedulers/wan/audio/scheduler.py index 4fcb584ec..4844dfa6c 100755 --- a/lightx2v/models/schedulers/wan/audio/scheduler.py +++ b/lightx2v/models/schedulers/wan/audio/scheduler.py @@ -72,8 +72,9 @@ def prepare_latents(self, seed, latent_shape, dtype=torch.float32): if self.prev_latents is not None: self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents - def prepare(self, seed, latent_shape, image_encoder_output=None): + def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None): self.prepare_latents(seed, latent_shape, dtype=torch.float32) + self.infer_steps = infer_steps timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32) self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE) diff --git a/lightx2v/models/schedulers/wan/feature_caching/scheduler.py b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py index c306ab87e..5d6e31683 100755 --- a/lightx2v/models/schedulers/wan/feature_caching/scheduler.py +++ b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py @@ -1,4 +1,8 @@ +import numpy as np +import torch + from lightx2v.models.schedulers.wan.scheduler import WanScheduler +from lightx2v_platform.base.global_var import AI_DEVICE class WanSchedulerCaching(WanScheduler): @@ -13,6 +17,32 @@ class WanSchedulerTaylorCaching(WanSchedulerCaching): def __init__(self, config): super().__init__(config) + def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None): + self.infer_steps = infer_steps pattern = [True, False, False, False] - self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] - self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] + self.caching_records = (pattern * ((self.config.infer_steps + 3) // 4))[: self.config.infer_steps] + self.caching_records_2 = (pattern * ((self.config.infer_steps + 3) // 4))[: self.config.infer_steps] + + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: + self.vae_encoder_out = image_encoder_output["vae_encoder_out"] + + self.prepare_latents(seed, latent_shape, dtype=torch.float32) + + alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy() + sigmas = 1.0 - alphas + sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) + + sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) + + self.sigmas = sigmas + self.timesteps = sigmas * self.num_train_timesteps + + self.model_outputs = [None] * self.solver_order + self.timestep_list = [None] * self.solver_order + self.last_sample = None + + self.sigmas = 
self.sigmas.to("cpu") + self.sigma_min = self.sigmas[-1].item() + self.sigma_max = self.sigmas[0].item() + + self.set_timesteps(self.infer_steps, device=AI_DEVICE, shift=self.sample_shift) diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py index 19cc3f844..edf425281 100755 --- a/lightx2v/models/schedulers/wan/scheduler.py +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -11,7 +11,6 @@ class WanScheduler(BaseScheduler): def __init__(self, config): super().__init__(config) - self.infer_steps = self.config["infer_steps"] self.target_video_length = self.config["target_video_length"] self.sample_shift = self.config["sample_shift"] if self.config["seq_parallel"]: @@ -25,10 +24,12 @@ def __init__(self, config): self.solver_order = 2 self.noise_pred = None self.sample_guide_scale = self.config["sample_guide_scale"] - self.caching_records_2 = [True] * self.config["infer_steps"] self.head_size = self.config["dim"] // self.config["num_heads"] - def prepare(self, seed, latent_shape, image_encoder_output=None): + def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None): + self.infer_steps = infer_steps + self.caching_records_2 = [True] * infer_steps + if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.vae_encoder_out = image_encoder_output["vae_encoder_out"] diff --git a/lightx2v/shot_runner/rs2v_infer.py b/lightx2v/shot_runner/rs2v_infer.py index ed35f9e39..a50d22224 100755 --- a/lightx2v/shot_runner/rs2v_infer.py +++ b/lightx2v/shot_runner/rs2v_infer.py @@ -6,8 +6,9 @@ import torchaudio as ta from loguru import logger -from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs +from lightx2v.shot_runner.shot_base import ShotPipeline, load_clip_configs from lightx2v.shot_runner.utils import RS2V_SlidingWindowReader, save_audio, save_to_video +from lightx2v.utils.input_info import init_input_info_from_args from lightx2v.utils.profiler import * from lightx2v.utils.utils import is_main_process, seed_all, vae_to_comfyui_image @@ -22,24 +23,27 @@ def get_reference_state_sequence(frames_per_clip=17, target_fps=16): class ShotRS2VPipeline(ShotPipeline): # type:ignore - def __init__(self, config): - super().__init__(config) + def __init__(self, clip_configs): + super().__init__(clip_configs) @torch.no_grad() - def generate(self): + def generate(self, args): rs2v = self.clip_generators["rs2v_clip"] + # 获取clip模型配置信息 target_video_length = rs2v.config.get("target_video_length", 81) target_fps = rs2v.config.get("target_fps", 16) audio_sr = rs2v.config.get("audio_sr", 16000) - video_duration = rs2v.config.get("video_duration", None) audio_per_frame = audio_sr // target_fps - # 根据 pipe 最长 overlap_len 初始化 tail buffer + # 获取用户输入信息 + clip_input_info = init_input_info_from_args(rs2v.config["task"], args, infer_steps=3, video_duration=20) + # 从默认配置中补全输入信息 + clip_input_info = self.check_input_info(clip_input_info, rs2v.config) gen_video_list = [] cut_audio_list = [] - - audio_array, ori_sr = ta.load(self.shot_cfg.audio_path) + video_duration = clip_input_info.video_duration + audio_array, ori_sr = ta.load(clip_input_info.audio_path) audio_array = audio_array.mean(0) if ori_sr != audio_sr: audio_array = ta.functional.resample(audio_array, ori_sr, audio_sr) @@ -68,42 +72,42 @@ def generate(self): is_last = True if pad_len > 0 else False pipe = rs2v - inputs = self.clip_inputs["rs2v_clip"] - inputs.is_first = is_first - inputs.is_last = is_last - inputs.ref_state = 
ref_state_sq[idx % len(ref_state_sq)] - inputs.seed = inputs.seed + idx - inputs.audio_clip = audio_clip + + clip_input_info.is_first = is_first + clip_input_info.is_last = is_last + clip_input_info.ref_state = ref_state_sq[idx % len(ref_state_sq)] + clip_input_info.seed = clip_input_info.seed + idx + clip_input_info.audio_clip = audio_clip idx = idx + 1 if self.progress_callback: self.progress_callback(idx, total_clips) - gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(inputs) + gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(clip_input_info) logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}") video_pad_len = pad_len // audio_per_frame gen_video_list.append(gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len].clone()) cut_audio_list.append(audio_clip[: audio_clip.shape[0] - pad_len]) - inputs.overlap_latent = gen_latents[:, -1:] + clip_input_info.overlap_latent = gen_latents[:, -1:] gen_lvideo = torch.cat(gen_video_list, dim=2).float() gen_lvideo = torch.clamp(gen_lvideo, -1, 1) merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32) - if is_main_process() and self.shot_cfg.save_result_path: + if is_main_process() and clip_input_info.save_result_path: out_path = os.path.join("./", "video_merge.mp4") audio_file = os.path.join("./", "audio_merge.wav") save_to_video(gen_lvideo, out_path, 16) - save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path) + save_audio(merge_audio, audio_file, out_path, output_path=clip_input_info.save_result_path) os.remove(out_path) os.remove(audio_file) return gen_lvideo, merge_audio, audio_sr def run_pipeline(self, input_info): - self.update_input_info(input_info) - gen_lvideo, merge_audio, audio_sr = self.generate() + # input_info = self.update_input_info(input_info) + gen_lvideo, merge_audio, audio_sr = self.generate(input_info) if isinstance(input_info, dict): return_result_tensor = input_info.get("return_result_tensor", False) else: @@ -130,23 +134,13 @@ def main(): args = parser.parse_args() seed_all(args.seed) - clip_configs = load_clip_configs(args.config_json) - shot_cfg = ShotConfig( - seed=args.seed, - image_path=args.image_path, - audio_path=args.audio_path, - prompt=args.prompt, - negative_prompt=args.negative_prompt, - save_result_path=args.save_result_path, - clip_configs=clip_configs, - target_shape=args.target_shape, - ) - - with ProfilingContext4DebugL1("Total Cost"): - shot_stream_pipe = ShotRS2VPipeline(shot_cfg) - shot_stream_pipe.generate() + with ProfilingContext4DebugL1("Init Pipeline Cost Time"): + shot_rs2v_pipe = ShotRS2VPipeline(clip_configs) + + with ProfilingContext4DebugL1("Generate Cost Time"): + shot_rs2v_pipe.generate(args) # Clean up distributed process group if dist.is_initialized(): diff --git a/lightx2v/shot_runner/shot_base.py b/lightx2v/shot_runner/shot_base.py index 5a477ed65..38d0eb4a6 100755 --- a/lightx2v/shot_runner/shot_base.py +++ b/lightx2v/shot_runner/shot_base.py @@ -7,29 +7,17 @@ import torch from loguru import logger -from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict +from lightx2v.utils.input_info import fill_input_info_from_defaults, init_empty_input_info from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER -from lightx2v.utils.set_config import auto_calc_config, get_default_config, print_config, 
set_parallel_config +from lightx2v.utils.set_config import print_config, set_config, set_parallel_config from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER @dataclass class ClipConfig: name: str - config_json: str | dict[str, Any] - - -@dataclass -class ShotConfig: - seed: int - image_path: str - audio_path: str - prompt: str - negative_prompt: str - save_result_path: str - clip_configs: list[ClipConfig] - target_shape: list[int] + config_json: dict[str, Any] def get_config_json(config_json): @@ -58,13 +46,11 @@ def load_clip_configs(main_json_path: str): clip_configs = [] for item in clip_configs_raw: if "config" in item: - config_json = item["config"] + config = item["config"] else: config_json = str(Path(lightx2v_path) / item["path"]) - - default_config = get_default_config() - default_config.update(get_config_json(config_json)) - config = auto_calc_config(default_config) + config_json = {"config_json": config_json} + config = set_config(Namespace(**config_json)) if "parallel" in cfg: # Add parallel config to clip json config["parallel"] = cfg["parallel"] @@ -75,34 +61,21 @@ def load_clip_configs(main_json_path: str): class ShotPipeline: - def __init__(self, shot_cfg: ShotConfig): - self.shot_cfg = shot_cfg + def __init__(self, clip_configs: list[ClipConfig]): + # self.clip_configs = clip_configs self.clip_generators = {} self.clip_inputs = {} - self.overlap_frame = None - self.overlap_latent = None self.progress_callback = None - for clip_config in shot_cfg.clip_configs: + for clip_config in clip_configs: name = clip_config.name self.clip_generators[name] = self.create_clip_generator(clip_config) - args = Namespace( - seed=self.shot_cfg.seed, - prompt=self.shot_cfg.prompt, - negative_prompt=self.shot_cfg.negative_prompt, - image_path=self.shot_cfg.image_path, - audio_path=self.shot_cfg.audio_path, - save_result_path=self.shot_cfg.save_result_path, - task=self.clip_generators[name].task, - return_result_tensor=True, - overlap_frame=self.overlap_frame, - overlap_latent=self.overlap_latent, - target_shape=self.shot_cfg.target_shape, - ) - input_info = init_empty_input_info(self.clip_generators[name].task) - update_input_info_from_dict(input_info, vars(args)) - self.clip_inputs[name] = input_info + def check_input_info(self, user_input_info, clip_config): + default_input_info = clip_config.get("default_input_info", None) + if default_input_info is not None: + fill_input_info_from_defaults(user_input_info, default_input_info) + return user_input_info.normalize_unset_to_none() def _input_data_to_dict(self, input_data): if isinstance(input_data, dict): @@ -145,15 +118,36 @@ def set_progress_callback(self, callback): self.progress_callback = callback def create_clip_generator(self, clip_config: ClipConfig): + print_config(clip_config.config_json) runner = self._init_runner(clip_config.config_json) logger.info(f"Clip {clip_config.name} initialized successfully!") - print_config(clip_config.config_json) + return runner @torch.no_grad() def generate(self): pass + def set_inputs(self, args): + args = Namespace( + seed=self.shot_cfg.seed, + prompt=self.shot_cfg.prompt, + negative_prompt=self.shot_cfg.negative_prompt, + image_path=self.shot_cfg.image_path, + audio_path=self.shot_cfg.audio_path, + save_result_path=self.shot_cfg.save_result_path, + task=self.clip_generators[name].task, + return_result_tensor=True, + overlap_frame=self.overlap_frame, + overlap_latent=self.overlap_latent, + target_shape=self.shot_cfg.target_shape, + ) + input_info = 
init_empty_input_info(self.clip_generators[name].task) + update_input_info_from_dict(input_info, vars(args)) + self.clip_inputs[name] = input_info + + return input_info + def run_pipeline(self, input_info): self.update_input_info(input_info) return self.generate() diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index 3ea8c3367..d75ff2e62 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -1,9 +1,18 @@ import inspect -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields +from typing import Any import torch +class _UnsetType: + def __repr__(self): + return "UNSET" + + +UNSET = _UnsetType() + + @dataclass class T2VInputInfo: seed: int = field(default_factory=int) @@ -77,61 +86,58 @@ class VaceInputInfo: @dataclass class S2VInputInfo: - seed: int = field(default_factory=int) - prompt: str = field(default_factory=str) - prompt_enhanced: str = field(default_factory=str) - negative_prompt: str = field(default_factory=str) - image_path: str = field(default_factory=str) - audio_path: str = field(default_factory=str) - audio_num: int = field(default_factory=int) - with_mask: bool = field(default_factory=lambda: False) - save_result_path: str = field(default_factory=str) - return_result_tensor: bool = field(default_factory=lambda: False) - stream_config: dict = field(default_factory=dict) - # shape related - resize_mode: str = field(default_factory=str) - original_shape: list = field(default_factory=list) - resized_shape: list = field(default_factory=list) - latent_shape: list = field(default_factory=list) - target_shape: list = field(default_factory=list) - + infer_steps: int | Any = UNSET + seed: int | Any = UNSET + prompt: str | Any = UNSET + prompt_enhanced: str | Any = UNSET + negative_prompt: str | Any = UNSET + image_path: str | Any = UNSET + audio_path: str | Any = UNSET + audio_num: int | Any = UNSET + video_duration: float | Any = UNSET + with_mask: bool | Any = UNSET + return_result_tensor: bool | Any = UNSET + save_result_path: str | Any = UNSET + return_result_tensor: bool | Any = UNSET + stream_config: dict | Any = UNSET + resize_mode: str | Any = UNSET + target_shape: list | Any = UNSET # prev info - overlap_frame: torch.Tensor = field(default_factory=lambda: None) - overlap_latent: torch.Tensor = field(default_factory=lambda: None) + overlap_frame: torch.Tensor | Any = UNSET + overlap_latent: torch.Tensor | Any = UNSET # input preprocess audio - audio_clip: torch.Tensor = field(default_factory=lambda: None) + audio_clip: torch.Tensor | Any = UNSET + + @classmethod + def from_args(cls, args, **overrides): + """ + Build InputInfo from argparse.Namespace (or any object with __dict__) + Priority: + args < overrides + """ + field_names = {f.name for f in fields(cls)} + data = {k: v for k, v in vars(args).items() if k in field_names} + data.update(overrides) + return cls(**data) + + def normalize_unset_to_none(self): + """ + Replace all UNSET fields with None. + Call this right before running / inference. 
+ """ + for f in fields(self): + if getattr(self, f.name) is UNSET: + setattr(self, f.name, None) + return self @dataclass -class RS2VInputInfo: - seed: int = field(default_factory=int) - prompt: str = field(default_factory=str) - prompt_enhanced: str = field(default_factory=str) - negative_prompt: str = field(default_factory=str) - image_path: str = field(default_factory=str) - audio_path: str = field(default_factory=str) - audio_num: int = field(default_factory=int) - with_mask: bool = field(default_factory=lambda: False) - save_result_path: str = field(default_factory=str) - return_result_tensor: bool = field(default_factory=lambda: False) - stream_config: dict = field(default_factory=dict) - # shape related - resize_mode: str = field(default_factory=str) - original_shape: list = field(default_factory=list) - resized_shape: list = field(default_factory=list) - latent_shape: list = field(default_factory=list) - target_shape: list = field(default_factory=list) - - # prev info - overlap_frame: torch.Tensor = field(default_factory=lambda: None) - overlap_latent: torch.Tensor = field(default_factory=lambda: None) - # input preprocess audio - audio_clip: torch.Tensor = field(default_factory=lambda: None) +class RS2VInputInfo(S2VInputInfo): # input reference state - ref_state: int = field(default_factory=int) + ref_state: int | Any = UNSET # flags for first and last clip - is_first: bool = field(default_factory=lambda: False) - is_last: bool = field(default_factory=lambda: False) + is_first: bool | Any = UNSET + is_last: bool | Any = UNSET # Need Check @@ -274,6 +280,15 @@ class WorldPlayT2VInputInfo: action: torch.Tensor = field(default_factory=lambda: None) +def init_input_info_from_args(task, args, **overrides): + if task == "s2v": + return S2VInputInfo.from_args(args, **overrides) + elif task == "rs2v": + return RS2VInputInfo.from_args(args, **overrides) + else: + raise ValueError(f"Unsupported task: {task}") + + def init_empty_input_info(task): if task == "t2v": return T2VInputInfo() @@ -305,6 +320,12 @@ def init_empty_input_info(task): raise ValueError(f"Unsupported task: {task}") +def fill_input_info_from_defaults(input_info, defaults): + for key in input_info.__dataclass_fields__: + if key in defaults and getattr(input_info, key) is UNSET: + setattr(input_info, key, defaults[key]) + + def update_input_info_from_dict(input_info, data): for key in input_info.__dataclass_fields__: if key in data: diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index db1d78168..e9ef2b9ac 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -47,7 +47,6 @@ def auto_calc_config(config): with open(config["config_json"], "r") as f: config_json = json.load(f) config.update(config_json) - if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: # Special config for hunyuan video 1.5 model folder structure config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"]) # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v] if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")): @@ -89,7 +88,6 @@ def auto_calc_config(config): elif os.path.exists(os.path.join(config["model_path"], "transformer", "config.json")): with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f: model_config = json.load(f) - if config["model_cls"] == "z_image": # https://huggingface.co/Tongyi-MAI/Z-Image-Turbo/blob/main/transformer/config.json 
z_image_patch_size = model_config.pop("all_patch_size", [2]) From 709d6a084cfc4598736fa6319e69b8414c7b2360 Mon Sep 17 00:00:00 2001 From: wangshankun Date: Fri, 6 Feb 2026 10:01:38 +0000 Subject: [PATCH 2/4] shot stream infer format inputinfo --- configs/seko_talk/shot/stream/f2v.json | 18 +++++--- configs/seko_talk/shot/stream/s2v.json | 17 ++++++-- lightx2v/shot_runner/rs2v_infer.py | 2 +- lightx2v/shot_runner/shot_base.py | 1 + lightx2v/shot_runner/stream_infer.py | 58 ++++++++++++++------------ lightx2v/utils/set_config.py | 3 ++ 6 files changed, 62 insertions(+), 37 deletions(-) diff --git a/configs/seko_talk/shot/stream/f2v.json b/configs/seko_talk/shot/stream/f2v.json index 73e5f5482..ce6558f7f 100644 --- a/configs/seko_talk/shot/stream/f2v.json +++ b/configs/seko_talk/shot/stream/f2v.json @@ -1,11 +1,9 @@ { "model_cls": "seko_talk", "task": "s2v", - "model_path":"Wan2.1-i2V1202-Audio-14B-720P/", - "infer_steps": 4, + "model_path":"/data/temp/Wan2.1-i2V1202-Audio-14B-720P/", "target_fps": 16, "audio_sr": 16000, - "resize_mode": "adaptive", "self_attn_1_type": "flash_attn3", "cross_attn_1_type": "flash_attn3", "cross_attn_2_type": "flash_attn3", @@ -28,8 +26,18 @@ "audio_adapter_cpu_offload": false, "lora_configs": [ { - "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors", + "path": "/data/temp/lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors", "strength": 1.0 } - ] + ], + "default_input_info": + { + "infer_steps": 4, + "resize_mode": "adaptive", + "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "assets/inputs/audio/seko_input.png", + "audio_path": "assets/inputs/audio/seko_input.mp3", + "save_result_path": "save_results/output_seko_talk_shot_stream.mp4" + } } diff --git a/configs/seko_talk/shot/stream/s2v.json b/configs/seko_talk/shot/stream/s2v.json index bf64bcdd0..627dbc34d 100644 --- a/configs/seko_talk/shot/stream/s2v.json +++ b/configs/seko_talk/shot/stream/s2v.json @@ -1,11 +1,9 @@ { "model_cls": "seko_talk", "task": "s2v", - "model_path":"Wan2.1-R2V721-Audio-14B-720P/", - "infer_steps": 4, + "model_path":"/data/temp/Wan2.1-R2V721-Audio-14B-720P/", "target_fps": 16, "audio_sr": 16000, - "resize_mode": "adaptive", "self_attn_1_type": "flash_attn3", "cross_attn_1_type": "flash_attn3", "cross_attn_2_type": "flash_attn3", @@ -24,5 +22,16 @@ "offload_ratio": 1, "use_tiling_vae": true, "audio_encoder_cpu_offload": true, - "audio_adapter_cpu_offload": false + "audio_adapter_cpu_offload": false, + + "default_input_info": + { + "infer_steps": 4, + "resize_mode": "adaptive", + "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.", + "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", + "image_path": "assets/inputs/audio/seko_input.png", + "audio_path": "assets/inputs/audio/seko_input.mp3", + "save_result_path": "save_results/output_seko_talk_shot_stream.mp4" + } } diff --git a/lightx2v/shot_runner/rs2v_infer.py b/lightx2v/shot_runner/rs2v_infer.py index a50d22224..dc4654686 100755 --- a/lightx2v/shot_runner/rs2v_infer.py +++ b/lightx2v/shot_runner/rs2v_infer.py @@ -29,7 +29,7 @@ def __init__(self, clip_configs): 
@torch.no_grad() def generate(self, args): rs2v = self.clip_generators["rs2v_clip"] - # 获取clip模型配置信息 + # 获取此clip模型的配置信息 target_video_length = rs2v.config.get("target_video_length", 81) target_fps = rs2v.config.get("target_fps", 16) audio_sr = rs2v.config.get("audio_sr", 16000) diff --git a/lightx2v/shot_runner/shot_base.py b/lightx2v/shot_runner/shot_base.py index 38d0eb4a6..0b568dd0d 100755 --- a/lightx2v/shot_runner/shot_base.py +++ b/lightx2v/shot_runner/shot_base.py @@ -118,6 +118,7 @@ def set_progress_callback(self, callback): self.progress_callback = callback def create_clip_generator(self, clip_config: ClipConfig): + logger.info(f"Clip {clip_config.name} initializing...") print_config(clip_config.config_json) runner = self._init_runner(clip_config.config_json) logger.info(f"Clip {clip_config.name} initialized successfully!") diff --git a/lightx2v/shot_runner/stream_infer.py b/lightx2v/shot_runner/stream_infer.py index 834bc0c5f..1e9cd4d16 100755 --- a/lightx2v/shot_runner/stream_infer.py +++ b/lightx2v/shot_runner/stream_infer.py @@ -7,8 +7,9 @@ import torchaudio as ta from loguru import logger -from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs +from lightx2v.shot_runner.shot_base import ShotPipeline, load_clip_configs from lightx2v.shot_runner.utils import SlidingWindowReader, save_audio, save_to_video +from lightx2v.utils.input_info import init_input_info_from_args from lightx2v.utils.profiler import * from lightx2v.utils.utils import seed_all @@ -18,20 +19,32 @@ def __init__(self, config): super().__init__(config) @torch.no_grad() - def generate(self): + def generate(self, args): s2v = self.clip_generators["s2v_clip"] # s2v一致性强,动态相应差 f2v = self.clip_generators["f2v_clip"] # f2v一致性差,动态响应强 # 根据 pipe 最长 overlap_len 初始化 tail buffer - self.max_tail_len = max(s2v.prev_frame_length, f2v.prev_frame_length) + self.max_tail_len = max(s2v.config.get("prev_frame_length", None), f2v.config.get("prev_frame_length", None)) + model_fps = s2v.config.get("target_fps", 16) + model_sr = s2v.config.get("audio_sr", 16000) + + # 获取用户输入信息 + s2v_input_info = init_input_info_from_args(s2v.config["task"], args, infer_steps=3) + f2v_input_info = init_input_info_from_args(f2v.config["task"], args) + # 从默认配置中补全输入信息 + s2v_input_info = self.check_input_info(s2v_input_info, s2v.config) + f2v_input_info = self.check_input_info(f2v_input_info, f2v.config) + + assert s2v_input_info.audio_path == f2v_input_info.audio_path, "s2v and f2v must use the same audio input" + self.global_tail_video = None gen_video_list = [] cut_audio_list = [] - audio_array, ori_sr = ta.load(self.shot_cfg.audio_path) + audio_array, ori_sr = ta.load(args.audio_path) audio_array = audio_array.mean(0) - if ori_sr != 16000: - audio_array = ta.functional.resample(audio_array, ori_sr, 16000) + if ori_sr != model_sr: + audio_array = ta.functional.resample(audio_array, ori_sr, model_sr) audio_reader = SlidingWindowReader(audio_array, frame_len=33) # Demo 交替生成 clip @@ -44,21 +57,22 @@ def generate(self): if i % 2 == 0: pipe = s2v - inputs = self.clip_inputs["s2v_clip"] + inputs = s2v_input_info else: pipe = f2v - inputs = self.clip_inputs["f2v_clip"] + inputs = f2v_input_info inputs.prompt = "A man speaks to the camera with a slightly furrowed brow and focused gaze. He raises both hands upward in powerful, emphatic gestures. 
" # 添加动作提示 inputs.seed = inputs.seed + i # 不同 clip 使用不同随机种子 inputs.audio_clip = audio_clip i = i + 1 + # if i % 4 == 0: + # inputs.infer_steps = 2#s2v 一半时间用2步推理 + if self.global_tail_video is not None: # 根据当前 pipe 需要多少 overlap_len 来裁剪 tail inputs.overlap_frame = self.global_tail_video[:, :, -pipe.prev_frame_length :] - gen_clip_video, audio_clip, _ = pipe.run_clip_pipeline(inputs) - aligned_len = gen_clip_video.shape[2] - overlap gen_video_list.append(gen_clip_video[:, :, :aligned_len]) cut_audio_list.append(audio_clip[: aligned_len * audio_reader.audio_per_frame]) @@ -72,8 +86,8 @@ def generate(self): out_path = os.path.join("./", "video_merge.mp4") audio_file = os.path.join("./", "audio_merge.wav") - save_to_video(gen_lvideo, out_path, 16) - save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path) + save_to_video(gen_lvideo, out_path, model_fps) + save_audio(merge_audio, audio_file, out_path, output_path=args.save_result_path) os.remove(out_path) os.remove(audio_file) @@ -92,23 +106,13 @@ def main(): args = parser.parse_args() seed_all(args.seed) - clip_configs = load_clip_configs(args.config_json) - shot_cfg = ShotConfig( - seed=args.seed, - image_path=args.image_path, - audio_path=args.audio_path, - prompt=args.prompt, - negative_prompt=args.negative_prompt, - save_result_path=args.save_result_path, - clip_configs=clip_configs, - target_shape=args.target_shape, - ) - - with ProfilingContext4DebugL1("Total Cost"): - shot_stream_pipe = ShotStreamPipeline(shot_cfg) - shot_stream_pipe.generate() + with ProfilingContext4DebugL1("Init Pipeline Cost Time"): + shot_stream_pipe = ShotStreamPipeline(clip_configs) + + with ProfilingContext4DebugL1("Generate Cost Time"): + shot_stream_pipe.generate(args) # Clean up distributed process group if dist.is_initialized(): diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index e9ef2b9ac..3017c9f67 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -47,6 +47,9 @@ def auto_calc_config(config): with open(config["config_json"], "r") as f: config_json = json.load(f) config.update(config_json) + + assert os.path.exists(config["model_path"]), f"Model path not found: {config['model_path']}" + if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]: # Special config for hunyuan video 1.5 model folder structure config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"]) # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v] if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")): From f4fa027e993f72dd43274a5524ee7e81b2f86cee Mon Sep 17 00:00:00 2001 From: wangshankun Date: Fri, 6 Feb 2026 10:50:47 +0000 Subject: [PATCH 3/4] Add SekoTalkInputs --- lightx2v/models/runners/default_runner.py | 4 +- .../models/schedulers/wan/audio/scheduler.py | 22 ++- .../wan/feature_caching/scheduler.py | 34 +--- lightx2v/models/schedulers/wan/scheduler.py | 7 +- lightx2v/shot_runner/shot_base.py | 24 +-- lightx2v/utils/input_info.py | 164 ++++++++++++------ 6 files changed, 139 insertions(+), 116 deletions(-) diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py index 7003cd358..d62963368 100755 --- a/lightx2v/models/runners/default_runner.py +++ b/lightx2v/models/runners/default_runner.py @@ -346,9 +346,7 @@ def init_run(self): self.model = self.load_transformer() self.model.set_scheduler(self.scheduler) - self.model.scheduler.prepare( - 
seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.input_info.infer_steps, image_encoder_output=self.inputs["image_encoder_output"] - ) + self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"]) if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.inputs["image_encoder_output"]["vae_encoder_out"] = None diff --git a/lightx2v/models/schedulers/wan/audio/scheduler.py b/lightx2v/models/schedulers/wan/audio/scheduler.py index 4844dfa6c..2c5277224 100755 --- a/lightx2v/models/schedulers/wan/audio/scheduler.py +++ b/lightx2v/models/schedulers/wan/audio/scheduler.py @@ -12,7 +12,27 @@ class EulerScheduler(WanScheduler): def __init__(self, config): - super().__init__(config) + self.config = config + self.latents = None + self.step_index = 0 + self.flag_df = False + self.transformer_infer = None + self.infer_condition = True # cfg status + self.keep_latents_dtype_in_scheduler = False + self.target_video_length = self.config["target_video_length"] + self.sample_shift = self.config["sample_shift"] + if self.config["seq_parallel"]: + self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") + else: + self.seq_p_group = None + self.patch_size = (1, 2, 2) + self.shift = 1 + self.num_train_timesteps = 1000 + self.disable_corrector = [] + self.solver_order = 2 + self.noise_pred = None + self.sample_guide_scale = self.config["sample_guide_scale"] + self.head_size = self.config["dim"] // self.config["num_heads"] if self.config["parallel"]: self.sp_size = self.config["parallel"].get("seq_p_size", 1) diff --git a/lightx2v/models/schedulers/wan/feature_caching/scheduler.py b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py index 5d6e31683..c306ab87e 100755 --- a/lightx2v/models/schedulers/wan/feature_caching/scheduler.py +++ b/lightx2v/models/schedulers/wan/feature_caching/scheduler.py @@ -1,8 +1,4 @@ -import numpy as np -import torch - from lightx2v.models.schedulers.wan.scheduler import WanScheduler -from lightx2v_platform.base.global_var import AI_DEVICE class WanSchedulerCaching(WanScheduler): @@ -17,32 +13,6 @@ class WanSchedulerTaylorCaching(WanSchedulerCaching): def __init__(self, config): super().__init__(config) - def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None): - self.infer_steps = infer_steps pattern = [True, False, False, False] - self.caching_records = (pattern * ((self.config.infer_steps + 3) // 4))[: self.config.infer_steps] - self.caching_records_2 = (pattern * ((self.config.infer_steps + 3) // 4))[: self.config.infer_steps] - - if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: - self.vae_encoder_out = image_encoder_output["vae_encoder_out"] - - self.prepare_latents(seed, latent_shape, dtype=torch.float32) - - alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy() - sigmas = 1.0 - alphas - sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32) - - sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas) - - self.sigmas = sigmas - self.timesteps = sigmas * self.num_train_timesteps - - self.model_outputs = [None] * self.solver_order - self.timestep_list = [None] * self.solver_order - self.last_sample = None - - self.sigmas = self.sigmas.to("cpu") - self.sigma_min = self.sigmas[-1].item() - self.sigma_max = self.sigmas[0].item() - - self.set_timesteps(self.infer_steps, 
device=AI_DEVICE, shift=self.sample_shift) + self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] + self.caching_records_2 = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps] diff --git a/lightx2v/models/schedulers/wan/scheduler.py b/lightx2v/models/schedulers/wan/scheduler.py index edf425281..19cc3f844 100755 --- a/lightx2v/models/schedulers/wan/scheduler.py +++ b/lightx2v/models/schedulers/wan/scheduler.py @@ -11,6 +11,7 @@ class WanScheduler(BaseScheduler): def __init__(self, config): super().__init__(config) + self.infer_steps = self.config["infer_steps"] self.target_video_length = self.config["target_video_length"] self.sample_shift = self.config["sample_shift"] if self.config["seq_parallel"]: @@ -24,12 +25,10 @@ def __init__(self, config): self.solver_order = 2 self.noise_pred = None self.sample_guide_scale = self.config["sample_guide_scale"] + self.caching_records_2 = [True] * self.config["infer_steps"] self.head_size = self.config["dim"] // self.config["num_heads"] - def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None): - self.infer_steps = infer_steps - self.caching_records_2 = [True] * infer_steps - + def prepare(self, seed, latent_shape, image_encoder_output=None): if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]: self.vae_encoder_out = image_encoder_output["vae_encoder_out"] diff --git a/lightx2v/shot_runner/shot_base.py b/lightx2v/shot_runner/shot_base.py index 0b568dd0d..eee41ea76 100755 --- a/lightx2v/shot_runner/shot_base.py +++ b/lightx2v/shot_runner/shot_base.py @@ -7,7 +7,7 @@ import torch from loguru import logger -from lightx2v.utils.input_info import fill_input_info_from_defaults, init_empty_input_info +from lightx2v.utils.input_info import fill_input_info_from_defaults from lightx2v.utils.profiler import * from lightx2v.utils.registry_factory import RUNNER_REGISTER from lightx2v.utils.set_config import print_config, set_config, set_parallel_config @@ -118,7 +118,7 @@ def set_progress_callback(self, callback): self.progress_callback = callback def create_clip_generator(self, clip_config: ClipConfig): - logger.info(f"Clip {clip_config.name} initializing...") + logger.info(f"Clip {clip_config.name} initializing ... 
") print_config(clip_config.config_json) runner = self._init_runner(clip_config.config_json) logger.info(f"Clip {clip_config.name} initialized successfully!") @@ -129,26 +129,6 @@ def create_clip_generator(self, clip_config: ClipConfig): def generate(self): pass - def set_inputs(self, args): - args = Namespace( - seed=self.shot_cfg.seed, - prompt=self.shot_cfg.prompt, - negative_prompt=self.shot_cfg.negative_prompt, - image_path=self.shot_cfg.image_path, - audio_path=self.shot_cfg.audio_path, - save_result_path=self.shot_cfg.save_result_path, - task=self.clip_generators[name].task, - return_result_tensor=True, - overlap_frame=self.overlap_frame, - overlap_latent=self.overlap_latent, - target_shape=self.shot_cfg.target_shape, - ) - input_info = init_empty_input_info(self.clip_generators[name].task) - update_input_info_from_dict(input_info, vars(args)) - self.clip_inputs[name] = input_info - - return input_info - def run_pipeline(self, input_info): self.update_input_info(input_info) return self.generate() diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index d75ff2e62..c0ae288c4 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -86,58 +86,61 @@ class VaceInputInfo: @dataclass class S2VInputInfo: - infer_steps: int | Any = UNSET - seed: int | Any = UNSET - prompt: str | Any = UNSET - prompt_enhanced: str | Any = UNSET - negative_prompt: str | Any = UNSET - image_path: str | Any = UNSET - audio_path: str | Any = UNSET - audio_num: int | Any = UNSET - video_duration: float | Any = UNSET - with_mask: bool | Any = UNSET - return_result_tensor: bool | Any = UNSET - save_result_path: str | Any = UNSET - return_result_tensor: bool | Any = UNSET - stream_config: dict | Any = UNSET - resize_mode: str | Any = UNSET - target_shape: list | Any = UNSET + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + audio_path: str = field(default_factory=str) + audio_num: int = field(default_factory=int) + with_mask: bool = field(default_factory=lambda: False) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + stream_config: dict = field(default_factory=dict) + # shape related + resize_mode: str = field(default_factory=str) + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: list = field(default_factory=list) + # prev info - overlap_frame: torch.Tensor | Any = UNSET - overlap_latent: torch.Tensor | Any = UNSET + overlap_frame: torch.Tensor = field(default_factory=lambda: None) + overlap_latent: torch.Tensor = field(default_factory=lambda: None) # input preprocess audio - audio_clip: torch.Tensor | Any = UNSET - - @classmethod - def from_args(cls, args, **overrides): - """ - Build InputInfo from argparse.Namespace (or any object with __dict__) - Priority: - args < overrides - """ - field_names = {f.name for f in fields(cls)} - data = {k: v for k, v in vars(args).items() if k in field_names} - data.update(overrides) - return cls(**data) - - def normalize_unset_to_none(self): - """ - Replace all UNSET fields with None. - Call this right before running / inference. 
- """ - for f in fields(self): - if getattr(self, f.name) is UNSET: - setattr(self, f.name, None) - return self + audio_clip: torch.Tensor = field(default_factory=lambda: None) @dataclass -class RS2VInputInfo(S2VInputInfo): +class RS2VInputInfo: + seed: int = field(default_factory=int) + prompt: str = field(default_factory=str) + prompt_enhanced: str = field(default_factory=str) + negative_prompt: str = field(default_factory=str) + image_path: str = field(default_factory=str) + audio_path: str = field(default_factory=str) + audio_num: int = field(default_factory=int) + with_mask: bool = field(default_factory=lambda: False) + save_result_path: str = field(default_factory=str) + return_result_tensor: bool = field(default_factory=lambda: False) + stream_config: dict = field(default_factory=dict) + # shape related + resize_mode: str = field(default_factory=str) + original_shape: list = field(default_factory=list) + resized_shape: list = field(default_factory=list) + latent_shape: list = field(default_factory=list) + target_shape: list = field(default_factory=list) + + # prev info + overlap_frame: torch.Tensor = field(default_factory=lambda: None) + overlap_latent: torch.Tensor = field(default_factory=lambda: None) + # input preprocess audio + audio_clip: torch.Tensor = field(default_factory=lambda: None) # input reference state - ref_state: int | Any = UNSET + ref_state: int = field(default_factory=int) # flags for first and last clip - is_first: bool | Any = UNSET - is_last: bool | Any = UNSET + is_first: bool = field(default_factory=lambda: False) + is_last: bool = field(default_factory=lambda: False) # Need Check @@ -280,15 +283,6 @@ class WorldPlayT2VInputInfo: action: torch.Tensor = field(default_factory=lambda: None) -def init_input_info_from_args(task, args, **overrides): - if task == "s2v": - return S2VInputInfo.from_args(args, **overrides) - elif task == "rs2v": - return RS2VInputInfo.from_args(args, **overrides) - else: - raise ValueError(f"Unsupported task: {task}") - - def init_empty_input_info(task): if task == "t2v": return T2VInputInfo() @@ -320,6 +314,68 @@ def init_empty_input_info(task): raise ValueError(f"Unsupported task: {task}") +@dataclass +class SekoTalkInputs: + infer_steps: int | Any = UNSET + seed: int | Any = UNSET + prompt: str | Any = UNSET + prompt_enhanced: str | Any = UNSET + negative_prompt: str | Any = UNSET + image_path: str | Any = UNSET + audio_path: str | Any = UNSET + audio_num: int | Any = UNSET + video_duration: float | Any = UNSET + with_mask: bool | Any = UNSET + return_result_tensor: bool | Any = UNSET + save_result_path: str | Any = UNSET + return_result_tensor: bool | Any = UNSET + stream_config: dict | Any = UNSET + + resize_mode: str | Any = UNSET + target_shape: list | Any = UNSET + + # prev info + overlap_frame: torch.Tensor | Any = UNSET + overlap_latent: torch.Tensor | Any = UNSET + # input preprocess audio + audio_clip: torch.Tensor | Any = UNSET + + # input reference state + ref_state: int | Any = UNSET + # flags for first and last clip + is_first: bool | Any = UNSET + is_last: bool | Any = UNSET + + @classmethod + def from_args(cls, args, **overrides): + """ + Build InputInfo from argparse.Namespace (or any object with __dict__) + Priority: + args < overrides + """ + field_names = {f.name for f in fields(cls)} + data = {k: v for k, v in vars(args).items() if k in field_names} + data.update(overrides) + return cls(**data) + + def normalize_unset_to_none(self): + """ + Replace all UNSET fields with None. 
+ Call this right before running / inference. + """ + for f in fields(self): + if getattr(self, f.name) is UNSET: + setattr(self, f.name, None) + return self + + +def init_input_info_from_args(task, args, **overrides): + if task in ["s2v", "rs2v"]: + return SekoTalkInputs.from_args(args, **overrides) + else: + raise ValueError(f"Unsupported task: {task}") + + def fill_input_info_from_defaults(input_info, defaults): for key in input_info.__dataclass_fields__: if key in defaults and getattr(input_info, key) is UNSET: From 7728a546df95caeea43bf6f57210e1502b1b39ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Shankun=20Wang=20=28=E7=8E=8B=E5=96=84=E6=98=86=29?= Date: Tue, 10 Feb 2026 18:35:54 +0800 Subject: [PATCH 4/4] Remove duplicate return_result_tensor declaration --- lightx2v/utils/input_info.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index c0ae288c4..3d227845b 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -326,7 +326,6 @@ class SekoTalkInputs: audio_num: int | Any = UNSET video_duration: float | Any = UNSET with_mask: bool | Any = UNSET - return_result_tensor: bool | Any = UNSET save_result_path: str | Any = UNSET return_result_tensor: bool | Any = UNSET stream_config: dict | Any = UNSET
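
Reviewer note (illustration only, not meant to be applied as a patch): below is a minimal sketch of the input-info flow these patches introduce for the s2v/rs2v shot runners. The CLI-style args, the override value, and the defaults dict are made up for the example; the imported helpers and the SekoTalkInputs behaviour are the ones added in PATCH 1/4 and PATCH 3/4 — argparse values lose to explicit overrides, any field still UNSET is filled from the clip config's "default_input_info" block, and the rest is normalized to None.

    from argparse import Namespace

    from lightx2v.utils.input_info import (
        fill_input_info_from_defaults,
        init_input_info_from_args,
    )

    # CLI-style arguments; from_args() only picks up keys that are
    # SekoTalkInputs fields, anything else in the namespace is ignored.
    args = Namespace(seed=42, audio_path="assets/inputs/audio/seko_input.mp3")

    # Explicit overrides win over args, e.g. a shorter schedule for one clip.
    info = init_input_info_from_args("rs2v", args, infer_steps=3)

    # Fields still UNSET are taken from the clip config's "default_input_info"
    # (values here are a trimmed copy of configs/seko_talk/shot/rs2v/rs2v.json).
    defaults = {
        "infer_steps": 4,
        "resize_mode": "adaptive",
        "image_path": "assets/inputs/audio/seko_input.png",
    }
    fill_input_info_from_defaults(info, defaults)

    # Remaining UNSET fields (overlap_frame, ref_state, ...) become None.
    info = info.normalize_unset_to_none()

    assert info.infer_steps == 3           # explicit override kept
    assert info.resize_mode == "adaptive"  # filled from default_input_info
    assert info.overlap_frame is None      # never provided anywhere

In the runners, the last two steps (fill from defaults, then normalize) are what ShotPipeline.check_input_info() performs with the clip config's "default_input_info" entry.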