From 3209687b9a1c569177027ce4db9384d5249dc5f7 Mon Sep 17 00:00:00 2001 From: zhaozhenlong Date: Fri, 24 Apr 2020 17:07:29 +0800 Subject: [PATCH] conv2d bp filter input primitive stride 2 and adapt to 4 --- mindspore/ccsrc/transform/op_adapter.h | 2 +- mindspore/ccsrc/transform/op_adapter_util.cc | 14 ++++++++++---- mindspore/ccsrc/transform/op_declare.cc | 4 ++-- mindspore/ops/operations/_grad_ops.py | 6 +++--- mindspore/ops/operations/nn_ops.py | 2 +- 5 files changed, 17 insertions(+), 11 deletions(-) diff --git a/mindspore/ccsrc/transform/op_adapter.h b/mindspore/ccsrc/transform/op_adapter.h index 2039dfa7d6..ae678606a4 100644 --- a/mindspore/ccsrc/transform/op_adapter.h +++ b/mindspore/ccsrc/transform/op_adapter.h @@ -736,7 +736,7 @@ class OpAdapter : public BaseOpAdapter { return static_cast(GetValue(value)); } - // specialization for int to Vector + // specialization for int or tuple broadcast to Vector static std::vector ConvertAny(const ValuePtr &value, const std::string &name, const AnyTraits> anyTraitsInt) { return ConvertAnyUtil(value, name, anyTraitsInt); diff --git a/mindspore/ccsrc/transform/op_adapter_util.cc b/mindspore/ccsrc/transform/op_adapter_util.cc index 0163b80f08..0d9e56e510 100644 --- a/mindspore/ccsrc/transform/op_adapter_util.cc +++ b/mindspore/ccsrc/transform/op_adapter_util.cc @@ -35,14 +35,20 @@ GeTensor ConvertAnyUtil(const ValuePtr &value, const AnyTraits ConvertAnyUtil(const ValuePtr &value, const std::string &name, const AnyTraits>) { + MS_EXCEPTION_IF_NULL(value); int64_t data = GetValue(value); std::vector list; int size = 2; // 2 int in list if (name == "pad") { - size = 4; // 4 int in list - list = TransformUtil::ConvertIntToList(data, size); - list[0] = 1; - list[1] = 1; + if (!value->isa()) { + MS_LOG(EXCEPTION) << "Value should be ValueTuple, but got" << value->type_name(); + } + auto vec = value->cast(); + list.push_back(1); + list.push_back(1); + for (auto &it : vec->value()) { + list.push_back(static_cast(GetValue(it))); + } } else { list = TransformUtil::ConvertIntToList(data, size); } diff --git a/mindspore/ccsrc/transform/op_declare.cc b/mindspore/ccsrc/transform/op_declare.cc index 377403cc89..5ec54b2037 100644 --- a/mindspore/ccsrc/transform/op_declare.cc +++ b/mindspore/ccsrc/transform/op_declare.cc @@ -733,7 +733,7 @@ INPUT_ATTR_MAP(Conv2DBackpropInputD) = { {3, ATTR_DESC(input_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropInputD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, {"group", ATTR_DESC(groups, AnyTraits())}, @@ -746,7 +746,7 @@ INPUT_ATTR_MAP(Conv2DBackpropFilterD) = { {3, ATTR_DESC(filter_size, AnyTraits>(), AnyTraits>())}}; ATTR_MAP(Conv2DBackpropFilterD) = { {"pad_list", ATTR_DESC(pads, AnyTraits>(), AnyTraits>())}, - {"stride", ATTR_DESC(strides, AnyTraits>(), AnyTraits>())}, + {"stride", ATTR_DESC(strides, "pad", AnyTraits>())}, {"dilation", ATTR_DESC(dilations, AnyTraits>(), AnyTraits>())}, {"data_format", ATTR_DESC(data_format, AnyTraits())}, {"group", ATTR_DESC(groups, AnyTraits())}, diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index c821063da8..782784ca00 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -174,9 +174,9 @@ class Conv2DBackpropFilter(PrimitiveWithInfer): pad_mode = pad_mode.upper() self.add_prim_attr('pad_mode', pad_mode) self.pad = pad - if isinstance(stride, tuple) and len(stride) == 2: - self.stride = stride - self.add_prim_attr('stride', (1, 1, self.stride[0], self.stride[1])) + if isinstance(stride, tuple) and len(stride) == 4: + self.stride = (stride[2], stride[3]) + self.add_prim_attr('stride', self.stride) self.dilation = dilation self.group = group self.add_prim_attr('data_format', "NCHW") diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 78a512fab7..9750549dc5 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -1084,7 +1084,7 @@ class Conv2DBackpropInput(PrimitiveWithInfer): self.init_prim_io_names(inputs=['out_backprop', 'filter', 'input_sizes'], outputs=['output']) self.out_channel = validator.check_integer('out_channel', out_channel, 0, Rel.GT, self.name) self.kernel_size = _check_positive_int_or_tuple('kernel_size', kernel_size, self.name) - self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=True) + self.stride = _check_positive_int_or_tuple('stride', stride, self.name, allow_four=True, ret_four=False) self.add_prim_attr('stride', self.stride) self.dilation = _check_positive_int_or_tuple('dilation', dilation, self.name, allow_four=True, ret_four=True) self.add_prim_attr('dilation', self.dilation)