Loading [MathJax]/jax/output/CommonHTML/config.js
首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
MCP广场
社区首页 >问答首页 >TensorFlow:图表已完成,无法修改?

TensorFlow:图表已完成,无法修改?

提问于 2020-06-25 15:06:18
回答 0关注 0查看 235

我在进行TensorFlow的分布式训练,想通过筛选梯度的方式,来实现较少通信时长的目的。这是我的代码:

import time

import tensorflow as tf

import numpy as np

from tensorflow.examples.tutorials.mnist import input_data # 数据的获取不是本章重点,这里直接导入

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('thread_steps', 0, 'Steps run before sync gradients.')

tf.app.flags.DEFINE_string('data_dir', '/tmp/mnist-data', 'Directory for storing mnist data')

tf.app.flags.DEFINE_string("job_name", "worker", "ps or worker")

tf.app.flags.DEFINE_integer("task_id", 0, "Task ID of the worker/ps running the train")

tf.app.flags.DEFINE_string("ps_hosts", "localhost:2222", "ps机")

tf.app.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224", "worker机,用逗号隔开")

全局变量

MODEL_DIR = "./distribute_model_ckpt/"

DATA_DIR = "./data/mnist/"

BATCH_SIZE = 32

THREAD_STEPS = FLAGS.thread_steps

main函数

def main(self):

代码语言:txt
AI代码解释
复制
# ==========  STEP1: 读取数据  ========== #
代码语言:txt
AI代码解释
复制
mnist = input\_data.read\_data\_sets(FLAGS.data\_dir, one\_hot=True)    # 读取数据

代码语言:txt
AI代码解释
复制
# ==========  STEP2: 声明集群  ========== #
代码语言:txt
AI代码解释
复制
# 构建集群ClusterSpec和服务声明
代码语言:txt
AI代码解释
复制
ps\_hosts = FLAGS.ps\_hosts.split(",")
代码语言:txt
AI代码解释
复制
worker\_hosts = FLAGS.worker\_hosts.split(",")
代码语言:txt
AI代码解释
复制
cluster = tf.train.ClusterSpec({"ps":ps\_hosts, "worker":worker\_hosts})    # 构建集群名单
代码语言:txt
AI代码解释
复制
server = tf.train.Server(cluster, job\_name=FLAGS.job\_name, task\_index=FLAGS.task\_id)    # 声明服务
代码语言:txt
AI代码解释
复制
n\_workers = len(worker\_hosts)    # worker机的数量

代码语言:txt
AI代码解释
复制
# ==========  STEP3: ps机内容  ========== #
代码语言:txt
AI代码解释
复制
# 分工,对于ps机器不需要执行训练过程,只需要管理变量。server.join()会一直停在这条语句上。
代码语言:txt
AI代码解释
复制
if FLAGS.job\_name == "ps":
代码语言:txt
AI代码解释
复制
    with tf.device("/cpu:0"):
代码语言:txt
AI代码解释
复制
        server.join()

代码语言:txt
AI代码解释
复制
# ==========  STEP4: worker机内容  ========== #
代码语言:txt
AI代码解释
复制
# 下面定义worker机需要进行的操作
代码语言:txt
AI代码解释
复制
is\_chief = (FLAGS.task\_id == 0)    # 选取task\_id=0的worker机作为chief

代码语言:txt
AI代码解释
复制
# 通过replica\_device\_setter函数来指定每一个运算的设备。
代码语言:txt
AI代码解释
复制
# replica\_device\_setter会自动将所有参数分配到参数服务器上,将计算分配到当前的worker机上。
代码语言:txt
AI代码解释
复制
device\_setter = tf.train.replica\_device\_setter(
代码语言:txt
AI代码解释
复制
    worker\_device="/job:worker/task:%d" % FLAGS.task\_id,
代码语言:txt
AI代码解释
复制
    cluster=cluster)

代码语言:txt
AI代码解释
复制
# 这一台worker机器需要做的计算内容
代码语言:txt
AI代码解释
复制
with tf.device(device\_setter):
代码语言:txt
AI代码解释
复制
    # 输入数据
