commit
55372f3c87
@ -0,0 +1,53 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/add.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto add_prim = primitive->cast<PrimTensorAddPtr>();
|
||||
MS_EXCEPTION_IF_NULL(add_prim);
|
||||
auto op_name = add_prim->name();
|
||||
return BroadCastInferShape(op_name, input_args);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("y", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(TensorAdd, prim::kPrimTensorAdd, TensorAddInfer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_ADD_H_
|
||||
#define MINDSPORE_CORE_C_OPS_ADD_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameTensorAdd = "TensorAdd";
|
||||
class TensorAdd : public PrimitiveC {
|
||||
public:
|
||||
TensorAdd() : PrimitiveC(kNameTensorAdd) { InitIOName({"x", "y"}, {"output"}); }
|
||||
~TensorAdd() = default;
|
||||
MS_DECLARE_PARENT(TensorAdd, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr TensorAddInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimTensorAddPtr = std::shared_ptr<TensorAdd>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_ADD_H_
|
@ -0,0 +1,104 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "c_ops/avg_pool.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void AvgPool::set_padding(const std::string &pad) { this->AddAttr("padding", MakeValue(pad)); }
|
||||
|
||||
void AvgPool::set_kernel_size(const std::vector<int> &kernel_size) { this->AddAttr("ksize", MakeValue(kernel_size)); }
|
||||
|
||||
void AvgPool::set_strides(const std::vector<int> &strides) { this->AddAttr("strides", MakeValue(strides)); }
|
||||
|
||||
std::vector<int> AvgPool::get_strides() const {
|
||||
auto value_ptr = GetAttr("strides");
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
|
||||
std::vector<int> AvgPool::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr("ksize");
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
|
||||
std::string AvgPool::get_padding() const {
|
||||
auto value_ptr = GetAttr("padding");
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
void AvgPool::Init(const std::vector<int> &kernel_size, const std::vector<int> &stride, const std::string &padding) {
|
||||
auto prim_name = this->name();
|
||||
this->AddAttr("data_format", MakeValue("NCHW"));
|
||||
this->set_padding(CheckAndConvertUtils::CheckString("padding", padding, {"valid", "same"}, prim_name));
|
||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector("ksize", kernel_size, prim_name, false, true));
|
||||
this->set_strides(CheckAndConvertUtils::CheckPositiveVector("strides", stride, this->name(), false, true));
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto pool_prim = primitive->cast<PrimAvgPoolPtr>();
|
||||
MS_EXCEPTION_IF_NULL(pool_prim);
|
||||
auto op_name = pool_prim->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", in_shape.size(), kEqual, 4, op_name);
|
||||
auto kernel_size = pool_prim->get_kernel_size();
|
||||
auto pad_mode = pool_prim->get_padding();
|
||||
auto batch = in_shape[0];
|
||||
auto channel = in_shape[1];
|
||||
auto in_h = in_shape[2];
|
||||
auto in_w = in_shape[3];
|
||||
|
||||
auto strides = pool_prim->get_strides();
|
||||
auto kernel_h = kernel_size[2];
|
||||
auto kernel_w = kernel_size[3];
|
||||
auto stride_h = strides[2];
|
||||
auto stride_w = strides[3];
|
||||
int out_h = -1;
|
||||
int out_w = -1;
|
||||
if (pad_mode == "valid") {
|
||||
out_h = ceil((in_h - (kernel_h - 1)) / stride_h);
|
||||
out_w = ceil((in_w - (kernel_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
out_h = ceil(in_h / stride_h);
|
||||
out_w = ceil(in_w / stride_w);
|
||||
}
|
||||
std::vector<int> out_shape = {batch, channel, out_h, out_w};
|
||||
if (std::any_of(out_shape.begin(), out_shape.end(), [](int a) { return a <= 0; })) {
|
||||
MS_LOG(EXCEPTION) << "Kernel size is not valid.";
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(AvgPool, prim::kPrimAvgPool, AvgPoolInfer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
||||
#define MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameAvgPool = "AvgPool";
|
||||
class AvgPool : public PrimitiveC {
|
||||
public:
|
||||
AvgPool() : PrimitiveC(kNameAvgPool) { InitIOName({"x"}, {"output"}); }
|
||||
~AvgPool() = default;
|
||||
MS_DECLARE_PARENT(AvgPool, PrimitiveC);
|
||||
void Init(const std::vector<int> &kernel_size = {1}, const std::vector<int> &stride = {1},
|
||||
const std::string &padding = "valid");
|
||||
void set_padding(const std::string &pad);
|
||||
void set_kernel_size(const std::vector<int> &kernel_size);
|
||||
void set_strides(const std::vector<int> &strides);
|
||||
std::vector<int> get_kernel_size() const;
|
||||
std::vector<int> get_strides() const;
|
||||
std::string get_padding() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr AvgPoolInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimAvgPoolPtr = std::shared_ptr<AvgPool>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_AVG_POOL_H_
|
@ -0,0 +1,199 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/depthwise_conv2d.h"
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void DepthWiseConv2D::Init(int channel_multiplier, const std::vector<int> &kernel_size, int mode,
|
||||
const std::string &pad_mode, const std::vector<int> &pad, const std::vector<int> &stride,
|
||||
const std::vector<int> &dilation, int group) {
|
||||
auto prim_name = this->name();
|
||||
this->AddAttr("data_format", MakeValue("NCHW"));
|
||||
this->AddAttr("offset_a", MakeValue(0));
|
||||
this->set_mode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 3, prim_name));
|
||||
|
||||
this->set_kernel_size(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name));
|
||||
auto strides = CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), false, false);
|
||||
if (strides[0] != strides[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of stride should be equal, but got height " << strides[0]
|
||||
<< ", width " << strides[1];
|
||||
}
|
||||
this->set_stride(strides);
|
||||
auto dilations = CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), false, false);
|
||||
if (dilations[0] != dilations[1]) {
|
||||
MS_EXCEPTION(ValueError) << "The height and width of dilation should be equal, but got height " << dilations[0]
|
||||
<< ", width " << dilations[1];
|
||||
}
|
||||
this->set_dilation(dilations);
|
||||
this->set_pad_mode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name));
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("pad_size", pad.size(), kEqual, 4, prim_name);
|
||||
if (pad_mode == "pad") {
|
||||
for (auto item : pad) {
|
||||
CheckAndConvertUtils::Check("pad_item", item, kGreaterEqual, "zeros_list", 0, prim_name);
|
||||
}
|
||||
} else {
|
||||
CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros_list", {0, 0, 0, 0}, prim_name);
|
||||
}
|
||||
this->set_pad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true));
|
||||
|
||||
this->set_out_channel(
|
||||
CheckAndConvertUtils::CheckInteger("channel_multiplier", channel_multiplier, kGreaterThan, 0, prim_name));
|
||||
this->set_group(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name));
|
||||
}
|
||||
|
||||
std::vector<int> DepthWiseConv2D::get_kernel_size() const {
|
||||
auto value_ptr = GetAttr(kKernelSize);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> DepthWiseConv2D::get_stride() const {
|
||||
auto value_ptr = GetAttr(kStride);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::vector<int> DepthWiseConv2D::get_dilation() const {
|
||||
auto value_ptr = GetAttr(kDilation);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
std::string DepthWiseConv2D::get_pad_mode() const {
|
||||
auto value_ptr = this->GetAttr(kPadMode);
|
||||
return GetValue<string>(value_ptr);
|
||||
}
|
||||
std::vector<int> DepthWiseConv2D::get_pad() const {
|
||||
auto value_ptr = this->GetAttr(kPad);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
int DepthWiseConv2D::get_mode() const {
|
||||
auto value_ptr = this->GetAttr(kMode);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
int DepthWiseConv2D::get_group() const {
|
||||
auto value_ptr = this->GetAttr(kGroup);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
int DepthWiseConv2D::get_output_channel() const {
|
||||
auto value_ptr = this->GetAttr(kOutputChannel);
|
||||
return GetValue<int>(value_ptr);
|
||||
}
|
||||
|
||||
void DepthWiseConv2D::set_kernel_size(const std::vector<int> &kernel_size) {
|
||||
this->AddAttr(kKernelSize, MakeValue(kernel_size));
|
||||
}
|
||||
|
||||
void DepthWiseConv2D::set_stride(const std::vector<int> &stride) { this->AddAttr(kStride, MakeValue(stride)); }
|
||||
void DepthWiseConv2D::set_dilation(const std::vector<int> &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); }
|
||||
void DepthWiseConv2D::set_pad_mode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); }
|
||||
void DepthWiseConv2D::set_pad(const std::vector<int> &pad) { this->AddAttr(kPad, MakeValue(pad)); }
|
||||
void DepthWiseConv2D::set_mode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
|
||||
void DepthWiseConv2D::set_group(int group) { this->AddAttr(kGroup, MakeValue(group)); }
|
||||
void DepthWiseConv2D::set_out_channel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
|
||||
void DepthWiseConv2D::set_pads(const std::vector<int> &pad_list) { this->AddAttr(kPads, MakeValue(pad_list)); }
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto conv_prim = primitive->cast<PrimDepthWiseConv2DPtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_prim);
|
||||
auto prim_name = conv_prim->name();
|
||||
CheckAndConvertUtils::CheckInRange("conv2d_Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
|
||||
auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
|
||||
|
||||
CheckAndConvertUtils::CheckInteger("weight_rank", w_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::CheckInteger("x_rank", x_shape.size(), kEqual, 4, prim_name);
|
||||
CheckAndConvertUtils::Check("x_shape[1]", x_shape[1], kEqual, "w_shape[1]", w_shape[1], conv_prim->name());
|
||||
auto out_channel = conv_prim->get_output_channel();
|
||||
|
||||
std::vector<int> temp_w;
|
||||
std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
|
||||
CheckAndConvertUtils::Check("kernel_size", conv_prim->get_kernel_size(), kEqual, "w_shape[2:4]", temp_w,
|
||||
conv_prim->name());
|
||||
|
||||
auto kernel_size_n = w_shape[0];
|
||||
if (kernel_size_n != 1) {
|
||||
MS_EXCEPTION(ValueError) << "The batch of input weeight should be 1, but got " << kernel_size_n;
|
||||
}
|
||||
auto kernel_size_h = w_shape[2];
|
||||
auto kernel_size_w = w_shape[3];
|
||||
auto stride = conv_prim->get_stride();
|
||||
auto dilation = conv_prim->get_dilation();
|
||||
auto stride_h = stride[2];
|
||||
auto stride_w = stride[3];
|
||||
auto dilation_h = dilation[2];
|
||||
auto dilation_w = dilation[3];
|
||||
int h_out = -1;
|
||||
int w_out = -1;
|
||||
std::vector<int> pad_list(4, 0);
|
||||
auto pad_mode = conv_prim->get_pad_mode();
|
||||
if (pad_mode == "valid") {
|
||||
h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h);
|
||||
w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w);
|
||||
} else if (pad_mode == "same") {
|
||||
h_out = ceil(x_shape[2] / stride_h);
|
||||
w_out = ceil(x_shape[3] / stride_w);
|
||||
|
||||
auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
|
||||
pad_list.emplace_back(floor(pad_needed_h / 2));
|
||||
pad_list.emplace_back(pad_needed_h / 2);
|
||||
auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
|
||||
auto pad_left = floor(pad_needed_w / 2);
|
||||
pad_list.emplace_back(pad_left);
|
||||
pad_list.emplace_back(pad_needed_h - pad_left);
|
||||
} else if (pad_mode == "pad") {
|
||||
std::copy(conv_prim->get_pad().begin(), conv_prim->get_pad().end(), std::back_inserter(pad_list));
|
||||
auto pad_top = conv_prim->get_pad()[0];
|
||||
auto pad_bottom = conv_prim->get_pad()[1];
|
||||
auto pad_right = conv_prim->get_pad()[2];
|
||||
auto pad_left = conv_prim->get_pad()[3];
|
||||
|
||||
h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
|
||||
w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
|
||||
h_out = floor(h_out);
|
||||
w_out = floor(w_out);
|
||||
}
|
||||
conv_prim->set_pads(pad_list);
|
||||
std::vector<int> out_shape = {x_shape[0], out_channel * x_shape[1], h_out, w_out};
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name());
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
types.emplace("w", input_args[1]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, common_valid_types, prim->name());
|
||||
if (infer_type == kNumberTypeInt8) {
|
||||
return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
|
||||
}
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
} // namespace mindspore
|
@ -0,0 +1,60 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H
|
||||
#define MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameDepthWiseConv2D = "DepthwiseConv2dNative";
|
||||
class DepthWiseConv2D : public PrimitiveC {
|
||||
public:
|
||||
DepthWiseConv2D() : PrimitiveC(kNameDepthWiseConv2D) { InitIOName({"x", "w"}, {"output"}); }
|
||||
~DepthWiseConv2D() = default;
|
||||
MS_DECLARE_PARENT(DepthWiseConv2D, PrimitiveC);
|
||||
void Init(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
||||
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
||||
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
||||
std::vector<int> get_kernel_size() const;
|
||||
std::vector<int> get_stride() const;
|
||||
std::vector<int> get_dilation() const;
|
||||
std::string get_pad_mode() const;
|
||||
std::vector<int> get_pad() const;
|
||||
int get_mode() const;
|
||||
int get_group() const;
|
||||
int get_output_channel() const;
|
||||
void set_kernel_size(const std::vector<int> &kernel_size);
|
||||
void set_stride(const std::vector<int> &stride);
|
||||
void set_dilation(const std::vector<int> &dilation);
|
||||
void set_pad_mode(const std::string &pad_mode);
|
||||
void set_pad(const std::vector<int> &pad);
|
||||
void set_mode(int mode);
|
||||
void set_group(int group);
|
||||
void set_out_channel(int output_channel);
|
||||
void set_pads(const std::vector<int> &pad_list);
|
||||
};
|
||||
AbstractBasePtr DepthWiseConv2DInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimDepthWiseConv2DPtr = std::shared_ptr<DepthWiseConv2D>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_DEPTHWISE_CONV2D_H
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_CONV_UTILS_H
|
||||
#define MINDSPORE_CORE_C_OPS_CONV_UTILS_H
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kKernelSize = "kernel_size";
|
||||
constexpr auto kStride = "stride";
|
||||
constexpr auto kDilation = "dilation";
|
||||
constexpr auto kPadMode = "pad_mode";
|
||||
constexpr auto kPad = "pad";
|
||||
constexpr auto kPads = "pads";
|
||||
constexpr auto kMode = "mode";
|
||||
constexpr auto kGroup = "group";
|
||||
constexpr auto kOutputChannel = "output_channel";
|
||||
constexpr auto kPadList = "pad_list";
|
||||
constexpr auto kAxis = "axis";
|
||||
|
||||
const std::set<TypeId> common_valid_types = {
|
||||
kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt64, kNumberTypeUInt8, kNumberTypeUInt16,
|
||||
kNumberTypeUInt32, kNumberTypeUInt64, kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
|
||||
|
||||
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args);
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_CONV_UTILS_H
|
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
abstract::ShapePtr BroadCastInferShape(const std::string &op_name, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_LOG(INFO) << "Do infer shape for op " << op_name;
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto y_shape = CheckAndConvertUtils::ConvertShapePtrToShape("y_shape", input_args[1]->GetShapeTrack(), op_name);
|
||||
if (x_shape == y_shape) {
|
||||
return std::make_shared<abstract::Shape>(x_shape);
|
||||
}
|
||||
|
||||
auto x_length = x_shape.size();
|
||||
auto y_length = y_shape.size();
|
||||
auto length = x_length < y_length ? x_length : y_length;
|
||||
std::vector<int> broadcast_shape;
|
||||
if (x_length == length) {
|
||||
std::copy(y_shape.begin(), y_shape.end() - length, std::back_inserter(broadcast_shape));
|
||||
} else {
|
||||
std::copy(x_shape.begin(), x_shape.end() - length, std::back_inserter(broadcast_shape));
|
||||
}
|
||||
for (int i = -length; i < 0; i++) {
|
||||
if (x_shape[x_length + i] == 1) {
|
||||
broadcast_shape.push_back(y_shape[y_length + i]);
|
||||
} else if (y_shape[y_length + i] == 1) {
|
||||
broadcast_shape.push_back(x_shape[x_length + i]);
|
||||
} else if (x_shape[x_length + i] == y_shape[y_length + i]) {
|
||||
broadcast_shape.push_back(x_shape[x_length + i]);
|
||||
} else {
|
||||
MS_EXCEPTION(ValueError) << "For op " << op_name << ", the two input can not broadcast";
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(broadcast_shape);
|
||||
}
|
||||
} // namespace mindspore
|
@ -0,0 +1,52 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/relu6.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
abstract::ShapePtr Relu6InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto x = input_args[0]->GetShapeTrack();
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
}
|
||||
|
||||
TypePtr Relu6InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32};
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(Relu6InferType(primitive, input_args),
|
||||
Relu6InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Relu6, prim::kPrimRelu6, Relu6Infer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,40 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_C_OPS_RELU6_H_
|
||||
#define MINDSPORE_CORE_C_OPS_RELU6_H_
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameRelu6 = "Relu6";
|
||||
class Relu6 : public PrimitiveC {
|
||||
public:
|
||||
Relu6() : PrimitiveC(kNameRelu6) { InitIOName({"x"}, {"output"}); }
|
||||
~Relu6() = default;
|
||||
MS_DECLARE_PARENT(Relu6, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr Relu6Infer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimRelu6Ptr = std::shared_ptr<Relu6>;
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_C_OPS_RELU6_H_
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/reshape.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// to do
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
// to do
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Reshape, prim::kPrimReshape, ReshapeInfer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CORE_C_OPS_RESHAPE_H_
|
||||
#define MINDSPORE_CORE_C_OPS_RESHAPE_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameReshape = "Reshape";
|
||||
class Reshape : public PrimitiveC {
|
||||
public:
|
||||
Reshape() : PrimitiveC(kNameReshape) { InitIOName({"tensor", "shape"}, {"output"}); }
|
||||
~Reshape() = default;
|
||||
MS_DECLARE_PARENT(Reshape, PrimitiveC);
|
||||
void Init() {}
|
||||
};
|
||||
|
||||
AbstractBasePtr ReshapeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimTensorAddPtr = std::shared_ptr<Reshape>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_RESHAPE_H_
|
@ -0,0 +1,68 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/softmax.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void Softmax::set_axis(const std::vector<int> &axis) { this->set_attr(kAxis, MakeValue(axis)); }
|
||||
|
||||
void Softmax::Init(int axis) {
|
||||
auto op_name = this->name();
|
||||
std::vector<int> axis_vec = {axis};
|
||||
CheckAndConvertUtils::CheckInteger("axis_len", axis_vec.size(), kEqual, 1, op_name);
|
||||
auto rank = axis_vec.size();
|
||||
for (auto &item : axis_vec) {
|
||||
CheckAndConvertUtils::CheckInRange("axis", item, kIncludeLeft, {-rank, rank}, op_name);
|
||||
}
|
||||
this->set_axis(axis_vec);
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto softmax_prim = primitive->cast<PrimSoftmaxPtr>();
|
||||
MS_EXCEPTION_IF_NULL(softmax_prim);
|
||||
auto op_name = softmax_prim->name();
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
return std::make_shared<abstract::Shape>(in_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](const AbstractBasePtr &a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
std::map<std::string, TypePtr> types;
|
||||
const std::set<TypeId> valid_types = {kNumberTypeFloat16, kNumberTypeFloat32, kNumberTypeFloat64};
|
||||
types.emplace("x", input_args[0]->BuildType());
|
||||
auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
|
||||
return TypeIdToType(infer_type);
|
||||
}
|
||||
|
||||
AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Softmax, prim::kPrimSoftmax, SoftmaxInfer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_SOFTMAX_H_
|
||||
#define MINDSPORE_CORE_C_OPS_SOFTMAX_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameSoftmax = "Softmax";
|
||||
class Softmax : public PrimitiveC {
|
||||
public:
|
||||
Softmax() : PrimitiveC(kNameSoftmax) { InitIOName({"x"}, {"output"}); }
|
||||
~Softmax() = default;
|
||||
MS_DECLARE_PARENT(Softmax, PrimitiveC);
|
||||
void Init(int axis = 1);
|
||||
void set_axis(const std::vector<int> &axis);
|
||||
};
|
||||
|
||||
AbstractBasePtr SoftmaxInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimSoftmaxPtr = std::shared_ptr<Softmax>;
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_SOFTMAX_H_
|
@ -0,0 +1,79 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "c_ops/squeeze.h"
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include "c_ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "abstract/primitive_infer_map.h"
|
||||
|
||||
namespace mindspore {
|
||||
void Squeeze::set_axis(const std::vector<int> &axis) { this->set_attr(kAxis, MakeValue(axis)); }
|
||||
void Squeeze::Init(const std::vector<int> &axis) { this->set_axis(axis); }
|
||||
std::vector<int> Squeeze::get_axis() const {
|
||||
auto value_ptr = this->GetAttr(kAxis);
|
||||
return GetValue<std::vector<int>>(value_ptr);
|
||||
}
|
||||
|
||||
abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto squeeze_prim = primitive->cast<PrimSqueezePtr>();
|
||||
MS_EXCEPTION_IF_NULL(squeeze_prim);
|
||||
auto op_name = squeeze_prim->name();
|
||||
auto axis = squeeze_prim->get_axis();
|
||||
std::vector<int> infer_shape;
|
||||
|
||||
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShape("input_shape", input_args[0]->GetShapeTrack(), op_name);
|
||||
auto len = in_shape.size();
|
||||
if (axis.empty()) {
|
||||
std::copy_if(in_shape.begin(), in_shape.end(), std::back_inserter(infer_shape),
|
||||
[](int value) { return value != 1; });
|
||||
} else {
|
||||
for (auto &item : axis) {
|
||||
CheckAndConvertUtils::CheckInRange("axis_or_elememt", item, kIncludeBoth, {-len, len + 1}, op_name);
|
||||
auto idx = item >= 0 ? item : len + item;
|
||||
if (in_shape[idx] != 1) {
|
||||
MS_EXCEPTION(ValueError) << "Cannot select an axis to squeeze out which has size not equal to one.";
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < len; i++) {
|
||||
auto it = std::find(axis.begin(), axis.end(), i);
|
||||
auto it2 = std::find(axis.begin(), axis.end(), i - len);
|
||||
if (!(it != axis.end() || it2 != axis.end())) {
|
||||
infer_shape.push_back(in_shape[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
return std::make_shared<abstract::Shape>(infer_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||
if (std::any_of(input_args.begin(), input_args.end(), [](AbstractBasePtr a) { return a == nullptr; })) {
|
||||
MS_LOG(EXCEPTION) << "nullptr";
|
||||
}
|
||||
return input_args[0]->BuildType();
|
||||
}
|
||||
|
||||
AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
return std::make_shared<abstract::AbstractTensor>(InferType(primitive, input_args),
|
||||
InferShape(primitive, input_args)->shape());
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_EVAL_IMPL(Squeeze, prim::kPrimSqueeze, SqueezeInfer);
|
||||
} // namespace mindspore
|
@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_C_OPS_SQUEEZE_H_
|
||||
#define MINDSPORE_CORE_C_OPS_SQUEEZE_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "c_ops/primitive_c.h"
|
||||
#include "abstract/abstract_value.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kNameSqueeze = "Squeeze";
|
||||
class Squeeze : public PrimitiveC {
|
||||
public:
|
||||
Squeeze() : PrimitiveC(kNameSqueeze) { InitIOName({"x"}, {"output"}); }
|
||||
~Squeeze() = default;
|
||||
MS_DECLARE_PARENT(Squeeze, PrimitiveC);
|
||||
void Init(const std::vector<int> &axis = {});
|
||||
void set_axis(const std::vector<int> &axis);
|
||||
std::vector<int> get_axis() const;
|
||||
};
|
||||
|
||||
AbstractBasePtr SqueezeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args);
|
||||
using PrimSqueezePtr = std::shared_ptr<Squeeze>;
|
||||
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CORE_C_OPS_SQUEEZE_H_
|
Loading…
Reference in new issue