前往小程序,Get更优阅读体验!
立即前往
首页
学习
活动
专区
工具
TVP
发布
社区首页 >专栏 >使用Tensor2Tensor和10行代码训练尖端语言翻译神经网络

使用Tensor2Tensor和10行代码训练尖端语言翻译神经网络

作者头像
AiTechYun
发布2018-12-11 15:34:54
2.7K0
发布2018-12-11 15:34:54
举报
文章被收录于专栏:ATYUN订阅号

编译:yxy

出品:ATYUN订阅号

有许多库可以帮助人们构建深度学习应用程序,但如果想使用最新架构的最先进模型和最少的代码,有这样一个API脱颖而出:Google的Tensor2Tensor。我通过这个库来使用高级的新神经网络架构(特别是Transformer)进行翻译,几乎不需要任何代码。

即使我不会说法语,我可以用T2T听懂法国团队和客户的话。

每周都有来自大学教授、谷歌和其他大型科技公司的研究人员、甚至是对深度学习有浓厚兴趣的开发人员发布新的神经网络架构和新的人工智能研究论文。

不幸的是,对于没有博士学位或者在反向传播,线性代数或计算数学方面了解不深的人,在没有高级API(如Keras)的情况下实现这些新的深度学习技术既困难又费时。

但好在,Google Brain团队认识到AI社区普遍存在的这些问题,随后创建了一个开源库来帮助解决它。

Tensor2Tensor,简称T2T,是一个深度学习模型和数据集库,旨在使深度学习更容易实现并加速ML研究。T2T由Google Brain团队和用户社区的研究人员和工程师积极使用和维护。

引自Github上的tensor2tensor介绍

深度学习和Tensor2Tensor

尽管深度学习并不总是人们在数据科学领域所期望的灵丹妙药,但它对于自然语言处理(NLP)任务来说非常有用。例如,使用词嵌入已经彻底改变了语言理解技术的有效性。

我想使用当前最先进的技术为我的团队和客户制作一个离线的法语到英语翻译器,也就是Transformer架构。T2T为快速、简单的训练和模型制作提供了一个框架,不需要从头开始编写和训练这个神经网络。

架构论文:https://arxiv.org/abs/1706.03762

Tensor2Tensor API概述

T2T库旨在与shell脚本一起使用,但你可以轻松地将其打包以供Python使用。API是多模块化的,这意味着任何内置模型都可以与各种类型的数据(文本,图像,音频等)一起使用。而API的作者为特定任务(如翻译,文本摘要,语音识别等)提供了推荐的数据集和模型。

GitHub:https://github.com/tensorflow/tensor2tensor#suggested-datasets-and-models

有时,你可能想要使用Tensor2Tensor的预编码模型之一,并将其应用于你自己的数据集和超参数组合。或者,你也可能想使用他们的简单框架来试验你自己的模型架构。通过定义一些新的子类可以很容易地做到这一点(我稍后会详细说明)。

T2T库有详细的说明文档,但为了深入了解,我们将逐步介绍其API的核心部分,并使用T2T开始你的第一个项目。

定义Tensor2Tensor问题

想要使用Tensor2Tensor(T2T),你要做的第一件事就是确定你要用它做什么,即问题是什么。这定义了你解决的任务,你使用的数据集,以及词汇表(如果可用)。这与模型架构和训练超参数无关。

你需要首先选择T2T中可找到的许多问题之一。你可以使用命令行查看API中已经内置的所有问题(使用命令t2t-datagen),也可以使用Python:

代码语言:javascript
复制
from tensor2tensorimport problems
代码语言:javascript
复制
代码语言:javascript
复制
# Print all T2T problems to console
代码语言:javascript
复制
problems.available()

