前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >词向量可视化--[tensorflow , python]

词向量可视化--[tensorflow , python]

作者头像
Gxjun
发布2018-12-28 16:56:53
1.7K0
发布2018-12-28 16:56:53
举报
文章被收录于专栏:mlml
代码语言:javascript
复制
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
----------------------------------
Version    : ??
File Name :     visual_vec.py
Description :   
Author  :       xijun1
Email   :
Date    :       2018/12/25
-----------------------------------
Change Activiy  :   2018/12/25
-----------------------------------

"""
__author__ = 'xijun1'
from tqdm import tqdm
import numpy as np
import tensorflow as tf
from tensorflow.contrib.tensorboard.plugins import projector
import os
import codecs

words, embeddings = [], []
log_path = 'model'

with codecs.open('/Users/xxx/github/python_demo/vec.txt', 'r') as f:
    header = f.readline()
    vocab_size, vector_size = map(int, header.split())
    for line in tqdm(range(vocab_size)):
        word_list = f.readline().split(' ')
        word = word_list[0]
        vector = word_list[1:-1]
        if word == "":
            continue
        words.append(word)
        embeddings.append(np.array(vector))
assert len(words) == len(embeddings)
print(len(words))

with tf.Session() as sess:
    X = tf.Variable([0.0], name='embedding')
    place = tf.placeholder(tf.float32, shape=[len(words), vector_size])
    set_x = tf.assign(X, place, validate_shape=False)
    sess.run(tf.global_variables_initializer())
    sess.run(set_x, feed_dict={place: embeddings})
    with codecs.open(log_path + '/metadata.tsv', 'w') as f:
        for word in tqdm(words):
            f.write(word + '\n')

    # with summary
    summary_writer = tf.summary.FileWriter(log_path, sess.graph)
    config = projector.ProjectorConfig()
    embedding_conf = config.embeddings.add()
    embedding_conf.tensor_name = 'embedding:0'
    embedding_conf.metadata_path = os.path.join('metadata.tsv')
    projector.visualize_embeddings(summary_writer, config)

    # save
    saver = tf.train.Saver()
    saver.save(sess, os.path.join(log_path, "model.ckpt"))

结果:

本文参与 腾讯云自媒体同步曝光计划,分享自作者个人站点/博客。
原始发表:2018-12-25 ,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 作者个人站点/博客 前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 结果:
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档