我是机器学习的新手,我目前正在努力创建一个能够预测品牌标识相似性的暹罗网络。我有一个210.000商标的数据集。用于暹罗网络的CNN看起来如下:
def build_cnn(inputShape, embeddingDim=48):
# specify the inputs for the feature extractor network
inputs = Input(shape=inputShape)
# define the first set of CONV => RELU => POOL => DROPOUT layers
x = Conv2D(64, (2, 2), padding="same", activation="relu")(inputs)
x = MaxPooling2D(pool_size=(5, 5))(x)
x = Dropout(0.3)(x)
# second set of CONV => RELU => POOL => DROPOUT layers
x = Conv2D(64, (2, 2), padding="same", activation="relu")(x)
x = MaxPooling2D(pool_size=2)(x)
x = Dropout(0.3)(x)
pooledOutput = GlobalAveragePooling2D()(x)
outputs = Dense(embeddingDim)(pooledOutput)
# build the model
model = Model(inputs, outputs)
model.summary()
plot_model(model, to_file=os.path.sep.join([config.BASE_OUTPUT,'model_cnn.png']))
# return the model to the calling function
return model

暹罗网络看起来是这样的(这里的模型是上面描述的cnn ):
imgA = Input(shape=config.IMG_SHAPE)
imgB = Input(shape=config.IMG_SHAPE)
featureExtractor = siamese_network.build_cnn(config.IMG_SHAPE)
featsA = featureExtractor(imgA)
featsB = featureExtractor(imgB)
distance = Lambda(euclidean_distance)([featsA, featsB])
outputs = Dense(1, activation="sigmoid")(distance)
model = Model(inputs=[imgA, imgB], outputs=outputs)

我的第一次测试是对800对阳性和800对阴性的测试,其准确性和丢失情况如下:

我的想法是,有一些过度拟合的发生,我的方法是创建更多的训练数据(2,000对正负对),并再次训练模型,但不幸的是,即使在20+时代之后,该模型也没有得到改进。

对于这两种情况,我使用以下方法来训练我的网络:
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])
print("[INFO] training model...")
history = model.fit(
[pairTrain[:, 0], pairTrain[:, 1]], labelTrain[:],
validation_data=([pairTest[:, 0], pairTest[:, 1]], labelTest[:]),
batch_size=10,
shuffle=True,
epochs=50)我不知道这里发生了什么,所以我真的很感激你的每一个帮助。我在这里的问题是,为什么连体网络学习(或者至少看起来像是在学习)的训练数据较少,但一旦我增加更多,准确性是不变的,根本没有改进?
根据艾伯托斯的评论,编辑,我试着用selu (仍然不起作用):

EDIT2和LeakyReLU看起来是这样的:

我最近的10k对训练结果如下:

发布于 2022-07-26 17:36:27
我也曾见过这样的事情发生,如果是这样的话,可以去看看,但是这个吉特布的问题实际上是同样的损失。
在这种情况下,我认为这更像是初始化问题,所以在这一点上,我认为您应该使用He-initialziation或非饱和激活函数(例如,尝试tf.keras.layers.LeakyReLU或tf.keras.activations.selu)。
https://stackoverflow.com/questions/73126888
复制相似问题