代码语言:txt
AI代码解释
复制
    x = tf.placeholder(name="x-input",shape=[None,28\*28],dtype=tf.float32)    # 输入样本像素为28\*28
代码语言:txt
AI代码解释
复制
    # x\_shape = x.get\_shape().as\_list()
代码语言:txt
AI代码解释
复制
    # length = x\_shape[1]
代码语言:txt
AI代码解释
复制
    # x\_reshaped = tf.reshape(x, [-1,length])
代码语言:txt
AI代码解释
复制
    y\_ = tf.placeholder(name="y-input", shape=[None,10],dtype=tf.float32)      # MNIST是十分类
代码语言:txt
AI代码解释
复制
    # 第一层(隐藏层)
代码语言:txt
AI代码解释
复制
    with tf.variable\_scope("layer1"):
代码语言:txt
AI代码解释
复制
        weight1 = tf.get\_variable(name="weight1", shape=[28\*28, 10], initializer=tf.glorot\_normal\_initializer())
代码语言:txt
AI代码解释
复制
        biases1 = tf.get\_variable(name="biases1", shape=[10], initializer=tf.glorot\_normal\_initializer())
代码语言:txt
AI代码解释
复制
        layer1 = tf.nn.relu(tf.matmul(x, weight1) + biases1, name="layer1")
代码语言:txt
AI代码解释
复制
    # 第二层(输出层)
代码语言:txt
AI代码解释
复制
    with tf.variable\_scope("layer2"):
代码语言:txt
AI代码解释
复制
        weight2 = tf.get\_variable(name="weight2", shape=[10, 10], initializer=tf.glorot\_normal\_initializer())
代码语言:txt
AI代码解释
复制
        biases2 = tf.get\_variable(name="biases2", shape=[10], initializer=tf.glorot\_normal\_initializer())
代码语言:txt
AI代码解释
复制
        y = tf.add(tf.matmul(layer1, weight2), biases2, name="y")
代码语言:txt
AI代码解释
复制
    pred = tf.argmax(y, axis=1, name="pred")
代码语言:txt
AI代码解释
复制
    global\_step = tf.contrib.framework.get\_or\_create\_global\_step()    # 必须手动声明global\_step否则会报错
代码语言:txt
AI代码解释
复制
    # 损失和优化
代码语言:txt
AI代码解释
复制
    cross\_entropy = tf.nn.sparse\_softmax\_cross\_entropy\_with\_logits(logits=y, labels=tf.argmax(y\_, axis=1))
代码语言:txt
AI代码解释
复制
    loss = tf.reduce\_mean(cross\_entropy)
代码语言:txt
AI代码解释
复制
    with tf.name\_scope('train'):
代码语言:txt
AI代码解释
复制
        optimizer = tf.train.GradientDescentOptimizer(0.01)
代码语言:txt
AI代码解释
复制
    with tf.name\_scope('gradient'):
代码语言:txt
AI代码解释
复制
        gradient\_all = optimizer.compute\_gradients(loss,weight2)
代码语言:txt
AI代码解释
复制
        gradients\_node=tf.gradients(loss,weight2)
代码语言:txt
AI代码解释
复制
        grads\_holder = [(tf.placeholder(tf.float32,shape=g.get\_shape()), v) 
代码语言:txt
AI代码解释
复制
                        for (g, v) in gradient\_all]
代码语言:txt
AI代码解释
复制
    # \*\*通过tf.train.SyncReplicasOptimizer函数实现函数同步更新\*\*
代码语言:txt
AI代码解释
复制
    opt = tf.train.SyncReplicasOptimizer(
代码语言:txt
AI代码解释
复制
        tf.train.GradientDescentOptimizer(0.01),
代码语言:txt
AI代码解释
复制
        replicas\_to\_aggregate=n\_workers,
代码语言:txt
AI代码解释
复制
        total\_num\_replicas=n\_workers
代码语言:txt
AI代码解释
复制
    )
