From c4eafb98200bec3f48d9e7560f2a7ec50f6a9c2c Mon Sep 17 00:00:00 2001
From: yuzhenhua
Date: Thu, 29 Oct 2020 15:21:03 +0800
Subject: [PATCH] fix tinybert bug and add export script file for mass

---
 model_zoo/official/nlp/mass/export.py     | 82 +++++++++++++++++++++++
 model_zoo/official/nlp/tinybert/export.py |  2 +
 2 files changed, 84 insertions(+)
 create mode 100644 model_zoo/official/nlp/mass/export.py

diff --git a/model_zoo/official/nlp/mass/export.py b/model_zoo/official/nlp/mass/export.py
new file mode 100644
index 0000000000..2b6da6d27d
--- /dev/null
+++ b/model_zoo/official/nlp/mass/export.py
@@ -0,0 +1,82 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""export checkpoint file into air models"""
+
+import argparse
+import numpy as np
+
+from mindspore import Tensor, context
+from mindspore.common import dtype as mstype
+from mindspore.train.serialization import export
+
+from src.utils import Dictionary
+from src.utils.load_weights import load_infer_weights
+from src.transformer.transformer_for_infer import TransformerInferModel
+from config import TransformerConfig
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+parser = argparse.ArgumentParser(description='mass')
+parser.add_argument('--gigaword_infer_config', type=str, required=True, help='gigaword config file')
+parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file')
+args = parser.parse_args()
+
+def get_config(config_file):
+    tfm_config = TransformerConfig.from_json_file(config_file)
+    tfm_config.compute_type = mstype.float16
+    tfm_config.dtype = mstype.float32
+
+    return tfm_config
+
+
+if __name__ == '__main__':
+    vocab = Dictionary.load_from_persisted_dict(args.vocab_file)
+    config = get_config(args.gigaword_infer_config)
+    dec_len = config.max_decode_length
+    output_file_name = 'giga_' + str(dec_len) + '.air'
+
+    tfm_model = TransformerInferModel(config=config, use_one_hot_embeddings=False)
+    tfm_model.init_parameters_data()
+
+    params = tfm_model.trainable_params()
+    weights = load_infer_weights(config)
+
+    for param in params:
+        value = param.data
+        name = param.name
+
+        if name not in weights:
+            raise ValueError(f'{name} is not found in weights.')
+
+        with open('weight_after_deal.txt', 'a+') as f:
+            weights_name = name
+            f.write(weights_name + '\n')
+
+            if isinstance(value, Tensor):
+                if weights_name in weights:
+                    assert weights_name in weights
+                    param.set_data(Tensor(weights[weights_name], mstype.float32))
+                else:
+                    raise ValueError(f'{weights_name} is not found in checkpoint')
+            else:
+                raise TypeError(f'Type of {weights_name} is not Tensor')
+
+    print(' | Load weights successfully.')
+    tfm_model.set_train(False)
+
+    source_ids = Tensor(np.ones((1, config.seq_length)).astype(np.int32))
+    source_mask = Tensor(np.ones((1, config.seq_length)).astype(np.int32))
+
+    export(tfm_model, source_ids, source_mask, file_name=output_file_name, file_format="AIR")
diff --git a/model_zoo/official/nlp/tinybert/export.py b/model_zoo/official/nlp/tinybert/export.py
index 6adc8ac7bc..3502ea485a 100644
--- a/model_zoo/official/nlp/tinybert/export.py
+++ b/model_zoo/official/nlp/tinybert/export.py
@@ -32,6 +32,7 @@ args = parser.parse_args()
 
 DEFAULT_NUM_LABELS = 2
 DEFAULT_SEQ_LENGTH = 128
+DEFAULT_BS = 32
 task_params = {"SST-2": {"num_labels": 2, "seq_length": 64},
                "QNLI": {"num_labels": 2, "seq_length": 128},
                "MNLI": {"num_labels": 3, "seq_length": 128}}
@@ -60,6 +61,7 @@ class Task:
 
 if __name__ == '__main__':
     task = Task(args.task_name)
     td_student_net_cfg.seq_length = task.seq_length
+    td_student_net_cfg.batch_size = DEFAULT_BS
     eval_model = BertModelCLS(td_student_net_cfg, False, task.num_labels, 0.0, phase_type="student")
     param_dict = load_checkpoint(args.ckpt_file)