BERT中文实战:文本相似度计算与文本分类
关注微信公众号 datanlp 然后回复 bert 即可获取下载链接。
下载预训练模型
谷歌提供了以下几个版本的BERT模型,每个模型的参数都做了简单的说明,中文的预训练模型在11月3日的时候提供了,这里我们只需要用到中文的版本
https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip
下载下来的文件包括以下内容
编写代码
模型准备好后就可以编写代码了,我们先把BERT的github代码clone下来,之后我们的代码编写会基于run_classifier.py文件,我们看下代码的结构
可以看到有好几个xxxProcessor
的类,这些类都有同一个父类DataProcessor
,其中DataProcessor
提供了4个抽象方法,如图
顾名思义,Processor就是用来获取对应的训练集、验证集、测试集的数据与label的数据,并把这些数据喂给BERT的,而我们要做的就是自定义新的Processor并重写这4个方法,也就是说我们只需要提供我们自己场景对应的数据。这里我自定义了一个名叫SimProcessor的类,我们简单看一下
读取的数据需要封装成一个InputExample的对象并添加到list中,注意这里有一个guid的参数,这个参数是必填的,是用来区分每一条数据的。是否进行训练集、验证集、测试集的计算,在执行代码时会有参数控制,我们下文会讲,所以这里的抽象方法也并不是需要全部都重写,但是为了体验一个完整的流程, 建议大家还是简单写一下。
get_labels方法返回的是一个数组,因为相似度问题可以理解为分类问题,所以返回的标签只有0和1,注意,这里我返回的是参数是字符串,所以在重写获取数据的方法时InputExample中的label也要传字符串的数据,可以看到上图中我对label做了一个str()的处理。
接下来还需要给Processor加一个名字,让我们的在运行时告诉代码我们要执行哪一个Processor,如图我自定义的叫做sim
ok,到这里我们已经把Processor
编写好了,接下来就是运行代码了,我们来看下run_classifier.py
的执行过程。
可以看到,在执行run_classifier.py
时需要先输入这5个必填参数,这里我们对参数做一个简单的说明
当然还有一些其他的参数,这里给出官方提供的运行参数
这里再补充下以下三个可选参数说明
执行以上的代码即可训练我们自己的模型了,如果需要使用模型来进行预测,可执行以下命令
当然,我们需要在data_dir下有测试数据,测试完成后会在output_dir路径下生成一个test_results.tsv文件,该文件包含了测试用例和相似度probabilities
总结
除了相似度计算,以上的代码完全能够用来做文本二分类,你也可以根据自己的需求来修改Processor,更多的细节大家可以参阅github源码。