Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >将灰度图像应用于RGB图像的角点模型

将灰度图像应用于RGB图像的角点模型
EN

Data Science用户
提问于 2020-04-05 00:10:49
回答 5查看 6K关注 0票数 4

我使用时装MNIST数据集跟踪这个基本分类TensorFlow教程。训练集包含60000个28x28像素的灰度图像,分为10个等级(裤子、套衫、鞋子等)。本教程使用了一个简单的模型:

代码语言:javascript
运行
AI代码解释
复制
model = keras.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(128, activation='relu'),
    keras.layers.Dense(10)
])

经过10个历次,该模型的精度达到91%。

我现在正在使用另一个名为CIFAR-10的数据集,它包含50,000,32*32像素的RGB图像,也分为10个类(青蛙、马、船等)。

考虑到时尚MNIST和CIFAR-10数据集在图像数量和图像大小上非常相似,而且它们的类数相同,我天真地尝试了一个类似的模型,只需调整输入形状:

代码语言:javascript
运行
AI代码解释
复制
  model = keras.Sequential([
     keras.layers.Flatten(input_shape=(32, 32, 3)),
     keras.layers.Dense(128, activation='relu'),
     keras.layers.Dense(10)
  ])

唉,经过10个年代,该模型的精度达到了45%。我做错了什么?

我知道我在RGB图像中的样本是灰度图像中的三倍,所以我尝试增加时间的数目以及中间致密层的大小,但是没有结果。

以下是我的完整代码:

代码语言:javascript
运行
AI代码解释
复制
import tensorflow as tf
import IPython.display as display
from PIL import Image
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import pdb
import pathlib
import os
from tensorflow.keras import layers #Needed to make the model
from tensorflow.keras import datasets, layers, models

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

IMG_HEIGHT = 32
IMG_WIDTH = 32

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


train_images = train_images / 255.0
test_images = test_images / 255.0

def make_model():
      model = keras.Sequential([
         keras.layers.Flatten(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
         keras.layers.Dense(512, activation='relu'),
         keras.layers.Dense(10)
      ])
      model.compile(optimizer='adam',
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                   metrics=['accuracy'])
      return model

model=make_model()
history = model.fit(train_images, train_labels, epochs=10)
EN

回答 5

Data Science用户

回答已采纳

发布于 2020-04-05 06:47:20

您的模型不够复杂,无法对CIFAR 10数据集进行适当的分类。CIFAR-10比时装MNIST数据集要复杂得多,因此您需要一个更复杂的model.You,可以为您的模型添加更多隐藏层来实现这一点。你也应该增加辍学层,以防止过度拟合。也许最简单的解决办法是使用转移学习。如果您想尝试转移学习,我建议您使用MobileNet CNN模式。这方面的文档可以找到这里。由于CIFAR-10有5万张样本图像,我不认为您需要数据增强。首先,尝试一个更复杂的模型,而不增加,看看你达到了什么样的准确性。如果不够,那么使用keras ImageData生成器来提供数据增强。这方面的文档是这里

票数 2
EN

Data Science用户

发布于 2020-04-05 06:56:56

我正在使用这个模型(基本上是基于乔莱特的工作)。它使用预训练模型(VGG16)来处理多类图像识别问题。

代码语言:javascript
运行
AI代码解释
复制
from keras.applications import VGG16
import os, datetime
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import to_categorical
from keras import models, layers, optimizers, regularizers
from keras.callbacks import EarlyStopping
from keras.callbacks import ReduceLROnPlateau
from keras.layers.core import Dense, Dropout, Activation
from keras.layers.normalization import BatchNormalization
from PIL import ImageFile
import statistics
ImageFile.LOAD_TRUNCATED_IMAGES = True

###############################################
# DIR with training images
base_dir = 'C:/pathtoimages'
# Number training images
ntrain = 2000
# Number validation images
nval  = 500
# Batch size
batch_size = 20 #20
# Epochs (fine tuning [100])
ep = 400 #400
# Epochs (first step [30])
ep_first = 30 
# Number of classes (for training, output layer)
nclasses = 30
###############################################
start = datetime.datetime.now()

conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'val')
#test_dir = os.path.join(base_dir, 'test')

