From f2278d5607d2927fca740967a6a1e52580204103 Mon Sep 17 00:00:00 2001 From: liubuyu Date: Wed, 4 Nov 2020 16:43:53 +0800 Subject: [PATCH] update c_ops op style --- mindspore/core/c_ops/avg_pool.cc | 22 +++++++++---------- mindspore/core/c_ops/conv2d.cc | 4 ++++ mindspore/core/c_ops/conv2d.h | 1 + mindspore/core/c_ops/depthwise_conv2d.cc | 6 +++++ mindspore/core/c_ops/depthwise_conv2d.h | 1 + .../core/c_ops/{op_utils.h.cc => op_utils.cc} | 0 mindspore/core/c_ops/softmax.cc | 10 +++++++++ mindspore/core/c_ops/softmax.h | 3 ++- 8 files changed, 34 insertions(+), 13 deletions(-) rename mindspore/core/c_ops/{op_utils.h.cc => op_utils.cc} (100%) diff --git a/mindspore/core/c_ops/avg_pool.cc b/mindspore/core/c_ops/avg_pool.cc index 183b2619c4..414776b2e6 100644 --- a/mindspore/core/c_ops/avg_pool.cc +++ b/mindspore/core/c_ops/avg_pool.cc @@ -26,30 +26,28 @@ namespace mindspore { void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); } -void AvgPool::set_kernel_size(const std::vector &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); } - -void AvgPool::set_strides(const std::vector &strides) { this->AddAttr("strides", MakeValue(strides)); } - -std::vector AvgPool::get_strides() const { - auto value_ptr = GetAttr("strides"); - return GetValue>(value_ptr); +std::string AvgPool::get_padding() const { + auto value_ptr = GetAttr("padding"); + return GetValue(value_ptr); } +void AvgPool::set_kernel_size(const std::vector &kernel_size) { this->AddAttr("k_size", MakeValue(kernel_size)); } std::vector AvgPool::get_kernel_size() const { - auto value_ptr = GetAttr("ksize"); + auto value_ptr = GetAttr("k_size"); return GetValue>(value_ptr); } +void AvgPool::set_strides(const std::vector &strides) { this->AddAttr("strides", MakeValue(strides)); } -std::string AvgPool::get_padding() const { - auto value_ptr = GetAttr("padding"); - return GetValue(value_ptr); +std::vector AvgPool::get_strides() const { + auto value_ptr = GetAttr("strides"); + return GetValue>(value_ptr); } void AvgPool::Init(const std::vector &kernel_size, const std::vector &stride, const std::string &padding) { auto prim_name = this->name(); this->AddAttr("data_format", MakeValue("NCHW")); this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name)); - this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true)); + this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("k_size", kernel_size, prim_name, false, true)); this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true)); } diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc index a25241401a..a38285f929 100644 --- a/mindspore/core/c_ops/conv2d.cc +++ b/mindspore/core/c_ops/conv2d.cc @@ -150,6 +150,10 @@ std::vector Conv2D::get_pad() const { auto value_ptr = this->GetAttr(kPad); return GetValue>(value_ptr); } +std::vector Conv2D::get_pad_list() const { + auto value_ptr = this->GetAttr(kPadList); + return GetValue>(value_ptr); +} int Conv2D::get_mode() const { auto value_ptr = this->GetAttr(kMode); return GetValue(value_ptr); diff --git a/mindspore/core/c_ops/conv2d.h b/mindspore/core/c_ops/conv2d.h index 670106a1ab..e899697e64 100644 --- a/mindspore/core/c_ops/conv2d.h +++ b/mindspore/core/c_ops/conv2d.h @@ -40,6 +40,7 @@ class Conv2D : public PrimitiveC { std::vector get_dilation() const; std::string get_pad_mode() const; std::vector get_pad() const; + std::vector get_pad_list() const; int get_mode() const; int get_group() const; int get_output_channel() const; diff --git a/mindspore/core/c_ops/depthwise_conv2d.cc b/mindspore/core/c_ops/depthwise_conv2d.cc index dd57a1cc0f..4e1e057240 100644 --- a/mindspore/core/c_ops/depthwise_conv2d.cc +++ b/mindspore/core/c_ops/depthwise_conv2d.cc @@ -82,6 +82,12 @@ std::vector DepthWiseConv2D::get_pad() const { auto value_ptr = this->GetAttr(kPad); return GetValue>(value_ptr); } + +std::vector DepthWiseConv2D::get_pads() const { + auto value_ptr = this->GetAttr(kPads); + return GetValue>(value_ptr); +} + int DepthWiseConv2D::get_mode() const { auto value_ptr = this->GetAttr(kMode); return GetValue(value_ptr); diff --git a/mindspore/core/c_ops/depthwise_conv2d.h b/mindspore/core/c_ops/depthwise_conv2d.h index f7773e8e88..9c19e81999 100644 --- a/mindspore/core/c_ops/depthwise_conv2d.h +++ b/mindspore/core/c_ops/depthwise_conv2d.h @@ -39,6 +39,7 @@ class DepthWiseConv2D : public PrimitiveC { std::vector get_dilation() const; std::string get_pad_mode() const; std::vector get_pad() const; + std::vector get_pads() const; int get_mode() const; int get_group() const; int get_output_channel() const; diff --git a/mindspore/core/c_ops/op_utils.h.cc b/mindspore/core/c_ops/op_utils.cc similarity index 100% rename from mindspore/core/c_ops/op_utils.h.cc rename to mindspore/core/c_ops/op_utils.cc diff --git a/mindspore/core/c_ops/softmax.cc b/mindspore/core/c_ops/softmax.cc index 7ee49646f9..5c6e865067 100644 --- a/mindspore/core/c_ops/softmax.cc +++ b/mindspore/core/c_ops/softmax.cc @@ -27,6 +27,11 @@ namespace mindspore { void Softmax::set_axis(const std::vector &axis) { this->set_attr(kAxis, MakeValue(axis)); } +std::vector Softmax::get_axis() const { + auto value_ptr = GetAttr(kAxis); + return GetValue>(value_ptr); +} + void Softmax::Init(int axis) { auto op_name = this->name(); std::vector axis_vec = {axis}; @@ -43,7 +48,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vectorcast(); MS_EXCEPTION_IF_NULL(softmax_prim); auto op_name = softmax_prim->name(); + auto axis = softmax_prim->get_axis(); auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name); + auto rank = in_shape.size(); + for (auto &item : axis) { + CheckAndConvertUtils::CheckInRange("axis", item, kIncludeLeft, {-rank, rank}, op_name); + } return std::make_shared(in_shape); } diff --git a/mindspore/core/c_ops/softmax.h b/mindspore/core/c_ops/softmax.h index 529750bd81..c9005132d8 100644 --- a/mindspore/core/c_ops/softmax.h +++ b/mindspore/core/c_ops/softmax.h @@ -32,8 +32,9 @@ class Softmax : public PrimitiveC { Softmax() : PrimitiveC(kNameSoftmax) { InitIOName({"x"}, {"output"}); } ~Softmax() = default; MS_DECLARE_PARENT(Softmax, PrimitiveC); - void Init(int axis = 1); + void Init(int axis = -1); void set_axis(const std::vector &axis); + std::vector get_axis() const; }; AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,