首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >[Tensorflow][转载]cifar10数据集模型搭建与训练

[Tensorflow][转载]cifar10数据集模型搭建与训练

作者头像
云未归来
发布2025-07-18 14:42:58
800
举报

# To add a new cell, type '# %%'

# To add a new markdown cell, type '# %% [markdown]'

# %%

from IPython import get_ipython

# %%

# Enable inline matplotlib output when running inside IPython/Jupyter.
# get_ipython() returns None in a plain Python interpreter, so the call is
# guarded to keep the script runnable outside a notebook.
ipython_shell = get_ipython()
if ipython_shell is not None:
    ipython_shell.run_line_magic('matplotlib', 'inline')

import matplotlib as mpl

import matplotlib.pyplot as plt

import numpy as np

import os

import pandas as pd

import sklearn

import sys

import tensorflow as tf

import time

from tensorflow import keras

# Record the runtime environment: the TensorFlow and Python versions first,
# then the name and version of each imported library.
print(tf.__version__)
print(sys.version_info)

for module in (mpl, np, pd, sklearn, tf, keras):
    print(f'{module.__name__} {module.__version__}')

# %%

# CIFAR-10 category names. The generators receive this list as `classes`,
# and predicted indices are mapped back through it, so its order matters.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Input locations, relative to the working directory.
train_lables_file = './cifar10/trainLabels.csv'
test_csv_file = './cifar10/sampleSubmission.csv'
train_folder = './cifar10/train/'
test_folder = './cifar10/test'

def parse_csv_file(filepath, folder):
    """Parse a two-column CSV file into ``(image_path, label)`` pairs.

    The first row is treated as a header and skipped. Every other row must be
    ``image_id,label``; the image path is built as ``<folder>/<image_id>.png``.
    Blank rows (for example a trailing empty line) are ignored.

    Args:
        filepath: path of the CSV file to read.
        folder: directory joined in front of each ``<image_id>.png``.

    Returns:
        A list of ``(image_full_path, label_str)`` tuples in file order.

    Raises:
        ValueError: if a non-blank data row does not have exactly two fields.
    """
    import csv  # local import keeps this helper self-contained

    results = []
    # newline='' is required by the csv module so it can handle \r\n itself.
    with open(filepath, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if not row:  # blank line -> empty row; previously crashed the unpack
                continue
            image_id, label_str = row
            image_full_path = os.path.join(folder, image_id + '.png')
            results.append((image_full_path, label_str))
    return results

# Parse both label files into lists of (image path, label) tuples.
train_labels_info = parse_csv_file(train_lables_file, train_folder)
test_csv_info = parse_csv_file(test_csv_file, test_folder)

import pprint

# Preview the first five entries of each list, then report both sizes.
for parsed in (train_labels_info, test_csv_info):
    pprint.pprint(parsed[:5])

print(len(train_labels_info), len(test_csv_info))

# %%

# Split the labelled entries: the first `num_train_samples` rows are used for
# training and all remaining rows are held out for validation.
num_train_samples = 45000
df_columns = ['filepath', 'class']

# Passing `columns=` to the constructor (rather than assigning `.columns`
# afterwards) also works when a slice is empty. Post-hoc assignment on an
# empty frame raises a length-mismatch ValueError.
train_df = pd.DataFrame(train_labels_info[:num_train_samples], columns=df_columns)
valid_df = pd.DataFrame(train_labels_info[num_train_samples:], columns=df_columns)
# For test rows, 'class' holds whatever label sampleSubmission.csv lists,
# presumably a placeholder rather than ground truth. TODO confirm.
test_df = pd.DataFrame(test_csv_info, columns=df_columns)

print(train_df.head())
print(valid_df.head())
print(test_df.head())

# %%

# Input image geometry, mini-batch size and number of output classes.
height, width, channels = 32, 32, 3
batch_size = 32
num_classes = 10

# Training pipeline: scale pixel values by 1/255 and apply random geometric
# augmentation (rotation, shifts, shear, zoom, horizontal flips) on the fly.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2, height_shift_range=0.2,
    shear_range=0.2, zoom_range=0.2,
    horizontal_flip=True, fill_mode='nearest',
)

# Paths in the 'filepath' column are resolved against directory './'.
train_generator = train_datagen.flow_from_dataframe(
    train_df, directory='./',
    x_col='filepath', y_col='class', classes=class_names,
    class_mode='sparse',
    target_size=(height, width), batch_size=batch_size,
    shuffle=True, seed=7,
)

# Validation pipeline: rescaling only, no augmentation.
valid_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

# shuffle=False keeps validation batches in valid_df row order.
valid_generator = valid_datagen.flow_from_dataframe(
    valid_df, directory='./',
    x_col='filepath', y_col='class', classes=class_names,
    class_mode="sparse",
    target_size=(height, width), batch_size=batch_size,
    shuffle=False, seed=7)

# Sample counts reported by the two generators.
train_num, valid_num = train_generator.samples, valid_generator.samples
print(train_num, valid_num)

# %%

# Sanity check: pull two training batches and show their shapes and labels.
for i in range(2):
    x, y = next(train_generator)
    print(x.shape, y.shape)
    print(y)

# %%

