!8831 modify export script for centerface and yolov4

From: @yuzhenhua666
Reviewed-by: @c_34,@yingjy
Signed-off-by: @c_34
pull/8831/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ce503eb3bc

@ -12,51 +12,43 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Convert ckpt to air."""
import os
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context import mindspore
from mindspore import Tensor from mindspore import context, Tensor
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.centerface import CenterfaceMobilev2 from src.centerface import CenterfaceMobilev2
from src.config import ConfigCenterface
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False)
parser = argparse.ArgumentParser(description='centerface export')
def save_air(): parser.add_argument("--device_id", type=int, default=0, help="Device id")
"""Save air file""" parser.add_argument("--batch_size", type=int, default=1, help="batch size")
print('============= centerface start save air ==================') parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="centerface.air", help="output file name.")
parser = argparse.ArgumentParser(description='Convert ckpt to air') parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--pretrained', type=str, default='', help='pretrained model to load') args = parser.parse_args()
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args.device_id)
args = parser.parse_args()
network = CenterfaceMobilev2() if __name__ == '__main__':
config = ConfigCenterface()
if os.path.isfile(args.pretrained): net = CenterfaceMobilev2()
param_dict = load_checkpoint(args.pretrained)
param_dict_new = {} param_dict = load_checkpoint(args.ckpt_file)
for key, values in param_dict.items(): param_dict_new = {}
if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'): for key, values in param_dict.items():
continue if key.startswith('moments.') or key.startswith('moment1.') or key.startswith('moment2.'):
elif key.startswith('centerface_network.'): continue
param_dict_new[key[19:]] = values elif key.startswith('centerface_network.'):
else: param_dict_new[key[19:]] = values
param_dict_new[key] = values else:
load_param_into_net(network, param_dict_new) param_dict_new[key] = values
print('load model {} success'.format(args.pretrained))
load_param_into_net(net, param_dict_new)
input_data = np.random.uniform(low=0, high=1.0, size=(args.batch_size, 3, 832, 832)).astype(np.float32) net.set_train(False)
tensor_input_data = Tensor(input_data) input_data = Tensor(np.zeros([args.batch_size, 3, config.input_h, config.input_w]), mindspore.float32)
export(network, tensor_input_data, export(net, input_data, file_name=args.file_name, file_format=args.file_format)
file_name=args.pretrained.replace('.ckpt', '_' + str(args.batch_size) + 'b.air'), file_format='AIR')
print("export model success.")
if __name__ == "__main__":
save_air()

@ -26,7 +26,7 @@ parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size") parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument("--testing_shape", type=int, default=608, help="test shape") parser.add_argument("--testing_shape", type=int, default=608, help="test shape")
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.") parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
parser.add_argument("--file_name", type=str, default="ssd.air", help="output file name.") parser.add_argument("--file_name", type=str, default="yolov4.air", help="output file name.")
parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format') parser.add_argument('--file_format', type=str, choices=["AIR", "ONNX", "MINDIR"], default='AIR', help='file format')
args = parser.parse_args() args = parser.parse_args()

Loading…
Cancel
Save