Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 54 additions & 40 deletions train_test_split.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,66 @@
import os
import random
import shutil

# Gather every entry of ./images; the list is partitioned into train/val below.
imgList = os.listdir('images')

# Shuffle in place so the train/validation partition is a random sample.
random.shuffle(imgList)

# Fraction of the images reserved for validation.
split = 0.2

train_path = 'custom_dataset/train'
val_path = 'custom_dataset/val'

# exist_ok=True creates each directory only when it is missing, which replaces
# the separate isdir() check and its check-then-create gap.
os.makedirs(train_path, exist_ok=True)
os.makedirs(val_path, exist_ok=True)

imgLen = len(imgList)
print("Images in total: ", imgLen)

# Index of the first validation image; everything before it is training data.
split_index = int(imgLen - (imgLen*split))
train_images = imgList[:split_index]
val_images = imgList[split_index:]
print("Training images: ", len(train_images))
print("Validation images: ", len(val_images))

for imgName in train_images:
og_path = os.path.join('images', imgName)
target_path = os.path.join(train_path, imgName)

import argparse

# Argument Parsing
parser = argparse.ArgumentParser(description='A python script that splits the labeled data into train/test data in Yolov5 format')
parser.add_argument('-name', help='Name of custom dataset directory', default='custom_dataset')
# type=float is required: without it a value given on the command line arrives
# as a str and the later split arithmetic (img_len - img_len*split) fails.
parser.add_argument('-testsize', help='Test split size. Expects floating point number. Default test split size is 0.2', type=float, default=0.2)
parser.add_argument('-images', help='Path to images directory', default='images')
parser.add_argument('-labels', help='Path to bounding box txt files directory', default='bbox_txt')
parsed_args = parser.parse_args()

# A split fraction outside [0, 1] (or NaN) would silently produce a nonsensical
# partition, so reject it with a normal argparse usage error instead.
if not 0.0 <= parsed_args.testsize <= 1.0:
    parser.error('-testsize must be a number between 0 and 1')

# The rest of the script reads the options through this plain dict.
args = vars(parsed_args)

# Collect every entry of the images directory and randomise the order so the
# train/validation partition below is a random sample.
img_list = os.listdir(args['images'])
random.shuffle(img_list)

# Fraction of the images that goes to the validation split.
split = args['testsize']
print('# Test split size:', split)

# Lay out <name>/{train,val}/{images,labels}; directories that already exist
# are reused as-is.
dataset_root = args['name']
train_images_path = os.path.join(dataset_root, 'train', 'images')
train_labels_path = os.path.join(dataset_root, 'train', 'labels')
val_images_path = os.path.join(dataset_root, 'val', 'images')
val_labels_path = os.path.join(dataset_root, 'val', 'labels')
for split_dir in (train_images_path, train_labels_path, val_images_path, val_labels_path):
    os.makedirs(split_dir, exist_ok=True)

img_len = len(img_list)
print("# Images in total: ", img_len)

# Everything before split_index is training data, the remainder is validation.
split_index = int(img_len - (img_len*split))
train_images = img_list[:split_index]
val_images = img_list[split_index:]
print("# Training images: ", len(train_images))
print("# Validation images: ", len(val_images))

# Copy every training image and its bounding-box file into the train split.
for img_name in train_images:
    # The label file shares the image's base name, whatever the image
    # extension is (no hard-coded '.jpg').
    base_name = os.path.splitext(img_name)[0]

    # Copy image
    og_path = os.path.join(args['images'], img_name)
    target_path = os.path.join(train_images_path, img_name)
    shutil.copyfile(og_path, target_path)

    # Copy bounding box txt file
    og_txt_path = os.path.join(args['labels'], base_name + '.txt')
    target_txt_path = os.path.join(train_labels_path, base_name + '.txt')
    shutil.copyfile(og_txt_path, target_txt_path)


# Copy every validation image and its bounding-box file into the val split.
for img_name in val_images:
    # The label file shares the image's base name, whatever the image
    # extension is (no hard-coded '.jpg').
    base_name = os.path.splitext(img_name)[0]

    # Copy image
    og_path = os.path.join(args['images'], img_name)
    target_path = os.path.join(val_images_path, img_name)
    shutil.copyfile(og_path, target_path)

    # Copy bounding box txt file
    og_txt_path = os.path.join(args['labels'], base_name + '.txt')
    target_txt_path = os.path.join(val_labels_path, base_name + '.txt')
    shutil.copyfile(og_txt_path, target_txt_path)


# Single completion message, using the same "# " prefix as the other status lines.
print("# Done!")