update c_ops op style

pull/8219/head
liubuyu 4 years ago
parent 28297a549b
commit f2278d5607

@ -26,30 +26,28 @@
namespace mindspore {
void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); }
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); }
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); }
std::vector<int> AvgPool::get_strides() const {
auto value_ptr = GetAttr("strides");
return GetValue<std::vector<int>>(value_ptr);
std::string AvgPool::get_padding() const {
auto value_ptr = GetAttr("padding");
return GetValue<std::string>(value_ptr);
}
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("k_size", MakeValue(kernel_size)); }
std::vector<int> AvgPool::get_kernel_size() const {
auto value_ptr = GetAttr("ksize");
auto value_ptr = GetAttr("k_size");
return GetValue<std::vector<int>>(value_ptr);
}
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); }
std::string AvgPool::get_padding() const {
auto value_ptr = GetAttr("padding");
return GetValue<std::string>(value_ptr);
std::vector<int> AvgPool::get_strides() const {
auto value_ptr = GetAttr("strides");
return GetValue<std::vector<int>>(value_ptr);
}
void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) {
auto prim_name = this->name();
this->AddAttr("data_format", MakeValue("NCHW"));
this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true));
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("k_size", kernel_size, prim_name, false, true));
this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true));
}

@ -150,6 +150,10 @@ std::vector<int> Conv2D::get_pad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> Conv2D::get_pad_list() const {
auto value_ptr = this->GetAttr(kPadList);
return GetValue<std::vector<int>>(value_ptr);
}
int Conv2D::get_mode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr);

@ -40,6 +40,7 @@ class Conv2D : public PrimitiveC {
std::vector<int> get_dilation() const;
std::string get_pad_mode() const;
std::vector<int> get_pad() const;
std::vector<int> get_pad_list() const;
int get_mode() const;
int get_group() const;
int get_output_channel() const;

@ -82,6 +82,12 @@ std::vector<int> DepthWiseConv2D::get_pad() const {
auto value_ptr = this->GetAttr(kPad);
return GetValue<std::vector<int>>(value_ptr);
}
std::vector<int> DepthWiseConv2D::get_pads() const {
auto value_ptr = this->GetAttr(kPads);
return GetValue<std::vector<int>>(value_ptr);
}
int DepthWiseConv2D::get_mode() const {
auto value_ptr = this->GetAttr(kMode);
return GetValue<int>(value_ptr);

@ -39,6 +39,7 @@ class DepthWiseConv2D : public PrimitiveC {
std::vector<int> get_dilation() const;
std::string get_pad_mode() const;
std::vector<int> get_pad() const;
std::vector<int> get_pads() const;
int get_mode() const;
int get_group() const;
int get_output_channel() const;

@ -27,6 +27,11 @@
namespace mindspore {
void Softmax::set_axis(const std::vector<int> &axis) { this->set_attr(kAxis, MakeValue(axis)); }
std::vector<int> Softmax::get_axis() const {
auto value_ptr = GetAttr(kAxis);
return GetValue<std::vector<int>>(value_ptr);
}
void Softmax::Init(int axis) {
auto op_name = this->name();
std::vector<int> axis_vec = {axis};
@ -43,7 +48,12 @@ abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<A
auto softmax_prim = primitive->cast<PrimSoftmaxPtr>();
MS_EXCEPTION_IF_NULL(softmax_prim);
auto op_name = softmax_prim->name();
auto axis = softmax_prim->get_axis();
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
auto rank = in_shape.size();
for (auto &item : axis) {
CheckAndConvertUtils::CheckInRange("axis", item, kIncludeLeft, {-rank, rank}, op_name);
}
return std::make_shared<abstract::Shape>(in_shape);
}

@ -32,8 +32,9 @@ class Softmax : public PrimitiveC {
Softmax() : PrimitiveC(kNameSoftmax) { InitIOName({"x"}, {"output"}); }
~Softmax() = default;
MS_DECLARE_PARENT(Softmax, PrimitiveC);
void Init(int axis = 1);
void Init(int axis = -1);
void set_axis(const std::vector<int> &axis);
std::vector<int> get_axis() const;
};
AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,

Loading…
Cancel
Save