From d11759d3975b755782cf2381970f80fade65cf32 Mon Sep 17 00:00:00 2001
From: linqingke
Date: Mon, 15 Mar 2021 15:37:17 +0800
Subject: [PATCH] fix optimizer and conv2d infer.

---
 mindspore/ops/operations/inner_ops.py |  2 ++
 mindspore/ops/operations/nn_ops.py    | 21 +++++++++++++++++----
 2 files changed, 19 insertions(+), 4 deletions(-)

diff --git a/mindspore/ops/operations/inner_ops.py b/mindspore/ops/operations/inner_ops.py
index efd5ef3ff1..cbcee7cef5 100644
--- a/mindspore/ops/operations/inner_ops.py
+++ b/mindspore/ops/operations/inner_ops.py
@@ -235,6 +235,7 @@ class LambApplyOptimizerAssign(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize LambApplyOptimizerAssign"""
+        self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, grad_shape, v_shape, m_shape, var_shape, beta1_shape, sub1_shape,
                     beta2_shape, sub2_shape, eps_shape, steps_shape, use_weight_shape, weight_decay_shape):
@@ -296,6 +297,7 @@ class LambApplyWeightAssign(PrimitiveWithInfer):
     @prim_attr_register
     def __init__(self):
         """Initialize LambApplyWeightAssign"""
+        self.add_prim_attr('side_effect_mem', True)
 
     def infer_shape(self, w_norm_shape, g_norm_shape, lr_shape, update_shape, var_shape):
         validator.check("var_shape", var_shape, "update_shape", update_shape, Rel.EQ, self.name)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index f3ec6f3c1c..7b87df0190 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -78,6 +78,17 @@ def _check_shape(arg_name, arg_value, prim_name):
     return arg_value
 
 
+def _update_attr_by_format(arg_value, arg_format):
+    """
+    If the format is NHWC, the strides or dilation shape needs to be modified.
+    """
+    ret = arg_value
+    if len(arg_value) == 4 and arg_format == "NHWC":
+        ret = arg_value[1:] + (1,)
+
+    return ret
+
+
 class Flatten(PrimitiveWithInfer):
     r"""
     Flattens a tensor without changing its batch size on the 0-th axis.
@@ -2080,9 +2091,15 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output'])
         self.out_channel = validator.check_positive_int(out_channel, 'out_channel', self.name)
         self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name)
+        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
+        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
+            raise ValueError("NHWC format is only supported on the GPU target.")
+        self.add_prim_attr('data_format', self.format)
         self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True)
+        self.stride = _update_attr_by_format(self.stride, self.format)
         self.add_prim_attr('stride', self.stride)
         self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True)
+        self.dilation = _update_attr_by_format(self.dilation, self.format)
         self.add_prim_attr('dilation', self.dilation)
         validator.check_value_type('pad', pad, (int, tuple), self.name)
         if isinstance(pad, int):
@@ -2103,10 +2120,6 @@ class Conv2DBackpropInput(PrimitiveWithInfer):
         self.mode = validator.check_equal_int(mode, 1, 'mode', self.name)
         self.group = validator.check_positive_int(group, 'group', self.name)
         self.add_prim_attr('groups', self.group)
-        self.format = validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.name)
-        if context.get_context("device_target") != "GPU" and self.format == "NHWC":
-            raise ValueError("NHWC format only support in GPU target.")
-        self.add_prim_attr('data_format', self.format)
         if pad_list:
             for x in pad_list:
                 validator.check_non_negative_int(x, 'element of pad_list', self.name)
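
Note for reviewers: _check_positive_int_or_tuple(..., ret_four=True) returns
the stride/dilation attribute as a 4-tuple in NCHW order (e.g. stride=2
becomes (1, 1, 2, 2)), so for NHWC the new helper moves the spatial values
into positions 1 and 2 and appends 1 for the channel dimension. A minimal,
standalone sketch of that behavior (plain Python, no MindSpore required):

    def _update_attr_by_format(arg_value, arg_format):
        """If the format is NHWC, move (H, W) up and append 1 for C."""
        ret = arg_value
        if len(arg_value) == 4 and arg_format == "NHWC":
            ret = arg_value[1:] + (1,)
        return ret

    # NCHW attributes pass through unchanged; NHWC ones are reordered.
    assert _update_attr_by_format((1, 1, 2, 2), "NCHW") == (1, 1, 2, 2)
    assert _update_attr_by_format((1, 1, 2, 2), "NHWC") == (1, 2, 2, 1)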
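Note on the inner_ops.py change: LambApplyOptimizerAssign and
LambApplyWeightAssign update their input parameters in place, so they carry
a memory side effect; the side_effect_mem attribute tells the graph compiler
to keep these ops and preserve their execution order rather than treating
them as pure and pruning them. A hedged sketch of the same pattern for a
hypothetical in-place primitive (InplaceScale is illustrative, not a real
MindSpore op):

    from mindspore.ops import PrimitiveWithInfer, prim_attr_register

    class InplaceScale(PrimitiveWithInfer):
        """Hypothetical op that multiplies `var` in place by `scale`."""

        @prim_attr_register
        def __init__(self):
            """Initialize InplaceScale"""
            # Mark the memory side effect; without it the compiler may
            # drop the op when its return value is unused, the failure
            # mode this patch fixes for the two Lamb assign primitives.
            self.add_prim_attr('side_effect_mem', True)

        def infer_shape(self, var_shape, scale_shape):
            return var_shape

        def infer_dtype(self, var_dtype, scale_dtype):
            return var_dtype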