fix bert bug and add export file for yolov3 and add test file for maskrcnn

pull/9055/head
yuzhenhua 4 years ago
parent a95cbdb121
commit a1db640211

@ -0,0 +1,48 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.yolo import YOLOV3DarkNet53
from src.config import ConfigYOLOV3DarkNet53
parser = argparse.ArgumentParser(description='yolov3_darknet53 export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="yolov3_darknet53.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == "__main__":
network = YOLOV3DarkNet53(is_training=False)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
config = ConfigYOLOV3DarkNet53()
network.set_train(False)
shape = [args.batch_size, 3] + config.test_img_shape
input_data = Tensor(np.zeros(shape), ms.float32)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)

@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from mindspore.compression.quant import QuantizationAwareTraining
from src.yolo import YOLOV3DarkNet53
from src.config import ConfigYOLOV3DarkNet53
parser = argparse.ArgumentParser(description='yolov3_darknet53_quant export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="yolov3_darknet53_quant.mindir", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='MINDIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == "__main__":
network = YOLOV3DarkNet53(is_training=False)
config = ConfigYOLOV3DarkNet53()
if config.quantization_aware:
quantizer = QuantizationAwareTraining(bn_fold=True, per_channel=[True, False], symmetric=[True, False])
network = quantizer.quantize(network)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
network.set_train(False)
shape = [args.batch_size, 3] + config.test_img_shape
input_data = Tensor(np.zeros(shape), ms.float32)
input_shape = Tensor(tuple(config.test_img_shape), ms.float32)
export(network, input_data, input_shape, file_name=args.file_name, file_format=args.file_format)

@ -0,0 +1,47 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
import numpy as np
import mindspore as ms
from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.yolov3 import yolov3_resnet18
from src.config import ConfigYOLOV3ResNet18
parser = argparse.ArgumentParser(description='yolov3_resnet18 export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="yolov3_resnet18.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
if __name__ == "__main__":
config = ConfigYOLOV3ResNet18()
network = yolov3_resnet18(config)
param_dict = load_checkpoint(args.ckpt_file)
load_param_into_net(network, param_dict)
network.set_train(False)
shape = [args.batch_size, 3] + config.img_shape
input_data = Tensor(np.zeros(shape), ms.float32)
export(network, input_data, file_name=args.file_name, file_format=args.file_format)

@ -22,6 +22,7 @@ from mindspore.train.serialization import load_checkpoint, export
from src.finetune_eval_model import BertCLSModel, BertSquadModel, BertNERModel from src.finetune_eval_model import BertCLSModel, BertSquadModel, BertNERModel
from src.finetune_eval_config import optimizer_cfg, bert_net_cfg from src.finetune_eval_config import optimizer_cfg, bert_net_cfg
from src.bert_for_finetune import BertNER
from src.utils import convert_labels_to_index from src.utils import convert_labels_to_index
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@ -30,7 +31,7 @@ parser = argparse.ArgumentParser(description='Bert export')
parser.add_argument('--use_crf', type=str, default="false", help='Use cfg, default is false.') parser.add_argument('--use_crf', type=str, default="false", help='Use cfg, default is false.')
parser.add_argument('--downstream_task', type=str, choices=["NER", "CLS", "SQUAD"], default="NER", parser.add_argument('--downstream_task', type=str, choices=["NER", "CLS", "SQUAD"], default="NER",
help='at presentsupport NER only') help='at presentsupport NER only')
parser.add_argument('--num_class', type=int, default=2, help='The number of class, default is 2.') parser.add_argument('--num_class', type=int, default=41, help='The number of class, default is 41.')
parser.add_argument('--label_file_path', type=str, default="", help='label file path, used in clue benchmark.') parser.add_argument('--label_file_path', type=str, default="", help='label file path, used in clue benchmark.')
parser.add_argument('--ckpt_file', type=str, required=True, help='Bert ckpt file.') parser.add_argument('--ckpt_file', type=str, required=True, help='Bert ckpt file.')
parser.add_argument('--output_file', type=str, default='Bert.air', help='bert output air name.') parser.add_argument('--output_file', type=str, default='Bert.air', help='bert output air name.')
@ -55,6 +56,10 @@ else:
if __name__ == '__main__': if __name__ == '__main__':
if args.downstream_task == "NER": if args.downstream_task == "NER":
if args.use_crf.lower() == "true":
net = BertNER(bert_net_cfg, optimizer_cfg.batch_size, False, num_labels=number_labels,
use_crf=True, tag_to_index=tag_to_index)
else:
net = BertNERModel(bert_net_cfg, False, number_labels, use_crf=(args.use_crf.lower() == "true")) net = BertNERModel(bert_net_cfg, False, number_labels, use_crf=(args.use_crf.lower() == "true"))
elif args.downstream_task == "CLS": elif args.downstream_task == "CLS":
net = BertCLSModel(bert_net_cfg, False, num_labels=number_labels) net = BertCLSModel(bert_net_cfg, False, num_labels=number_labels)
@ -69,6 +74,10 @@ if __name__ == '__main__':
input_ids = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) input_ids = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32)
input_mask = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) input_mask = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32)
token_type_id = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32) token_type_id = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32)
label_ids = Tensor(np.zeros([optimizer_cfg.batch_size, bert_net_cfg.seq_length]), mstype.int32)
if args.downstream_task == "NER" and args.use_crf.lower() == "true":
input_data = [input_ids, input_mask, token_type_id, label_ids]
else:
input_data = [input_ids, input_mask, token_type_id] input_data = [input_ids, input_mask, token_type_id]
export(net, *input_data, file_name=args.output_file, file_format=args.file_format) export(net, *input_data, file_name=args.output_file, file_format=args.file_format)

@ -155,7 +155,7 @@ def parse_args():
help="Use crf, default is false") help="Use crf, default is false")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.") parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.") parser.add_argument("--epoch_num", type=int, default="1", help="Epoch number, default is 1.")
parser.add_argument("--num_class", type=int, default="2", help="The number of class, default is 2.") parser.add_argument("--num_class", type=int, default="41", help="The number of class, default is 41.")
parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"], parser.add_argument("--train_data_shuffle", type=str, default="true", choices=["true", "false"],
help="Enable train data shuffle, default is true") help="Enable train data shuffle, default is true")
parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"], parser.add_argument("--eval_data_shuffle", type=str, default="false", choices=["true", "false"],

@ -0,0 +1,56 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""maskrcnn testing script."""
import os
import pytest
import numpy as np
from model_zoo.official.cv.maskrcnn.src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from model_zoo.official.cv.maskrcnn.src.config import config
from mindspore import Tensor, context, export
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_maskrcnn_export():
"""
export maskrcnn air.
"""
net = Mask_Rcnn_Resnet50(config=config)
net.set_train(False)
bs = config.test_batch_size
img = Tensor(np.zeros([bs, 3, 768, 1280], np.float16))
img_metas = Tensor(np.zeros([bs, 4], np.float16))
gt_bboxes = Tensor(np.zeros([bs, 128, 4], np.float16))
gt_labels = Tensor(np.zeros([bs, 128], np.int32))
gt_num = Tensor(np.zeros([bs, 128], np.bool))
gt_mask = Tensor(np.zeros([bs, 128], np.bool))
input_data = [img, img_metas, gt_bboxes, gt_labels, gt_num, gt_mask]
file_name = "maskrcnn.air"
export(net, *input_data, file_name=file_name, file_format="AIR")
assert os.path.exists(file_name)
os.remove(file_name)
if __name__ == '__main__':
test_maskrcnn_export()
Loading…
Cancel
Save