PyTorch Dataset

Dataset & DataLoader

Dataset stores the samples and their corresponding labels, and DataLoader wraps an iterable around the Dataset to enable easy access to the samples.
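The sketch below shows the two roles side by side. It uses PyTorch's built-in `TensorDataset` on made-up in-memory tensors: indexing the Dataset returns one (sample, label) pair, and iterating the DataLoader yields mini-batches.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Ten toy samples with 3 features each, plus one integer label per sample
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10) % 2

# The Dataset stores samples and labels; dataset[i] returns (features[i], labels[i])
dataset = TensorDataset(features, labels)
print(dataset[1])  # (tensor([3., 4., 5.]), tensor(1))

# The DataLoader wraps the Dataset in an iterable that yields mini-batches
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for x, y in loader:
    print(x.shape, y.shape)
# torch.Size([4, 3]) torch.Size([4])
# torch.Size([4, 3]) torch.Size([4])
# torch.Size([2, 3]) torch.Size([2])
```

The last batch is smaller because 10 samples do not divide evenly by 4. Passing `drop_last=True` to the DataLoader would discard it.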

Custom Dataset

A custom Dataset must implement three methods: __init__(), __len__(), and __getitem__().
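Here is a toy illustration of what each of the three methods is responsible for. `SquaresDataset` is a hypothetical class invented for this sketch: sample `i` is the pair `(i, i ** 2)`, computed on the fly instead of loaded from disk.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the pair (i, i ** 2)."""

    def __init__(self, n):
        # __init__: store whatever is needed to locate samples (here, just a count)
        self.n = n

    def __len__(self):
        # __len__: number of samples; the DataLoader's sampler draws indices from range(len)
        return self.n

    def __getitem__(self, idx):
        # __getitem__: build (or load) and return the sample at position idx
        x = torch.tensor([float(idx)])
        return x, x ** 2


dataset = SquaresDataset(5)
print(len(dataset))  # 5
print(dataset[3])    # (tensor([3.]), tensor([9.]))

# The DataLoader stacks individual samples into batches of shape (batch_size, 1)
loader = DataLoader(dataset, batch_size=2)
```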

Custom dataset example: defocus image deblurring

The code below is based on DRBNet.


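import os
from glob import glob

import cv2
import numpy as np
import torch
from natsort import natsorted
from torch.utils.data import Dataset, DataLoader


def read_image(path, norm_val=255.0):
    # Assumed stand-in for DRBNet's read_image helper: load an RGB image and
    # return a float array of shape (1, H, W, C) scaled to [0, 1]
    img = cv2.cvtColor(cv2.imread(path, cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    return np.expand_dims(img.astype(np.float32) / norm_val, axis=0)


# Naturally sorted lists of blurred inputs and their sharp ground truths
input_path = natsorted(glob(os.path.join("datasets/patch", 'train_c', 'source', '*.png')))
gt_path = natsorted(glob(os.path.join("datasets/patch", 'gt', 'source', '*.png')))


class MyDataset(Dataset):

    def __init__(self, input_path, gt_path):
        self.input_path = input_path
        self.gt_path = gt_path

    def __len__(self):
        return len(self.input_path)

    def __getitem__(self, idx):
        # Rescale pixel values from [0, 1] to [-1, 1]
        pt = read_image(self.input_path[idx], 255.0) * 2 - 1
        gt = read_image(self.gt_path[idx], 255.0) * 2 - 1
        # (1, H, W, C) -> (C, H, W); the DataLoader adds the batch dimension.
        # Tensors stay on the CPU; move them to the GPU in the training loop.
        x = torch.from_numpy(pt[0].transpose(2, 0, 1).copy()).float()
        y = torch.from_numpy(gt[0].transpose(2, 0, 1).copy()).float()
        return x, y


batch_size = 4  # example value
dataset = MyDataset(input_path, gt_path)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)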
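A loader like this is normally consumed in a training loop. The sketch below assumes no image files are available, so a hypothetical `RandomPairDataset` produces random (blurred, sharp) patch pairs with the same `(C, H, W)` layout and `[-1, 1]` range, and a single `Conv2d` stands in for DRBNet. It shows the batch shape the DataLoader produces and where the tensors are moved to the GPU.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader


class RandomPairDataset(Dataset):
    """Hypothetical stand-in for MyDataset that needs no image files.

    Each sample is a (blurred, sharp) pair of CPU tensors shaped (C, H, W)
    with values in [-1, 1].
    """

    def __init__(self, n, size=64):
        self.n = n
        self.size = size

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        g = torch.Generator().manual_seed(idx)  # same idx -> same sample
        sharp = torch.rand(3, self.size, self.size, generator=g) * 2 - 1
        noise = torch.randn(3, self.size, self.size, generator=g)
        blurred = (sharp + 0.1 * noise).clamp(-1, 1)  # placeholder degradation
        return blurred, sharp


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
loader = DataLoader(RandomPairDataset(10), batch_size=4, shuffle=True)

# A single convolution stands in for the real deblurring network (hypothetical)
model = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

shapes = []
for x, y in loader:
    # Move each batch to the device here, not inside __getitem__
    x, y = x.to(device), y.to(device)
    shapes.append(tuple(x.shape))

    loss = F.l1_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(shapes)  # [(4, 3, 64, 64), (4, 3, 64, 64), (2, 3, 64, 64)]
```

Keeping `__getitem__` on the CPU also leaves room to raise `num_workers` later: worker processes generally cannot use CUDA once the parent process has initialized it, and `pin_memory=True` can speed up the host-to-GPU copies instead.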

PyTorch Dataset
https://dreamerland.cn/2024/07/29/pytorch/Dataset/
Author: Silva31
Published: July 29, 2024
License