datagen = ImageDataGenerator(rescale=1./255)

def extract_features(directory, sample_count):
    features = np.zeros(shape=(sample_count, 4, 4, 512))
    labels = np.zeros(shape=(sample_count))
    generator = datagen.flow_from_directory(
        directory,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')
    i = 0
    for inputs_batch, labels_batch in generator:
        features_batch = conv_base.predict(inputs_batch)
        features[i * batch_size : (i + 1) * batch_size] = features_batch
        labels[i * batch_size : (i + 1) * batch_size] = labels_batch
        i += 1
        if i * batch_size >= sample_count:
            break
    return features, labels

train_features, train_labels = extract_features(train_dir, ntrain)
validation_features, validation_labels = extract_features(validation_dir, nval)
#test_features, test_labels = extract_features(test_dir, 1000)

# Labels and features
train_labels = to_categorical(train_labels)
validation_labels = to_categorical(validation_labels)
#test_labels = to_categorical(test_labels)
train_features = np.reshape(train_features, (ntrain, 4 * 4 * 512))
validation_features = np.reshape(validation_features, (nval, 4 * 4 * 512))
#test_features = np.reshape(test_features, (1000, 4 * 4 * 512))

#######################################
# Model
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(4096, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(BatchNormalization())

model.add(layers.Dense(2048, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(2048, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(BatchNormalization())

model.add(layers.Dense(1024, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(1024, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(BatchNormalization())

model.add(layers.Dense(512, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(512, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(BatchNormalization())

model.add(layers.Dense(256, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(256, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(BatchNormalization())

model.add(layers.Dense(128, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(128, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(128, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002
model.add(layers.Dense(128, activation='relu',kernel_regularizer=regularizers.l2(0.003)))#0.002

model.add(layers.Dense(nclasses, activation='softmax'))
conv_base.trainable = False

#######################################
# Data generators
train_datagen = ImageDataGenerator(
      rescale=1./255,
      rotation_range=40,
      width_shift_range=0.2,
      height_shift_range=0.2,
      shear_range=0.2,
      zoom_range=0.2,
      horizontal_flip=True,
      fill_mode='nearest')

# Note that the validation data should not be augmented!
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
        # This is the target directory
        train_dir,
        # All images will be resized to 150x150
        target_size=(150, 150),
        batch_size=batch_size,
        # Since we use categorical_crossentropy loss, we need binary labels
        class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='categorical')

# Model compile / fit
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(lr=2e-5),
              metrics=['acc'])

# early stopping: https://keras.io/callbacks/#earlystopping
es = EarlyStopping(monitor='val_loss', mode='min', min_delta=0.001, verbose=1, patience=40, restore_best_weights=True)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', mode='min', factor=0.9, patience=15, min_lr=1e-20, verbose=1, cooldown=3)

history = model.fit_generator(
      train_generator,
      steps_per_epoch=round((ntrain+nval)/batch_size,0),
      epochs=ep_first,
      validation_data=validation_generator,
      validation_steps=20, #50
      verbose=2,
      callbacks=[es, reduce_lr])

#######################################
# Fine tuning
conv_base.trainable = True

set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'block5_conv1':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False

model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.RMSprop(lr=0.00001), #1e-5
              metrics=['acc'])

history = model.fit_generator(
      train_generator,
      steps_per_epoch=round((ntrain+nval)/batch_size,0),
      epochs=ep,
      validation_data=validation_generator,
      validation_steps=20,
      callbacks=[es, reduce_lr])

#######################################
# Save model
model.save('C:/yourpath/yourmodel.hdf5')
end = datetime.datetime.now()
delta = str(end-start)

# Metrics
acc = history.history['acc']
acc = acc[-5:]
val_acc = history.history['val_acc']
val_acc = val_acc[-5:]
loss = history.history['loss']
loss = loss[-5:]
val_loss = history.history['val_loss']
val_loss = val_loss[-5:]

# End statement
print("============================================")
print("Time taken (h/m/s): %s" %delta[:7])
print("============================================")
print("Metrics (average last five steps)")
print("--------------------------------------------")
print("Loss       %.3f" %statistics.mean(loss))
print("Val. Loss  %.3f" %statistics.mean(val_loss))
print("--------------------------------------------")
print("Acc.       %.3f" %statistics.mean(acc))
print("Val. Acc.  %.3f" %statistics.mean(val_acc))
print("============================================")
print("Epochs:    %s / %s" %(ep,ep_first))
票数 2
EN

Data Science用户

发布于 2020-04-05 00:33:48

我想到了两件事:

您可以添加数据生成器。这将通过引入一组小变化(即随机旋转、缩放、剪切、水平/垂直移动.)从当前图像生成新图像,从而迫使模型学习不同类别图像之间的重要区别特征。

您也可以添加辍学层,以对抗过度适应。

下面是一个很好的例子:https://keras.io/examples/cifar10_cnn/

票数 1
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/71751

复制
相关文章
使枚举类型的选项在VS的属性窗里显示为中文
我们自己做的组件,一般希望它的属性在设计时能够在属性窗里显示为中文,可以在属性上添加System.ComponentModel.DisplayNameAttribute标注达到这个目的。但是,枚举的选项如何以中文的形式显示在属性窗里呢?
明年我18
2019/09/18
1.2K0
使枚举类型的选项在VS的属性窗里显示为中文
javascript表单提交的内容显示在表格中
实现三个文本域的内容提交之后显示在表格中,代码直接用文本文件运行,记得后缀改为.html 运行结果
别团等shy哥发育
2023/02/27
8.1K0
javascript表单提交的内容显示在表格中
R语言提取PDF文件中的文本内容
综上步骤,我们便可以随便获取任意章节的任意内容。那么接下来就是对这些文字的应用,各位集思广益吧。
一粒沙
2019/07/31
10K1
如何使特定的数据高亮显示?
当表格里数据比较多时,很多时候我们为了便于观察数据,会特意把符合某些特征的数据行高亮显示出来。这不,公司的HR小姐姐就有这个需求,说她手头上有一份招聘数据,她想把“薪水”超过20000的行突出显示出来,应该怎么操作呢?
猴子聊数据分析
2020/02/26
5.9K0
R中的向量化运算
1、R中的向量化运算-seq seq(1, 10, by=1) seq(1, 10, by=0.1) seq(1.9, 10, by=0.1) #注意,不能这样子递减 seq(10, 1, by=0.1) #注意,你可以这样子递减 seq(10, 1, by=-0.1) #除了设置步长,还可以设置均分的步数 seq(10, 1, length.out=10) seq(10, 1, length.out=100) seq(10, 1, length.out=91) #数清楚里面的个数 2、R中
Erin
2018/01/09
2.1K0
"0.1"在PL/SQL Developer和sqlplus中如何不显示为".1"?
微信群有朋友问,PL/SQL Developer显示0.1的时候自动将0删除,即".1",因此有什么方法,可以显示小数点之前的0?
bisal
2019/01/30
2.1K0
在 Linux 中如何按名称和 Grep 内容查找文件?
如果您使用该find命令递归搜索某些文件,然后将结果通过管道传递给该grep命令,那么您实际上将解析文件路径/名称,而不是它们的内容。
网络技术联盟站
2022/05/11
6.9K0
在 Linux 中如何按名称和 Grep 内容查找文件?
cat命令 – 在终端设备上显示文件内容
Linux系统中有很多个用于查看文件内容的命令,每个命令又都有自己的特点,比如这个cat命令就是用于查看内容较少的纯文本文件的。cat这个命令也很好记,因为cat在英语中是“猫”的意思,小猫咪是不是给您一种娇小、可爱的感觉呢?
用户4988085
2021/07/24
1.7K0
MATLAB中向量_向量法表示字符串
matlab中的向量是只有一行元素的数组,向量中的单个项通常称为元素。Matlab中的向量索引值从1开始,而不是从0开始。
全栈程序员站长
2022/11/17
2.4K0
MATLAB中向量_向量法表示字符串
linux 查看文件内容 显示行号
linux 系统中文件内容显示行号分为临时显示和永久显示两种,本文对两种方式进行介绍
全栈程序员站长
2022/06/25
15.3K0
linux 查看文件内容 显示行号
Python把PDF文件中每页内容分离为独立图片文件
封面图片:《Python程序设计实验指导书》(ISBN:9787302525790),董付国,清华大学出版社
Python小屋屋主
2019/07/23
1.5K0
Python把PDF文件中每页内容分离为独立图片文件
每日一题--4--在两个文件中取交集,显示指定的内容
把这个两个文件都存在的用户的密码输出出来 [root@sentinel student]# head file1 file2 ==> file1 <== oldboy 1234 alex 4567 lidao 9999 ==> file2 <== 001 lidao 002 alex 003 oldboy 004 oldgirl 提示:需要用到如何判断这两个文件不是一个文件。 解题思路 awk 'FNR==NR{h[$1]=$2}FNR!=NR{print h[$2]}' file1 fi
张琳兮
2019/03/14
1.4K0
在DragonOS中,使蜂鸣器发声
很简单,代码如下: void beep(uint64_t times) { io_out8(0x43, 182&0xff); io_out8(0x42, 2280&0xff); io_out8(0x42, (2280>>8)&0xff); uint32_t x = io_in8(0x61)&0xff; x |= 3; io_out8(0x61, x&0xff); times *= 10000; for(uint64_t i=0;i<times
灯珑LoGin
2022/10/31
4310
转义字符'\r'在Python内置函数print()中的妙用
在Python 3.x中,内置函数print()用来实现格式化输出,各参数含义请参考本文末尾的相关阅读。本文重点介绍print()函数的end参数以及转义字符'\r'的妙用。 本文末尾的相关阅读中已经
Python小屋屋主
2018/04/16
4.3K0
转义字符'\r'在Python内置函数print()中的妙用
如何将文件内容转成String字符串
以上两种方式从编码简洁度来讲,肯定是第二种好很多,但其实性能是差不多的,一个是牺牲了读的性能,另一个是牺牲了写的性能。
Java深度编程
2020/06/10
3.6K0
R沟通|​在Rstudio中运行tex文件
这期主要介绍下如何在Rstudio中运行和使用.tex文件,并给大家安利一个非常nice的模板和根据该模板制作的案例。
庄闪闪
2021/04/09
4K0
如何使用EvilTree在文件中搜索正则或关键字匹配的内容
 关于EvilTree  EvilTree是一款功能强大的文件内容搜索工具,该工具基于经典的“tree”命令实现其功能,本质上来说它就是“tree”命令的一个独立Python 3重制版。但EvilTree还增加了在文件中搜索用户提供的关键字或正则表达式的额外功能,而且还支持突出高亮显示包含匹配项的关键字/内容。  工具特性  1、当在嵌套目录结构的文件中搜索敏感信息时,能够可视化哪些文件包含用户提供的关键字/正则表达式模式以及这些文件在文件夹层次结构中的位置,这是EvilTree的一个非常显著的优势;
FB客服
2023/03/29
4.2K0
如何使用EvilTree在文件中搜索正则或关键字匹配的内容
获取类路径某个json文件中的内容字符串
实际项目中可能会有需要读取类路径下面的配置文件中的内容的需求,由于springboot项目打包的是jar包,通过文件读取获取流的方式开发的时候没有问题,但是上到linux服务器上就有问题了,对于这个问题记录一下处理的方式
在水一方
2022/09/16
2.8K0
点击加载更多

相似问题

EventKit框架中的Snooze方法?

20

Google错误“多个命名为'initWithArray:‘的方法”

30

eventKit教程

21

EventKit权限

121

EventKit提醒

11
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档