代码语言:txt
AI代码解释
复制
    sync\_replicas\_hook = opt.make\_session\_run\_hook(is\_chief)
代码语言:txt
AI代码解释
复制
    train\_op = opt.apply\_gradients(grads\_holder, global\_step=global\_step)
代码语言:txt
AI代码解释
复制
    if is\_chief:
代码语言:txt
AI代码解释
复制
        train\_op = tf.no\_op()
代码语言:txt
AI代码解释
复制
    hooks = [sync\_replicas\_hook, tf.train.StopAtStepHook(last\_step=10000)]    # 把同步更新的hook加进来
代码语言:txt
AI代码解释
复制
    config = tf.ConfigProto(
代码语言:txt
AI代码解释
复制
        allow\_soft\_placement=True,    # 设置成True,那么当运行设备不满足要求时,会自动分配GPU或者CPU。
代码语言:txt
AI代码解释
复制
        log\_device\_placement=False,   # 设置为True时,会打印出TensorFlow使用了哪种操作
代码语言:txt
AI代码解释
复制
    )

代码语言:txt
AI代码解释
复制
    # ==========  STEP5: 打开会话  ========== #
代码语言:txt
AI代码解释
复制
    # 对于分布式训练,打开会话时不采用tf.Session(),而采用tf.train.MonitoredTrainingSession()
代码语言:txt
AI代码解释
复制
    # 详情参考:https://www.cnblogs.com/estragon/p/10034511.html
代码语言:txt
AI代码解释
复制
    with tf.train.MonitoredTrainingSession(
代码语言:txt
AI代码解释
复制
            master=server.target,
代码语言:txt
AI代码解释
复制
            is\_chief=is\_chief,
代码语言:txt
AI代码解释
复制
            # checkpoint\_dir=MODEL\_DIR,
代码语言:txt
AI代码解释
复制
            hooks=hooks,
代码语言:txt
AI代码解释
复制
            # save\_checkpoint\_secs=10,
代码语言:txt
AI代码解释
复制
            config=config) as sess:
代码语言:txt
AI代码解释
复制
        print("session started!")
代码语言:txt
AI代码解释
复制
        start\_time = time.time()
代码语言:txt
AI代码解释
复制
        step = 0
代码语言:txt
AI代码解释
复制
        while not sess.should\_stop(): 
代码语言:txt
AI代码解释
复制
            # for step in range(THREAD\_STEPS):   
代码语言:txt
AI代码解释
复制
            xs,ys= mnist.train.next\_batch(BATCH\_SIZE)    # batch\_size=32
代码语言:txt
AI代码解释
复制
            #求每个梯度
代码语言:txt
AI代码解释
复制
            grads = sess.run(gradients\_node, feed\_dict={x:xs, y\_: ys})
代码语言:txt
AI代码解释
复制
            grads=np.array(grads)
代码语言:txt
AI代码解释
复制
            grads=grads.reshape((10,10))
代码语言:txt
AI代码解释
复制
            print(grads)
代码语言:txt
AI代码解释
复制
            grad\_abs=np.abs(grads)
代码语言:txt
AI代码解释
复制
            variance = np.var(grad\_abs, axis=1)
代码语言:txt
AI代码解释
复制
            print('variance:',variance)
代码语言:txt
AI代码解释
复制
            #取方差最大的几组值
代码语言:txt
AI代码解释
复制
            topk\_var=tf.constant(variance)
代码语言:txt
AI代码解释
复制
            k=1
代码语言:txt
AI代码解释
复制
            output1 = tf.nn.top\_k(topk\_var, k)
代码语言:txt
AI代码解释
复制
            with tf.Session() as sess1:
代码语言:txt
AI代码解释
复制
                print(sess1.run(output1))
代码语言:txt
AI代码解释
复制
                a=output1.indices[-1]
代码语言:txt
AI代码解释
复制
                # print(sess1.run(a))  #a是所在TOPK个方差最大的索引值
代码语言:txt
AI代码解释
复制
                x=a.eval()
代码语言:txt
AI代码解释
复制
                a=int(x)
