diff --git a/configs/seko_talk/shot/rs2v/rs2v.json b/configs/seko_talk/shot/rs2v/rs2v.json
index c071ffae..8c766f0b 100644
--- a/configs/seko_talk/shot/rs2v/rs2v.json
+++ b/configs/seko_talk/shot/rs2v/rs2v.json
@@ -1,11 +1,9 @@
 {
     "model_cls": "seko_talk",
     "task": "rs2v",
-    "model_path":"/data/temp/SekoTalk-v2.7_beta1-bf16-step4",
-    "infer_steps": 4,
+    "model_path":"/data/temp/SekoTalk-v2.7_beta2-bf16-step4",
     "target_fps": 16,
     "audio_sr": 16000,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -14,5 +12,16 @@
     "enable_cfg": false,
     "use_31_block": true,
     "target_video_length": 81,
-    "prev_frame_length": 0
+    "prev_frame_length": 0,
+
+    "default_input_info":
+    {
+        "infer_steps": 4,
+        "resize_mode": "adaptive",
+        "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
+        "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+        "image_path": "assets/inputs/audio/seko_input.png",
+        "audio_path": "assets/inputs/audio/seko_input.mp3",
+        "save_result_path": "save_results/output_seko_talk_shot_rs2v.mp4"
+    }
 }
diff --git a/configs/seko_talk/shot/stream/f2v.json b/configs/seko_talk/shot/stream/f2v.json
index 73e5f548..ce6558f7 100644
--- a/configs/seko_talk/shot/stream/f2v.json
+++ b/configs/seko_talk/shot/stream/f2v.json
@@ -1,11 +1,9 @@
 {
     "model_cls": "seko_talk",
     "task": "s2v",
-    "model_path":"Wan2.1-i2V1202-Audio-14B-720P/",
-    "infer_steps": 4,
+    "model_path":"/data/temp/Wan2.1-i2V1202-Audio-14B-720P/",
     "target_fps": 16,
     "audio_sr": 16000,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -28,8 +26,18 @@
     "audio_adapter_cpu_offload": false,
     "lora_configs": [
         {
-            "path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
+            "path": "/data/temp/lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
             "strength": 1.0
         }
-    ]
+    ],
+    "default_input_info":
+    {
+        "infer_steps": 4,
+        "resize_mode": "adaptive",
+        "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
+        "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+        "image_path": "assets/inputs/audio/seko_input.png",
+        "audio_path": "assets/inputs/audio/seko_input.mp3",
+        "save_result_path": "save_results/output_seko_talk_shot_stream.mp4"
+    }
 }
diff --git a/configs/seko_talk/shot/stream/s2v.json b/configs/seko_talk/shot/stream/s2v.json
index bf64bcdd..627dbc34 100644
--- a/configs/seko_talk/shot/stream/s2v.json
+++ b/configs/seko_talk/shot/stream/s2v.json
@@ -1,11 +1,9 @@
 {
     "model_cls": "seko_talk",
     "task": "s2v",
-    "model_path":"Wan2.1-R2V721-Audio-14B-720P/",
-    "infer_steps": 4,
+    "model_path":"/data/temp/Wan2.1-R2V721-Audio-14B-720P/",
     "target_fps": 16,
     "audio_sr": 16000,
-    "resize_mode": "adaptive",
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
@@ -24,5 +22,16 @@
     "offload_ratio": 1,
     "use_tiling_vae": true,
     "audio_encoder_cpu_offload": true,
-    "audio_adapter_cpu_offload": false
+    "audio_adapter_cpu_offload": false,
+
+    "default_input_info":
+    {
+        "infer_steps": 4,
+        "resize_mode": "adaptive",
+        "prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
+        "negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
+        "image_path": "assets/inputs/audio/seko_input.png",
+        "audio_path": "assets/inputs/audio/seko_input.mp3",
+        "save_result_path": "save_results/output_seko_talk_shot_stream.mp4"
+    }
 }
diff --git a/lightx2v/models/runners/wan/wan_audio_runner.py b/lightx2v/models/runners/wan/wan_audio_runner.py
index 25425715..1f2af628 100755
--- a/lightx2v/models/runners/wan/wan_audio_runner.py
+++ b/lightx2v/models/runners/wan/wan_audio_runner.py
@@ -817,7 +817,10 @@ def run_clip(self):

     def run_clip_main(self):
         self.scheduler.set_audio_adapter(self.audio_adapter)
