前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >专栏 >TensorFlow域对抗训练DANN神经网络分析MNIST与Blobs数据集梯度反转层提升目标域适应能力可视化

TensorFlow域对抗训练DANN神经网络分析MNIST与Blobs数据集梯度反转层提升目标域适应能力可视化

作者头像
拓端
发布于 2025-02-08 06:14:23
发布于 2025-02-08 06:14:23
18300
代码可运行
举报
文章被收录于专栏:拓端tecdat拓端tecdat
运行总次数:0
代码可运行

本文围绕基于TensorFlow实现的神经网络对抗训练域适应方法展开研究。详细介绍了梯度反转层的原理与实现,通过MNIST和Blobs等数据集进行实验,对比了不同训练方式(仅源域训练、域对抗训练等)下的分类性能。结果表明,域对抗训练能够有效提升模型在目标域上的适应能力,为解决无监督域适应问题提供了一种有效的途径点击文末“阅读原文”获取完整代码、数据、远程指导)。

机器学习深度学习领域,域适应是一个重要的研究方向。不同数据源(即不同域)之间往往存在分布差异,这使得在一个域上训练的模型在另一个域上的性能显著下降。“Unsupervised Domain Adaptation by Backpropagation” 论文提出了一种简单有效的方法,通过随机梯度下降(SGD)和梯度反转层来实现域适应。后续的 “Domain - Adversarial Training of Neural Networks” 对该工作进行了详细阐述和扩展。

梯度反转层

梯度反转层是实现域对抗训练的关键。

代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 反转 x 关于 y 的梯度,并按 l 进行缩放(默认为 1.0)
y = flip_gradient(x, l)
MNIST
构建MNIST - M数据集
实验结果对比

以下是大致的结果:

Blobs - DANN
Blob数据集
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
# 绘制数据集
plt.scatter(Xs\[:, 0\], Xs\[:, 1\], c=ys, cmap='coolwarm', alpha=0.4)
plt.scatter(Xt\[:, 0\], Xt\[:, 1\], c=yt, cmap='cool', alpha=0.4)
plt.show()
Blob数据集可视化
Blob数据集可视化

构建模型

不同训练方式的实验
  • 域分类:设置 grad_scale=-1.0 可以有效关闭梯度反转。仅训练域分类器会创建使类别合并的表示。
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 train\_loss = sess.graph.get\_tensor\_by\_name(train\_loss\_name + ':0')
 train\_op = sess.graph.get\_operation\_by\_name(train\_op\_name)
 sess.run(tf.global\_variables\_initializer())
 for i in range(num_batches):
 if grad_scale is None:
不同训练方式的实验
  • 域分类
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 F = sess.graph.get\_tensor\_by\_name(feat\_tensor_name + ':0')
 emb\_s = sess.run(F, feed\_dict={'X:0': Xs})
 emb\_t = sess.run(F, feed\_dict={'X:0': Xt})
 emb\_all = np.vstack(\[emb\_s, emb_t\])
 pca = PCA(n_components=2)
 pca\_emb = pca.fit\_transform(emb_all)
 num = pca_emb.shape\[0\] // 2
 plt.scatter(pca\_emb\[:num, 0\], pca\_emb\[:num, 1\], c=ys, cmap='coolwarm', alpha=0.4)
 plt.scatter(pca\_emb\[num:, 0\], pca\_emb\[num:, 1\], c=yt, cmap='cool', alpha=0.4)
 plt.show()
train\_and\_evaluate(sess, 'domain\_train\_op', 'domain\_loss', grad\_scale=-1.0, verbose=False)
extract\_and\_plot\_pca\_feats(sess)

运行结果如下:

域分类PCA特征可视化
域分类PCA特征可视化

从结果可以看出,仅训练域分类器时,模型能够很好地区分源域和目标域,但对类别的区分能力较差,这表明这种训练方式创建的表示使类别合并了。

  • 标签分类

运行结果如下:

标签分类PCA特征可视化
标签分类PCA特征可视化

在源域上进行标签预测训练时,模型在源域上能够很好地区分不同类别,但在目标域上的类别区分能力较差,说明这种训练方式对目标域的适应能力不足。

  • 域适应

运行结果如下:

域适应PCA特征可视化
域适应PCA特征可视化

使用域对抗损失进行训练时,模型在源域和目标域上的类别分类准确率都较高,说明域对抗训练能够有效提升模型在目标域上的适应能力。

  • 更深的域分类器的域适应