代码语言:txt
AI代码解释
复制
            g\_a=grads[a,:]
代码语言:txt
AI代码解释
复制
                # print('g\_a=',g\_a) 
代码语言:txt
AI代码解释
复制
                # print('\n')
代码语言:txt
AI代码解释
复制
            #取方差最大的一组值中的前几个大的梯度值,设置梯度阈值
代码语言:txt
AI代码解释
复制
            g\_a\_abs=np.abs(g\_a)        
代码语言:txt
AI代码解释
复制
            k=3
代码语言:txt
AI代码解释
复制
            output2 = tf.nn.top\_k(g\_a\_abs, k)
代码语言:txt
AI代码解释
复制
            with tf.Session() as sess2:
代码语言:txt
AI代码解释
复制
                print(sess2.run(output2))
代码语言:txt
AI代码解释
复制
                b=output2.indices[-1]
代码语言:txt
AI代码解释
复制
                # print(sess2.run(b))  #a是所在TOPK个方差最大的索引值
代码语言:txt
AI代码解释
复制
                x=b.eval()
代码语言:txt
AI代码解释
复制
                b=int(x) 
代码语言:txt
AI代码解释
复制
                threshold=g\_a\_abs[b]
代码语言:txt
AI代码解释
复制
            grad\_end=np.where(grad\_abs<threshold,0,grads)
代码语言:txt
AI代码解释
复制
            grad\_end=[grad\_end]
代码语言:txt
AI代码解释
复制
            grad\_var={}
代码语言:txt
AI代码解释
复制
            for i in range(len(grads\_holder)):
代码语言:txt
AI代码解释
复制
                    k = grads\_holder[i][0]
代码语言:txt
AI代码解释
复制
                    if k is not None:
代码语言:txt
AI代码解释
复制
                          # grad\_var[k] =np.var([tf.reshape(g, [-1]) for g in grads])
代码语言:txt
AI代码解释
复制
                          grad\_var[k] =[g[i][0] for g in grad\_end]

代码语言:txt
AI代码解释
复制
            \_, loss\_value, global\_step\_value = sess.run([train\_op, loss, global\_step], feed\_dict=grad\_var)
代码语言:txt
AI代码解释
复制
            if step > 0 and step % 100 == 0:
代码语言:txt
AI代码解释
复制
                duration = time.time() - start\_time
代码语言:txt
AI代码解释
复制
                sec\_per\_batch = duration / global\_step\_value
代码语言:txt
AI代码解释
复制
                print("After %d training steps(%d global steps), loss on training batch is %g (%.3f sec/batch)" % (step, global\_step\_value, loss\_value, sec\_per\_batch))
代码语言:txt
AI代码解释
复制
                print('Training elapsed time:%f s' % duration)
代码语言:txt
AI代码解释
复制
            step += 1

if __name__ == "__main__":

代码语言:txt
AI代码解释
复制
tf.app.run()

这是我的报错信息:

回答

