首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >paddle深度学习8 自定义数据集

paddle深度学习8 自定义数据集

原创
作者头像
用户11104668
发布2025-01-14 10:17:16
发布2025-01-14 10:17:16
3760
举报
文章被收录于专栏:paddle深度学习paddle深度学习

除了Paddle中一些已经包含的常用数据集,在实际的深度学习项目中,经常需要使用自定义的数据集(以便灵活地使用一些其它地外部数据集)进行训练和测试。PaddlePaddle 提供了灵活的工具来加载和处理自定义数据集。下面我们将详细介绍如何使用 PaddlePaddle 加载和使用一个简单的二维空间点的二分类数据集。

【准备自定义数据集】

假设要完成一个二维空间点的二分类任务,数据集的结构如下:

l 每个样本由两个浮点数 (x1, x2) 组成,表示二维空间中的一个点。

l 标签 label 为 0 或 1,表示该点属于哪一类别。

可以用 Python 生成一个简单的数据集:

import numpy as np

np.random.seed(123)

num_samples = 1000

# 生成两类数据点

class_0 = np.random.normal(loc=[-1, -1], scale=0.5, size=(num_samples // 2, 2))

class_1 = np.random.normal(loc=[1, 1], scale=0.5, size=(num_samples // 2, 2))

data = np.vstack([class_0, class_1])

labels = np.hstack([np.zeros(num_samples // 2), np.ones(num_samples // 2)])

# 打乱数据

indices = np.arange(num_samples)

np.random.shuffle(indices)

data = data[indices]

labels = labels[indices]

# 划分训练集和验证集

train_ratio = 0.8

train_size = int(num_samples * train_ratio)

train_data, val_data = data[:train_size], data[train_size:]

train_labels, val_labels = labels[:train_size], labels[train_size:]

print(len(train_data))

print(len(val_data))

可以查看一下这些数据的分布情况

import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))

plt.subplot(1, 2, 1)

plt.scatter(train_data[:, 0], train_data[:, 1], c=train_labels, cmap='viridis')

plt.title('Train Data Distribution')

plt.subplot(1, 2, 2)

plt.scatter(val_data[:, 0], val_data[:, 1], c=val_labels, cmap='viridis')

plt.title('Validation Data Distribution')

plt.show()

结果是随机的

【加载自定义数据集】

PaddlePaddle 提供了 paddle.io.Dataset 类,我们可以通过继承这个类来定义自己的数据集

import paddle

class MyDataset(paddle.io.Dataset):

def __init__(self, data, labels):

super(MyDataset, self).__init__()

self.data = data

self.labels = labels

def __getitem__(self, idx):

# 返回单个样本和标签

sample = self.data[idx]

label = self.labels[idx]

return sample.astype('float32'), label.astype('int64') #转换为Paddle支持的类型

def __len__(self):

# 返回数据集长度

return len(self.labels)

# 创建训练集和验证集

train_dataset = MyDataset(train_data, train_labels)

val_dataset = MyDataset(val_data, val_labels)

# 打印 train_dataset 的前5项

for i in range(5):

sample, label = train_dataset[i]

print(f"Sample {i+1}: {sample}, Label: {label}")

【使用 paddle.io.DataLoader 加载数据】

定义好数据集后,惯用的做法是使用 paddle.io.DataLoader 来加载数据,需要把数据集转换为DataLoader类型

# 创建 DataLoader

train_loader = paddle.io.DataLoader(train_dataset, batch_size=4, shuffle=True)

val_loader = paddle.io.DataLoader(val_dataset, batch_size=4, shuffle=False)

# 使用 DataLoader 进行迭代

for i, data in enumerate(train_loader):

samples, labels = data

print(f"Batch {i}, Samples shape: {samples.shape}, Labels shape: {labels.shape}")

if i==4:

print('提前结束')

break

shuffle参数表示是否打乱数据

dataloader会重新对数据进行分批,每次读入的数据将不再是单个,而是多个,batch_size表示每次读入的数据个数

因此[4,2]表示每批数据集的数据为4个包含(x,y)的数据点,而[4]表示每批数据有4个标签值

对dalaloader的迭代会直到数据集的最后一个数据为止,为了防止输出过长,我们这里使用break提前结束迭代

【用tqdm显示进度条】

在实际项目中,为了运行过程更直观,会使用tqdm工具显示数据集的加载进度

from tqdm.notebook import tqdm

import warnings

import time

warnings.filterwarnings('ignore',message="This function will be removed in tqdm")

for i, data in enumerate(tqdm(train_loader)):

samples, labels = data

time.sleep(0.1)

注意这里的time.sleep并不是必须的,只是为了视觉效果增加了一些运行时间

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

原创声明:本文系作者授权腾讯云开发者社区发表,未经许可,不得转载。

如有侵权,请联系 cloudcommunity@tencent.com 删除。

评论
作者已关闭评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档