运行结果如下:

更深域分类器的域适应PCA特征可视化
更深域分类器的域适应PCA特征可视化

使用更深的域分类器进行域适应训练时,在多次实验中似乎更能可靠地合并域,同时保持较高的类别分类准确率。

MNIST - DANN

数据处理

在数据处理阶段,我们对MNIST和MNIST - M数据集进行了预处理。对于MNIST数据,将其转换为适合卷积神经网络输入的格式,并扩展为三通道图像。MNIST - M数据则直接从之前生成的 pkl 文件中加载。通过计算像素均值,我们对数据进行归一化处理,这有助于提高模型的训练效果。最后,创建了一个混合数据集用于后续的TSNE可视化,方便我们直观地观察模型在不同域上的特征分布情况。

数据可视化
MNIST训练数据可视化
MNIST训练数据可视化
MNIST - M训练数据可视化
MNIST - M训练数据可视化

通过 函数对MNIST和MNIST - M的训练数据进行可视化展示,我们可以直观地看到两个数据集之间的差异,这也体现了域适应问题的挑战性,即不同域之间的数据分布存在明显差异。

构建模型
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
 # 特征提取器 - CNN模型

 b\_conv1 = bias\_variable(\[48\])
 h\_conv1 = tf.nn.relu(conv2d(h\_pool0, W\_conv1) + b\_conv1)
 h\_pool1 = max\_pool\_2x2(h\_conv1)
 self.feature = tf.reshape(h_pool1, \[-1, 7 * 7 * 48\])
 # 标签预测器 - MLP模型
 with tf.variable\_scope('label\_predictor'):
 
 W\_fc2 = weight\_variable(\[100, 10\])
 b\_fc2 = bias\_variable(\[10\])
 logits = tf.matmul(h\_fc1, W\_fc2) + b_fc2
 self.pred = tf.nn.softmax(logits)
 self.pred\_loss = tf.nn.softmax\_cross\_entropy\_with\_logits(logits=logits, labels=self.classify\_labels)
 # 域预测器 -MLP模型,带有对抗损失

 d\_b\_fc1 = bias_variable(\[2\])
 d\_logits = tf.matmul(d\_h\_fc0, d\_W\_fc1) + d\_b_fc1
 self.domain\_pred = tf.nn.softmax(d\_logits)
 self.domain\_loss = tf.nn.softmax\_cross\_entropy\_with\_logits(logits=d\_logits, labels=self.domain)

该模型主要由三个部分组成:特征提取器、标签预测器和域预测器。特征提取器使用卷积神经网络(CNN)从输入图像中提取特征;标签预测器是一个多层感知机(MLP),用于对图像的类别进行预测;域预测器同样是一个MLP,用于判断输入数据来自源域还是目标域。在域预测器中,使用了梯度反转层 flip_gradient 来实现对抗训练,使得特征提取器学习到的特征能够在不同域之间具有不变性。

模型训练与评估

上述代码实现了两种训练模式:仅在源域上训练(source)和使用域对抗训练(dann)。在训练过程中,根据论文中的方法动态调整适应参数 l 和学习率 lr。 运行结果如下:

从结果可以看出,仅在源域上训练时,模型在源域(MNIST)上有较高的准确率,但在目标域(MNIST - M)上的准确率较低,说明模型对目标域的适应能力较差。而使用域对抗训练后,虽然源域的准确率略有下降,但目标域的准确率有了显著提升,表明域对抗训练有效地提高了模型在不同域之间的泛化能力。

特征可视化
代码语言:javascript
代码运行次数:0
运行
AI代码解释
复制
plot\_embedding(dann\_tsne

通过t - 分布随机邻域嵌入(t - SNE)方法将高维特征映射到二维空间进行可视化。从可视化结果可以直观地看到,仅在源域上训练时,源域和目标域的数据在特征空间中分离明显,说明模型没有学习到域不变的特征。而使用域对抗训练后,源域和目标域的数据在特征空间中更加接近,表明模型学习到了更具泛化性的特征,能够更好地适应不同的域。

结论

本文详细介绍了基于TensorFlow实现的神经网络对抗训练域适应方法。通过梯度反转层和域对抗训练,模型能够学习到域不变的特征,从而提高在目标域上的分类性能。在MNIST和Blobs数据集上的实验结果表明,域对抗训练相比于仅在源域上训练,能够显著提升模型在目标域上的准确率。同时,通过特征可视化可以直观地观察到域对抗训练对特征分布的影响,进一步验证了该方法的有效性。未来的研究可以考虑在更复杂的数据集和任务上应用该方法,以及探索如何进一步优化域对抗训练的效果。

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-02-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 拓端数据部落 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
暂无评论
推荐阅读
编辑精选文章
换一批
Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
前言 tensorflow中文社区对官方文档进行了完整翻译。鉴于官方更新不少内容,而现有的翻译基本上都已过时。故本人对更新后文档进行翻译工作,纰漏之处请大家指正。(如需了解其他方面知识,可参阅以下Tensorflow系列文章)。 深入MNIST TensorFlow是一个非常强大的用来做大规模数值计算的库。其所擅长的任务之一就是实现以及训练深度神经网络。在本教程中,通过为MNIST构建一个深度卷积神经网络的分类器,我们将学到构建一个TensorFlow模型的基本步骤。 这个教程假设你已经熟悉神经网络和MNI
用户1332428
2018/03/08
1.6K0
Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
【深度域自适应】一、DANN与梯度反转层(GRL)详解
在当前人工智能的如火如荼在各行各业得到广泛应用,尤其是人工智能也因此从各个方面影响当前人们的衣食住行等日常生活。这背后的原因都是因为如CNN、RNN、LSTM和GAN等各种深度神经网络的强大性能,在各个应用场景中解决了各种难题。
AI那点小事
2022/01/21
3.5K0
【深度域自适应】一、DANN与梯度反转层(GRL)详解
【深度域适配】一、DANN与梯度反转层(GRL)详解
CSDN博客原文链接:https://blog.csdn.net/qq_30091945/article/details/104478550
AI那点小事
2020/04/15
6K0
【深度域适配】一、DANN与梯度反转层(GRL)详解
【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练
在前一篇文章【深度域自适应】一、DANN与梯度反转层(GRL)详解中,我们主要讲解了DANN的网络架构与梯度反转层(GRL)的基本原理,接下来这篇文章中我们将主要复现DANN论文Unsupervised Domain Adaptation by Backpropagation中MNIST和MNIST-M数据集的迁移训练实验。
AI那点小事
2022/01/21
1.5K0
【深度域自适应】二、利用DANN实现MNIST和MNIST-M数据集迁移训练
TensorFlow 卷积神经网络实用指南:6~10
本章将介绍一种与到目前为止所看到的模型稍有不同的模型。 到目前为止提供的所有模型都属于一种称为判别模型的模型。 判别模型旨在找到不同类别之间的界限。 他们对找到P(Y|X)-给定某些输入X的输出Y的概率感兴趣。 这是用于分类的自然概率分布,因为您通常要在给定一些输入X的情况下找到标签Y。
ApacheCN_飞龙
2023/04/23
7120
基于TensorFlow卷积神经网络与MNIST数据集设计手写数字识别算法
TensorFlow是一个基于Python和基于数据流编程的机器学习框架,由谷歌基于DistBelief进行研发,并在图形分类、音频处理、推荐系统和自然语言处理等场景下有着丰富的应用。2015年11月9日,TensorFlow依据Apache 2.0 开源协议开放源代码。
润森
2022/09/22
7900
基于TensorFlow卷积神经网络与MNIST数据集设计手写数字识别算法
Tensorflow MNIST CNN 手写数字识别
Tesorflow实现基于MNIST数据集上简单CNN: https://github.com/Asurada2015/TF_Cookbook/blob/master/08_Convolutional_Neural_Networks/02_Intro_to_CNN_MNIST/02_introductory_cnn.py
演化计算与人工智能
2020/08/14
7250
Tensorflow MNIST CNN 手写数字识别
【TensorFlow】TensorFlow 的卷积神经网络 CNN - 无TensorBoard版
本文介绍了如何使用深度学习实现图像分类,并通过CIFAR-10数据集进行了实验。首先,作者介绍了如何使用卷积神经网络(CNN)进行图像分类,并给出了详细的理论推导。其次,作者介绍了在CIFAR-10数据集上如何进行数据扩充,并给出了具体的代码实现。最后,作者对实验结果进行了分析,并给出了在实验过程中需要注意的一些问题。
Alan Lee
2018/01/08
9040
【TensorFlow】TensorFlow 的卷积神经网络 CNN - 无TensorBoard版
TensorFlow 卷积神经网络实用指南:1~5
TensorFlow 是 Google 创建的开源软件库,可让您构建和执行数据流图以进行数值计算。 在这些图中,每个节点表示要执行的某些计算或功能,连接节点的图边表示它们之间流动的数据。 在 TensorFlow 中,数据是称为张量的多维数组。 张量围绕图流动,因此命名为 TensorFlow。
ApacheCN_飞龙
2023/04/23
1.1K0
TensorFlow 深度学习第二版:1~5
人工神经网络利用了 DL 的概念 。它们是人类神经系统的抽象表示,其中包含一组神经元,这些神经元通过称为轴突的连接相互通信。
ApacheCN_飞龙
2023/04/23
1.7K0
TensorFlow 深度学习第二版:1~5
tensorflow的基本用法——使用MNIST训练神经网络
本文主要是使用tensorflow和mnist数据集来训练神经网络。 #!/usr/bin/env python # _*_ coding: utf-8 _*_ import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 下载mnist数据 mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 定义神经网络模型的评估
Tyan
2019/05/25
6600
如何使用TensorFlow实现卷积神经网络
编者按:本文节选自图书《TensorFlow实战》第五章,本书将重点从实用的层面,为读者讲解如何使用TensorFlow实现全连接神经网络、卷积神经网络、循环神经网络,乃至Deep Q-Network。同时结合TensorFlow原理,以及深度学习的部分知识,尽可能让读者通过学习本书做出实际项目和成果。 卷积神经网络简介 卷积神经网络(Convolutional Neural Network,CNN)最初是为解决图像识别等问题设计的,当然其现在的应用不仅限于图像和视频,也可用于时间序列信号,比如音频信号
用户1737318
2018/07/20
6660
入门 | Tensorflow实战讲解神经网络搭建详细过程
作者 | AI小昕 编辑 | 磐石 出品 | 磐创AI技术团队 【磐创AI导读】:本文详细介绍了神经网络在实战过程中的构建与调节方式。主欢迎大家点击上方蓝字关注我们的公众号:磐创AI。点击公众号下方文
磐创AI
2018/07/03
5300
Hello TensorFlow : MINST数据集识别
我们需要做的就是通过算法让电脑能够识别出图片中的数字,是不是像识别验证码一样。 本文会介绍两种方法:
Awesome_Tang
2018/12/27
1.3K0
Tensorflow入门1-CNN网络及MNIST例子讲解
人工智能自从阿尔法狗大败李世石后就异常火爆,最近工作中需要探索AI在移动端的应用,趁着这个计划入门下深度学习吧。
用户3578099
2019/08/16
1.3K0
基于tensorflow实现简单卷积神经网络Lenet5
参考博客:https://blog.csdn.net/u012871279/article/details/78037984 https://blog.csdn.net/u014380165/article/details/77284921 目前人工智能神经网络已经成为非常火的一门技术,今天就用tensorflow来实现神经网络的第一块敲门砖。 首先先分模块解释代码。 1.先导入模块,若没有tensorflow还需去网上下载,这里使用mnist训练集来训练,进行手写数字的识别。 from tensorflo
徐飞机
2018/05/15
1.1K0
TF图层指南:构建卷积神经网络
本文介绍了如何利用TensorFlow搭建一个简单的CNN模型来识别MNIST数据集中的手写数字。首先,介绍了CNN模型的基本原理和TensorFlow中的Keras API。然后,使用MNIST数据集训练了一个具有卷积层和全连接层的CNN模型。最后,通过在测试集上评估模型的性能,得到了97.3%的准确率。
片刻
2018/01/05
2.5K0
TF图层指南:构建卷积神经网络
TensorFlow - TF-Slim 使用总览
虽然这里是采用 TF-Slim 处理图像分类问题,还需要安装 TF-Slim 图像模型库 tensorflow/models/research/slim. 假设该库的安装路径为 TF_MODELS. 添加 TF_MODELS/research/slim 到 python path.
狼啸风云
2019/07/08
3K0
深度学习_1_神经网络_2_深度神经网络
​ sigmod f(x)=1/(1+e^(-x)) 计算量大,反向传播时 容易出现梯度爆炸
Dean0731
2020/05/08
5800
TensorFlow2.0代码实战专栏(五):神经网络示例
原项目 | https://github.com/aymericdamien/TensorFlow-Examples/
磐创AI
2019/12/11
2.2K0
推荐阅读
相关推荐
Tensorflow之 CNN卷积神经网络的MNIST手写数字识别
更多 >
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档
本文部分代码块支持一键运行,欢迎体验
本文部分代码块支持一键运行,欢迎体验