diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 2a2530a7ae..8eb6a6d05b 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -49,8 +49,6 @@ AbstractBasePtr InferImplFusedSparseAdam(const AnalysisEnginePtr &, const Primit const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplBatchNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplBatchNormGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplReluGrad(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplConv2D(const AnalysisEnginePtr &, const PrimitivePtr &primitive, diff --git a/mindspore/core/ops/avg_pool.cc b/mindspore/core/ops/avg_pool.cc index 01d581d6ac..11462a7273 100644 --- a/mindspore/core/ops/avg_pool.cc +++ b/mindspore/core/ops/avg_pool.cc @@ -36,8 +36,8 @@ PadMode AvgPool::get_pad_mode() const { return PadMode(GetValue(value_ptr)); } void AvgPool::set_kernel_size(const std::vector &kernel_size) { - this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), - false, true))); + this->AddAttr(kKernelSize, + MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); } std::vector AvgPool::get_kernel_size() const { @@ -45,8 +45,7 @@ std::vector AvgPool::get_kernel_size() const { return GetValue>(value_ptr); } void AvgPool::set_strides(const std::vector &strides) { - this->AddAttr(kStrides, - MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); + this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); } std::vector AvgPool::get_strides() const { diff --git a/mindspore/core/ops/conv2d.cc b/mindspore/core/ops/conv2d.cc index 24127da3ed..4aff527387 100644 --- a/mindspore/core/ops/conv2d.cc +++ b/mindspore/core/ops/conv2d.cc @@ -93,8 +93,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve w_out = floor(w_out); } CheckAndConvertUtils::CheckInteger("pad_size", pad_list.size(), kEqual, 4, prim_name); - primitive->AddAttr(kPadList, - MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name, true, true))); + primitive->AddAttr(kPadList, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad_list, prim_name))); std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; if (format == NHWC) { out_shape = {x_shape[0], h_out, w_out, out_channel}; @@ -144,11 +143,11 @@ void Conv2D::set_kernel_size(const std::vector &kernel_size) { } void Conv2D::set_stride(const std::vector &stride) { - AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); + AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name()))); } void Conv2D::set_dilation(const std::vector &dilation) { - AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); + AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name()))); } void Conv2D::set_pad_mode(const PadMode &pad_mode) { @@ -166,7 +165,7 @@ void Conv2D::set_pad_mode(const PadMode &pad_mode) { void Conv2D::set_pad(const std::vector &pad) { CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); - AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); + AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); } void Conv2D::set_mode(int64_t mode) { diff --git a/mindspore/core/ops/conv2d_transpose.cc b/mindspore/core/ops/conv2d_transpose.cc index ea41d32ba3..36c08c7d4b 100644 --- a/mindspore/core/ops/conv2d_transpose.cc +++ b/mindspore/core/ops/conv2d_transpose.cc @@ -111,7 +111,7 @@ void Conv2dTranspose::set_pad_mode(const PadMode &pad_mode) { void Conv2dTranspose::set_pad(const std::vector &pad) { CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); - AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); + AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); } void Conv2dTranspose::set_mode(int64_t mode) { diff --git a/mindspore/core/ops/depthwise_conv2d.cc b/mindspore/core/ops/depthwise_conv2d.cc index 93290dbc9d..bbc5599cf9 100644 --- a/mindspore/core/ops/depthwise_conv2d.cc +++ b/mindspore/core/ops/depthwise_conv2d.cc @@ -35,13 +35,13 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vectorset_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); - auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); + auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name()); if (strides[0] != strides[1]) { MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] << ", width " << strides[1]; } this->set_stride(strides); - auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); + auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name()); if (dilations[0] != dilations[1]) { MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] << ", width " << dilations[1]; @@ -57,7 +57,7 @@ void DepthWiseConv2D::Init(const int64_t channel_multiplier, const std::vectorset_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); + this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name())); this->set_out_channel( CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); diff --git a/mindspore/core/ops/fusion/depthwise_conv2d_fusion.cc b/mindspore/core/ops/fusion/depthwise_conv2d_fusion.cc index 52b3d56ff7..cf1831472e 100644 --- a/mindspore/core/ops/fusion/depthwise_conv2d_fusion.cc +++ b/mindspore/core/ops/fusion/depthwise_conv2d_fusion.cc @@ -30,13 +30,13 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name)); this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); - auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false); + auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name()); if (strides[0] != strides[1]) { MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0] << ", width " << strides[1]; } this->set_stride(strides); - auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false); + auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name()); if (dilations[0] != dilations[1]) { MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0] << ", width " << dilations[1]; @@ -52,7 +52,7 @@ void DepthWiseConv2DFusion::Init(const int64_t channel_multiplier, const std::ve } else { CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name); } - this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); + this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name())); this->set_out_channel( CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name)); diff --git a/mindspore/core/ops/grad/conv2d_backprop_input.cc b/mindspore/core/ops/grad/conv2d_backprop_input.cc index 0fa3cb0269..b801164cac 100644 --- a/mindspore/core/ops/grad/conv2d_backprop_input.cc +++ b/mindspore/core/ops/grad/conv2d_backprop_input.cc @@ -105,11 +105,11 @@ void Conv2DBackpropInput::set_kernel_size(const std::vector &kernel_siz } void Conv2DBackpropInput::set_stride(const std::vector &stride) { - AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name(), true, true))); + AddAttr(kStride, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, name()))); } void Conv2DBackpropInput::set_dilation(const std::vector &dilation) { - AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name(), true, true))); + AddAttr(kDilation, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, name()))); } void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { @@ -127,7 +127,7 @@ void Conv2DBackpropInput::set_pad_mode(const PadMode &pad_mode) { void Conv2DBackpropInput::set_pad(const std::vector &pad) { CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, name()); - AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name(), true, true))); + AddAttr(kPad, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, name()))); } void Conv2DBackpropInput::set_mode(int64_t mode) { diff --git a/mindspore/core/ops/max_pool.cc b/mindspore/core/ops/max_pool.cc index b4b324014e..22e0b464da 100644 --- a/mindspore/core/ops/max_pool.cc +++ b/mindspore/core/ops/max_pool.cc @@ -36,8 +36,8 @@ PadMode MaxPool::get_pad_mode() const { return PadMode(GetValue(value_ptr)); } void MaxPool::set_kernel_size(const std::vector &kernel_size) { - this->AddAttr(kKernelSize, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name(), - false, true))); + this->AddAttr(kKernelSize, + MakeValue(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, this->name()))); } std::vector MaxPool::get_kernel_size() const { @@ -45,8 +45,7 @@ std::vector MaxPool::get_kernel_size() const { return GetValue>(value_ptr); } void MaxPool::set_strides(const std::vector &strides) { - this->AddAttr(kStrides, - MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name(), false, true))); + this->AddAttr(kStrides, MakeValue(CheckAndConvertUtils::CheckPositiveVector(kStrides, strides, this->name()))); } std::vector MaxPool::get_strides() const { diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index af83e22b94..9cbe3cdfae 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ b/mindspore/core/utils/check_convert_utils.cc @@ -330,24 +330,10 @@ bool CheckAndConvertUtils::IsEqualVector(const std::vector &vec_1, cons std::vector CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, const std::vector &arg_value, - const std::string &prim_name, bool allow_four, - bool ret_four) { - auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void { - std::ostringstream buffer; - buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; - if (allow_four) { - buffer << "or four "; - } - buffer << " positive int64_t numbers , but got ["; - for (auto item : arg_value) { - buffer << item << ","; - } - buffer << "]"; - MS_EXCEPTION(ValueError) << buffer.str(); - }; + const std::string &prim_name) { for (auto item : arg_value) { if (item < 0) { - raise_message(); + MS_EXCEPTION(ValueError) << "For " << prim_name << " attr " << arg_name << " should be a positive vector"; } } return arg_value; diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index c47c8b3779..bb667e5436 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ b/mindspore/core/utils/check_convert_utils.h @@ -162,8 +162,7 @@ const std::map> kCompareRangeT class CheckAndConvertUtils { public: static std::vector CheckPositiveVector(const std::string &arg_name, const std::vector &arg_value, - const std::string &prim_name, bool allow_four = false, - bool ret_four = false); + const std::string &prim_name); static std::string CheckString(const std::string &arg_name, const std::string &arg_value, const std::set &check_list, const std::string &prim_name);