Using BERT breaks down into two steps: pre-training and fine-tuning. Pre-training lets you tailor the model to your own task, but it is very expensive (four days on 4 to 16 Cloud TPUs), so training from scratch is out of reach for most practitioners. Fortunately, Google has released a variety of pre-trained models, and all that is left to do is fine-tune them for the specific task.
In this post we continue to follow the structure of the original paper and read through the BERT pre-training source code. Pre-training consists of two concrete sub-tasks: Masked LM and Next Sentence Prediction. The files involved are:
√ tokenization.py
√ create_pretraining_data.py
x run_pretraining.py (left for the next post)
tokenization.py handles the raw text corpus and contains two tokenizer classes: BasicTokenizer and WordpieceTokenizer.
BasicTokenizer does plain tokenization on whitespace and punctuation and returns a list of words; for Chinese text it returns a list of individual characters.
class BasicTokenizer(object):
  def __init__(self, do_lower_case=True):
    self.do_lower_case = do_lower_case

  def tokenize(self, text):
    text = convert_to_unicode(text)
    text = self._clean_text(text)
    # Add Chinese support
    text = self._tokenize_chinese_chars(text)

    orig_tokens = whitespace_tokenize(text)
    split_tokens = []
    for token in orig_tokens:
      if self.do_lower_case:
        token = token.lower()
        token = self._run_strip_accents(token)
      split_tokens.extend(self._run_split_on_punc(token))

    output_tokens = whitespace_tokenize(" ".join(split_tokens))
    return output_tokens

  def _run_strip_accents(self, text):
    # Normalize the text (NFD) and drop accents
    text = unicodedata.normalize("NFD", text)
    output = []
    for char in text:
      cat = unicodedata.category(char)
      # Skip characters whose Unicode category is Mn (nonspacing mark)
      # refer: https://www.fileformat.info/info/unicode/category/Mn/list.htm
      if cat == "Mn":
        continue
      output.append(char)
    return "".join(output)

  def _run_split_on_punc(self, text):
    # Split on punctuation and return a list of pieces
    chars = list(text)
    i = 0
    start_new_word = True
    output = []
    while i < len(chars):
      char = chars[i]
      if _is_punctuation(char):
        output.append([char])
        start_new_word = True
      else:
        if start_new_word:
          output.append([])
        start_new_word = False
        output[-1].append(char)
      i += 1

    return ["".join(x) for x in output]

  def _tokenize_chinese_chars(self, text):
    # Split Chinese text into individual characters by padding each CJK
    # character with spaces on both sides
    output = []
    for char in text:
      cp = ord(char)
      if self._is_chinese_char(cp):
        output.append(" ")
        output.append(char)
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)

  def _is_chinese_char(self, cp):
    # Check whether the code point is a CJK character
    # refer: https://www.cnblogs.com/straybirds/p/6392306.html
    if ((cp >= 0x4E00 and cp <= 0x9FFF) or
        (cp >= 0x3400 and cp <= 0x4DBF) or
        (cp >= 0x20000 and cp <= 0x2A6DF) or
        (cp >= 0x2A700 and cp <= 0x2B73F) or
        (cp >= 0x2B740 and cp <= 0x2B81F) or
        (cp >= 0x2B820 and cp <= 0x2CEAF) or
        (cp >= 0xF900 and cp <= 0xFAFF) or
        (cp >= 0x2F800 and cp <= 0x2FA1F)):
      return True

    return False

  def _clean_text(self, text):
    # Remove meaningless characters and normalize whitespace
    output = []
    for char in text:
      cp = ord(char)
      if cp == 0 or cp == 0xfffd or _is_control(char):
        continue
      if _is_whitespace(char):
        output.append(" ")
      else:
        output.append(char)
    return "".join(output)
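As a quick illustration of its behavior, here is a small sketch of my own (assuming the BERT repo's tokenization.py is on the Python path):

from tokenization import BasicTokenizer

tokenizer = BasicTokenizer(do_lower_case=True)
# Punctuation is split off, text is lower-cased, and CJK characters
# are separated into individual tokens.
print(tokenizer.tokenize("Hello, BERT! 你好世界"))
# ['hello', ',', 'bert', '!', '你', '好', '世', '界']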
WordpieceTokenizer further splits the output of BasicTokenizer into finer-grained pieces. The main purpose of this step is to reduce the impact of out-of-vocabulary words on the model. It has no effect on Chinese, because BasicTokenizer has already split Chinese text into individual characters.
class WordpieceTokenizer(object):
  def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=200):
    self.vocab = vocab
    self.unk_token = unk_token
    self.max_input_chars_per_word = max_input_chars_per_word

  def tokenize(self, text):
    """Greedy longest-match-first (maximum forward matching) tokenization.

    For example:
      input = "unaffable"
      output = ["un", "##aff", "##able"]
    """
    text = convert_to_unicode(text)

    output_tokens = []
    for token in whitespace_tokenize(text):
      chars = list(token)
      if len(chars) > self.max_input_chars_per_word:
        output_tokens.append(self.unk_token)
        continue

      is_bad = False
      start = 0
      sub_tokens = []
      while start < len(chars):
        end = len(chars)
        cur_substr = None
        while start < end:
          substr = "".join(chars[start:end])
          if start > 0:
            substr = "##" + substr
          if substr in self.vocab:
            cur_substr = substr
            break
          end -= 1
        if cur_substr is None:
          is_bad = True
          break
        sub_tokens.append(cur_substr)
        start = end

      if is_bad:
        output_tokens.append(self.unk_token)
      else:
        output_tokens.extend(sub_tokens)
    return output_tokens
Let's trace the code with an example. Suppose the input is "unaffable". Jumping to the while loop: at this point start=0 and end=len(chars)=9, so we first check whether "unaffable" is in the vocabulary. If it is, it becomes a single WordPiece; if not, end -= 1 and we check "unaffabl", and so on. Eventually "un" is found in the vocabulary and added to the result.
Then start=2: we check "##affable", then "##affabl", ..., until "##aff" is found in the vocabulary. Note that the "##" prefix marks a piece that continues the previous one, which makes the WordPiece segmentation reversible: we can recover the "real" words.
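The following minimal sketch reproduces this behavior with a toy vocabulary of my own (again assuming tokenization.py is importable):

from tokenization import WordpieceTokenizer

# Toy vocabulary just for illustration; a real vocab.txt has ~30k entries.
vocab = {"un", "##aff", "##able", "[UNK]"}
tokenizer = WordpieceTokenizer(vocab=vocab)
print(tokenizer.tokenize("unaffable"))  # ['un', '##aff', '##able']
print(tokenizer.tokenize("xyzzy"))      # ['[UNK]'] -- nothing matches the toy vocab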
FullTokenizer is BERT's main tokenization interface; it combines the two tokenizers above.
class FullTokenizer(object):
  def __init__(self, vocab_file, do_lower_case=True):
    # Load the vocabulary file into a dict
    self.vocab = load_vocab(vocab_file)
    self.inv_vocab = {v: k for k, v in self.vocab.items()}
    self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
    self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

  def tokenize(self, text):
    split_tokens = []
    # Coarse-grained split with BasicTokenizer
    for token in self.basic_tokenizer.tokenize(text):
      # Fine-grained split with WordpieceTokenizer
      for sub_token in self.wordpiece_tokenizer.tokenize(token):
        split_tokens.append(sub_token)

    return split_tokens

  def convert_tokens_to_ids(self, tokens):
    return convert_by_vocab(self.vocab, tokens)

  def convert_ids_to_tokens(self, ids):
    return convert_by_vocab(self.inv_vocab, ids)
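Putting it together, a typical call looks like the sketch below (the vocab path is a placeholder for wherever your pre-trained checkpoint lives):

import tokenization

# Hypothetical path to the vocab file shipped with a pre-trained model.
tokenizer = tokenization.FullTokenizer(
    vocab_file="uncased_L-12_H-768_A-12/vocab.txt", do_lower_case=True)

tokens = tokenizer.tokenize("BERT tokenization is unaffable, isn't it?")
ids = tokenizer.convert_tokens_to_ids(tokens)
print(tokens)
print(ids)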
Next is create_pretraining_data.py. Its job is to convert the raw input corpus into TFRecord, the data format required for model pre-training. The main flags are defined as follows:
flags.DEFINE_string("input_file", None,
                    "Input raw text file (or comma-separated list of files).")

flags.DEFINE_string("output_file", None,
                    "Output TF example file (or comma-separated list of files).")

flags.DEFINE_string("vocab_file", None,
                    "The vocabulary file that the BERT model was trained on.")

flags.DEFINE_bool("do_lower_case", True,
                  "Whether to lower case the input text. Should be True for uncased "
                  "models and False for cased models.")

flags.DEFINE_integer("max_seq_length", 128, "Maximum sequence length.")

flags.DEFINE_integer("max_predictions_per_seq", 20,
                     "Maximum number of masked LM predictions per sequence.")

flags.DEFINE_integer("random_seed", 12345, "Random seed for data generation.")

flags.DEFINE_integer("dupe_factor", 10,
                     "Number of times to duplicate the input data (with different masks).")

flags.DEFINE_float("masked_lm_prob", 0.15, "Masked LM probability.")

flags.DEFINE_float("short_seq_prob", 0.1,
                   "Probability of creating sequences which are shorter than the maximum length.")
A few of these flags are worth calling out. dupe_factor is the number of times the input data is duplicated, each copy with a different random mask. For the sentence "Hello world, this is bert.", the first copy might be masked as "Hello [MASK], this is bert." and the second as "Hello world, this is [MASK].", which makes fuller use of the data. masked_lm_prob (15% by default) is the fraction of tokens selected for masking, and short_seq_prob is the probability of producing a sequence shorter than max_seq_length.
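Here is a toy sketch of the dupe_factor idea (my own illustration, not the script's actual masking logic, which lives in create_masked_lm_predictions later on):

import random

rng = random.Random(12345)
tokens = ["Hello", "world", ",", "this", "is", "bert", "."]
dupe_factor = 3
for _ in range(dupe_factor):
  masked = list(tokens)
  # Mask one random position per copy, just to show the effect of duplication.
  masked[rng.randint(0, len(masked) - 1)] = "[MASK]"
  print(" ".join(masked))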
First, let's look at the overall flow of data construction:
def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)

  tokenizer = tokenization.FullTokenizer(
      vocab_file=FLAGS.vocab_file, do_lower_case=FLAGS.do_lower_case)

  input_files = []
  for input_pattern in FLAGS.input_file.split(","):
    input_files.extend(tf.gfile.Glob(input_pattern))

  tf.logging.info("*** Reading from input files ***")
  for input_file in input_files:
    tf.logging.info("  %s", input_file)

  rng = random.Random(FLAGS.random_seed)
  instances = create_training_instances(
      input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
      FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
      rng)

  output_files = FLAGS.output_file.split(",")
  tf.logging.info("*** Writing to output files ***")
  for output_file in output_files:
    tf.logging.info("  %s", output_file)

  write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
                                  FLAGS.max_predictions_per_seq, output_files)
create_training_instances builds the training instances, and write_instance_to_example_files saves them in TFRecord format.
Below we walk through these functions one by one. First, a class representing a single training instance is defined:
class TrainingInstance(object):

  def __init__(self, tokens, segment_ids, masked_lm_positions, masked_lm_labels,
               is_random_next):
    self.tokens = tokens
    self.segment_ids = segment_ids
    self.is_random_next = is_random_next
    self.masked_lm_positions = masked_lm_positions
    self.masked_lm_labels = masked_lm_labels

  def __str__(self):
    s = ""
    s += "tokens: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.tokens]))
    s += "segment_ids: %s\n" % (" ".join([str(x) for x in self.segment_ids]))
    s += "is_random_next: %s\n" % self.is_random_next
    s += "masked_lm_positions: %s\n" % (" ".join(
        [str(x) for x in self.masked_lm_positions]))
    s += "masked_lm_labels: %s\n" % (" ".join(
        [tokenization.printable_text(x) for x in self.masked_lm_labels]))
    s += "\n"
    return s

  def __repr__(self):
    return self.__str__()
Google ships a sample input file (sample_text.txt) in the source release. The expected input format is one sentence per line, with documents separated by blank lines. After the small illustration below, the code that constructs the training instances follows.
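As a toy illustration of that layout (hypothetical content, not taken from sample_text.txt), such an input file could be produced like this:

# Two tiny documents, one sentence per line, separated by a blank line.
toy_corpus = (
    "This is the first sentence of document one.\n"
    "Here is its second sentence.\n"
    "\n"
    "Document two starts after the blank line.\n"
    "It also contains a couple of sentences.\n"
)
with open("toy_corpus.txt", "w") as f:
  f.write(toy_corpus)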
def create_training_instances(input_files, tokenizer, max_seq_length,
                              dupe_factor, short_seq_prob, masked_lm_prob,
                              max_predictions_per_seq, rng):
  all_documents = [[]]
  # all_documents is a list of lists: the outer list represents documents,
  # each inner list holds the (tokenized) sentences of one document.
  for input_file in input_files:
    with tf.gfile.GFile(input_file, "r") as reader:
      while True:
        line = tokenization.convert_to_unicode(reader.readline())
        if not line:
          break
        line = line.strip()

        # An empty line marks a document boundary
        if not line:
          all_documents.append([])
        tokens = tokenizer.tokenize(line)
        if tokens:
          all_documents[-1].append(tokens)

  # Remove empty documents
  all_documents = [x for x in all_documents if x]
  rng.shuffle(all_documents)

  vocab_words = list(tokenizer.vocab.keys())
  instances = []
  # Repeat dupe_factor times
  for _ in range(dupe_factor):
    for document_index in range(len(all_documents)):
      instances.extend(
          create_instances_from_document(
              all_documents, document_index, max_seq_length, short_seq_prob,
              masked_lm_prob, max_predictions_per_seq, vocab_words, rng))

  rng.shuffle(instances)
  return instances
The function above calls create_instances_from_document to extract multiple training instances from a single document.
def create_instances_from_document(
    all_documents, document_index, max_seq_length, short_seq_prob,
    masked_lm_prob, max_predictions_per_seq, vocab_words, rng):

  document = all_documents[document_index]

  # Reserve three slots for [CLS], [SEP], [SEP]
  max_num_tokens = max_seq_length - 3

  target_seq_length = max_num_tokens
  # With probability short_seq_prob, pick a random target length in [2, max_num_tokens]
  if rng.random() < short_seq_prob:
    target_seq_length = rng.randint(2, max_num_tokens)

  instances = []
  current_chunk = []
  current_length = 0
  i = 0
  while i < len(document):
    segment = document[i]
    current_chunk.append(segment)
    current_length += len(segment)
    # Keep adding sentences to current_chunk until the document is exhausted
    # or the target length is reached
    if i == len(document) - 1 or current_length >= target_seq_length:
      if current_chunk:
        # `a_end` is the index where sentence A ends
        a_end = 1
        # Randomly choose the split point
        if len(current_chunk) >= 2:
          a_end = rng.randint(1, len(current_chunk) - 1)

        tokens_a = []
        for j in range(a_end):
          tokens_a.extend(current_chunk[j])

        tokens_b = []
        # Whether the next sentence is random
        is_random_next = False
        # Build a random next sentence (negative sample)
        if len(current_chunk) == 1 or rng.random() < 0.5:
          is_random_next = True
          target_b_length = target_seq_length - len(tokens_a)

          # Randomly pick another document and a random starting sentence in it.
          # In theory the randomly chosen document could be the current one,
          # so we retry up to 10 times; a collision is still possible after
          # that, but we ignore it.
          for _ in range(10):
            random_document_index = rng.randint(0, len(all_documents) - 1)
            if random_document_index != document_index:
              break

          random_document = all_documents[random_document_index]
          random_start = rng.randint(0, len(random_document) - 1)
          for j in range(random_start, len(random_document)):
            tokens_b.extend(random_document[j])
            if len(tokens_b) >= target_b_length:
              break
          # The segments we skipped were never actually used for this instance,
          # so we "put them back" to avoid wasting data
          num_unused_segments = len(current_chunk) - a_end
          i -= num_unused_segments
        # Build the actual next sentence (positive sample)
        else:
          is_random_next = False
          for j in range(a_end, len(current_chunk)):
            tokens_b.extend(current_chunk[j])
        # If the pair is too long, randomly truncate tokens
        truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng)

        assert len(tokens_a) >= 1
        assert len(tokens_b) >= 1

        tokens = []
        segment_ids = []
        # Sentence A
        tokens.append("[CLS]")
        segment_ids.append(0)
        for token in tokens_a:
          tokens.append(token)
          segment_ids.append(0)
        # End of sentence A, append [SEP]
        tokens.append("[SEP]")
        segment_ids.append(0)
        # Sentence B
        for token in tokens_b:
          tokens.append(token)
          segment_ids.append(1)
        # End of sentence B, append [SEP]
        tokens.append("[SEP]")
        segment_ids.append(1)

        # Call create_masked_lm_predictions to randomly mask some tokens
        (tokens, masked_lm_positions,
         masked_lm_labels) = create_masked_lm_predictions(
             tokens, masked_lm_prob, max_predictions_per_seq, vocab_words, rng)
        instance = TrainingInstance(
            tokens=tokens,
            segment_ids=segment_ids,
            is_random_next=is_random_next,
            masked_lm_positions=masked_lm_positions,
            masked_lm_labels=masked_lm_labels)
        instances.append(instance)
      current_chunk = []
      current_length = 0
    i += 1

  return instances
The code above is fairly long; I've added comments at the key points. Let's walk through it with a concrete example, using a small excerpt of the corpus in the provided sample_text.txt. The excerpt contains two documents: the first has 6 sentences and the second has 4.
create_instances_from_document works on one document at a time, so we take the first one as our example.
For the next-sentence (is_random_next) decision:
(1) for a positive sample, the first two sentences become sentence A and the following sentence becomes sentence B;
(2) for a negative sample, the first two sentences still become sentence A, but sentence B is an unrelated sentence drawn at random from another document.
Randomly masking tokens is one of BERT's key innovations. The mask is there to prevent the model from "seeing itself" during bidirectional training. The strategy chosen in the paper is to select 15% of the tokens in the input sequence, cover them with the [MASK] token, and predict them from the surrounding context. To keep the model from overfitting to the [MASK] token itself, the 15% of selected tokens are handled further: 80% of the time they are replaced with [MASK], 10% of the time they are kept unchanged, and 10% of the time they are replaced with a random token from the vocabulary.
def create_masked_lm_predictions(tokens, masked_lm_prob,
                                 max_predictions_per_seq, vocab_words, rng):

  cand_indexes = []
  # [CLS] and [SEP] must not be masked
  for (i, token) in enumerate(tokens):
    if token == "[CLS]" or token == "[SEP]":
      continue
    cand_indexes.append(i)

  rng.shuffle(cand_indexes)

  output_tokens = list(tokens)

  num_to_predict = min(max_predictions_per_seq,
                       max(1, int(round(len(tokens) * masked_lm_prob))))

  masked_lms = []
  covered_indexes = set()
  for index in cand_indexes:
    if len(masked_lms) >= num_to_predict:
      break
    if index in covered_indexes:
      continue
    covered_indexes.add(index)

    masked_token = None
    # 80% of the time, replace with [MASK]
    if rng.random() < 0.8:
      masked_token = "[MASK]"
    else:
      # 10% of the time, keep original
      if rng.random() < 0.5:
        masked_token = tokens[index]
      # 10% of the time, replace with random word
      else:
        masked_token = vocab_words[rng.randint(0, len(vocab_words) - 1)]

    output_tokens[index] = masked_token

    # MaskedLmInstance is a namedtuple("MaskedLmInstance", ["index", "label"])
    # defined earlier in the file
    masked_lms.append(MaskedLmInstance(index=index, label=tokens[index]))

  # Re-sort by index so the positions follow the original sentence order
  masked_lms = sorted(masked_lms, key=lambda x: x.index)

  masked_lm_positions = []
  masked_lm_labels = []
  for p in masked_lms:
    masked_lm_positions.append(p.index)
    masked_lm_labels.append(p.label)

  return (output_tokens, masked_lm_positions, masked_lm_labels)
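As a quick sanity check on the numbers (my own arithmetic, using the default flag values): for a sequence that fills the whole 128-token budget, 19 positions end up being selected.

max_predictions_per_seq = 20
masked_lm_prob = 0.15
num_tokens = 128  # a sequence that uses the full max_seq_length

num_to_predict = min(max_predictions_per_seq,
                     max(1, int(round(num_tokens * masked_lm_prob))))
print(num_to_predict)  # 19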
Finally, the data produced by the steps above is saved to TFRecord files. The overall logic is straightforward; the code is as follows.
def write_instance_to_example_files(instances, tokenizer, max_seq_length,
                                    max_predictions_per_seq, output_files):

  writers = []
  for output_file in output_files:
    writers.append(tf.python_io.TFRecordWriter(output_file))

  writer_index = 0

  total_written = 0
  for (inst_index, instance) in enumerate(instances):
    # Convert the input tokens to word ids
    input_ids = tokenizer.convert_tokens_to_ids(instance.tokens)
    # The mask records the actual (unpadded) sequence length
    input_mask = [1] * len(input_ids)
    segment_ids = list(instance.segment_ids)
    assert len(input_ids) <= max_seq_length

    # Padding
    while len(input_ids) < max_seq_length:
      input_ids.append(0)
      input_mask.append(0)
      segment_ids.append(0)

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    masked_lm_positions = list(instance.masked_lm_positions)
    masked_lm_ids = tokenizer.convert_tokens_to_ids(instance.masked_lm_labels)
    masked_lm_weights = [1.0] * len(masked_lm_ids)

    while len(masked_lm_positions) < max_predictions_per_seq:
      masked_lm_positions.append(0)
      masked_lm_ids.append(0)
      masked_lm_weights.append(0.0)

    next_sentence_label = 1 if instance.is_random_next else 0

    features = collections.OrderedDict()
    features["input_ids"] = create_int_feature(input_ids)
    features["input_mask"] = create_int_feature(input_mask)
    features["segment_ids"] = create_int_feature(segment_ids)
    features["masked_lm_positions"] = create_int_feature(masked_lm_positions)
    features["masked_lm_ids"] = create_int_feature(masked_lm_ids)
    features["masked_lm_weights"] = create_float_feature(masked_lm_weights)
    features["next_sentence_labels"] = create_int_feature([next_sentence_label])

    # Build the training example
    tf_example = tf.train.Example(features=tf.train.Features(feature=features))

    # Write to file, round-robin across the writers
    writers[writer_index].write(tf_example.SerializeToString())
    writer_index = (writer_index + 1) % len(writers)

    total_written += 1

    # Log the first 20 examples
    if inst_index < 20:
      tf.logging.info("*** Example ***")
      tf.logging.info("tokens: %s" % " ".join(
          [tokenization.printable_text(x) for x in instance.tokens]))

      for feature_name in features.keys():
        feature = features[feature_name]
        values = []
        if feature.int64_list.value:
          values = feature.int64_list.value
        elif feature.float_list.value:
          values = feature.float_list.value
        tf.logging.info(
            "%s: %s" % (feature_name, " ".join([str(x) for x in values])))

  for writer in writers:
    writer.close()

  tf.logging.info("Wrote %d total instances", total_written)
The script is invoked like this:

python create_pretraining_data.py \
  --input_file=./sample_text_zh.txt \
  --output_file=/tmp/tf_examples.tfrecord \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5
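To spot-check the output, the generated TFRecord can be read back with the TF 1.x utilities (a minimal sketch of my own, assuming the output path used above):

import tensorflow as tf

# Iterate over the serialized records and parse the first few back into
# tf.train.Example protos (TF 1.x API, same as the BERT repo uses).
for i, record in enumerate(tf.python_io.tf_record_iterator("/tmp/tf_examples.tfrecord")):
  example = tf.train.Example.FromString(record)
  print(example.features.feature["input_ids"].int64_list.value[:10])
  print(example.features.feature["next_sentence_labels"].int64_list.value)
  if i >= 2:  # just the first three examples
    break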
Since the vocabulary I downloaded earlier is Chinese, I grabbed a few news articles from the web for a quick test and looked over one of the generated samples.
This post covered BERT's bundled tokenization components and the pre-training data generation pipeline, which is the preparation stage of the whole project. There turned out to be more code than expected, so the actual pre-training part will be left for the next post.