
dataloader source code: DataLoader

全栈程序员站长
Published on 2022-06-30 10:31:00

Hello everyone, we meet again. I'm your friend 全栈君.

import paddle.fluid as fluid
import numpy as np

BATCH_NUM = 10
BATCH_SIZE = 16
EPOCH_NUM = 4
CLASS_NUM = 10

ITERABLE = True  # whether the created DataLoader object is iterable
USE_GPU = False  # whether to use GPU
DATA_FORMAT = 'batch_generator'  # data format of data source user provides

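def simple_net(image, label):
    fc_tmp = fluid.layers.fc(image, size=CLASS_NUM)
    # The loss must be computed on the fc output (the CLASS_NUM logits), not on
    # the raw 784-dim image; otherwise fc_tmp and its parameters are never used
    cross_entropy = fluid.layers.softmax_with_cross_entropy(fc_tmp, label)
    loss = fluid.layers.reduce_mean(cross_entropy)
    sgd = fluid.optimizer.SGD(learning_rate=1e-3)
    sgd.minimize(loss)
    return loss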

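def get_random_images_and_labels(image_shape, label_shape):
    image = np.random.random(size=image_shape).astype('float32')
    # np.random.random draws from [0, 1), so casting it to int64 would make every
    # label 0; draw integer class ids in [0, CLASS_NUM) instead
    label = np.random.randint(0, CLASS_NUM, size=label_shape).astype('int64')
    return image, label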

# If the data generator yields one sample each time,
# use DataLoader.set_sample_generator to set the data source.
def sample_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM * BATCH_SIZE):
            image, label = get_random_images_and_labels([784], [1])
            yield image, label

    return __reader__


# If the data generator yields a list of samples each time,
# use DataLoader.set_sample_list_generator to set the data source.
def sample_list_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            sample_list = []
            for _ in range(BATCH_SIZE):
                image, label = get_random_images_and_labels([784], [1])
                sample_list.append([image, label])

            yield sample_list

    return __reader__


# If the data generator yields a batch each time,
# use DataLoader.set_batch_generator to set the data source.
def batch_generator_creator():
    def __reader__():
        for _ in range(BATCH_NUM):
            batch_image, batch_label = get_random_images_and_labels([BATCH_SIZE, 784], [BATCH_SIZE, 1])
            yield batch_image, batch_label

    return __reader__

# If DataLoader is iterable, use a for loop to train the network
def train_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        for data in loader():
            exe.run(prog, feed=data, fetch_list=[loss])


# If DataLoader is not iterable, use the start() and reset() methods to control the process
def train_non_iterable(exe, prog, loss, loader):
    for _ in range(EPOCH_NUM):
        loader.start()  # call DataLoader.start() before each epoch starts
        try:
            while True:
                exe.run(prog, fetch_list=[loss])
        except fluid.core.EOFException:
            loader.reset()  # call DataLoader.reset() after catching EOFException

def set_data_source(loader, places):
    if DATA_FORMAT == 'sample_generator':
        loader.set_sample_generator(sample_generator_creator(), batch_size=BATCH_SIZE, drop_last=True, places=places)
    elif DATA_FORMAT == 'sample_list_generator':
        loader.set_sample_list_generator(sample_list_generator_creator(), places=places)
    elif DATA_FORMAT == 'batch_generator':
        loader.set_batch_generator(batch_generator_creator(), places=places)
    else:
        raise ValueError('Unsupported data format')

image = fluid.layers.data(name='image', shape=[784], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

# Define DataLoader
loader = fluid.io.DataLoader.from_generator(feed_list=[image, label], capacity=16, iterable=ITERABLE)

# Define network
loss = simple_net(image, label)

# Set data source of DataLoader
#
# If DataLoader is iterable, places must be given and the number of places must equal the number of devices.
#  - If you are using GPU, call `fluid.cuda_places()` to get all GPU places.
#  - If you are using CPU, call `fluid.cpu_places()` to get all CPU places.
#
# If DataLoader is not iterable, places can be None.
places = fluid.cuda_places() if USE_GPU else fluid.cpu_places()
set_data_source(loader, places)

exe = fluid.Executor(places[0])
exe.run(fluid.default_startup_program())

prog = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)

if loader.iterable:
    train_iterable(exe, prog, loss, loader)
else:
    train_non_iterable(exe, prog, loss, loader)
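In the non-iterable branch, `loader.start()` and `loader.reset()` bracket every epoch, and running out of data surfaces as `fluid.core.EOFException`. The Paddle documentation describes `capacity` as the size of the queue the DataLoader maintains, counted in batches. The following plain-Python sketch imitates that control flow under our own assumption of a producer thread filling a bounded `queue.Queue`. `ToyLoader`, `next_batch`, `run_non_iterable_epochs` and the local `EOFException` class are hypothetical names for illustration only; they are not Paddle APIs, and this is not how Paddle implements its loader.

```python
import queue
import threading


class EOFException(Exception):
    """Hypothetical stand-in for fluid.core.EOFException."""


class ToyLoader:
    """Hypothetical non-iterable loader: a producer thread fills a queue that
    holds at most `capacity` batches; the consumer side raises EOFException
    once the current epoch's batches are exhausted."""

    _END = object()  # sentinel marking the end of an epoch

    def __init__(self, batch_generator, capacity):
        self._batch_generator = batch_generator
        self._queue = queue.Queue(maxsize=capacity)
        self._thread = None

    def _produce(self):
        for batch in self._batch_generator():
            self._queue.put(batch)  # blocks while `capacity` batches are waiting
        self._queue.put(self._END)

    def start(self):
        # Launch a fresh producer for the new epoch
        self._thread = threading.Thread(target=self._produce, daemon=True)
        self._thread.start()

    def next_batch(self):
        batch = self._queue.get()
        if batch is self._END:
            raise EOFException()
        return batch

    def reset(self):
        # Wait for the producer to exit so the next start() begins cleanly
        self._thread.join()
        self._thread = None


def run_non_iterable_epochs(loader, epochs):
    """Same start()/try/except/reset() shape as train_non_iterable above;
    returns how many batches were consumed in each epoch."""
    batches_per_epoch = []
    for _ in range(epochs):
        loader.start()
        count = 0
        try:
            while True:
                loader.next_batch()  # stands in for exe.run(prog, fetch_list=[loss])
                count += 1
        except EOFException:
            loader.reset()
        batches_per_epoch.append(count)
    return batches_per_epoch


if __name__ == "__main__":
    toy = ToyLoader(lambda: iter(range(10)), capacity=2)
    print(run_non_iterable_epochs(toy, epochs=4))  # → [10, 10, 10, 10]
```

Because `put()` blocks once `capacity` batches are queued, a small capacity keeps a fast generator from running arbitrarily far ahead of training, while `reset()` guarantees each epoch starts from an empty queue.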

'''
Users can use return_list=True in dygraph mode.
'''
with fluid.dygraph.guard(places[0]):
    loader = fluid.io.DataLoader.from_generator(capacity=2, return_list=True)
    set_data_source(loader, places[0])
    for image, label in loader():
        relu = fluid.layers.relu(image)
        assert image.shape == [BATCH_SIZE, 784]
        assert label.shape == [BATCH_SIZE, 1]
        assert relu.shape == [BATCH_SIZE, 784]
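Whichever `DATA_FORMAT` is chosen in `set_data_source`, the network should end up seeing batches shaped `[BATCH_SIZE, 784]` and `[BATCH_SIZE, 1]`, which is what the dygraph assertions check. The NumPy sketch below illustrates how the one-sample-at-a-time format could be turned into the sample-list format and then into the batch format. It assumes samples are grouped in generator order and that an incomplete final group is discarded when `drop_last=True`; `batch_samples` and `stack_sample_list` are hypothetical helpers, not Paddle functions, and this is not Paddle's own batching code.

```python
import numpy as np

BATCH_NUM = 10
BATCH_SIZE = 16


def sample_generator():
    # Yields one (image, label) sample at a time, like sample_generator_creator()
    for _ in range(BATCH_NUM * BATCH_SIZE):
        yield (np.random.random(784).astype('float32'),
               np.random.randint(0, 10, size=1).astype('int64'))


def batch_samples(samples, batch_size, drop_last=True):
    """Hypothetical helper: group single samples into lists of batch_size
    (the sample-list format), dropping a short final group if drop_last."""
    group = []
    for sample in samples:
        group.append(sample)
        if len(group) == batch_size:
            yield group
            group = []
    if group and not drop_last:
        yield group


def stack_sample_list(sample_list):
    """Hypothetical helper: turn a list of (image, label) samples into one
    (batch_image, batch_label) pair, i.e. the batch-generator format."""
    images, labels = zip(*sample_list)
    return np.stack(images), np.stack(labels)


if __name__ == "__main__":
    batches = [stack_sample_list(group)
               for group in batch_samples(sample_generator(), BATCH_SIZE)]
    print(len(batches), batches[0][0].shape, batches[0][1].shape)
    # → 10 (16, 784) (16, 1)
```

With 160 samples and a batch size of 16 the sample generator divides evenly, so `drop_last` makes no difference in this example; it only matters when the sample count is not a multiple of `BATCH_SIZE`.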

Publisher: 全栈程序员栈长. Please credit the source when reposting: https://javaforall.cn/132199.html Original link: https://javaforall.cn

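This article is part of the Tencent Cloud self-media sync exposure program and is shared from the author's personal site/blog.
In case of infringement, please contact cloudcommunity@tencent.com for removal.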