-        self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
+        self.model.scheduler.prepare(
+            seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.input_info.infer_steps, image_encoder_output=self.inputs["image_encoder_output"]
+        )
+
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]:
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None

diff --git a/lightx2v/models/schedulers/scheduler.py b/lightx2v/models/schedulers/scheduler.py
index a14403d0..b29786f4 100755
--- a/lightx2v/models/schedulers/scheduler.py
+++ b/lightx2v/models/schedulers/scheduler.py
@@ -6,8 +6,6 @@ def __init__(self, config):
         self.config = config
         self.latents = None
         self.step_index = 0
-        self.infer_steps = config["infer_steps"]
-        self.caching_records = [True] * config["infer_steps"]
         self.flag_df = False
         self.transformer_infer = None
         self.infer_condition = True  # cfg status
diff --git a/lightx2v/models/schedulers/wan/audio/scheduler.py b/lightx2v/models/schedulers/wan/audio/scheduler.py
index 4fcb584e..2c527722 100755
--- a/lightx2v/models/schedulers/wan/audio/scheduler.py
+++ b/lightx2v/models/schedulers/wan/audio/scheduler.py
@@ -12,7 +12,27 @@ class EulerScheduler(WanScheduler):
     def __init__(self, config):
-        super().__init__(config)
+        self.config = config
+        self.latents = None
+        self.step_index = 0
+        self.flag_df = False
+        self.transformer_infer = None
+        self.infer_condition = True  # cfg status
+        self.keep_latents_dtype_in_scheduler = False
+        self.target_video_length = self.config["target_video_length"]
+        self.sample_shift = self.config["sample_shift"]
+        if self.config["seq_parallel"]:
+            self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+        else:
+            self.seq_p_group = None
+        self.patch_size = (1, 2, 2)
+        self.shift = 1
+        self.num_train_timesteps = 1000
+        self.disable_corrector = []
+        self.solver_order = 2
+        self.noise_pred = None
+        self.sample_guide_scale = self.config["sample_guide_scale"]
+        self.head_size = self.config["dim"] // self.config["num_heads"]

         if self.config["parallel"]:
             self.sp_size = self.config["parallel"].get("seq_p_size", 1)

@@ -72,8 +92,9 @@ def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
         if self.prev_latents is not None:
             self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

-    def prepare(self, seed, latent_shape, image_encoder_output=None):
+    def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None):
         self.prepare_latents(seed, latent_shape, dtype=torch.float32)
+        self.infer_steps = infer_steps

         timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)
         self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE)
diff --git a/lightx2v/shot_runner/rs2v_infer.py b/lightx2v/shot_runner/rs2v_infer.py
index ed35f9e3..dc465468 100755
--- a/lightx2v/shot_runner/rs2v_infer.py
+++ b/lightx2v/shot_runner/rs2v_infer.py
@@ -6,8 +6,9 @@
 import torchaudio as ta
 from loguru import logger

-from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs
+from lightx2v.shot_runner.shot_base import ShotPipeline, load_clip_configs
 from lightx2v.shot_runner.utils import RS2V_SlidingWindowReader, save_audio, save_to_video
+from lightx2v.utils.input_info import init_input_info_from_args
 from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import is_main_process, seed_all, vae_to_comfyui_image

@@ -22,24 +23,27 @@ def get_reference_state_sequence(frames_per_clip=17, target_fps=16):


 class ShotRS2VPipeline(ShotPipeline):  # type:ignore
-    def __init__(self, config):
-        super().__init__(config)
+    def __init__(self, clip_configs):
+        super().__init__(clip_configs)

     @torch.no_grad()
-    def generate(self):
+    def generate(self, args):
         rs2v = self.clip_generators["rs2v_clip"]

+        # Read this clip model's configuration
         target_video_length = rs2v.config.get("target_video_length", 81)
         target_fps = rs2v.config.get("target_fps", 16)
         audio_sr = rs2v.config.get("audio_sr", 16000)
-        video_duration = rs2v.config.get("video_duration", None)
         audio_per_frame = audio_sr // target_fps

-        # 根据 pipe 最长 overlap_len 初始化 tail buffer
+        # Collect the user input info
+        clip_input_info = init_input_info_from_args(rs2v.config["task"], args, infer_steps=3, video_duration=20)
+        # Fill in missing input fields from the default config
+        clip_input_info = self.check_input_info(clip_input_info, rs2v.config)
         gen_video_list = []
         cut_audio_list = []