T2T的命名方案遵循[task-family] _ [task] _ [specifics]的形式。因此,如果你想使用WMT翻译数据集(http://data.statmt.org/)制作一个英语到法语翻译器,词汇量为32k ,你可以选择问题名称:translate_enfr_wmt32k。对于某些预先存在的T2T模型,你可以通过在问题名末尾添加字符串_rev来反转输入和输出。为了训练法语到英语翻译的模型,问题名称将是translate_enfr_wmt32k_rev。Tensor2Tensor内置的其他问题包括:

  • summarize_cnn_dailymail32:使用具有32k词汇量的CNN Daily Mail数据集的文本摘要神经网络
  • img2img_celeba:超分辨率的图像到图像转换(8×8到32×32)
  • sentiment_imdb:使用IMBD数据集的情绪分析模型

生成训练数据

选择并命名要解决的问题后,你需要为其选择正确的数据。如果你使用预置的问题,Tensor2Tensor会自动下载和准备用于训练的数据。

你首先需要选择一个目录来存储T2T将为你下载的未处理数据。目录名为tmp_dir。很多相同的问题都下载相同的数据,因此可以在T2T中重复使用此目录来解决多个问题,尤其是如果这些问题位于同一个任务或问题系列中。

在生成最终训练数据之前,你还需要确定存储预处理数据的目录。Tensor2Tensor中名为data_dir。同样,你可以在适当时重用目录。

可以认为tmp_dir是internet上的zip文件存储的位置,而data_dir是在从tmp_dir中读取数据之后,针对特定的T2T问题进行适当的预处理的位置。例如,如果进行NLP,在预处理期间,T2T将使用数字对每个单词进行编码,分割训练和测试集,创建词汇表等。

如果你想使用自己的数据集并使用T2T的预编码神经网络训练模型,则需要创建一个新的问题子类。

初始化这些目录后,你可以使用命令行生成数据,如下:

代码语言:javascript
复制
t2t-datagen \
代码语言:javascript
复制
  --data_dir=$DATA_DIR \
代码语言:javascript
复制
  --tmp_dir=$TMP_DIR \
代码语言:javascript
复制
  --problem=$PROBLEM

也可以用Python:

代码语言:javascript
复制
from tensor2tensorimport problems
代码语言:javascript
复制
代码语言:javascript
复制
PROBLEM= '{Tensor2Tensor_problem_name}
代码语言:javascript
复制
TMP_DIR= '{/Tmp_Dir_Path}' # Where data files from internet stored
代码语言:javascript
复制
DATA_DIR= '{/Data_Dir_Path}' # Where pre-prcessed data is stored
代码语言:javascript
复制
代码语言:javascript
复制
# Init problem T2T object the generated training data
代码语言:javascript
复制
t2t_problem= problems.problem(PROBLEM)
代码语言:javascript
复制
t2t_problem.generate_data(DATA_DIR, TMP_DIR)

模型选择和超参数

你可以通过t2t-trainer在命令行中调用,也可以使用Python调用来查看所有可用的模型

代码语言:javascript
复制
from tensor2tensor.utilsimport registry
代码语言:javascript
复制
from tensor2tensorimport models
代码语言:javascript
复制
代码语言:javascript
复制
# Print all models in T2T to console
代码语言:javascript
复制
registry.list_models()

例如,Transformer 模型最适合翻译。

当然,你还可以在模型中自定义多个超参数集。例如,在Transformer python文件的底部,你可以看到所有可以进行训练的超参数(见下图)。但通常最好先从基础参数集开始,然后根据需要进行调整。

值得注意的是,用于Tensor2Tensor的hparams和模型参数一起定义了训练参数。这意味着在测试新模型时,你可以非常轻松地调整网络的大小、批尺寸,学习率,优化器类型等。

训练你最先进的神经网络

现在,你已准备好用几行代码训练你的神经网络。

使用命令行,你需要做的就是通过设置相应的变量来执行以下脚本:

代码语言:javascript
复制
t2t-trainer \
代码语言:javascript
复制
  --data_dir=$DATA_DIR \
代码语言:javascript
复制
  --problem=$PROBLEM \
代码语言:javascript
复制
  --model=$MODEL \
代码语言:javascript
复制
  --hparams_set=$HPARAMS \
代码语言:javascript
复制
  --output_dir=$TRAIN_DIR

output_dir参数是为此模型运行存储模型文件检查点的位置,这样你可以通过预加载该目录中的模型文件来获取之前的训练。

你可以通过在上面的shell脚本中添加额外的标志来更改任何超参数。

要在Python中设置训练,需要花费更多精力,但同样可行。

使用逆向工程Notebook构建翻译器

首先,你必须设置所需的T2T变量,目录,预处理数据的位置以及模型文件存储位置。

代码语言:javascript
复制
PROBLEM =  ' translate_enfr_wmt32k_rev '
代码语言:javascript
复制
MODEL =  ' TRANSFORMER '
代码语言:javascript
复制
HPARAMS =  ' transformer_base '
代码语言:javascript
复制
代码语言:javascript
复制
TRAIN_DIR =  '〜/ translator / model_files '
代码语言:javascript
复制
DATA_DIR =  '〜/ translator / fr_en_data '

接下来,你需要初始化hparam对象并重置一些变量。如果你的VRAM(显存)有限,你需要减少批尺寸(例如,从4096减小到1024),以便在训练时可以适应内存。随后,你将需要调整学习率和学习率准备步骤,以针对修改的批尺寸优化模型的收敛。接下来,你可以使用隐藏层来确定这是否有助于提高特定情况下的模型性能。

代码语言:javascript
复制
from tensor2tensor.utils.trainer_libimport create_hparams
代码语言:javascript
复制
代码语言:javascript
复制
# Init Hparams object from T2T Problem
代码语言:javascript
复制
hparams= create_hparams(HPARAMS)
代码语言:javascript
复制
代码语言:javascript
复制
# Make Chngaes to Hparams
代码语言:javascript
复制
hparams.batch_size= 1024
代码语言:javascript
复制
hparams.learning_rate_warmup_steps= 45000
代码语言:javascript
复制
hparams.learning_rate= .4
代码语言:javascript
复制
代码语言:javascript
复制
# Can see all Hparams with code below
代码语言:javascript
复制
print(json.loads(hparams.to_json())

要开始训练模型,你需要初始化Tensorflow的run_config和实验对象。最后,打电话tensorflow_exp_fn.train_and_evaluate()实施训练。

代码语言:javascript
复制
from tensor2tensor.utils.trainer_libimport create_run_config, create_experiment
代码语言:javascript
复制
代码语言:javascript
复制
# Initi Run COnfig for Model Training
代码语言:javascript
复制
RUN_CONFIG= create_run_config(
代码语言:javascript
复制
      model_dir=TRAIN_DIR# Location of where model file is store
代码语言:javascript
复制
      # More Params here in this fucntion for controling how noften to tave checkpoints and more.
代码语言:javascript
复制
)
代码语言:javascript
复制
代码语言:javascript
复制
# Create Tensorflow Experiment Object
代码语言:javascript
复制
tensorflow_exp_fn= create_experiment(
代码语言:javascript
复制
        run_config=RUN_CONFIG,
代码语言:javascript
复制
        hparams=hparams,
代码语言:javascript
复制
        model_name=MODEL,
代码语言:javascript
复制
        problem_name=PROBLEM,
代码语言:javascript
复制
        data_dir=DATA_DIR,
代码语言:javascript
复制
        train_steps=400000,# Total number of train steps for all Epochs
代码语言:javascript
复制
        eval_steps=100 # Number of steps to perform for each evaluation
代码语言:javascript
复制
    )
代码语言:javascript
复制
代码语言:javascript
复制
# Kick off Training
代码语言:javascript
复制
tensorflow_exp_fn.train_and_evaluate()

跟踪模型训练和表现

现在你的模型训练已经开始,你可以看到损失和准确性指标的变化:

初始化Tensorflow实验对象时设置了train_steps参数。这是训练停止前的训练次数。你可以使用save_checkpoints_steps(默认为1000)控制执行评估的频率。初始化run_config对象时,将其设置为可选的hypermeter。

虽然Tensor2Tensor与CPU完全兼容,但GPU和分布式训练也有很多选择。比如使用哪一个,在每个GPU中限制多少内存,等等。如果你有兴趣学习如何集成GPU以训练T2T模型,访问下方链接。

链接:https://github.com/tensorflow/tensor2tensor/blob/master/docs/distributed_training.md

Tensorboard

要激活Tensorboard,首先需要转到命令行并输入tensorboard — logdir /{path_to_train_dir}。要直接访问Tensorboard二进制文件,你可能首先要在bash shell中激活包含Tensorflow的python环境。

详细操作:https://stackoverflow.com/questions/14604699/how-to-activate-virtualenv

激活后,你将能够在http://{host_ip}:6006 /实时跟踪模型性能。Tensorboard可用于比较周期的训练和评估指标,请参阅Tensorflow模型图。

最常用于此任务准确性的指标是BLEU分数。对于法语到英语的翻译,这个模型的BLEU得分大约为28,这是最先进水平。

使用Tensor2Tensor模型进行评分

要使用新训练的模型进行评分,你可以使用t2t-decoder二进制文件:

代码语言:javascript
复制
t2t-decoder \
代码语言:javascript
复制
  --data_dir=$DATA_DIR \
代码语言:javascript
复制
  --problem=$PROBLEM \
代码语言:javascript
复制
  --model=$MODEL \
代码语言:javascript
复制
  --hparams_set=$HPARAMS \
代码语言:javascript
复制
  --output_dir=$TRAIN_DIR \
代码语言:javascript
复制
  --decode_hparams="beam_size=$BEAM_SIZE,alpha=$ALPHA" \
代码语言:javascript
复制
  --decode_from_file=$DECODE_FILE \
代码语言:javascript
复制
  --decode_to_file=translation.en

但是,这只能读取文本文件,并将结果输出到文本文件,这并不总是你需要的。所以我已经移植了它,以便在Python中更轻松地访问模型。你可以查看下面的代码,看看它是如何实现的。

代码语言:javascript
复制
from tensor2tensor.utils.trainer_libimport create_hparams, registry
代码语言:javascript
复制
from tensor2tensorimport problems
代码语言:javascript
复制
代码语言:javascript
复制
INPUT_TEXT_TO_TRANSLATE= 'Translate this sentence into French'
代码语言:javascript
复制
代码语言:javascript
复制
# Set Tensor2Tensor Arguments
代码语言:javascript
复制
MODEL_DIR_PATH= ~/En_to_Fr_translator'
代码语言:javascript
复制
MODEL= 'transformer'
代码语言:javascript
复制
HPARAMS= 'transformer_big_single_gpu'
代码语言:javascript
复制
T2T_PROBLEM= 'translate_enfr_wmt32k'
代码语言:javascript
复制
代码语言:javascript
复制
hparams= create_hparams(HPARAMS, data_dir=model_dir, problem_name=T2T_PROBLEM)
代码语言:javascript
复制
代码语言:javascript
复制
# Make any changes to default Hparams for model architechture used during training
代码语言:javascript
复制
hparams.batch_size= 1024
代码语言:javascript
复制
hparams.hidden_size= 7*80
代码语言:javascript
复制
hparams.filter_size= 7*80*4
代码语言:javascript
复制
hparams.num_heads= 8
代码语言:javascript
复制
代码语言:javascript
复制
# Load model into Memory
代码语言:javascript
复制
T2T_MODEL= registry.model(MODEL)(hparams, tf.estimator.ModeKeys.PREDICT)
代码语言:javascript
复制
代码语言:javascript
复制
# Init T2T Token Encoder/ Decoders
代码语言:javascript
复制
DATA_ENCODERS= problems.problem(T2T_PROBLEM).feature_encoders(model_dir)
代码语言:javascript
复制
代码语言:javascript
复制
### START USING MODELS
代码语言:javascript
复制
encoded_inputs= encode(INPUT_TEXT_TO_TRANSLATE, DATA_ENCODERS)
代码语言:javascript
复制
model_output= T2T_MODEL.infer(encoded_inputs, beam_size=2)["outputs"]
代码语言:javascript
复制
translated_text_in_french=  decode(model_output, DATA_ENCODERS)
代码语言:javascript
复制
代码语言:javascript
复制
print(translated_text_in_french)

你可能已经看到上面的代码中有两个函数,名为encode()和decode()。它们用于获取常规文本数据并将其编码为适合模型的格式。类似地,在相应的输出格式中对模型输出进行解码。

这意味着他们可以在批尺寸(1024)上输入许多输入序列,并且可以更快地翻译长段落,而不必对模型进行1024次翻译调用以翻译1024个句子。

代码语言:javascript
复制
def encode(input_txt, encoders):
代码语言:javascript
复制
    """List of Strings to features dict, ready for inference"""
代码语言:javascript
复制
    encoded_inputs= [encoders["inputs"].encode(x)+ [1]for xin input_txt]
代码语言:javascript
复制
代码语言:javascript
复制
    # pad each input so is they are the same length
代码语言:javascript
复制
    biggest_seq= len(max(encoded_inputs, key=len))
代码语言:javascript
复制
    for i, text_inputin enumerate(encoded_inputs):
代码语言:javascript
复制
        encoded_inputs[i]= text_input+ [0 for xin range(biggest_seq- len(text_input))]
代码语言:javascript
复制
代码语言:javascript
复制
    # Format Input Data For Model
代码语言:javascript
复制
    batched_inputs= tf.reshape(encoded_inputs, [len(encoded_inputs),-1,1])
代码语言:javascript
复制
    return {"inputs": batched_inputs}
代码语言:javascript
复制
代码语言:javascript
复制
代码语言:javascript
复制
def decode(integers, encoders):
代码语言:javascript
复制
    """Decode list of ints to list of strings"""
代码语言:javascript
复制
代码语言:javascript
复制
    # Turn to list to remove EOF mark
代码语言:javascript
复制
    to_decode= list(np.squeeze(integers))
代码语言:javascript
复制
    if isinstance(to_decode[0], np.ndarray):
代码语言:javascript
复制
        to_decode= map(lambda x:list(np.squeeze(x)), to_decode)
代码语言:javascript
复制
    else:
代码语言:javascript
复制
        to_decode= [to_decode]
代码语言:javascript
复制
代码语言:javascript
复制
    # remove <EOF> Tag before decoding
代码语言:javascript
复制
    to_decode= map(lambda x: x[:x.index(1)],filter(lambda x:1 in x, to_decode))
代码语言:javascript
复制
代码语言:javascript
复制
    # Decode and return Translated text
代码语言:javascript
复制
    return [encoders["inputs"].decode(np.squeeze(x))for xin to_decode]

让我的Tensor2Tensor模型投入生产

我做法语翻译器的主要原因之一是因为我在一家法国公司工作。很多人在团队聊天中讲法语。不幸的是,我根本不知道他们在说什么。

我最终使用Dataiku创建REST API端点,以使用我制作的Tensorflow模型执行翻译。我使用名为Errbot的聊天机器人API将REST端点连接到公司的Hipchat上。

Dataiku:https://www.dataiku.com/learn/guide/tutorials/deploy-scoring.html

errbot:http://errbot.io/en/latest/

现在,无论同事说些什么,我都可以轻松看懂。

完整项目:https://github.com/alexwolf22/tensor2tensor_translator

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2018-11-03,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 ATYUN订阅号 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 深度学习和Tensor2Tensor
  • Tensor2Tensor API概述
  • 定义Tensor2Tensor问题
  • 生成训练数据
  • 模型选择和超参数
  • 训练你最先进的神经网络
  • 使用逆向工程Notebook构建翻译器
  • 跟踪模型训练和表现
  • Tensorboard
  • 使用Tensor2Tensor模型进行评分
  • 让我的Tensor2Tensor模型投入生产
相关产品与服务
文件存储
文件存储(Cloud File Storage,CFS)为您提供安全可靠、可扩展的共享文件存储服务。文件存储可与腾讯云服务器、容器服务、批量计算等服务搭配使用,为多个计算节点提供容量和性能可弹性扩展的高性能共享存储。腾讯云文件存储的管理界面简单、易使用,可实现对现有应用的无缝集成;按实际用量付费,为您节约成本,简化 IT 运维工作。
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档