From a1db640211fe95cbbbeb6c74511d00ca455850b8 Mon Sep 17 00:00:00 2001 From: yuzhenhua Date: Wed, 25 Nov 2020 20:14:58 +0800 Subject: [PATCH] fix bert bug and add export file for yolov3 and add test file for maskrcnn --- .../official/cv/yolov3_darknet53/export.py | 48 ++++++++++++++++ .../cv/yolov3_darknet53_quant/export.py | 53 ++++++++++++++++++ .../official/cv/yolov3_resnet18/export.py | 47 ++++++++++++++++ model_zoo/official/nlp/bert/export.py | 15 ++++- model_zoo/official/nlp/bert/run_ner.py | 2 +- .../model_zoo_tests/maskrcnn/test_maskrcnn.py | 56 +++++++++++++++++++ 6 files changed, 217 insertions(+), 4 deletions(-) create mode 100644 model_zoo/official/cv/yolov3_darknet53/export.py create mode 100644 model_zoo/official/cv/yolov3_darknet53_quant/export.py create mode 100644 model_zoo/official/cv/yolov3_resnet18/export.py create mode 100644 tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py diff --git a/model_zoo/official/cv/yolov3_darknet53/export.py b/model_zoo/official/cv/yolov3_darknet53/export.py new file mode 100644 index 0000000000..bab2192c53 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/export.py @@ -0,0 +1,48 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import argparse +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.yolo import YOLOV3DarkNet53 +from src.config import ConfigYOLOV3DarkNet53 + +parser = argparse.ArgumentParser(description='yolov3_darknet53 export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="yolov3_darknet53.air", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) + +if __name__ == "__main__": + network = YOLOV3DarkNet53(is_training=False) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(network, param_dict) + + config = ConfigYOLOV3DarkNet53() + network.set_train(False) + + shape = [args.batch_size, 3] + config.test_img_shape + input_data = Tensor(np.zeros(shape), ms.float32) + input_shape = Tensor(tuple(config.test_img_shape), ms.float32) + + export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/yolov3_darknet53_quant/export.py b/model_zoo/official/cv/yolov3_darknet53_quant/export.py new file mode 100644 index 0000000000..9a82efc2ac --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53_quant/export.py @@ -0,0 +1,53 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import argparse +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net +from mindspore.compression.quant import QuantizationAwareTraining + +from src.yolo import YOLOV3DarkNet53 +from src.config import ConfigYOLOV3DarkNet53 + +parser = argparse.ArgumentParser(description='yolov3_darknet53_quant export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="yolov3_darknet53_quant.mindir", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format') +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) + +if __name__ == "__main__": + network = YOLOV3DarkNet53(is_training=False) + config = ConfigYOLOV3DarkNet53() + + if config.quantization_aware: + quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False]) + network = quantizer.quantize(network) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(network, param_dict) + + network.set_train(False) + + shape = [args.batch_size, 3] + config.test_img_shape + input_data = Tensor(np.zeros(shape), ms.float32) + input_shape = Tensor(tuple(config.test_img_shape), ms.float32) + + export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/cv/yolov3_resnet18/export.py b/model_zoo/official/cv/yolov3_resnet18/export.py new file mode 100644 index 0000000000..06c4baad2e --- /dev/null +++ b/model_zoo/official/cv/yolov3_resnet18/export.py @@ -0,0 +1,47 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import argparse +import numpy as np + +import mindspore as ms +from mindspore import context, Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.yolov3 import yolov3_resnet18 +from src.config import ConfigYOLOV3ResNet18 + +parser = argparse.ArgumentParser(description='yolov3_resnet18 export') +parser.add_argument("--device_id", type=int, default=0, help="Device id") +parser.add_argument("--batch_size", type=int, default=1, help="batch size") +parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") +parser.add_argument("--file_name", type=str, default="yolov3_resnet18.air", help="output file name.") +parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') +args = parser.parse_args() + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id) + +if __name__ == "__main__": + config = ConfigYOLOV3ResNet18() + network = yolov3_resnet18(config) + + param_dict = load_checkpoint(args.ckpt_file) + load_param_into_net(network, param_dict) + + network.set_train(False) + + shape = [args.batch_size, 3] + config.img_shape + input_data = Tensor(np.zeros(shape), ms.float32) + + export(network, input_data, file_name=args.file_name, file_format=args.file_format) diff --git a/model_zoo/official/nlp/bert/export.py b/model_zoo/official/nlp/bert/export.py index b77fa08fa0..30296e7ae8 100644 --- a/model_zoo/official/nlp/bert/export.py +++ b/model_zoo/official/nlp/bert/export.py @@ -22,6 +22,7 @@ from mindspore.train.serialization import load_checkpoint, export from src.finetune_eval_model import BertCLSModel, BertSquadModel, BertNERModel from src.finetune_eval_config import optimizer_cfg, bert_net_cfg +from src.bert_for_finetune import BertNER from src.utils import convert_labels_to_index context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") @@ -30,7 +31,7 @@ parser = argparse.ArgumentParser(description='Bert export') parser.add_argument('--use_crf', type=str, default="false", help='Use cfg, default is false.') parser.add_argument('--downstream_task', type=str, choices=["NER", "CLS", "SQUAD"], default="NER", help='at present,support NER only') -parser.add_argument('--num_class', type=int, default=2, help='The number of class, default is 2.') +parser.add_argument('--num_class', type=int, default=41, help='The number of class, default is 41.') parser.add_argument('--label_file_path', type=str, default="", help='label file path, used in clue benchmark.') parser.add_argument('--ckpt_file', type=str, required=True, help='Bert ckpt file.') parser.add_argument('--output_file', type=str, default='Bert.air', help='bert output air name.') @@ -55,7 +56,11 @@ else: if __name__ == '__main__': if args.downstream_task == "NER": - net = BertNERModel(bert_net_cfg, False, number_labels, use_crf=(args.use_crf.lower() == "true")) + if args.use_crf.lower() == "true": + net = BertNER(bert_net_cfg, optimizer_cfg.batch_size, False, num_labels=number_labels, + use_crf=True, tag_to_index=tag_to_index) + else: + net = BertNERModel(bert_net_cfg, False, number_labels, use_crf=(args.use_crf.lower() == "true")) elif args.downstream_task == "CLS": net = BertCLSModel(bert_net_cfg, False, num_labels=number_labels) elif args.downstream_task == "SQUAD": @@ -69,6 +74,10 @@ if __name__ == '__main__': input_ids = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) input_mask = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) token_type_id = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) + label_ids = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) - input_data = [input_ids, input_mask, token_type_id] + if args.downstream_task == "NER" and args.use_crf.lower() == "true": + input_data = [input_ids, input_mask, token_type_id, label_ids] + else: + input_data = [input_ids, input_mask, token_type_id] export(net, *input_data, file_name=args.output_file, file_format=args.file_format) diff --git a/model_zoo/official/nlp/bert/run_ner.py b/model_zoo/official/nlp/bert/run_ner.py index cf38bf1a4b..cd95657082 100644 --- a/model_zoo/official/nlp/bert/run_ner.py +++ b/model_zoo/official/nlp/bert/run_ner.py @@ -155,7 +155,7 @@ def parse_args(): help="Use crf, default is false") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") - parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") + parser.add_argument("--num_class", type=int, default="41", help="The number of class, default is 41.") parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"], help="Enable train data shuffle, default is true") parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"], diff --git a/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py b/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py new file mode 100644 index 0000000000..6b4e1662fb --- /dev/null +++ b/tests/st/model_zoo_tests/maskrcnn/test_maskrcnn.py @@ -0,0 +1,56 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""maskrcnn testing script.""" + +import os +import pytest +import numpy as np +from model_zoo.official.cv.maskrcnn.src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 +from model_zoo.official.cv.maskrcnn.src.config import config + +from mindspore import Tensor, context, export + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_maskrcnn_export(): + """ + export maskrcnn air. + """ + net = Mask_Rcnn_Resnet50(config=config) + net.set_train(False) + + bs = config.test_batch_size + + img = Tensor(np.zeros([bs, 3, 768, 1280], np.float16)) + img_metas = Tensor(np.zeros([bs, 4], np.float16)) + gt_bboxes = Tensor(np.zeros([bs, 128, 4], np.float16)) + gt_labels = Tensor(np.zeros([bs, 128], np.int32)) + gt_num = Tensor(np.zeros([bs, 128], np.bool)) + gt_mask = Tensor(np.zeros([bs, 128], np.bool)) + + input_data = [img, img_metas, gt_bboxes, gt_labels, gt_num, gt_mask] + file_name = "maskrcnn.air" + + export(net, *input_data, file_name=file_name, file_format="AIR") + + assert os.path.exists(file_name) + os.remove(file_name) + +if __name__ == '__main__': + test_maskrcnn_export()