From 4c028167c1fb502149c020ed217b5abf0f1c89de Mon Sep 17 00:00:00 2001 From: bai-yangfan Date: Thu, 10 Dec 2020 14:18:55 +0800 Subject: [PATCH] export_gpu --- model_zoo/official/cv/mobilenetv2_quant/export.py | 11 +++-------- .../official/cv/mobilenetv2_quant/src/config.py | 13 +++++++++++++ 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/model_zoo/official/cv/mobilenetv2_quant/export.py b/model_zoo/official/cv/mobilenetv2_quant/export.py index c951d64f20..fce359ce2f 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/export.py +++ b/model_zoo/official/cv/mobilenetv2_quant/export.py @@ -22,7 +22,7 @@ from mindspore import Tensor, context, load_checkpoint, load_param_into_net, exp from mindspore.compression.quant import QuantizationAwareTraining from src.mobilenetV2 import mobilenetV2 -from src.config import config_ascend_quant +from src.config import config_quant parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') @@ -30,13 +30,8 @@ parser.add_argument('--device_target', type=str, default=None, help='Run device args_opt = parser.parse_args() if __name__ == '__main__': - cfg = None - if args_opt.device_target == "Ascend": - cfg = config_ascend_quant - context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False) - else: - raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) - + cfg = config_quant(args_opt.device_target) + context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target, save_graphs=False) # define fusion network network = mobilenetV2(num_classes=cfg.num_classes) # convert fusion network to quantization aware network diff --git a/model_zoo/official/cv/mobilenetv2_quant/src/config.py b/model_zoo/official/cv/mobilenetv2_quant/src/config.py index 86ae10e738..ce2e613c18 100644 --- a/model_zoo/official/cv/mobilenetv2_quant/src/config.py +++ b/model_zoo/official/cv/mobilenetv2_quant/src/config.py @@ -38,6 +38,8 @@ config_ascend_quant = ed({ config_gpu_quant = ed({ "num_classes": 1000, + "image_height": 224, + "image_width": 224, "batch_size": 300, "epoch_size": 60, "start_epoch": 200, @@ -52,3 +54,14 @@ config_gpu_quant = ed({ "keep_checkpoint_max": 300, "save_checkpoint_path": "./checkpoint", }) + +def config_quant(device_target): + if device_target not in ["Ascend", "GPU"]: + raise ValueError("Unsupported device target: {}.".format(device_target)) + configs = ed({ + "Ascend": config_ascend_quant, + "GPU": config_gpu_quant + }) + config = configs.Ascend if device_target == "Ascend" else configs.GPU + config["device_target"] = device_target + return config