From 4128fb1155d08a8d744ff98354feb100538dcb86 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Sat, 19 Dec 2020 14:45:25 +0800 Subject: [PATCH] modify export to support mindir --- model_zoo/official/nlp/bert/export.py | 31 +++-- model_zoo/official/nlp/mass/export.py | 15 ++- model_zoo/official/nlp/transformer/export.py | 8 +- model_zoo/official/recommend/ncf/export.py | 68 +++++++++++ .../official/recommend/ncf/src/export.py | 112 ------------------ 5 files changed, 96 insertions(+), 138 deletions(-) create mode 100644 model_zoo/official/recommend/ncf/export.py delete mode 100644 model_zoo/official/recommend/ncf/src/export.py diff --git a/model_zoo/official/nlp/bert/export.py b/model_zoo/official/nlp/bert/export.py index 84a71344ea..5a34edd703 100644 --- a/model_zoo/official/nlp/bert/export.py +++ b/model_zoo/official/nlp/bert/export.py @@ -16,30 +16,30 @@ import argparse import numpy as np -from mindspore import Tensor, context import mindspore.common.dtype as mstype -from mindspore.train.serialization import load_checkpoint, export +from mindspore import Tensor, context, load_checkpoint, export from src.finetune_eval_model import BertCLSModel, BertSquadModel, BertNERModel from src.finetune_eval_config import bert_net_cfg from src.bert_for_finetune import BertNER from src.utils import convert_labels_to_index - -parser = argparse.ArgumentParser(description='Bert export') +parser = argparse.ArgumentParser(description="Bert export") parser.add_argument("--device_id", type=int, default=0, help="Device id") -parser.add_argument('--use_crf', type=str, default="false", help='Use cfg, default is false.') -parser.add_argument('--downstream_task', type=str, choices=["NER", "CLS", "SQUAD"], default="NER", - help='at present,support NER only') -parser.add_argument('--num_class', type=int, default=41, help='The number of class, default is 41.') +parser.add_argument("--use_crf", type=str, default="false", help="Use cfg, default is false.") +parser.add_argument("--downstream_task", type=str, choices=["NER", "CLS", "SQUAD"], default="NER", + help="at present,support NER only") +parser.add_argument("--num_class", type=int, default=41, help="The number of class, default is 41.") parser.add_argument("--batch_size", type=int, default=16, help="batch size") -parser.add_argument('--label_file_path', type=str, default="", help='label file path, used in clue benchmark.') -parser.add_argument('--ckpt_file', type=str, required=True, help='Bert ckpt file.') -parser.add_argument('--output_file', type=str, default='Bert', help='bert output air name.') -parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +parser.add_argument("--label_file_path", type=str, default="", help="label file path, used in clue benchmark.") +parser.add_argument("--ckpt_file", type=str, required=True, help="Bert ckpt file.") +parser.add_argument("--file_name", type=str, default="Bert", help="bert output air name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") +parser.add_argument("--device_target", type=str, default="Ascend", + choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) label_list = [] with open(args.label_file_path) as f: @@ -56,8 +56,7 @@ if args.use_crf.lower() == "true": else: number_labels = args.num_class - -if __name__ == '__main__': +if __name__ == "__main__": if args.downstream_task == "NER": if args.use_crf.lower() == "true": net = BertNER(bert_net_cfg, args.batch_size, False, num_labels=number_labels, @@ -83,4 +82,4 @@ if __name__ == '__main__': input_data = [input_ids, input_mask, token_type_id, label_ids] else: input_data = [input_ids, input_mask, token_type_id] - export(net, *input_data, file_name=args.output_file, file_format=args.file_format) + export(net, *input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/nlp/mass/export.py b/model_zoo/official/nlp/mass/export.py index 2b6da6d27d..fb92fc594e 100644 --- a/model_zoo/official/nlp/mass/export.py +++ b/model_zoo/official/nlp/mass/export.py @@ -26,13 +26,18 @@ from src.utils.load_weights import load_infer_weights from src.transformer.transformer_for_infer import TransformerInferModel from config import TransformerConfig -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - -parser = argparse.ArgumentParser(description='mass') +parser = argparse.ArgumentParser(description="mass export") +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--file_name", type=str, default="mass", help="output file name.") +parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format") +parser.add_argument("--device_target", type=str, default="Ascend", + choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") parser.add_argument('--gigaword_infer_config', type=str, required=True, help='gigaword config file') parser.add_argument('--vocab_file', type=str, required=True, help='vocabulary file') args = parser.parse_args() +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) + def get_config(config_file): tfm_config = TransformerConfig.from_json_file(config_file) tfm_config.compute_type = mstype.float16 @@ -40,12 +45,10 @@ def get_config(config_file): return tfm_config - if __name__ == '__main__': vocab = Dictionary.load_from_persisted_dict(args.vocab_file) config = get_config(args.gigaword_infer_config) dec_len = config.max_decode_length - output_file_name = 'giga_' + str(dec_len) + '.air' tfm_model = TransformerInferModel(config=config, use_one_hot_embeddings=False) tfm_model.init_parameters_data() @@ -79,4 +82,4 @@ if __name__ == '__main__': source_ids = Tensor(np.ones((1, config.seq_length)).astype(np.int32)) source_mask = Tensor(np.ones((1, config.seq_length)).astype(np.int32)) - export(tfm_model, source_ids, source_mask, file_name=output_file_name, file_format="AIR") + export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/nlp/transformer/export.py b/model_zoo/official/nlp/transformer/export.py index 3342ae14e9..6cc998f7d6 100644 --- a/model_zoo/official/nlp/transformer/export.py +++ b/model_zoo/official/nlp/transformer/export.py @@ -29,9 +29,11 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id") parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--file_name", type=str, default="transformer", help="output file name.") parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +parser.add_argument("--device_target", type=str, default="Ascend", + choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") args = parser.parse_args() -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) if __name__ == '__main__': tfm_model = TransformerModel(config=transformer_net_cfg, is_training=False, use_one_hot_embeddings=False) @@ -42,6 +44,4 @@ if __name__ == '__main__': source_ids = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) source_mask = Tensor(np.ones((args.batch_size, transformer_net_cfg.seq_length)).astype(np.int32)) - dec_len = transformer_net_cfg.max_decode_length - - export(tfm_model, source_ids, source_mask, file_name=args.file_name + str(dec_len), file_format=args.file_format) + export(tfm_model, source_ids, source_mask, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/recommend/ncf/export.py b/model_zoo/official/recommend/ncf/export.py new file mode 100644 index 0000000000..10f8503969 --- /dev/null +++ b/model_zoo/official/recommend/ncf/export.py @@ -0,0 +1,68 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ncf export file""" +import argparse +import numpy as np + +from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export + +import src.constants as rconst +from src.config import cfg +from ncf import NCFModel, PredictWithSigmoid + +parser = argparse.ArgumentParser(description='ncf export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"], help="Dataset.") +parser.add_argument("--file_name", type=str, default="ncf", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +parser.add_argument("--device_target", type=str, default="Ascend", + choices=["Ascend", "GPU", "CPU"], help="device target (default: Ascend)") +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id) + +if __name__ == "__main__": + topk = rconst.TOP_K + num_eval_neg = rconst.NUM_EVAL_NEGATIVES + + if args.dataset == "ml-1m": + num_eval_users = 6040 + num_eval_items = 3706 + elif args.dataset == "ml-20m": + num_eval_users = 138493 + num_eval_items = 26744 + else: + raise ValueError("not supported dataset") + + ncf_net = NCFModel(num_users=num_eval_users, + num_items=num_eval_items, + num_factors=cfg.num_factors, + model_layers=cfg.layers, + mf_regularization=0, + mlp_reg_layers=[0.0, 0.0, 0.0, 0.0], + mf_dim=16) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(ncf_net, param_dict) + + network = PredictWithSigmoid(ncf_net, topk, num_eval_neg) + + users = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32)) + items = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.int32)) + masks = Tensor(np.zeros([cfg.eval_batch_size, 1]).astype(np.float32)) + + input_data = [users, items, masks] + export(network, *input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/recommend/ncf/src/export.py b/model_zoo/official/recommend/ncf/src/export.py deleted file mode 100644 index bd7f3d936f..0000000000 --- a/model_zoo/official/recommend/ncf/src/export.py +++ /dev/null @@ -1,112 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Export NCF air file.""" -import os -import argparse -from absl import logging - -from mindspore.train.serialization import load_checkpoint, load_param_into_net, export -from mindspore import Tensor, context, Model - -import constants as rconst -from dataset import create_dataset -from metrics import NCFMetric -from ncf import NCFModel, NetWithLossClass, TrainStepWrap, PredictWithSigmoid - - -logging.set_verbosity(logging.INFO) - - -def argparse_init(): - """Argparse init method""" - parser = argparse.ArgumentParser(description='NCF') - - parser.add_argument("--data_path", type=str, default="./dataset/") # The location of the input data. - parser.add_argument("--dataset", type=str, default="ml-1m", choices=["ml-1m", "ml-20m"]) # Dataset to be trained and evaluated. ["ml-1m", "ml-20m"] - parser.add_argument("--eval_batch_size", type=int, default=160000) # The batch size used for evaluation. - parser.add_argument("--layers", type=int, default=[64, 32, 16]) # The sizes of hidden layers for MLP - parser.add_argument("--num_factors", type=int, default=16) # The Embedding size of MF model. - parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file. - parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file. - parser.add_argument("--checkpoint_file_path", type=str, default="./checkpoint/NCF.ckpt") # The location of the checkpoint file. - return parser - - -def export_air_file(): - """"Export file for eval""" - parser = argparse_init() - args, _ = parser.parse_known_args() - - if not os.path.exists(args.output_path): - os.makedirs(args.output_path) - - layers = args.layers - num_factors = args.num_factors - topk = rconst.TOP_K - num_eval_neg = rconst.NUM_EVAL_NEGATIVES - - ds_eval, num_eval_users, num_eval_items = create_dataset(test_train=False, data_dir=args.data_path, - dataset=args.dataset, train_epochs=0, - eval_batch_size=args.eval_batch_size) - print("ds_eval.size: {}".format(ds_eval.get_dataset_size())) - - ncf_net = NCFModel(num_users=num_eval_users, - num_items=num_eval_items, - num_factors=num_factors, - model_layers=layers, - mf_regularization=0, - mlp_reg_layers=[0.0, 0.0, 0.0, 0.0], - mf_dim=16) - param_dict = load_checkpoint(args.checkpoint_file_path) - load_param_into_net(ncf_net, param_dict) - - loss_net = NetWithLossClass(ncf_net) - train_net = TrainStepWrap(loss_net) - train_net.set_train() - eval_net = PredictWithSigmoid(ncf_net, topk, num_eval_neg) - - ncf_metric = NCFMetric() - model = Model(train_net, eval_network=eval_net, metrics={"ncf": ncf_metric}) - - ncf_metric.clear() - out = model.eval(ds_eval) - - eval_file_path = os.path.join(args.output_path, args.eval_file_name) - eval_file = open(eval_file_path, "a+") - eval_file.write("EvalCallBack: HR = {}, NDCG = {}\n".format(out['ncf'][0], out['ncf'][1])) - eval_file.close() - print("EvalCallBack: HR = {}, NDCG = {}".format(out['ncf'][0], out['ncf'][1])) - - param_dict = load_checkpoint(args.checkpoint_file_path) - # load the parameter into net - load_param_into_net(eval_net, param_dict) - - input_tensor_list = [] - for data in ds_eval: - for j in data: - input_tensor_list.append(Tensor(j)) - print(len(a)) - break - print(input_tensor_list) - export(eval_net, *input_tensor_list, file_name='NCF.air', file_format='AIR') - -if __name__ == '__main__': - devid = int(os.getenv('DEVICE_ID')) - context.set_context(mode=context.GRAPH_MODE, - device_target="Davinci", - save_graphs=True, - device_id=devid) - - export_air_file()