Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >我们的训练/验证损失曲线很好,但测试性能受到影响

我们的训练/验证损失曲线很好,但测试性能受到影响
EN

Stack Overflow用户
提问于 2021-07-04 15:59:34
回答 2查看 43关注 0票数 0

我们目前正在研究一种从胸部x光图像中检测肺结核的图像分类任务。你可以在下面看到我们的代码。我们将0.7用于训练集,0.2用于验证集,0.1用于测试集。我们的训练和验证损失在这里

但是当我们在我们的测试数据集上尝试它时,我们得到的是:

我们的代码有什么问题吗?提前谢谢你。

代码语言:javascript
运行
AI代码解释
复制
from tensorflow import keras
from keras.applications.mobilenet_v2 import MobileNetV2
from keras.applications.mobilenet_v2 import preprocess_input
from keras.layers import Dense, Flatten
from keras.models import Sequential
from keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from keras.preprocessing.image import ImageDataGenerator
from datetime import datetime, date
from keras.callbacks import ModelCheckpoint
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

#Loading a pre-trained model
image_size = 224

base_model = MobileNetV2(input_shape=(image_size,image_size,3), weights='imagenet', include_top=False)

for layer in base_model.layers:
    layer.trainable = False

model = Sequential()
model.add(base_model)
model.add(Flatten())
model.add(Dense(1000, activation='relu'))
model.add(Dense(2, activation="sigmoid"))

loss_func = BinaryCrossentropy()
opt = Adam(learning_rate=0.001)

model.compile(loss=loss_func,
              optimizer=opt,  
              metrics=['accuracy'])

#Training
test_path = '...' 
val_path = '...'

datagen = ImageDataGenerator(rescale = 1./255,horizontal_flip = True, shear_range = 0.2, zoom_range=0.2)

batch_size=32
validation_size=8

train_set = datagen.flow_from_directory(test_path, 
                                        target_size = (image_size, image_size),
                                        batch_size=batch_size,
                                        class_mode = 'categorical')

validation_set = datagen.flow_from_directory(val_path, 
                                             target_size = (image_size, image_size),
                                             batch_size=validation_size,
                                             class_mode = 'categorical')

#Fitting the data to the model
model_name = 'MobileNetV2'
date_today= date.today().strftime('%m_%d_%Y')

checkpoint = ModelCheckpoint(filepath=f'Models/{model_name}_{date_today}.h5',
                             monitor='val_loss',
                             mode='min',
                             verbose=1, 
                             save_best_only=True)

model_history = model.fit(train_set, 
                          validation_data=validation_set,
                          epochs=100,
                          steps_per_epoch=len(train_set)//batch_size,
                          validation_steps=len(validation_set)//validation_size,
                          callbacks=[checkpoint],
                          verbose=1)


#Testing the model on the test set
test_path = '...'

test_datagen = ImageDataGenerator()

test_set = test_datagen.flow_from_directory(test_path,
                                            target_size = (image_size, image_size),
                                            class_mode = 'categorical')

predictions = model.predict(test_set, verbose=1)

y_pred = np.argmax(predictions, axis=1)
class_labels = list(test_set.class_indices.keys())  

print('Classification Report')
clsf = classification_report(test_set.classes, y_pred, target_names=class_labels)
print(clsf)
print('\n')
print('Confusion Matrix')
cfm = confusion_matrix(test_set.classes, y_pred)
print(cfm)
EN

回答 2

Stack Overflow用户

发布于 2021-07-04 18:04:48

代码是正确的,但是,我发现了一个小错误,那就是你在sigmoid输出层分配了2个单元。这是不正确的;应该有一个单元,因为这是一个二进制分类问题。如下所示:

代码语言:javascript
运行
AI代码解释
复制
model.add(Dense(1, activation="sigmoid"))
票数 0
EN

Stack Overflow用户

发布于 2021-07-05 01:35:29

肺结核是一个复杂的物体,具有复杂的特征。因此,测试集可能会产生意外的结果。要避免这一点,您必须修改您的网络并合并其他训练图像。您可以尝试转移学习,但如果要从中转移参数的网络是在与结核病完全无关的对象上进行训练的,那么它可能并不合适。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68246532

复制
相关文章

相似问题

添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档