1 year ago
#371020
Kyda
How to change my targets in the dataset after random_split?
I split my dataset into training and test sets as follows:
```python
dataset['train'], dataset['test'] = torch.utils.data.random_split(
    dataset_all, [num_train, num_test],
    generator=torch.Generator().manual_seed(random_seed))
```
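`random_split` returns `torch.utils.data.Subset` objects rather than copies of the data. Each `Subset` keeps a reference to the original dataset (`.dataset`) and the positions it selected (`.indices`), which is what makes index-based selection possible. A minimal sketch, using a small synthetic `TensorDataset` as a hypothetical stand-in for `dataset_all` and the names `train_set`/`test_set` in place of the `dataset` dict:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Hypothetical stand-in for dataset_all: 10 samples, 3 features, labels 0/1/2
inputs_all = torch.randn(10, 3)
labels_all = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
dataset_all = TensorDataset(inputs_all, labels_all)

train_set, test_set = random_split(
    dataset_all, [7, 3], generator=torch.Generator().manual_seed(0)
)

# Each split is a Subset: a reference to the original dataset plus the
# list of positions it selected, not a copy of the data.
print(type(train_set).__name__)           # Subset
print(train_set.dataset is dataset_all)   # True
print(len(train_set.indices))             # 7

# Item i of the subset is item train_set.indices[i] of the original dataset.
x0, y0 = train_set[0]
assert torch.equal(x0, inputs_all[train_set.indices[0]])
```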
Is there a good way to change the targets and retrieve a subset of the dataset by providing indices? Right now, I use the following approach just to get the samples whose label is 0:
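```python
import torch

# Load the entire training split as one batch (the slow, memory-heavy step)
dataloader['train'] = torch.utils.data.DataLoader(dataset['train'], batch_size=len(dataset['train']), num_workers=4)
inputs, labels = next(iter(dataloader['train']))
mask = labels == 0  # keep only the samples whose label is 0
x_train = inputs[mask]
y_train = labels[mask]
data_train = My_Dataset(x_train, y_train, transform=None)  # My_Dataset: custom Dataset class (not shown)
```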
This approach takes a lot of time and memory when the dataset is large, because it loads the entire training split into memory as a single batch.
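A sketch of a lighter-weight alternative, assuming the labels of `dataset_all` can be read by position without loading the inputs (torchvision datasets such as MNIST expose them as a `targets` attribute; here the label tensor of a synthetic `TensorDataset` plays that role). The idea is to pick positions from the split's `.indices` by looking only at labels, then wrap them in a new `Subset`, so no input tensors are materialized up front. For changing targets, `RelabeledDataset` is a hypothetical wrapper that remaps labels when an item is accessed; all dataset names here are illustrative stand-ins.

```python
import torch
from torch.utils.data import Dataset, DataLoader, Subset, TensorDataset, random_split

# Hypothetical stand-in for dataset_all and its random split
inputs_all = torch.randn(10, 3)
labels_all = torch.tensor([0, 1, 2, 0, 1, 2, 0, 1, 2, 0])
dataset_all = TensorDataset(inputs_all, labels_all)
train_set, test_set = random_split(
    dataset_all, [7, 3], generator=torch.Generator().manual_seed(0)
)

# Assumption: labels can be looked up cheaply by position in dataset_all.
# For a torchvision dataset this would be dataset_all.targets instead.
all_targets = labels_all

# Keep only the training positions whose label is 0; no inputs are loaded here.
zero_indices = [i for i in train_set.indices if int(all_targets[i]) == 0]
train_zero = Subset(dataset_all, zero_indices)


class RelabeledDataset(Dataset):
    """Hypothetical wrapper that remaps targets lazily in __getitem__."""

    def __init__(self, base, label_map):
        self.base = base
        self.label_map = label_map  # e.g. {0: 1, 1: 0} swaps two classes

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        x, y = self.base[idx]
        # Labels missing from label_map are returned unchanged
        return x, self.label_map.get(int(y), int(y))


# Example: relabel class 0 as 1 in the filtered subset
train_zero_relabeled = RelabeledDataset(train_zero, {0: 1})

# Samples are still loaded on demand, with an ordinary batch size.
loader = DataLoader(train_zero_relabeled, batch_size=4, shuffle=True)
for xb, yb in loader:
    print(xb.shape, yb)
```

How much this saves depends on how cheaply the labels can be read. If a label can only be obtained by loading the whole sample, the filtering pass still reads every sample once, but it no longer keeps all inputs in memory at the same time.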
pytorch
dataset
dataloader