modify tinybert for MindSpore BERT

pull/13999/head
wang_hua_2019 4 years ago
parent 9bac30d37f
commit e0d85aecea

@@ -50,8 +50,15 @@ The backbone structure of BERT is transformer. For BERT_base, the transformer co
# [Dataset](#contents)
- Download the zhwiki or enwiki dataset for pre-training. Extract and refine texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format. Please refer to the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository.
- Download datasets for fine-tuning and evaluation, such as CLUENER, TNEWS, SQuAD v1.1, etc. Convert the dataset files from JSON to TFRecord format; please refer to the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
- Create pre-training dataset
- Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for pre-training.
- Extract and refine texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The commands are as follows:
- pip install wikiextractor
- python -m wikiextractor.WikiExtractor -o <output file path> -b <output file size> <Wikipedia dump file>
- Convert the dataset to TFRecord format. Please refer to the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository and download the corresponding vocab.txt. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, please install bert-tensorflow. A quick way to inspect the converted records is sketched after this list.
- Create fine-tune dataset
- Download dataset for fine-tuning and evaluation such as [CLUENER](https://github.com/CLUEbenchmark/CLUENER2020), [TNEWS](https://github.com/CLUEbenchmark/CLUE), [SQuAD v1.1 train dataset](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json), [SQuAD v1.1 eval dataset](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json), etc.
- Convert the dataset files from JSON to TFRecord format; please refer to the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
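Once conversion finishes, the output can be sanity-checked with MindSpore directly. This is a minimal sketch: the file path is a placeholder, and the column names assume the default feature names written by create_pretraining_data.py, so adjust both to your setup.

```python
# Minimal sketch: inspect a converted pre-training TFRecord with MindSpore.
# The path is a placeholder; the columns assume create_pretraining_data.py's
# default feature names.
import mindspore.dataset as ds

columns = ["input_ids", "input_mask", "segment_ids",
           "masked_lm_positions", "masked_lm_ids", "next_sentence_labels"]
data_set = ds.TFRecordDataset("/path/to/examples.tfrecord", columns_list=columns)

for row in data_set.create_dict_iterator(num_epochs=1, output_numpy=True):
    print({name: value.shape for name, value in row.items()})
    break
```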
# [Environment Requirements](#contents)

@@ -53,8 +53,15 @@ The backbone structure of BERT is Transformer. For BERT_base, the Transformer contains 12 encoder modules
# Dataset
- Download the zhwiki or enwiki dataset for pre-training. Extract and clean up the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor), and convert the dataset to TFRecord format. For details, see the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository.
- Download datasets for fine-tuning and evaluation, such as CLUENER, TNEWS, SQuAD v1.1, etc. Convert the dataset files from JSON to TFRecord format. For details, see the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
- Create the pre-training dataset
- Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for pre-training.
- Extract and clean up the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The steps are as follows:
- pip install wikiextractor
- python -m wikiextractor.WikiExtractor -o <output file path> -b <output file size> <Wikipedia dump file>
- Convert the dataset to TFRecord format. For details, see the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository, and download the corresponding vocab.txt file. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, please install bert-tensorflow.
- Create the downstream task dataset
- Download datasets for fine-tuning and evaluation, such as [CLUENER](https://github.com/CLUEbenchmark/CLUENER2020), [TNEWS](https://github.com/CLUEbenchmark/CLUE), [SQuAD v1.1 train set](https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json), [SQuAD v1.1 dev set](https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json), etc.
- Convert the dataset files from JSON to TFRecord format. For details, see the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
# Environment Requirements

@@ -45,8 +45,15 @@ The backbone structure of TinyBERT is transformer, the transformer contains four
# [Dataset](#contents)
- Download the zhwiki or enwiki dataset for general distillation. Extract and clean the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). Convert the dataset to TFRecord format; please refer to create_pretraining_data.py in the [BERT](https://github.com/google-research/bert) repository.
- Download the GLUE dataset for task distillation. Convert the dataset files from JSON to TFRecord format; please refer to run_classifier.py in the [BERT](https://github.com/google-research/bert) repository.
- Create the dataset for the general distill phase
- Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for pre-training.
- Extract and refine texts in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The commands are as follows:
- pip install wikiextractor
- python -m wikiextractor.WikiExtractor -o <output file path> -b <output file size> <Wikipedia dump file>
- Convert the dataset to TFRecord format. Please refer to the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository and download the corresponding vocab.txt. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, please install bert-tensorflow.
- Create the dataset for the task distill phase
- Download the [GLUE](https://github.com/nyu-mll/GLUE-baselines) dataset for the task distill phase.
- Convert the dataset files from JSON to TFRecord format; please refer to the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository. A sketch for inspecting the result follows this list.
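As with the pre-training data, the converted task records can be spot-checked. This minimal sketch assumes run_classifier.py's default feature names (input_ids, input_mask, segment_ids, label_ids) and uses a placeholder path:

```python
# Minimal sketch: count records in a converted GLUE TFRecord. The path and
# column names are assumptions based on run_classifier.py's defaults.
import mindspore.dataset as ds

data_set = ds.TFRecordDataset("/path/to/train.tfrecord",
                              columns_list=["input_ids", "input_mask",
                                            "segment_ids", "label_ids"])
print("records:", data_set.get_dataset_size())
```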
# [Environment Requirements](#contents)

@@ -50,8 +50,15 @@ The backbone structure of the TinyBERT model is the transformer, which contains four encoder modules
# Dataset
- Download the zhwiki or enwiki dataset for general distillation. Extract and clean up the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor), and convert the dataset to TFRecord format. For details, see the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository.
- Download the GLUE dataset for task distillation. Convert the dataset from JSON to TFRecord format. For details, see the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
- Create the dataset for the general distillation phase
- Download the [zhwiki](https://dumps.wikimedia.org/zhwiki/) or [enwiki](https://dumps.wikimedia.org/enwiki/) dataset for pre-training.
- Extract and clean up the text in the dataset with [WikiExtractor](https://github.com/attardi/wikiextractor). The steps are as follows:
- pip install wikiextractor
- python -m wikiextractor.WikiExtractor -o <output file path> -b <output file size> <Wikipedia dump file>
- Convert the dataset to TFRecord format. For details, see the create_pretraining_data.py file in the [BERT](https://github.com/google-research/bert) repository, and download the corresponding vocab.txt file. If `AttributeError: module 'tokenization' has no attribute 'FullTokenizer'` occurs, please install bert-tensorflow.
- Create the dataset for the downstream task distillation phase
- Download datasets for fine-tuning and evaluation, such as [GLUE](https://github.com/nyu-mll/GLUE-baselines).
- Convert the dataset files from JSON to TFRecord format. For details, see the run_classifier.py file in the [BERT](https://github.com/google-research/bert) repository.
# Environment Requirements

@@ -52,7 +52,7 @@ def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
data_set = ds.MindDataset(data_files, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
else:
data_set = ds.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
data_set = ds.TFRecordDataset(data_files, schema_dir if schema_dir != "" else None, columns_list=columns_list,
shuffle=shuffle, num_shards=device_num, shard_id=rank,
shard_equal_rows=shard_equal_rows)
if device_num == 1 and shuffle is True:
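The one-line fix above passes None instead of an empty string, so MindSpore infers the schema from the record file itself when no schema directory is given. A minimal standalone sketch of the pattern, with placeholder arguments:

```python
# Minimal sketch of the fixed schema handling: an empty schema_dir must become
# None so TFRecordDataset infers the schema from the file itself.
import mindspore.dataset as ds

def load_tfrecord(data_files, columns_list, schema_dir=""):
    schema = schema_dir if schema_dir != "" else None
    return ds.TFRecordDataset(data_files, schema, columns_list=columns_list)
```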

@@ -86,55 +86,6 @@ class BertConfig:
self.dtype = dtype
self.compute_type = compute_type
class EmbeddingLookup(nn.Cell):
"""
An embedding lookup table with a fixed dictionary and size.
Args:
vocab_size (int): Size of the dictionary of embeddings.
embedding_size (int): The size of each embedding vector.
embedding_shape (list): [batch_size, seq_length, embedding_size], the shape of
each embedding vector.
use_one_hot_embeddings (bool): Specifies whether to use one hot encoding form. Default: False.
initializer_range (float): Initialization value of TruncatedNormal. Default: 0.02.
"""
def __init__(self,
vocab_size,
embedding_size,
embedding_shape,
use_one_hot_embeddings=False,
initializer_range=0.02):
super(EmbeddingLookup, self).__init__()
self.vocab_size = vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[vocab_size, embedding_size]))
self.expand = P.ExpandDims()
self.shape_flat = (-1,)
self.gather = P.Gather()
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
self.off_value = Tensor(0.0, mstype.float32)
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
def construct(self, input_ids):
"""embedding lookup"""
extended_ids = self.expand(input_ids, -1)
flat_ids = self.reshape(extended_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids, self.vocab_size, self.on_value, self.off_value)
output_for_reshape = self.array_mul(
one_hot_ids, self.embedding_table)
else:
output_for_reshape = self.gather(self.embedding_table, flat_ids, 0)
output = self.reshape(output_for_reshape, self.shape)
return output, self.embedding_table
class EmbeddingPostprocessor(nn.Cell):
"""
Postprocessors apply positional and token type embeddings to word embeddings.
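The class removed above duplicated what MindSpore's nn.Embedding already provides. A minimal sketch of the replacement, with illustrative sizes: note that the old cell returned (output, embedding_table) as a pair, while nn.Embedding returns only the output and exposes the table as an attribute (which the BertModel change further down relies on).

```python
# Minimal sketch: nn.Embedding covers the removed EmbeddingLookup.
# Sizes are illustrative; TruncatedNormal(0.02) mirrors initializer_range=0.02.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal

embedding = nn.Embedding(vocab_size=21128, embedding_size=768,
                         use_one_hot=False,
                         embedding_table=TruncatedNormal(0.02))
ids = Tensor(np.array([[101, 2769, 102]], dtype=np.int32))
word_embeddings = embedding(ids)        # shape (1, 3, 768)
table = embedding.embedding_table       # the Parameter the old cell returned
```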
@@ -166,10 +117,10 @@ class EmbeddingPostprocessor(nn.Cell):
self.token_type_vocab_size = token_type_vocab_size
self.use_one_hot_embeddings = use_one_hot_embeddings
self.max_position_embeddings = max_position_embeddings
self.embedding_table = Parameter(initializer
(TruncatedNormal(initializer_range),
[token_type_vocab_size,
embedding_size]))
self.token_type_embedding = nn.Embedding(
vocab_size=token_type_vocab_size,
embedding_size=embedding_size,
use_one_hot=use_one_hot_embeddings)
self.shape_flat = (-1,)
self.one_hot = P.OneHot()
self.on_value = Tensor(1.0, mstype.float32)
@@ -177,35 +128,28 @@ class EmbeddingPostprocessor(nn.Cell):
self.array_mul = P.MatMul()
self.reshape = P.Reshape()
self.shape = tuple(embedding_shape)
self.layernorm = nn.LayerNorm((embedding_size,))
self.dropout = nn.Dropout(1 - dropout_prob)
self.gather = P.Gather()
self.use_relative_positions = use_relative_positions
self.slice = P.StridedSlice()
self.full_position_embeddings = Parameter(initializer
(TruncatedNormal(initializer_range),
[max_position_embeddings,
embedding_size]))
_, seq, _ = self.shape
self.full_position_embedding = nn.Embedding(
vocab_size=max_position_embeddings,
embedding_size=embedding_size,
use_one_hot=False)
self.layernorm = nn.LayerNorm((embedding_size,))
self.position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))
self.add = P.Add()
def construct(self, token_type_ids, word_embeddings):
"""embedding postprocessor"""
"""Postprocessors apply positional and token type embeddings to word embeddings."""
output = word_embeddings
if self.use_token_type:
flat_ids = self.reshape(token_type_ids, self.shape_flat)
if self.use_one_hot_embeddings:
one_hot_ids = self.one_hot(flat_ids,
self.token_type_vocab_size, self.on_value, self.off_value)
token_type_embeddings = self.array_mul(one_hot_ids,
self.embedding_table)
else:
token_type_embeddings = self.gather(self.embedding_table, flat_ids, 0)
token_type_embeddings = self.reshape(token_type_embeddings, self.shape)
output += token_type_embeddings
token_type_embeddings = self.token_type_embedding(token_type_ids)
output = self.add(output, token_type_embeddings)
if not self.use_relative_positions:
_, seq, width = self.shape
position_embeddings = self.slice(self.full_position_embeddings, (0, 0), (seq, width), (1, 1))
position_embeddings = self.reshape(position_embeddings, (1, seq, width))
output += position_embeddings
position_embeddings = self.full_position_embedding(self.position_ids)
output = self.add(output, position_embeddings)
output = self.layernorm(output)
output = self.dropout(output)
return output
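The refactor above replaces a StridedSlice over a raw position-embedding Parameter with an nn.Embedding lookup driven by fixed position_ids. The two are equivalent because looking up ids 0..seq-1 selects exactly the first seq rows of the table; a minimal sketch with illustrative sizes:

```python
# Minimal sketch: an nn.Embedding lookup over ids [0..seq-1] returns the same
# rows as slicing the first `seq` entries of the embedding table.
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

seq, max_position_embeddings, width = 8, 512, 16
pos_embedding = nn.Embedding(vocab_size=max_position_embeddings,
                             embedding_size=width)
position_ids = Tensor(np.arange(seq).reshape(-1, seq).astype(np.int32))

by_lookup = pos_embedding(position_ids).asnumpy()[0]         # (seq, width)
by_slice = pos_embedding.embedding_table.asnumpy()[:seq, :]  # (seq, width)
print(np.allclose(by_lookup, by_slice))                      # True
```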
@@ -788,12 +732,10 @@ class BertModel(nn.Cell):
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [-1, self.seq_length,
self.embedding_size]
self.bert_embedding_lookup = EmbeddingLookup(
self.bert_embedding_lookup = nn.Embedding(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
use_one_hot=use_one_hot_embeddings)
self.bert_embedding_postprocessor = EmbeddingPostprocessor(
use_relative_positions=config.use_relative_positions,
embedding_size=self.embedding_size,
@@ -831,7 +773,8 @@ class BertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""bert model"""
# embedding
word_embeddings, embedding_tables = self.bert_embedding_lookup(input_ids)
embedding_tables = self.bert_embedding_lookup.embedding_table
word_embeddings = self.bert_embedding_lookup(input_ids)
embedding_output = self.bert_embedding_postprocessor(token_type_ids, word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
attention_mask = self._create_attention_mask_from_input_mask(input_mask)
@@ -883,12 +826,10 @@ class TinyBertModel(nn.Cell):
self.last_idx = self.num_hidden_layers - 1
output_embedding_shape = [-1, self.seq_length,
self.embedding_size]
self.tinybert_embedding_lookup = EmbeddingLookup(
self.tinybert_embedding_lookup = nn.Embedding(
vocab_size=config.vocab_size,
embedding_size=self.embedding_size,
embedding_shape=output_embedding_shape,
use_one_hot_embeddings=use_one_hot_embeddings,
initializer_range=config.initializer_range)
use_one_hot=use_one_hot_embeddings)
self.tinybert_embedding_postprocessor = EmbeddingPostprocessor(
use_relative_positions=config.use_relative_positions,
embedding_size=self.embedding_size,
@@ -926,7 +867,8 @@ class TinyBertModel(nn.Cell):
def construct(self, input_ids, token_type_ids, input_mask):
"""tiny bert model"""
# embedding
word_embeddings, embedding_tables = self.tinybert_embedding_lookup(input_ids)
embedding_tables = self.tinybert_embedding_lookup.embedding_table
word_embeddings = self.tinybert_embedding_lookup(input_ids)
embedding_output = self.tinybert_embedding_postprocessor(token_type_ids,
word_embeddings)
# attention mask [batch_size, seq_length, seq_length]
@@ -969,12 +911,8 @@ class BertModelCLS(nn.Cell):
self.dtype = config.dtype
self.num_labels = num_labels
self.phase_type = phase_type
if self.phase_type == "teacher":
self.dense = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
else:
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dense_1 = nn.Dense(config.hidden_size, self.num_labels, weight_init=self.weight_init,
has_bias=True).to_float(config.compute_type)
self.dropout = nn.ReLU()
def construct(self, input_ids, token_type_id, input_mask):
@@ -982,10 +920,7 @@ class BertModelCLS(nn.Cell):
_, pooled_output, _, seq_output, att_output = self.bert(input_ids, token_type_id, input_mask)
cls = self.cast(pooled_output, self.dtype)
cls = self.dropout(cls)
if self.phase_type == "teacher":
logits = self.dense(cls)
else:
logits = self.dense_1(cls)
logits = self.dense_1(cls)
logits = self.cast(logits, self.dtype)
log_probs = self.log_softmax(logits)
if self._phase == 'train' or self.phase_type == "teacher":
