diff --git a/repos/hifi-gan/meldataset.py b/repos/hifi-gan/meldataset.py index 87dcba7..61a7a59 100644 --- a/repos/hifi-gan/meldataset.py +++ b/repos/hifi-gan/meldataset.py @@ -15,18 +15,33 @@ except: pw = None +from tqdm import tqdm def check_files(sampling_rate, segment_size, training_files): len_training_files = len(training_files) training_files = [x for x in training_files if os.path.exists(x)] if (len_training_files - len(training_files)) > 0: print(len_training_files - len(training_files), "Files don't exist (and have been removed from training)") - - len_training_files = len(training_files) - training_files = [x for x in training_files if len(load_wav_to_torch(x, target_sr=sampling_rate, return_empty_on_exception=True)[0]) > segment_size] - if (len_training_files - len(training_files)) > 0: - print(len_training_files - len(training_files), "Files are too short (and have been removed from training)") + if not os.path.exists("./bad_files.txt"): + len_training_files = len(training_files) + for training_file in tqdm(training_files): + verify = len(load_wav_to_torch(training_file, target_sr=sampling_rate, return_empty_on_exception=True)[0]) + if verify < segment_size: + training_files.remove(training_file) + with open("./bad_files.txt", "a") as shit_file: + shit_file.write(training_file + "\n") + if (len_training_files - len(training_files)) > 0: + print(len_training_files - len(training_files), "Files are too short (and have been removed from training)") + else: + with open("./bad_files.txt") as f: + bad_files = f.read().splitlines() + len_training_files = len(training_files) + for training_file in tqdm(training_files): + if training_file in bad_files: + training_files.remove(training_file) + if (len_training_files - len(training_files)) > 0: + print(len_training_files - len(training_files), "Files are too short (and have been removed from training)") + return training_files - def get_dataset_filelist(a, segment_size, sampling_rate): if a.input_wavs_dir is None: