首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >[Tensorflow][转载]利用resnet迁移学习重新训练模型

[Tensorflow][转载]利用resnet迁移学习重新训练模型

作者头像
云未归来
发布2025-07-18 14:41:59
1100
举报

# To add a new cell, type '# %%'

# To add a new markdown cell, type '# %% [markdown]'

# %%

from IPython import get_ipython

# %%

# Kaggle kernel boilerplate. The kaggle/python docker image
# (https://github.com/kaggle/docker-python) ships many analytics libraries;
# a few commonly used ones are imported here.
import numpy as np  # linear algebra
import pandas as pd  # data processing, CSV file I/O (e.g. pd.read_csv)
import os

# Input data files live under "../input/"; list that directory so the
# available datasets show up in the cell output.
input_entries = os.listdir("../input")
print(input_entries)

# Any results written to the current directory are saved as output.

# %%

# Turn on inline matplotlib rendering. get_ipython() returns None when this
# file is executed by a plain Python interpreter rather than IPython/Jupyter,
# so guard the call instead of crashing with AttributeError on None.
_ipython_shell = get_ipython()
if _ipython_shell is not None:
    _ipython_shell.run_line_magic('matplotlib', 'inline')

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import sklearn
import sys
import tensorflow as tf
import time
from tensorflow import keras

# Record interpreter and library versions so results can be reproduced.
print(tf.__version__)
print(sys.version_info)
for module in mpl, np, pd, sklearn, tf, keras:
    print(module.__name__, module.__version__)

# %%

# Dataset locations inside the Kaggle input directory.
train_dir = "../input/training/training"
valid_dir = "../input/validation/validation"
label_file = "../input/monkey_labels.txt"

# Confirm each expected path is present (prints True/False per path).
for expected_path in (train_dir, valid_dir, label_file):
    print(os.path.exists(expected_path))

# List the entries of each split directory; flow_from_directory below treats
# each sub-directory as one class.
for split_dir in (train_dir, valid_dir):
    print(os.listdir(split_dir))

# %%

# Label table shipped with the dataset (first row is the header); presumably
# maps class folders to names — confirm against the file.
labels = pd.read_csv(label_file, header=0)
print(labels)

# %%

# Input image geometry fed to ResNet50, batch size, and number of classes.
height, width, channels = 224, 224, 3
batch_size = 24
num_classes = 10

# Both splits share the same ResNet50 input preprocessing.
preprocess = keras.applications.resnet50.preprocess_input

# Training images are additionally augmented with random rotations, shifts,
# shears, zooms and horizontal flips.
train_datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=preprocess,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Iterator settings common to both splits; labels are one-hot ("categorical")
# to match the categorical_crossentropy loss used when compiling the models.
flow_options = dict(
    target_size=(height, width),
    batch_size=batch_size,
    seed=7,
    class_mode="categorical",
)

# Shuffle the training batches; keep validation order fixed.
train_generator = train_datagen.flow_from_directory(
    train_dir, shuffle=True, **flow_options)

valid_datagen = keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=preprocess)
valid_generator = valid_datagen.flow_from_directory(
    valid_dir, shuffle=False, **flow_options)

# Number of images discovered in each split.
train_num, valid_num = train_generator.samples, valid_generator.samples
print(train_num, valid_num)

# %%

# Sanity check: pull two batches from the training iterator and print the
# image-batch and label-batch shapes along with the label arrays.
for _ in range(2):
    batch_images, batch_labels = next(train_generator)
    print(batch_images.shape, batch_labels.shape)
    print(batch_labels)

# %%

# Model 1 (feature extraction): an ImageNet-pretrained ResNet50 without its
# classifier top, globally average-pooled, followed by a new softmax layer
# with num_classes outputs.
resnet50_fine_tune = keras.models.Sequential([
    keras.applications.ResNet50(
        include_top=False,
        pooling='avg',
        weights='imagenet',
    ),
    keras.layers.Dense(num_classes, activation='softmax'),
])

# Freeze the entire backbone so only the Dense head is trained.
resnet50_fine_tune.layers[0].trainable = False

resnet50_fine_tune.compile(
    loss="categorical_crossentropy",
    optimizer="sgd",
    metrics=['accuracy'],
)
resnet50_fine_tune.summary()

# %%

epochs = 10

