17 changes: 13 additions & 4 deletions configs/seko_talk/shot/rs2v/rs2v.json
@@ -1,11 +1,9 @@
{
"model_cls": "seko_talk",
"task": "rs2v",
"model_path":"/data/temp/SekoTalk-v2.7_beta1-bf16-step4",
"infer_steps": 4,
"model_path":"/data/temp/SekoTalk-v2.7_beta2-bf16-step4",
Contributor Author (review comment): model_path doesn't belong in rs2v; it would be better placed one level up, in the clip item's config (see the sketch after this file's diff).

"target_fps": 16,
"audio_sr": 16000,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
@@ -14,5 +12,16 @@
"enable_cfg": false,
"use_31_block": true,
"target_video_length": 81,
"prev_frame_length": 0
"prev_frame_length": 0,

"default_input_info":
{
"infer_steps": 4,
"resize_mode": "adaptive",
"prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "assets/inputs/audio/seko_input.png",
"audio_path": "assets/inputs/audio/seko_input.mp3",
"save_result_path": "save_results/output_seko_talk_shot_rs2v.mp4"
}
}
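
To make the review suggestion concrete, here is a minimal sketch of a shot-level clip entry that carries model_path itself, so the task config (rs2v.json) no longer has to. The surrounding key names ("clips", "name", "path") are assumptions for illustration only; the actual schema read by load_clip_configs may differ.

```python
# Hypothetical shot-level config layout (a sketch, not the actual schema).
shot_config = {
    "clips": [  # assumed top-level key; the real shot JSON may name this differently
        {
            "name": "rs2v_clip",
            "path": "configs/seko_talk/shot/rs2v/rs2v.json",  # task-specific clip config
            "model_path": "/data/temp/SekoTalk-v2.7_beta2-bf16-step4",  # moved up from rs2v.json
        }
    ]
}
```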
18 changes: 13 additions & 5 deletions configs/seko_talk/shot/stream/f2v.json
@@ -1,11 +1,9 @@
{
"model_cls": "seko_talk",
"task": "s2v",
"model_path":"Wan2.1-i2V1202-Audio-14B-720P/",
"infer_steps": 4,
"model_path":"/data/temp/Wan2.1-i2V1202-Audio-14B-720P/",
"target_fps": 16,
"audio_sr": 16000,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
@@ -28,8 +26,18 @@
"audio_adapter_cpu_offload": false,
"lora_configs": [
{
"path": "lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
"path": "/data/temp/lightx2v_I2V_14B_480p_cfg_step_distill_rank32_bf16.safetensors",
"strength": 1.0
}
]
],
"default_input_info":
{
"infer_steps": 4,
"resize_mode": "adaptive",
"prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "assets/inputs/audio/seko_input.png",
"audio_path": "assets/inputs/audio/seko_input.mp3",
"save_result_path": "save_results/output_seko_talk_shot_stream.mp4"
}
}
17 changes: 13 additions & 4 deletions configs/seko_talk/shot/stream/s2v.json
@@ -1,11 +1,9 @@
{
"model_cls": "seko_talk",
"task": "s2v",
"model_path":"Wan2.1-R2V721-Audio-14B-720P/",
"infer_steps": 4,
"model_path":"/data/temp/Wan2.1-R2V721-Audio-14B-720P/",
"target_fps": 16,
"audio_sr": 16000,
"resize_mode": "adaptive",
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
@@ -24,5 +22,16 @@
"offload_ratio": 1,
"use_tiling_vae": true,
"audio_encoder_cpu_offload": true,
"audio_adapter_cpu_offload": false
"audio_adapter_cpu_offload": false,

"default_input_info":
{
"infer_steps": 4,
"resize_mode": "adaptive",
"prompt": "The video features a male speaking to the camera with arms spread out, a slightly furrowed brow, and a focused gaze.",
"negative_prompt": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "assets/inputs/audio/seko_input.png",
"audio_path": "assets/inputs/audio/seko_input.mp3",
"save_result_path": "save_results/output_seko_talk_shot_stream.mp4"
}
}
5 changes: 4 additions & 1 deletion lightx2v/models/runners/wan/wan_audio_runner.py
@@ -817,7 +817,10 @@ def run_clip(self):

def run_clip_main(self):
self.scheduler.set_audio_adapter(self.audio_adapter)
self.model.scheduler.prepare(seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, image_encoder_output=self.inputs["image_encoder_output"])
self.model.scheduler.prepare(
seed=self.input_info.seed, latent_shape=self.input_info.latent_shape, infer_steps=self.input_info.infer_steps, image_encoder_output=self.inputs["image_encoder_output"]
)

if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v", "rs2v"]:
self.inputs["image_encoder_output"]["vae_encoder_out"] = None

2 changes: 0 additions & 2 deletions lightx2v/models/schedulers/scheduler.py
@@ -6,8 +6,6 @@ def __init__(self, config):
self.config = config
self.latents = None
self.step_index = 0
self.infer_steps = config["infer_steps"]
self.caching_records = [True] * config["infer_steps"]
self.flag_df = False
self.transformer_infer = None
self.infer_condition = True # cfg status
25 changes: 23 additions & 2 deletions lightx2v/models/schedulers/wan/audio/scheduler.py
@@ -12,7 +12,27 @@

class EulerScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.config = config
self.latents = None
self.step_index = 0
self.flag_df = False
self.transformer_infer = None
self.infer_condition = True # cfg status
self.keep_latents_dtype_in_scheduler = False
self.target_video_length = self.config["target_video_length"]
self.sample_shift = self.config["sample_shift"]
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
self.patch_size = (1, 2, 2)
self.shift = 1
self.num_train_timesteps = 1000
self.disable_corrector = []
self.solver_order = 2
self.noise_pred = None
self.sample_guide_scale = self.config["sample_guide_scale"]
self.head_size = self.config["dim"] // self.config["num_heads"]

if self.config["parallel"]:
self.sp_size = self.config["parallel"].get("seq_p_size", 1)
@@ -72,8 +92,9 @@ def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
if self.prev_latents is not None:
self.latents = (1.0 - self.mask) * self.prev_latents + self.mask * self.latents

def prepare(self, seed, latent_shape, image_encoder_output=None):
def prepare(self, seed, latent_shape, infer_steps, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.float32)
self.infer_steps = infer_steps
timesteps = np.linspace(self.num_train_timesteps, 0, self.infer_steps + 1, dtype=np.float32)

self.timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32, device=AI_DEVICE)
64 changes: 29 additions & 35 deletions lightx2v/shot_runner/rs2v_infer.py
@@ -6,8 +6,9 @@
import torchaudio as ta
from loguru import logger

from lightx2v.shot_runner.shot_base import ShotConfig, ShotPipeline, load_clip_configs
from lightx2v.shot_runner.shot_base import ShotPipeline, load_clip_configs
from lightx2v.shot_runner.utils import RS2V_SlidingWindowReader, save_audio, save_to_video
from lightx2v.utils.input_info import init_input_info_from_args
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import is_main_process, seed_all, vae_to_comfyui_image

@@ -22,24 +23,27 @@ def get_reference_state_sequence(frames_per_clip=17, target_fps=16):


class ShotRS2VPipeline(ShotPipeline): # type:ignore
def __init__(self, config):
super().__init__(config)
def __init__(self, clip_configs):
super().__init__(clip_configs)

@torch.no_grad()
def generate(self):
def generate(self, args):
rs2v = self.clip_generators["rs2v_clip"]
# Get the configuration info for this clip's model
target_video_length = rs2v.config.get("target_video_length", 81)
target_fps = rs2v.config.get("target_fps", 16)
audio_sr = rs2v.config.get("audio_sr", 16000)
video_duration = rs2v.config.get("video_duration", None)
audio_per_frame = audio_sr // target_fps

# Initialize the tail buffer based on the pipe's longest overlap_len
# Get the user's input info
clip_input_info = init_input_info_from_args(rs2v.config["task"], args, infer_steps=3, video_duration=20)
# Fill in missing input info from the default config
clip_input_info = self.check_input_info(clip_input_info, rs2v.config)

gen_video_list = []
cut_audio_list = []

audio_array, ori_sr = ta.load(self.shot_cfg.audio_path)
video_duration = clip_input_info.video_duration
audio_array, ori_sr = ta.load(clip_input_info.audio_path)
audio_array = audio_array.mean(0)
if ori_sr != audio_sr:
audio_array = ta.functional.resample(audio_array, ori_sr, audio_sr)
@@ -68,42 +72,42 @@ def generate(self):
is_last = True if pad_len > 0 else False

pipe = rs2v
inputs = self.clip_inputs["rs2v_clip"]
inputs.is_first = is_first
inputs.is_last = is_last
inputs.ref_state = ref_state_sq[idx % len(ref_state_sq)]
inputs.seed = inputs.seed + idx
inputs.audio_clip = audio_clip

clip_input_info.is_first = is_first
clip_input_info.is_last = is_last
clip_input_info.ref_state = ref_state_sq[idx % len(ref_state_sq)]
clip_input_info.seed = clip_input_info.seed + idx
clip_input_info.audio_clip = audio_clip
idx = idx + 1
if self.progress_callback:
self.progress_callback(idx, total_clips)

gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(inputs)
gen_clip_video, audio_clip, gen_latents = pipe.run_clip_pipeline(clip_input_info)
logger.info(f"Generated rs2v clip {idx}, pad_len {pad_len}, gen_clip_video shape: {gen_clip_video.shape}, audio_clip shape: {audio_clip.shape} gen_latents shape: {gen_latents.shape}")

video_pad_len = pad_len // audio_per_frame
gen_video_list.append(gen_clip_video[:, :, : gen_clip_video.shape[2] - video_pad_len].clone())
cut_audio_list.append(audio_clip[: audio_clip.shape[0] - pad_len])
inputs.overlap_latent = gen_latents[:, -1:]
clip_input_info.overlap_latent = gen_latents[:, -1:]

gen_lvideo = torch.cat(gen_video_list, dim=2).float()
gen_lvideo = torch.clamp(gen_lvideo, -1, 1)
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)

if is_main_process() and self.shot_cfg.save_result_path:
if is_main_process() and clip_input_info.save_result_path:
out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")

save_to_video(gen_lvideo, out_path, 16)
save_audio(merge_audio, audio_file, out_path, output_path=self.shot_cfg.save_result_path)
save_audio(merge_audio, audio_file, out_path, output_path=clip_input_info.save_result_path)
os.remove(out_path)
os.remove(audio_file)

return gen_lvideo, merge_audio, audio_sr

def run_pipeline(self, input_info):
self.update_input_info(input_info)
gen_lvideo, merge_audio, audio_sr = self.generate()
# input_info = self.update_input_info(input_info)
gen_lvideo, merge_audio, audio_sr = self.generate(input_info)
if isinstance(input_info, dict):
return_result_tensor = input_info.get("return_result_tensor", False)
else:
@@ -130,23 +134,13 @@ def main():
args = parser.parse_args()

seed_all(args.seed)

clip_configs = load_clip_configs(args.config_json)

shot_cfg = ShotConfig(
seed=args.seed,
image_path=args.image_path,
audio_path=args.audio_path,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
save_result_path=args.save_result_path,
clip_configs=clip_configs,
target_shape=args.target_shape,
)

with ProfilingContext4DebugL1("Total Cost"):
shot_stream_pipe = ShotRS2VPipeline(shot_cfg)
shot_stream_pipe.generate()
with ProfilingContext4DebugL1("Init Pipeline Cost Time"):
shot_rs2v_pipe = ShotRS2VPipeline(clip_configs)

with ProfilingContext4DebugL1("Generate Cost Time"):
shot_rs2v_pipe.generate(args)

# Clean up distributed process group
if dist.is_initialized():
59 changes: 17 additions & 42 deletions lightx2v/shot_runner/shot_base.py
@@ -7,29 +7,17 @@
import torch
from loguru import logger

from lightx2v.utils.input_info import init_empty_input_info, update_input_info_from_dict
from lightx2v.utils.input_info import fill_input_info_from_defaults
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import auto_calc_config, get_default_config, print_config, set_parallel_config
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
from lightx2v_platform.registry_factory import PLATFORM_DEVICE_REGISTER


@dataclass
class ClipConfig:
name: str
config_json: str | dict[str, Any]


@dataclass
class ShotConfig:
seed: int
image_path: str
audio_path: str
prompt: str
negative_prompt: str
save_result_path: str
clip_configs: list[ClipConfig]
target_shape: list[int]
config_json: dict[str, Any]


def get_config_json(config_json):
@@ -58,13 +46,11 @@ def load_clip_configs(main_json_path: str):
clip_configs = []
for item in clip_configs_raw:
if "config" in item:
config_json = item["config"]
config = item["config"]
else:
config_json = str(Path(lightx2v_path) / item["path"])

default_config = get_default_config()
default_config.update(get_config_json(config_json))
config = auto_calc_config(default_config)
config_json = {"config_json": config_json}
config = set_config(Namespace(**config_json))

if "parallel" in cfg: # Add parallel config to clip json
config["parallel"] = cfg["parallel"]
@@ -75,34 +61,21 @@


class ShotPipeline:
def __init__(self, shot_cfg: ShotConfig):
self.shot_cfg = shot_cfg
def __init__(self, clip_configs: list[ClipConfig]):
# self.clip_configs = clip_configs
self.clip_generators = {}
self.clip_inputs = {}
self.overlap_frame = None
self.overlap_latent = None
self.progress_callback = None

for clip_config in shot_cfg.clip_configs:
for clip_config in clip_configs:
name = clip_config.name
self.clip_generators[name] = self.create_clip_generator(clip_config)

args = Namespace(
seed=self.shot_cfg.seed,
prompt=self.shot_cfg.prompt,
negative_prompt=self.shot_cfg.negative_prompt,
image_path=self.shot_cfg.image_path,
audio_path=self.shot_cfg.audio_path,
save_result_path=self.shot_cfg.save_result_path,
task=self.clip_generators[name].task,
return_result_tensor=True,
overlap_frame=self.overlap_frame,
overlap_latent=self.overlap_latent,
target_shape=self.shot_cfg.target_shape,
)
input_info = init_empty_input_info(self.clip_generators[name].task)
update_input_info_from_dict(input_info, vars(args))
self.clip_inputs[name] = input_info
def check_input_info(self, user_input_info, clip_config):
default_input_info = clip_config.get("default_input_info", None)
if default_input_info is not None:
fill_input_info_from_defaults(user_input_info, default_input_info)
return user_input_info.normalize_unset_to_none()

def _input_data_to_dict(self, input_data):
if isinstance(input_data, dict):
@@ -145,9 +118,11 @@ def set_progress_callback(self, callback):
self.progress_callback = callback

def create_clip_generator(self, clip_config: ClipConfig):
logger.info(f"Clip {clip_config.name} initializing ... ")
print_config(clip_config.config_json)
runner = self._init_runner(clip_config.config_json)
logger.info(f"Clip {clip_config.name} initialized successfully!")
print_config(clip_config.config_json)

return runner

@torch.no_grad()