# Three conv stages (128 -> 256 -> 512 filters). Each stage has two 3x3
# 'same'-padded ReLU convolutions with batch norm, followed by 2x2 max
# pooling. A dense head ends in a softmax over `num_classes` outputs.
model = keras.models.Sequential([
    # With the channels_last layout (the Keras default), Conv2D input is
    # (rows, cols, channels), i.e. (height, width, channels).
    keras.layers.Conv2D(filters=128, kernel_size=3, padding='same',
                        activation='relu',
                        input_shape=[height, width, channels]),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=128, kernel_size=3, padding='same',
                        activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',
                        activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=256, kernel_size=3, padding='same',
                        activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    keras.layers.Conv2D(filters=512, kernel_size=3, padding='same',
                        activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.Conv2D(filters=512, kernel_size=3, padding='same',
                        activation='relu'),
    keras.layers.BatchNormalization(),
    keras.layers.MaxPool2D(pool_size=2),

    keras.layers.Flatten(),
    keras.layers.Dense(512, activation='relu'),
    keras.layers.Dense(num_classes, activation='softmax'),
])

# The generators use class_mode='sparse' (integer labels), which pairs with
# sparse_categorical_crossentropy.
model.compile(loss="sparse_categorical_crossentropy",
              optimizer="adam", metrics=['accuracy'])
model.summary()

# %%

epochs = 20

# `Model.fit` accepts generators / Keras iterators directly. `fit_generator`
# is deprecated in TF 2.x and removed in later releases.
history = model.fit(train_generator,
                    steps_per_epoch=train_num // batch_size,
                    epochs=epochs,
                    validation_data=valid_generator,
                    # Round up so the final, partial validation batch is
                    # evaluated too and every validation image counts.
                    validation_steps=(valid_num + batch_size - 1) // batch_size)

# %%

def plot_learning_curves(history, label, epcohs, min_value, max_value):
    """Plot one training metric and its validation counterpart per epoch.

    Args:
        history: the object returned by ``model.fit``. Its ``history``
            attribute maps metric names, and their ``'val_'``-prefixed
            validation names, to per-epoch values.
        label: metric to plot, e.g. ``'loss'`` or ``'acc'``. ``'acc'`` and
            ``'accuracy'`` are treated as interchangeable because Keras
            releases differ in which key they record accuracy under.
        epcohs: x-axis upper limit, normally the number of training epochs.
            The misspelled name is kept so keyword callers keep working.
        min_value: y-axis lower limit.
        max_value: y-axis upper limit.

    Raises:
        KeyError: if neither ``label`` nor its alias (or the matching
            ``'val_'`` entry) is present in ``history.history``.
    """
    metrics = history.history
    key = label
    if key not in metrics:
        key = {'acc': 'accuracy', 'accuracy': 'acc'}.get(label, label)
    val_key = 'val_' + key
    data = {key: metrics[key], val_key: metrics[val_key]}
    pd.DataFrame(data).plot(figsize=(8, 5))
    plt.grid(True)
    # Use the epoch-count argument (not a module-level global) for the x range.
    plt.axis([0, epcohs, min_value, max_value])
    plt.show()

# Accuracy is plotted on a [0, 1] y-range, loss on [0, 2].
for metric_name, y_min, y_max in (('acc', 0, 1), ('loss', 0, 2)):
    plot_learning_curves(history, metric_name, epochs, y_min, y_max)

# %%

# Inference pipeline: rescaling only, no augmentation.
test_datagen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255)

# class_mode=None yields image batches only, which is all prediction needs.
# With a labelled class_mode, rows whose 'class' value is not in `classes`
# would be filtered out. That would shift the positional ids written to the
# submission file. shuffle=False keeps predictions in test_df row order.
test_generator = test_datagen.flow_from_dataframe(
    test_df,
    directory='./',
    x_col='filepath',
    target_size=(height, width),
    batch_size=batch_size,
    shuffle=False,
    class_mode=None)

test_num = test_generator.samples
print(test_num)

# %%

# `Model.predict` accepts generators / Keras iterators directly.
# `predict_generator` is deprecated in TF 2.x. Output rows follow the
# generator's order (test_generator is built with shuffle=False).
test_predict = model.predict(test_generator,
                             workers=10,
                             use_multiprocessing=True)

# %%

print(test_predict.shape)

# %%

print(test_predict[:5])

# %%

# Index of the highest-scoring class for each test image.
test_predict_class_indices = np.argmax(test_predict, axis=1)

# %%

print(test_predict_class_indices[:5])

# %%

# Translate each predicted index into its name via class_names.
test_predict_class = list(map(class_names.__getitem__,
                              test_predict_class_indices))

# %%

print(test_predict_class[:5])

# %%

def generate_submissions(filename, predict_class):
    """Write class predictions to a CSV submission file.

    The file begins with an ``id,label`` header, followed by one row per
    prediction. Ids are 1-based positions in ``predict_class``. An existing
    file at ``filename`` is overwritten.
    """
    with open(filename, 'w') as out:
        out.write('id,label\n')
        out.writelines('%d,%s\n' % (row_id, label)
                       for row_id, label in enumerate(predict_class, 1))

# Destination of the submission CSV.
output_file = './cifar10/submission.csv'
generate_submissions(filename=output_file, predict_class=test_predict_class)

# %%

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2020-04-09,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档