From 8aa805099a4c5e0e2f74fd413196552aa73e17e3 Mon Sep 17 00:00:00 2001 From: Mikkel Garcia Date: Sat, 15 Jul 2023 14:00:18 -0600 Subject: [PATCH] add --format for mp4 support --- README.md | 4 ++++ scripts/animate.py | 8 +++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 09b6353d..bfaa4edc 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,10 @@ conda env create -f environment.yaml conda activate animatediff ``` +To save output with `--format mp4`, also install imageio's ffmpeg extra: + +`pip install "imageio[ffmpeg]"` + ### Download Base T2I & Motion Module Checkpoints We provide two versions of our Motion Module, which are trained on stable-diffusion-v1-4 and finetuned on v1-5 seperately. It's recommanded to try both of them for best results. diff --git a/scripts/animate.py b/scripts/animate.py index 8bb5dd74..db928b7d 100644 --- a/scripts/animate.py +++ b/scripts/animate.py @@ -33,6 +33,7 @@ def main(args): time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") savedir = f"samples/{Path(args.config).stem}-{time_str}" + extension = args.format os.makedirs(savedir) inference_config = OmegaConf.load(args.inference_config) @@ -134,13 +135,13 @@ def main(args): samples.append(sample) prompt = "-".join((prompt.replace("/", "").split(" ")[:10])) - save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif") - print(f"save to {savedir}/sample/{prompt}.gif") + save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.{extension}") + print(f"save to {savedir}/sample/{prompt}.{extension}") sample_idx += 1 samples = torch.concat(samples) - save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4) + save_videos_grid(samples, f"{savedir}/sample.{extension}", n_rows=4) OmegaConf.save(config, f"{savedir}/config.yaml") @@ -150,6 +151,7 @@ def main(args): parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",) parser.add_argument("--inference_config", type=str, 
default="configs/inference/inference.yaml") parser.add_argument("--config", type=str, required=True) + parser.add_argument("--format", type=str, default="gif") parser.add_argument("--L", type=int, default=16 ) parser.add_argument("--W", type=int, default=512)