-
-        audio_array, ori_sr = ta.load(self.shot_cfg.audio_path)
+        video_duration = clip_input_info.video_duration
+        audio_array, ori_sr = ta.load(clip_input_info.audio_path)
         audio_array = audio_array.mean(0)
         if ori_sr != audio_sr:
             audio_array = ta.functional.resample(audio_array, ori_sr, audio_sr)
@@ -68,42 +72,42 @@ def generate(self):

             is_last = True if pad_len > 0 else False
             pipe = rs2v
-            inputs = self.clip_inputs["rs2v_clip"]
-            inputs.is_first = is_first
-            inputs.is_last = is_last
-            inputs.ref_state = ref_state_sq[idx % len(ref_state_sq)]
-            inputs.seed = inputs.seed + idx
-            inputs.audio_clip = audio_clip
+
+            clip_input_info.is_first = is_first
+            clip_input_info.is_last = is_last
+            clip_input_info.ref_state = ref_state_sq[idx % len(ref_state_sq)]
+            clip_input_info.seed = clip_input_info.seed + idx
+            clip_input_info.audio_clip = audio_clip
             idx = idx + 1

             if self.progress_callback:
                 self.progress_callback(idx, total_clips)

-            gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(inputs)
+            gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(clip_input_info)
             logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}")
             video_pad_len = pad_len // audio_per_frame
             gen_video_list.append(gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len].clone())
             cut_audio_list.append(audio_clip[: audio_clip.shape[0] - pad_len])

-            inputs.overlap_latent = gen_latents[:, -1:]
+            clip_input_info.overlap_latent = gen_latents[:, -1:]

         gen_lvideo = torch.cat(gen_video_list, dim=2).float()
         gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
         merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)

-        if is_main_process() and self.shot_cfg.save_result_path:
+        if is_main_process() and clip_input_info.save_result_path:
             out_path = os.path.join("./", "video_merge.mp4")
             audio_file = os.path.join("./", "audio_merge.wav")
             save_to_video(gen_lvideo, out_path, 16)
-            save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path)
+            save_audio(merge_audio, audio_file, out_path, output_path=clip_input_info.save_result_path)
             os.remove(out_path)
             os.remove(audio_file)

         return gen_lvideo, merge_audio, audio_sr

     def run_pipeline(self, input_info):
-        self.update_input_info(input_info)
-        gen_lvideo, merge_audio, audio_sr = self.generate()
+        # input_info = self.update_input_info(input_info)
+        gen_lvideo, merge_audio, audio_sr = self.generate(input_info)
         if isinstance(input_info, dict):
             return_result_tensor = input_info.get("return_result_tensor", False)
         else:
@@ -130,23 +134,13 @@ def main():
     args = parser.parse_args()
     seed_all(args.seed)

-    clip_configs = load_clip_configs(args.config_json)
-    shot_cfg = ShotConfig(
-        seed=args.seed,
-        image_path=args.image_path,
-        audio_path=args.audio_path,
-        prompt=args.prompt,
-        negative_prompt=args.negative_prompt,
-        save_result_path=args.save_result_path,
-        clip_configs=clip_configs,
-        target_shape=args.target_shape,
-    )
-
-    with ProfilingContext4DebugL1("Total Cost"):
-        shot_stream_pipe = ShotRS2VPipeline(shot_cfg)
-        shot_stream_pipe.generate()
+    with ProfilingContext4DebugL1("Init Pipeline Cost Time"):
+        shot_rs2v_pipe = ShotRS2VPipeline(clip_configs)
+
+    with ProfilingContext4DebugL1("Generate Cost Time"):
+        shot_rs2v_pipe.generate(args)

     # Clean up distributed process group
     if dist.is_initialized():
diff --git a/lightx2v/shot_runner/shot_base.py b/lightx2v/shot_runner/shot_base.py
index 5a477ed6..eee41ea7 100755
--- a/lightx2v/shot_runner/shot_base.py
+++ b/lightx2v/shot_runner/shot_base.py
@@ -7,29 +7,17 @@
 import torch
 from loguru import logger