# Train the frozen-backbone model.
# validation_steps uses ceiling division so every validation image, including
# the final partial batch, is evaluated each epoch. Floor division silently
# skipped up to batch_size - 1 images of the (unshuffled) validation set and
# gave 0 steps when valid_num < batch_size.
# NOTE: fit_generator is deprecated since TF 2.1 in favour of fit(), which
# accepts generators directly; it is kept here for older TF/Keras versions.
history = resnet50_fine_tune.fit_generator(
    train_generator,
    steps_per_epoch=train_num // batch_size,
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=(valid_num + batch_size - 1) // batch_size,
)

# %%

def plot_learning_curves(history, label, epcohs, min_value, max_value):
    """Plot the training and validation curves of one metric against epochs.

    Args:
        history: the object returned by ``fit``/``fit_generator``; its
            ``history`` attribute maps metric names to per-epoch values.
        label: metric to plot, e.g. ``'loss'`` or ``'acc'``. ``'acc'`` and
            ``'accuracy'`` are treated as aliases: whichever one the running
            Keras version actually recorded is used.
        epcohs: right-hand limit of the x axis (number of epochs trained).
            The misspelled name is kept so keyword callers keep working.
        min_value: lower limit of the y axis.
        max_value: upper limit of the y axis.

    Raises:
        KeyError: if neither ``label`` nor its alias (or the matching
            ``val_`` entry) was recorded in ``history.history``.
    """
    recorded = history.history
    # Older Keras (TF 1.x) logs metrics=['accuracy'] as 'acc'; TF 2.x logs it
    # as 'accuracy'. Fall back to the alias when the requested key is absent.
    aliases = {'acc': 'accuracy', 'accuracy': 'acc'}
    if label not in recorded and aliases.get(label) in recorded:
        label = aliases[label]
    val_label = 'val_' + label
    data = {label: recorded[label], val_label: recorded[val_label]}
    pd.DataFrame(data).plot(figsize=(8, 5))
    plt.grid(True)
    # Use the epoch count passed in; previously this read the global `epochs`
    # and the argument was silently ignored.
    plt.axis([0, epcohs, min_value, max_value])
    plt.show()

# Older Keras (TF 1.x) records metrics=['accuracy'] as 'acc', TF 2.x as
# 'accuracy'; plot whichever key this training run actually produced.
accuracy_key = 'acc' if 'acc' in history.history else 'accuracy'
plot_learning_curves(history, accuracy_key, epochs, 0, 1)
plot_learning_curves(history, 'loss', epochs, 0, 2)

# %%

# A second, independent ImageNet-pretrained ResNet50 backbone (no classifier
# top, global average pooling) used for partial fine-tuning below.
resnet50 = keras.applications.ResNet50(
    include_top=False,
    weights='imagenet',
    pooling='avg',
)
resnet50.summary()

# %%

# Model 2 (partial fine-tuning): freeze every backbone layer except the last
# five, so those five layers and the new head are updated during training.
for layer in resnet50.layers[:-5]:
    layer.trainable = False

resnet50_new = keras.models.Sequential()
resnet50_new.add(resnet50)
resnet50_new.add(keras.layers.Dense(num_classes, activation='softmax'))

resnet50_new.compile(
    loss="categorical_crossentropy",
    optimizer="sgd",
    metrics=['accuracy'],
)
resnet50_new.summary()

# %%

epochs = 10

# Train the partially unfrozen model.
# validation_steps uses ceiling division so every validation image, including
# the final partial batch, is evaluated each epoch. Floor division silently
# skipped up to batch_size - 1 images of the (unshuffled) validation set and
# gave 0 steps when valid_num < batch_size.
# NOTE: fit_generator is deprecated since TF 2.1 in favour of fit(), which
# accepts generators directly; it is kept here for older TF/Keras versions.
history = resnet50_new.fit_generator(
    train_generator,
    steps_per_epoch=train_num // batch_size,
    epochs=epochs,
    validation_data=valid_generator,
    validation_steps=(valid_num + batch_size - 1) // batch_size,
)

# %%

# Older Keras (TF 1.x) records metrics=['accuracy'] as 'acc', TF 2.x as
# 'accuracy'; plot whichever key this training run actually produced.
accuracy_key = 'acc' if 'acc' in history.history else 'accuracy'
plot_learning_curves(history, accuracy_key, epochs, 0, 1)
plot_learning_curves(history, 'loss', epochs, 0, 2)

本文参与 腾讯云自媒体同步曝光计划，分享自作者个人站点/博客。
原始发表：2020-04-08，如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权，请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划，欢迎热爱写作的你一起参与！

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档