fix eval symmetric bug

pull/9480/head
xiaoyisd 4 years ago
parent f019a4a0af
commit 607a729705

@ -41,10 +41,12 @@ if __name__ == '__main__':
config_device_target = config_ascend_quant config_device_target = config_ascend_quant
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", context.set_context(mode=context.GRAPH_MODE, device_target="Ascend",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)
symmetric_list = [True, False]
elif args_opt.device_target == "GPU": elif args_opt.device_target == "GPU":
config_device_target = config_gpu_quant config_device_target = config_gpu_quant
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", context.set_context(mode=context.GRAPH_MODE, device_target="GPU",
device_id=device_id, save_graphs=False) device_id=device_id, save_graphs=False)
symmetric_list = [False, False]
else: else:
raise ValueError("Unsupported device target: {}.".format(args_opt.device_target)) raise ValueError("Unsupported device target: {}.".format(args_opt.device_target))
@ -53,7 +55,7 @@ if __name__ == '__main__':
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
quantizer = QuantizationAwareTraining(bn_fold=True, quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False]) symmetric=symmetric_list)
network = quantizer.quantize(network) network = quantizer.quantize(network)
# define network loss # define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

Loading…
Cancel
Save