-from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
+from lightx2v.utils.input_info import fill_input_info_from_defaults
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.set_config import auto_calc_config, get_default_config, print_config, set_parallel_config
+from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
 from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


 @dataclass
 class ClipConfig:
     name: str
-    config_json: str | dict[str, Any]
-
-
-@dataclass
-class ShotConfig:
-    seed: int
-    image_path: str
-    audio_path: str
-    prompt: str
-    negative_prompt: str
-    save_result_path: str
-    clip_configs: list[ClipConfig]
-    target_shape: list[int]
+    config_json: dict[str, Any]


 def get_config_json(config_json):
@@ -58,13 +46,11 @@ def load_clip_configs(main_json_path: str):
     clip_configs = []
     for item in clip_configs_raw:
         if "config" in item:
-            config_json = item["config"]
+            config = item["config"]
         else:
             config_json = str(Path(lightx2v_path) / item["path"])
-
-        default_config = get_default_config()
-        default_config.update(get_config_json(config_json))
-        config = auto_calc_config(default_config)
+            config_json = {"config_json": config_json}
+            config = set_config(Namespace(**config_json))
         if "parallel" in cfg:
             # Add parallel config to clip json
             config["parallel"] = cfg["parallel"]
@@ -75,34 +61,21 @@ class ShotPipeline:
-    def __init__(self, shot_cfg: ShotConfig):
-        self.shot_cfg = shot_cfg
+    def __init__(self, clip_configs: list[ClipConfig]):
+        # self.clip_configs = clip_configs
         self.clip_generators = {}
         self.clip_inputs = {}
-        self.overlap_frame = None
-        self.overlap_latent = None
         self.progress_callback = None

-        for clip_config in shot_cfg.clip_configs:
+        for clip_config in clip_configs:
             name = clip_config.name
             self.clip_generators[name] = self.create_clip_generator(clip_config)

-            args = Namespace(
-                seed=self.shot_cfg.seed,
-                prompt=self.shot_cfg.prompt,
-                negative_prompt=self.shot_cfg.negative_prompt,
-                image_path=self.shot_cfg.image_path,
-                audio_path=self.shot_cfg.audio_path,
-                save_result_path=self.shot_cfg.save_result_path,
-                task=self.clip_generators[name].task,
-                return_result_tensor=True,
-                overlap_frame=self.overlap_frame,
-                overlap_latent=self.overlap_latent,
-                target_shape=self.shot_cfg.target_shape,
-            )
-            input_info = init_empty_input_info(self.clip_generators[name].task)
-            update_input_info_from_dict(input_info, vars(args))
-            self.clip_inputs[name] = input_info
+    def check_input_info(self, user_input_info, clip_config):
+        default_input_info = clip_config.get("default_input_info", None)
+        if default_input_info is not None:
+            fill_input_info_from_defaults(user_input_info, default_input_info)
+        return user_input_info.normalize_unset_to_none()

     def _input_data_to_dict(self, input_data):
         if isinstance(input_data, dict):
@@ -145,9 +118,11 @@ def set_progress_callback(self, callback):
         self.progress_callback = callback

     def create_clip_generator(self, clip_config: ClipConfig):
+        logger.info(f"Clip {clip_config.name} initializing ... ")
+        print_config(clip_config.config_json)
         runner = self._init_runner(clip_config.config_json)
         logger.info(f"Clip {clip_config.name} initialized successfully!")
-        print_config(clip_config.config_json)
+
         return runner

     @torch.no_grad()
diff --git a/lightx2v/shot_runner/stream_infer.py b/lightx2v/shot_runner/stream_infer.py
index 834bc0c5..1e9cd4d1 100755
--- a/lightx2v/shot_runner/stream_infer.py
+++ b/lightx2v/shot_runner/stream_infer.py
@@ -7,8 +7,9 @@
 import torchaudio as ta
 from loguru import logger

-from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs
+from lightx2v.shot_runner.shot_base import ShotPipeline, load_clip_configs
 from lightx2v.shot_runner.utils import SlidingWindowReader, save_audio, save_to_video
+from lightx2v.utils.input_info import init_input_info_from_args
 from lightx2v.utils.profiler import *
 from lightx2v.utils.utils import seed_all

@@ -18,20 +19,32 @@ def __init__(self, config):
         super().__init__(config)

     @torch.no_grad()
