diff --git a/src/python/detectors/pytorch_data_loader_with_multiple_workers/pytorch_data_loader_with_multiple_workers.py b/src/python/detectors/pytorch_data_loader_with_multiple_workers/pytorch_data_loader_with_multiple_workers.py index ed8c5e3..359842c 100644 --- a/src/python/detectors/pytorch_data_loader_with_multiple_workers/pytorch_data_loader_with_multiple_workers.py +++ b/src/python/detectors/pytorch_data_loader_with_multiple_workers/pytorch_data_loader_with_multiple_workers.py @@ -37,25 +37,32 @@ def pytorch_data_loader_with_multiple_workers_noncompliant(): # {fact rule=pytorch-data-loader-with-multiple-workers@v1.0 defects=0} def pytorch_data_loader_with_multiple_workers_compliant(args): - import torch.optim - import torchvision.datasets as datasets - # Data loading code - traindir = os.path.join(args.data, 'train') - valdir = os.path.join(args.data, 'val') - normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225]) - - train_dataset = datasets.ImageFolder(traindir, imagenet_transforms) - train_sampler = torch.utils.data.distributed\ - .DistributedSampler(train_dataset) - - # Compliant: args.workers value is assigned to num_workers, - # but native python 'list/dict' is not used here to store the dataset. - train_loader = torch.utils.data.DataLoader(train_dataset, - batch_size=args.batch_size, - shuffle=(train_sampler is None), - num_workers=args.workers, - pin_memory=True, - sampler=train_sampler) + from torch.utils.data import Dataset, DataLoader + import numpy as np + import torch + + class DataIter(Dataset): + def __init__(self): + self.data_np = np.array([x for x in range(24000000)]) + + def __len__(self): + return len(self.data_np) + + def __getitem__(self, idx): + data = self.data_np[idx] + data = np.array([data], dtype=np.int64) + return torch.tensor(data) + + train_data = DataIter() + # Compliant: native python `list/dict` is not used to store the dataset + # for non zero `num_workers`. + train_loader = DataLoader(train_data, batch_size=300, + shuffle=True, + drop_last=True, + pin_memory=False, + num_workers=8) + for i, item in enumerate(train_loader): + if i % 1000 == 0: + print(i) # {/fact}