欢迎访问 生活随笔!

生活随笔

当前位置: 首页 > 编程资源 > 编程问答 >内容正文

编程问答

Pytorch 怎么构建自己的数据集。怎么重写官方数据集。

发布时间:2025/4/16 编程问答 14 豆豆
生活随笔 收集整理的这篇文章主要介绍了 Pytorch 怎么构建自己的数据集。怎么重写官方数据集。 小编觉得挺不错的,现在分享给大家,帮大家做个参考.

 

小白记录,大神勿扰

 

小白入门的时候,发现,现有的基本都是直接类似这样的:

trainset = datasets.MNIST('../MNIST', download=True,train=True, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])) train_loader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

这个download=True直接解决了一切问题,却不理解发生肾么事了。

而且由于网不好等原因,常常无法自动下载。

这个网上有一些方法,提前自己把数据下载好,放在download的那个目录下。

或者改源代码的下载链接为本地目录。

例如:https://zhuanlan.zhihu.com/p/129081723

 

有时候,大多数时候想用自己数据集,如下这样类似的写法:

class MyDataset(Dataset):def __init__(self, image_path, label_path, setid_path, train=True, transform=None):setid = scipy.io.loadmat(setid_path)labels = scipy.io.loadmat(label_path)['labels'][0]if train:trnid = setid['tstid'][0]self.labels = [labels[i - 1] - 1 for i in trnid]self.images = ['%s/image_%05d.jpg' % (image_path, i) for i in trnid]else:tstid = np.append(setid['valid'][0], setid['trnid'][0])self.labels = [labels[i - 1] - 1 for i in tstid]self.images = ['%s/image_%05d.jpg' % (image_path, i) for i in tstid]self.transform = transformdef __getitem__(self, index):label = self.labels[index]image = self.images[index]if self.transform is not None:image = self.transform(Image.open(image))return image, labeldef __len__(self):return len(self.labels)

init初始化,一般就包括加载数据啊,然后整体数据的一些基本处理之类的。数据可以来自自己定义放好的本地文件夹,也可以是自己在code之前就完成加载的numpy格式或者其他格式的数据(这时候init中就不需要加本地路径了)。

getitem,每次调用数据,其实就是调用它,后面index不要丢。内部一般就写 init之后,数据被加载之前 还需要进行的一些处理。这里,比如你要加载不一样的图像,这里return不同的就可以了。

len,就返回一个数据长度即可。

 

然后调用自己定义的数据集,MyDataset,再放到loader中,再从loader中直接拿数据就ok了,这时候拿到的数据就是一个batch一个batch的。

train_dataset = MyDataset(image_path, label_path, setid_path,train=True, transform=transforms.Compose([transforms.RandomHorizontalFlip(),transforms.RandomRotation(30),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]))train_loader = DataLoader(train_dataset, batch_size=BATCH, shuffle=True).......for batch_idx, (image, label) in enumerate(train_loader):...

好,然后怎么重载官网数据集,比如说,你载loader中,希望每次加载这样的数据,image1,image2,label

又是基于现有数据集,比如minist,那么就可以重写这个官网的数据集。本质上和完全自己定义是一回事。

示例代码如下:

class CIFAR10_(datasets.CIFAR10):"""CIFAR10 Dataset."""def __getitem__(self, index):img, target = self.data[index], self.targets[index]img = Image.fromarray(img)if self.target_transform is not None:target = self.target_transform(target)if self.transform is not None:img1 = self.transform(img)if self.train:img2 = self.transform(img)if self.train:return img1, img2, target, index

 然后改怎么使用,就怎么使用,可以自己下载好:

trainset = datasets.CIFAR10_(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True) testset = datasets.CIFAR10_(root='./data', train=True, download=False, transform=transform_test) testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4, drop_last=True)

这时候train loader中出来的 就是 这个样子的:

for batch_idx, (inputs1, inputs2, target, indexes) in enumerate(trainloader):...

ok

 

 

 

 

 

 

 

 

《新程序员》:云原生和全面数字化实践50位技术专家共同创作,文字、视频、音频交互阅读

总结

以上是生活随笔为你收集整理的Pytorch 怎么构建自己的数据集。怎么重写官方数据集。的全部内容,希望文章能够帮你解决所遇到的问题。

如果觉得生活随笔网站内容还不错,欢迎将生活随笔推荐给好友。