-    def generate(self):
+    def generate(self, args):
         s2v = self.clip_generators["s2v_clip"]  # s2v一致性强,动态相应差
         f2v = self.clip_generators["f2v_clip"]  # f2v一致性差,动态响应强

         # 根据 pipe 最长 overlap_len 初始化 tail buffer
-        self.max_tail_len = max(s2v.prev_frame_length, f2v.prev_frame_length)
+        self.max_tail_len = max(s2v.config.get("prev_frame_length", None), f2v.config.get("prev_frame_length", None))
+        model_fps = s2v.config.get("target_fps", 16)
+        model_sr = s2v.config.get("audio_sr", 16000)
+
+        # Collect the user input info
+        s2v_input_info = init_input_info_from_args(s2v.config["task"], args, infer_steps=3)
+        f2v_input_info = init_input_info_from_args(f2v.config["task"], args)
+        # Fill in missing input fields from the default config
+        s2v_input_info = self.check_input_info(s2v_input_info, s2v.config)
+        f2v_input_info = self.check_input_info(f2v_input_info, f2v.config)
+
+        assert s2v_input_info.audio_path == f2v_input_info.audio_path, "s2v and f2v must use the same audio input"
+
         self.global_tail_video = None
         gen_video_list = []
         cut_audio_list = []
-        audio_array, ori_sr = ta.load(self.shot_cfg.audio_path)
+        audio_array, ori_sr = ta.load(args.audio_path)
         audio_array = audio_array.mean(0)

-        if ori_sr != 16000:
-            audio_array = ta.functional.resample(audio_array, ori_sr, 16000)
+        if ori_sr != model_sr:
+            audio_array = ta.functional.resample(audio_array, ori_sr, model_sr)
         audio_reader = SlidingWindowReader(audio_array, frame_len=33)

         # Demo 交替生成 clip
@@ -44,21 +57,22 @@
             if i % 2 == 0:
                 pipe = s2v
-                inputs = self.clip_inputs["s2v_clip"]
+                inputs = s2v_input_info
             else:
                 pipe = f2v
-                inputs = self.clip_inputs["f2v_clip"]
+                inputs = f2v_input_info

             inputs.prompt = "A man speaks to the camera with a slightly furrowed brow and focused gaze. He raises both hands upward in powerful, emphatic gestures. "  # 添加动作提示
             inputs.seed = inputs.seed + i  # 不同 clip 使用不同随机种子
             inputs.audio_clip = audio_clip
             i = i + 1

+            # if i % 4 == 0:
+            #     inputs.infer_steps = 2  # s2v runs 2-step inference half of the time
+
             if self.global_tail_video is not None:
                 # 根据当前 pipe 需要多少 overlap_len 来裁剪 tail
                 inputs.overlap_frame = self.global_tail_video[:, :, -pipe.prev_frame_length :]
-
             gen_clip_video, audio_clip, _ = pipe.run_clip_pipeline(inputs)
-
             aligned_len = gen_clip_video.shape[2] - overlap
             gen_video_list.append(gen_clip_video[:, :, :aligned_len])
             cut_audio_list.append(audio_clip[: aligned_len * audio_reader.audio_per_frame])
@@ -72,8 +86,8 @@

         out_path = os.path.join("./", "video_merge.mp4")
         audio_file = os.path.join("./", "audio_merge.wav")
-        save_to_video(gen_lvideo, out_path, 16)
-        save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path)
+        save_to_video(gen_lvideo, out_path, model_fps)
+        save_audio(merge_audio, audio_file, out_path, output_path=args.save_result_path)
         os.remove(out_path)
         os.remove(audio_file)

@@ -92,23 +106,13 @@ def main():
     args = parser.parse_args()
     seed_all(args.seed)

-    clip_configs = load_clip_configs(args.config_json)
-    shot_cfg = ShotConfig(
-        seed=args.seed,
-        image_path=args.image_path,
-        audio_path=args.audio_path,
-        prompt=args.prompt,
-        negative_prompt=args.negative_prompt,
-        save_result_path=args.save_result_path,
-        clip_configs=clip_configs,
-        target_shape=args.target_shape,
-    )
-
-    with ProfilingContext4DebugL1("Total Cost"):
-        shot_stream_pipe = ShotStreamPipeline(shot_cfg)
-        shot_stream_pipe.generate()
+    with ProfilingContext4DebugL1("Init Pipeline Cost Time"):
+        shot_stream_pipe = ShotStreamPipeline(clip_configs)
+
+    with ProfilingContext4DebugL1("Generate Cost Time"):
+        shot_stream_pipe.generate(args)

     # Clean up distributed process group
     if dist.is_initialized():
diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py
index 3ea8c336..3d227845 100755
--- a/lightx2v/utils/input_info.py
+++ b/lightx2v/utils/input_info.py
@@ -1,9 +1,18 @@
 import inspect
-from dataclasses import dataclass, field
+from dataclasses import dataclass, field, fields
+from typing import Any

 import torch


+class _UnsetType:
+    def __repr__(self):
+        return "UNSET"
+
+
+UNSET = _UnsetType()
+
+
 @dataclass
 class T2VInputInfo:
     seed: int = field(default_factory=int)
@@ -305,6 +314,73 @@ def init_empty_input_info(task):
         raise ValueError(f"Unsupported task: {task}")


+@dataclass
+class SekoTalkInputs:
+    infer_steps: int | Any = UNSET
+    seed: int | Any = UNSET
+    prompt: str | Any = UNSET
+    prompt_enhanced: str | Any = UNSET
+    negative_prompt: str | Any = UNSET
+    image_path: str | Any = UNSET
+    audio_path: str | Any = UNSET
+    audio_num: int | Any = UNSET
+    video_duration: float | Any = UNSET
+    with_mask: bool | Any = UNSET
+    save_result_path: str | Any = UNSET
+    return_result_tensor: bool | Any = UNSET
+    stream_config: dict | Any = UNSET
+
+    resize_mode: str | Any = UNSET
+    target_shape: list | Any = UNSET
+
+    # prev info
+    overlap_frame: torch.Tensor | Any = UNSET
+    overlap_latent: torch.Tensor | Any = UNSET
+    # input preprocess audio
+    audio_clip: torch.Tensor | Any = UNSET
+
+    # input reference state
+    ref_state: int | Any = UNSET
+    # flags for first and last clip
+    is_first: bool | Any = UNSET
+    is_last: bool | Any = UNSET
+
+    @classmethod
+    def from_args(cls, args, **overrides):
+        """
+        Build InputInfo from argparse.Namespace (or any object with __dict__)
+        Priority:
+            args < overrides
+        """
+        field_names = {f.name for f in fields(cls)}
+        data = {k: v for k, v in vars(args).items() if k in field_names}
+        data.update(overrides)
+        return cls(**data)
+
+    def normalize_unset_to_none(self):
+        """
+        Replace all UNSET fields with None.
+        Call this right before running / inference.
+        """
+        for f in fields(self):
+            if getattr(self, f.name) is UNSET:
+                setattr(self, f.name, None)
+        return self
+
+
+def init_input_info_from_args(task, args, **overrides):
+    if task in ["s2v", "rs2v"]:
+        return SekoTalkInputs.from_args(args, **overrides)
+    else:
+        raise ValueError(f"Unsupported task: {task}")
+
+
+def fill_input_info_from_defaults(input_info, defaults):
+    for key in input_info.__dataclass_fields__:
+        if key in defaults and getattr(input_info, key) is UNSET:
+            setattr(input_info, key, defaults[key])
+
+
 def update_input_info_from_dict(input_info, data):
     for key in input_info.__dataclass_fields__:
         if key in data:
diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py
index db1d7816..3017c9f6 100755
--- a/lightx2v/utils/set_config.py
+++ b/lightx2v/utils/set_config.py
@@ -48,6 +48,8 @@ def auto_calc_config(config):
         config_json = json.load(f)
     config.update(config_json)

+    assert os.path.exists(config["model_path"]), f"Model path not found: {config['model_path']}"
+
     if config["model_cls"] in ["hunyuan_video_1.5", "hunyuan_video_1.5_distill"]:
         # Special config for hunyuan video 1.5 model folder structure
         config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"])  # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v]
         if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")):
@@ -89,7 +91,6 @@
     elif os.path.exists(os.path.join(config["model_path"], "transformer", "config.json")):
         with open(os.path.join(config["model_path"], "transformer", "config.json"), "r") as f:
             model_config = json.load(f)
-
         if config["model_cls"] == "z_image":
             # https://huggingface.co/Tongyi-MAI/Z-Image-Turbo/blob/main/transformer/config.json
             z_image_patch_size = model_config.pop("all_patch_size", [2])