Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions repos/hifi-gan/meldataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,33 @@
except:
pw = None

from tqdm import tqdm
def check_files(sampling_rate, segment_size, training_files):
len_training_files = len(training_files)
training_files = [x for x in training_files if os.path.exists(x)]
if (len_training_files - len(training_files)) > 0:
print(len_training_files - len(training_files), "Files don't exist (and have been removed from training)")

len_training_files = len(training_files)
training_files = [x for x in training_files if len(load_wav_to_torch(x, target_sr=sampling_rate, return_empty_on_exception=True)[0]) > segment_size]
if (len_training_files - len(training_files)) > 0:
print(len_training_files - len(training_files), "Files are too short (and have been removed from training)")
if not os.path.exists("./bad_files.txt"):
len_training_files = len(training_files)
for training_file in tqdm(training_files):
verify = len(load_wav_to_torch(training_file, target_sr=sampling_rate, return_empty_on_exception=True)[0])
if verify < segment_size:
training_files.remove(training_file)
with open("./bad_files.txt", "a") as shit_file:
shit_file.write(training_file + "\n")
if (len_training_files - len(training_files)) > 0:
print(len_training_files - len(training_files), "Files are too short (and have been removed from training)")
else:
with open("./bad_files.txt") as f:
bad_files = f.read().splitlines()
len_training_files = len(training_files)
for training_file in tqdm(training_files):
if training_file in bad_files:
training_files.remove(training_file)
if (len_training_files - len(training_files)) > 0:
print(len_training_files - len(training_files), "Files are too short (and have been removed from training)")

return training_files


def get_dataset_filelist(a, segment_size, sampling_rate):
if a.input_wavs_dir is None:
Expand Down