成为首答用户。去 写回答
相关文章
JDBC完成修改
使用流程不变: 导入jar包 加载驱动 创建连接对象 创建sql命名对象 创建sql命令 执行sql命令 关闭资源
葆宁
2019/04/19
4300
TensorFlow已死,TensorFlow万岁!
如果你是一名人工智能爱好者,却没有关注到一条重大新闻,就好比你在一场罕见的地震中打了个盹。等你醒来,会发现一切都将改变!
abs_zero
2020/11/11
5490
TensorFlow已死,TensorFlow万岁!
tensorflow: 畅玩tensorboard图表(SCALARS)
这篇博客建立在你已经会使用tensorboard的基础上。如果你还不会记录数据并使用tensorboard,请移步我之前的另一篇博客:tensorflow: tensorboard 探究
JNingWei
2018/09/27
9510
tensorflow: 畅玩tensorboard图表(SCALARS)
C# 如何避免异常”集合已修改;可能无法执行枚举操作。“
private static List<string> lstShare = new List<string>();
全栈程序员站长
2022/07/05
7460
修改excel图表中的“系列一”
方法与步骤 设置好图表之后,右键点击图表→“选择数据(源)”,在系列一、系列二的地方点击并编辑: 弹出的窗口中,在系列名称处输入名称即可添加或修改:
演化计算与人工智能
2020/08/14
1.6K0
修改excel图表中的“系列一”
如何用Tensorflow完成手写数字识别?
深度学习最经典的任务问题就是分类。通过分类,我们可以将照片中的数字,人脸,动植物等等分到它属于的那一类当中,完成识别。接下来,我就带着大家一起完成一个简单的程序,来实现经典问题手写数字识别。
HuangWeiAI
2019/08/01
7270
如何用Tensorflow完成手写数字识别?
OSError: [WinError 1455] 页面文件太小,无法完成操作
解决方式目前查到三种: 1、重启pycharm 长时间运行pycharm可能会在后台占用大量内存。重启清除内存也许会解决问题。 但对我来说并无作用。
zstar
2022/06/14
4.8K0
Knock Knock!你的模型已训练完成……
项目地址:https://github.com/huggingface/knockknock
机器之心
2019/04/29
7100
[1058]centos修改/etc/fstab后无法启动
今天做实验,增加了一个磁盘sdb1,而且也增加了自动挂载的功能/etc/fstab里增加了记录。
周小董
2021/10/28
1.5K0
解决无法修改Hosts文件
作者:matrix 被围观: 1,514 次 发布时间:2013-05-07 分类:兼容并蓄 | 无评论 »
HHTjim 部落格
2022/09/26
3.7K0
解决无法修改Hosts文件
无法完成要求,暂存盘已满_无法使用因为暂存盘已满
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。
全栈程序员站长
2022/11/03
9480
git 撤销修改:未push 、已push
场景:不小心把一次错误的代码push到远程服务器上的分支上,需要立即删除/撤销这次代码提交。
不吃西红柿
2022/07/29
1.8K0
git 撤销修改:未push 、已push
Excel催化剂图表系列之一键完成IBCS国际商业标准图表
这两年自助式BI异常火爆,在整个BI的体系中,在PowerBI和Tableau一轮轮的可视化追逐中,看得眼花嘹亮,虽然有现成的各种图表可视化控件,但仍然深深地陷入图表选择综合征。
Excel催化剂
2021/08/19
6350
cp: 无法创建普通文件 : 文件已存在
看了下 Makefile,这句非常简单,就是 cp ./xxx ../xxx 而已,本身没什么问题。
zqb_all
2020/05/27
6.5K1
全面爆发,EasyShu网页图表精讲,已完成十五集,只剩下10个图表
全面爆发,离终点越来越近,来到每一个独立网页图表的精讲篇,每个图表一节视频,讲透每个细节,让EasyShu成为你的数据可视化终极武器。
Excel催化剂
2021/08/18
3170
matplot代码配置化,修改Excel就能调整图表!
这依然是我在准备可视化专栏的过程笔记,主题仍然是模仿各种非常规图表,大部分使用 matplotlib 包完成。
咋咋
2021/09/01
6710
matplot代码配置化,修改Excel就能调整图表!
hosts文件无法修改怎么办?
 hosts文件修改完不是直接保存而是弹出另存为窗口 解决: 1、右击hosts文件——属性——把“只读”前面勾去掉。 未经允许不得转载:肥猫博客 » hosts文件无法修改怎么办?
超级小可爱
2023/02/20
1.9K0
点击加载更多

相似问题

蓝鲸最后一步,host修改完成后无法访问?

1716

如何修改已关联的外部仓库?

1177

无法修改dns?

31.4K

无法完成配置?

1446

已备案已解析成功已绑定无法访问?

1470
相关问答用户
高级数据分析师擅长5个领域
擅长4个领域
萃橙科技 | 合伙人擅长4个领域
添加站长 进交流群

领取专属 10元无门槛券

AI混元助手 在线答疑

扫码加入开发者社群
关注 腾讯云开发者公众号

洞察 腾讯核心技术

剖析业界实践案例

扫码关注腾讯云开发者公众号
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档