Dataset类
Dataset
类是数据加载的核心组件之一。它是一个抽象类,用户需要通过继承这个类并实现其中的两个方法:__len__
和__getitem__
。
1. 数据集结构
-
数据分为训练集和测试集,训练集和测试集中分别有两个文件夹,文件夹名称为数据的类别,每个类别文件夹下有多个数据
2. 搭建框架
-
需要继承
torch.utils.data
中的Dataset
类,并重写两个魔法方法__getitem__
__len__
3.__init__
-
在初始化函数中完成对图像数据名称的获取,用于对后期数据的加载
4.__getitem__
-
getitem要根据给定的索引返回一个样本。通常会包含数据、标签,必要时还会应用数据变换
5.__len__
-
len方法用于返回加载的数据集中有多少个数据
完整代码
from torch.utils.data import Dataset
import os
from PIL import Image
class Mydata(Dataset):def __init__(self, root_dir, label_dir):self.root_dit = root_dir # 根目录 ./data/traimself.label_dir = label_dir # 类别 antsself.image_path = os.path.join(self.root_dit, self.label_dir) # './data/train/ants self.image_path_list = os.listdir(self.image_path) # 获取ants下的所有文件的名称def __getitem__(self, index):image = self.image_path_list[index] # 通过编号获取图像的名称image_item_path = os.path.join(self.image_path, image) # 拼接出图像的具体路径img = Image.open(image_item_path)label = self.label_dirreturn img, label # 返回图像数据、标签def __len__(self):return len(self.image_path_list)
两个Dataset实例求和的数据集
-
Dataset类的实例化支持求和操作,首先需要设置len方法,两个Dataset的实例的求和是将
__len__
方法中的返回的计算长度的列表作为数据集相加,从而得到新的数据集