From 4344d2d940a7fd3b6bfebbe2f30e9dce326cb449 Mon Sep 17 00:00:00 2001 From: mert <3117970+CatcherInThePy@users.noreply.github.com> Date: Mon, 4 Apr 2022 15:51:42 +0300 Subject: [PATCH] Updated train_test_split.py --- train_test_split.py | 94 ++++++++++++++++++++++++++------------------- 1 file changed, 54 insertions(+), 40 deletions(-) diff --git a/train_test_split.py b/train_test_split.py index 30d3a7a..7ec37f5 100644 --- a/train_test_split.py +++ b/train_test_split.py @@ -1,52 +1,66 @@ import os import random import shutil - -imgList = os.listdir('images') - - -#shuffling images -random.shuffle(imgList) - -split = 0.2 - -train_path = 'custom_dataset/train' -val_path = 'custom_dataset/val' - -if os.path.isdir(train_path) == False: - os.makedirs(train_path) -if os.path.isdir(val_path) == False: - os.makedirs(val_path) - -imgLen = len(imgList) -print("Images in total: ", imgLen) - -train_images = imgList[: int(imgLen - (imgLen*split))] -val_images = imgList[int(imgLen - (imgLen*split)):] -print("Training images: ", len(train_images)) -print("Validation images: ", len(val_images)) - -for imgName in train_images: - og_path = os.path.join('images', imgName) - target_path = os.path.join(train_path, imgName) - +import argparse + +# Argument Parsing +parser = argparse.ArgumentParser(description='A python script that splits the labeled data into train/test data in Yolov5 format') +parser.add_argument('-name', help='Name of custom dataset directory', default='custom_dataset') +parser.add_argument('-testsize', help='Test split size. Expects floating point number. Default test split size is 0.2', default=0.2) +parser.add_argument('-images', help='Path to images directory', default='images') +parser.add_argument('-labels', help='Path to bounding box txt files directory', default='bbox_txt') +args = vars(parser.parse_args()) + +img_list = os.listdir(args['images']) + +# Shuffling images +random.shuffle(img_list) + +split = args['testsize'] +print('# Test split size:', split) + +# Creating split directory +train_images_path = os.path.join(args['name'], 'train', 'images') +train_labels_path = os.path.join(args['name'], 'train', 'labels') +val_images_path = os.path.join(args['name'], 'val', 'images') +val_labels_path = os.path.join(args['name'], 'val', 'labels') +os.makedirs(train_images_path, exist_ok = True) +os.makedirs(train_labels_path, exist_ok = True) +os.makedirs(val_images_path, exist_ok = True) +os.makedirs(val_labels_path, exist_ok = True) + +img_len = len(img_list) +print("# Images in total: ", img_len) + +train_images = img_list[: int(img_len - (img_len*split))] +val_images = img_list[int(img_len - (img_len*split)):] +print("# Training images: ", len(train_images)) +print("# Validation images: ", len(val_images)) + +for img_name in train_images: + base_name, ext = os.path.splitext(img_name) + + # Copy image + og_path = os.path.join(args['images'], img_name) + target_path = os.path.join(train_images_path, img_name) shutil.copyfile(og_path, target_path) - og_txt_path = os.path.join('bbox_txt', imgName.replace('.jpg', '.txt')) - target_txt_path = os.path.join(train_path, imgName.replace('.jpg', '.txt')) - + # Copy bounding box txt file + og_txt_path = os.path.join(args['labels'], base_name + '.txt') + target_txt_path = os.path.join(train_labels_path, base_name + '.txt') shutil.copyfile(og_txt_path, target_txt_path) -for imgName in val_images: - og_path = os.path.join('images', imgName) - target_path = os.path.join(val_path, imgName) +for img_name in val_images: + base_name, ext = os.path.splitext(img_name) + # Copy image + og_path = os.path.join(args['images'], img_name) + target_path = os.path.join(val_images_path, img_name) shutil.copyfile(og_path, target_path) - og_txt_path = os.path.join('bbox_txt', imgName.replace('.jpg', '.txt')) - target_txt_path = os.path.join(val_path, imgName.replace('.jpg', '.txt')) - + # Copy bounding box txt file + og_txt_path = os.path.join(args['labels'], base_name + '.txt') + target_txt_path = os.path.join(val_labels_path, base_name + '.txt') shutil.copyfile(og_txt_path, target_txt_path) - -print("Done! ") +print("# Done!")