当前位置:网站首页>Pytorch——数据加载和处理
Pytorch——数据加载和处理
2022-04-23 05:43:00 【左小田^O^】
Pytorch中进行数据处理的工具
scikit-image:用于图像的IO和变换
pandas:用于更容易地进行csv解析
从此处下载数据集。
数据存于“data / faces /”的目录中。这个数据集实际上是imagenet数据集标
注为face的图片当中在 dlib 面部检测 (dlib’s pose estimation) 表现良好的图片。我们要处理的是一个面部姿态的数据集。如下图
数据集类
torch.utils.data.Dataset
表示数据集的抽象类。
自定义数据集应继承Dataset并覆盖以下方法 * __len__
实现 len(dataset) 返还数据集的尺寸。 * __getitem__
用来获取一些索引数据,例如 dataset[i] 中的(i)。
建立数据集类
为面部数据集创建一个数据集类。我们将在 __init__
中读取csv的文件内容,在 __getitem__
中读取图片。
我们的数据样本将按这样一个字典 {‘image’: image, ‘landmarks’: landmarks} 组织。
我们的数据集类将添加一个可选参数 transform
以方便对样本进行预处理。 init 方法如下图所示:
class FaceLandmarksDataset(Dataset):
"""面部标记数据集."""
def __init__(self, csv_file, root_dir, transform=None):
""" csv_file(string):带注释的csv文件的路径。 root_dir(string):包含所有图像的目录。 transform(callable, optional):一个样本上的可用的可选变换 """
self.landmarks_frame = pd.read_csv(csv_file)
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.landmarks_frame)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir,self.landmarks_frame.iloc[idx, 0])
image = io.imread(img_name)
landmarks = self.landmarks_frame.iloc[idx, 1:]
landmarks = np.array([landmarks])
landmarks = landmarks.astype('float').reshape(-1, 2)
sample = {
'image': image, 'landmarks': landmarks}
if self.transform:
sample = self.transform(sample)
return sample
数据可视化
face_dataset = FaceLandmarksDataset(csv_file='data/faces/face_landmarks.csv',root_dir='data/faces/')
fig = plt.figure()
for i in range(len(face_dataset)):
sample = face_dataset[i]
print(i, sample['image'].shape, sample['landmarks'].shape)
ax = plt.subplot(1, 4, i + 1)
plt.tight_layout()
ax.set_title('Sample #{}'.format(i))
ax.axis('off')
show_landmarks(**sample)
if i == 3:
plt.show()
break
数据变换
通过上面的例子我们会发现图片并不是同样的尺寸。
绝大多数神经网络都假定图片的尺寸相同。
因此我们需要做一些预处理。让我们创建三个转换:
* Rescale
:缩放图片
* RandomCrop
:对图片进行随机裁剪。
* ToTensor
:这是一种数据增强操作,把numpy格式图片转为torch格式图片 (我们需要交换坐标轴).
我们会把它们写成可调用的类的形式而不是简单的函数,这样就不需要每次调用时传递一遍参数。我们只需要实现 call 方法,必 要的时候实现 init 方法。我们可以这样调用这些转换:
版权声明
本文为[左小田^O^]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_45802081/article/details/115459764
边栏推荐
- io. lettuce. core. RedisCommandExecutionException: ERR wrong number of arguments for ‘auth‘ command
- Anaconda
- Map对象 map.get(key)
- DWSurvey是一个开源的调查问卷系统。解决无法运行问题,修改bug。
- Pytorch学习记录(十二):学习率衰减+正则化
- interviewter:介绍一下MySQL日期函数
- Navicate连接oracle(11g)时ORA:28547 Connection to server failed probable Oeacle Net admin error
- 第36期《AtCoder Beginner Contest 248 打比赛总结》
- Pytorch学习记录(十一):数据增强、torchvision.transforms各函数讲解
- Software architecture design - software architecture style
猜你喜欢
随机推荐
Fletter next generation graphics renderer impaller
软件架构设计——软件架构风格
JSP语法及JSTL标签
2 - software design principles
No.1.#_6 Navicat快捷键
MySQL triggers, stored procedures, stored functions
Map object map get(key)
Dva中在effects中获取state的值
RedHat6之smb服务访问速度慢解决办法记录
freemark中插入图片
Solid contract DoS attack
umi官网yarn create @umijs/umi-app 报错:文件名、目录名或卷标语法不正确
框架解析1.系统架构简介
Anaconda
类的加载与ClassLoader的理解
Duplicate key update in MySQL
尚硅谷 p290 多态性练习
What is JSON? First acquaintance with JSON
The list attribute in the entity is empty or null, and is set to an empty array
C3P0数据库连接池使用