!4293 add infer function of primitive c

Merge pull request !4293 from lianliguang/add-infer-to-primitive-c
pull/4293/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit ab84b2f18a

@ -18,8 +18,6 @@
#include "abstract/utils.h"
#include "abstract/param_validator.h"
#include "utils/check_convert_utils.h"
#include "c_ops/conv2d.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace abstract {
@ -428,91 +426,5 @@ AbstractBasePtr InferImplDropoutGenMask(const AnalysisEnginePtr &, const Primiti
return std::make_shared<AbstractTensor>(std::make_shared<AbstractScalar>(kAnyValue, kUInt8),
std::make_shared<Shape>(std::vector<int64_t>{shape_y}));
}
// Infers the NCHW output shape of Conv2D and records the resolved pad_list
// (top, bottom, left, right) back onto the primitive.
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto conv_prim = primitive->cast<PrimConv2dPtr>();
  MS_EXCEPTION_IF_NULL(conv_prim);
  auto prim_name = conv_prim->name();
  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
  // Conv2D inputs are ordered (x, w); the previous code read them swapped,
  // inconsistently with Conv2dInferType which takes x from input_args[0].
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
                              w_shape[1], conv_prim->name());
  auto out_channel = conv_prim->GetOutputChannel();
  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
  std::vector<int> temp_w;
  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
  CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
                              conv_prim->name());
  auto kernel_size_h = w_shape[2];
  auto kernel_size_w = w_shape[3];
  auto stride = conv_prim->GetStride();
  auto dilation = conv_prim->GetDilation();
  auto stride_h = stride[2];
  auto stride_w = stride[3];
  auto dilation_h = dilation[2];
  auto dilation_w = dilation[3];
  // Integer ceiling division for positive operands. The old `ceil(a / b)` on two
  // ints truncated before calling ceil, so it always rounded DOWN.
  auto div_ceil = [](int dividend, int divisor) { return (dividend + divisor - 1) / divisor; };
  int h_out = -1;
  int w_out = -1;
  // Assign pads in place: the previous code emplace_back'ed after the four
  // zero-initialized slots, handing SetPadList an 8-element list.
  std::vector<int> pad_list(4, 0);
  auto pad_mode = conv_prim->GetPadMode();
  if (pad_mode == "valid") {
    h_out = div_ceil(x_shape[2] - dilation_h * (kernel_size_h - 1), stride_h);
    w_out = div_ceil(x_shape[3] - dilation_w * (kernel_size_w - 1), stride_w);
  } else if (pad_mode == "same") {
    h_out = div_ceil(x_shape[2], stride_h);
    w_out = div_ceil(x_shape[3], stride_w);
    auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
    pad_list[0] = pad_needed_h / 2;
    // top + bottom must cover pad_needed_h even when it is odd (was pad_needed_h / 2 twice).
    pad_list[1] = pad_needed_h - pad_list[0];
    auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
    pad_list[2] = pad_needed_w / 2;
    // Was computed from pad_needed_h (copy-paste bug).
    pad_list[3] = pad_needed_w - pad_list[2];
  } else if (pad_mode == "pad") {
    // User-supplied pads, presumably (top, bottom, left, right) — only the
    // per-axis sums matter for the output size below.
    pad_list = conv_prim->GetPad();
    auto pad_top = pad_list[0];
    auto pad_bottom = pad_list[1];
    auto pad_right = pad_list[2];
    auto pad_left = pad_list[3];
    h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
    w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
  }
  conv_prim->SetPadList(pad_list);
  std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
  return std::make_shared<abstract::Shape>(out_shape);
}
// Infers the output dtype of Conv2D: int8 inputs accumulate into int32, all
// other supported dtypes pass through unchanged. Throws TypeError when x and w
// disagree or use an unsupported dtype.
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Name the checked argument so range-violation messages are meaningful (was "").
  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name());
  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
  std::map<std::string, TypePtr> types;
  types.emplace("x", input_args[0]->GetTypeTrack());
  types.emplace("w", input_args[1]->GetTypeTrack());
  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  // int8 x int8 convolutions produce int32 accumulators.
  if (x_type == kNumberTypeInt8) {
    return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
  }
  return std::make_shared<TensorType>(TypeIdToType(x_type));
}
// Builds the complete abstract value (dtype + shape) for Conv2D by combining
// the results of the dedicated type and shape inference helpers.
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  auto inferred_type = Conv2dInferType(primitive, input_args);
  auto inferred_shape = Conv2dInferShape(primitive, input_args);
  return std::make_shared<abstract::AbstractTensor>(inferred_type, inferred_shape->shape());
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
} // namespace abstract
} // namespace mindspore

@ -47,7 +47,7 @@ class RegisterStandardPrimitiveEvalHelper {
};
// Registers a standard-primitive infer implementation at static-init time.
// The helper type is namespace-qualified so the macro can be used from
// translation units outside namespace abstract (e.g. c_ops).
// NOTE(review): the diff left both the old and new expansion lines here, which
// would expand to two conflicting definitions; resolved to the new line.
#define REGISTER_PRIMITIVE_EVAL_IMPL(name, primitive, impl) \
  static auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(primitive, impl)
} // namespace abstract
} // namespace mindspore
#endif // MINDSPORE_CORE_ABSTRACT_PRIMITIVE_INFER_MAP_H_

@ -23,6 +23,7 @@
#include <set>
#include <vector>
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace {
@ -36,6 +37,84 @@ constexpr auto kGroup = "group";
constexpr auto kOutputChannel = "output channel";
constexpr auto kPadList = "pad_list";
constexpr auto kConv2DName = "Conv2D";
// Infers the NCHW output shape of Conv2D and records the resolved pad_list
// (top, bottom, left, right) back onto the Conv2d primitive.
abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto conv_prim = std::dynamic_pointer_cast<Conv2d>(primitive);
  MS_EXCEPTION_IF_NULL(conv_prim);
  auto prim_name = conv_prim->name();
  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name);
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name);
  auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name);
  CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name);
  CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]",
                              w_shape[1], conv_prim->name());
  auto out_channel = conv_prim->GetOutputChannel();
  CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name());
  std::vector<int> temp_w;
  std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w));
  CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w,
                              conv_prim->name());
  auto kernel_size_h = w_shape[2];
  auto kernel_size_w = w_shape[3];
  auto stride = conv_prim->GetStride();
  auto dilation = conv_prim->GetDilation();
  auto stride_h = stride[2];
  auto stride_w = stride[3];
  auto dilation_h = dilation[2];
  auto dilation_w = dilation[3];
  // Integer ceiling division for positive operands. The old `ceil(a / b)` on two
  // ints truncated before calling ceil, so it always rounded DOWN.
  auto div_ceil = [](int dividend, int divisor) { return (dividend + divisor - 1) / divisor; };
  int h_out = -1;
  int w_out = -1;
  // Assign pads in place: the previous code emplace_back'ed after the four
  // zero-initialized slots, handing SetPadList an 8-element list.
  std::vector<int> pad_list(4, 0);
  auto pad_mode = conv_prim->GetPadMode();
  if (pad_mode == "valid") {
    h_out = div_ceil(x_shape[2] - dilation_h * (kernel_size_h - 1), stride_h);
    w_out = div_ceil(x_shape[3] - dilation_w * (kernel_size_w - 1), stride_w);
  } else if (pad_mode == "same") {
    h_out = div_ceil(x_shape[2], stride_h);
    w_out = div_ceil(x_shape[3], stride_w);
    auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]);
    pad_list[0] = pad_needed_h / 2;
    // top + bottom must cover pad_needed_h even when it is odd (was pad_needed_h / 2 twice).
    pad_list[1] = pad_needed_h - pad_list[0];
    auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]);
    pad_list[2] = pad_needed_w / 2;
    // Was computed from pad_needed_h (copy-paste bug).
    pad_list[3] = pad_needed_w - pad_list[2];
  } else if (pad_mode == "pad") {
    // User-supplied pads, presumably (top, bottom, left, right) — only the
    // per-axis sums matter for the output size below.
    pad_list = conv_prim->GetPad();
    auto pad_top = pad_list[0];
    auto pad_bottom = pad_list[1];
    auto pad_right = pad_list[2];
    auto pad_left = pad_list[3];
    h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h;
    w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w;
  }
  conv_prim->SetPadList(pad_list);
  std::vector<int> out_shape = {x_shape[0], out_channel, h_out, w_out};
  return std::make_shared<abstract::Shape>(out_shape);
}
// Infers the output element dtype of Conv2D from the element types of x and w:
// int8 inputs accumulate into int32, other supported dtypes pass through.
// Throws TypeError when x and w disagree or use an unsupported dtype.
TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  // Name the checked argument so range-violation messages are meaningful (was "").
  CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim->name());
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const std::set<TypeId> valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32};
  std::map<std::string, TypePtr> types;
  types.emplace("x", input_args[0]->BuildType());
  types.emplace("w", input_args[1]->BuildType());
  auto infer_type = CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
  // NOTE(review): the int8 branch wraps the result in TensorType while the
  // default branch returns the bare number type — confirm which form the
  // caller (AbstractTensor's element type) actually expects.
  if (infer_type == kNumberTypeInt8) {
    return std::make_shared<TensorType>(TypeIdToType(kNumberTypeInt32));
  }
  return TypeIdToType(infer_type);
}
} // namespace
Conv2d::Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); }
@ -105,4 +184,11 @@ void Conv2d::SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); }
void Conv2d::SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); }
void Conv2d::SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); }
void Conv2d::SetPadList(const std::vector<int> &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); }
// Assembles the abstract output of Conv2D: runs type inference, then shape
// inference, and wraps both into a single AbstractTensor.
AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                            const std::vector<AbstractBasePtr> &input_args) {
  auto out_type = Conv2dInferType(primitive, input_args);
  auto out_shape = Conv2dInferShape(primitive, input_args)->shape();
  return std::make_shared<abstract::AbstractTensor>(out_type, out_shape);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Conv2D, prim::kPrimConv2D, Conv2dInfer);
} // namespace mindspore

@ -55,5 +55,4 @@ AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const Primitive
const std::vector<AbstractBasePtr> &input_args);
using PrimConv2dPtr = std::shared_ptr<Conv2d>;
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_CONV2D_H_

@ -0,0 +1,36 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "c_ops/primitive_c.h"
#include <memory>
#include <string>
namespace mindspore {
void PrimitiveC::InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name) {
this->AddAttr("input_names", MakeValue(inputs_name));
this->AddAttr("output_names", MakeValue(outputs_name));
}
// Runs the statically registered infer implementation for this primitive on
// the given abstract arguments.
// @throws NotExistsError when no infer implementation is registered.
AbstractBasePtr PrimitiveC::Infer(const AbstractBasePtrList &abstract_list) {
  auto infer_map = abstract::GetPrimitiveToEvalImplMap();
  // Probe the map with a fresh Primitive of the same name — presumably the
  // map's key comparator matches primitives by name; verify against the map's
  // hash/equality functors.
  auto iter = infer_map.find(std::make_shared<Primitive>(this->name()));
  if (iter == infer_map.end()) {
    // Fixed missing separator: the old message ran the primitive name and
    // "infer" together ("Conv2Dinfer function...").
    MS_EXCEPTION(NotExistsError) << "Cannot find the " << this->name() << " infer function in the infer map!";
  }
  auto infer_function = iter->second.impl_;
  // The engine pointer is unused by standard infer impls, hence nullptr.
  return infer_function(nullptr, shared_from_base<Primitive>(), abstract_list);
}
} // namespace mindspore

@ -21,17 +21,16 @@
#include <string>
#include <vector>
#include "ir/primitive.h"
#include "abstract/primitive_infer_map.h"
#include "ir/value.h"
namespace mindspore {
// Base class for C++-side operator wrappers: a Primitive that additionally
// exposes a static-infer entry point backed by the registered infer map.
// NOTE(review): the diff left both the old inline InitIOName body and the new
// out-of-line declaration in this class (a redefinition); resolved to the
// declaration, which matches the definition in primitive_c.cc.
class PrimitiveC : public Primitive {
 public:
  explicit PrimitiveC(const std::string &name) : Primitive(name) {}
  // Runs the registered infer implementation on the abstract arguments.
  AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list);

 protected:
  // Records canonical input/output names as attributes (defined in primitive_c.cc).
  void InitIOName(const std::vector<std::string> &inputs_name, const std::vector<std::string> &outputs_name);
};
} // namespace mindspore
#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H_

@ -162,16 +162,6 @@ std::vector<int> CheckAndConvertUtils::ConvertShapePtrToShape(const std::string
return shape_element->shape();
}
// Converts a TypePtr to its TypeId, accepting only tensor or number types.
// @throws ValueError when the type is neither a TensorType nor a Number.
TypeId CheckAndConvertUtils::ConvertTypePtrToTypeId(const string &arg_name, const TypePtr &type_ptr,
                                                    const string &prim_name) {
  MS_EXCEPTION_IF_NULL(type_ptr);
  // A type can never be both TensorType and Number, so the old `||` condition
  // was true for EVERY input and this function always threw. Reject only types
  // that are neither.
  if (!type_ptr->isa<TensorType>() && !type_ptr->isa<Number>()) {
    // Message also said "shape" while printing a type; fixed.
    MS_EXCEPTION(ValueError) << "The " << arg_name << "'s type is " << type_ptr->ToString()
                             << ", should be a common type!(tensor_type && numbertype)";
  }
  return type_ptr->type_id();
}
void CheckAndConvertUtils::Check(const string &arg_name, int arg_value, CompareEnum compare_type,
const string &value_name, int value, const string &prim_name,
ExceptionType exception_type) {
@ -231,11 +221,10 @@ void CheckAndConvertUtils::Check(const string &arg_name, const std::vector<int>
MS_EXCEPTION(exception_type) << buffer.str();
}
void CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
TypeId CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypePtr> &types,
const std::set<TypeId> &check_list, const std::string &prim_name) {
if (types.empty()) {
MS_LOG(WARNING) << "Tryinh to use the function to check a empty types map!";
return;
MS_EXCEPTION(ArgumentError) << "Trying to use the function to check a empty types map!";
}
std::set<TypeId> types_id;
std::ostringstream buffer;
@ -246,7 +235,11 @@ void CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypeP
MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << type.first << " input must be tensor type but got "
<< type.second->ToString();
}
types_id.emplace(type.second->type_id());
auto tensor_type = type.second->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
auto element = tensor_type->element();
MS_EXCEPTION_IF_NULL(element);
types_id.emplace(element->type_id());
}
if (types_id.size() > 1) {
buffer << "'s input type is not same : ";
@ -255,16 +248,17 @@ void CheckAndConvertUtils::CheckTensorTypeSame(const std::map<std::string, TypeP
}
MS_EXCEPTION(TypeError) << buffer.str();
}
if (check_list.find(*(types_id.begin())) != check_list.end()) {
if (check_list.find(*types_id.begin()) == check_list.end()) {
buffer << " type of ";
for (const auto &elem : types) {
buffer << elem.first << " should be in [";
for (auto type_elem : check_list) {
buffer << type_elem << " ,";
buffer << TypeIdToType(type_elem)->ToString() << " ,";
}
buffer << "] , but got " << types.begin()->second->ToString();
}
}
MS_EXCEPTION(TypeError) << buffer.str();
}
return *types_id.begin();
}
} // namespace mindspore

@ -55,14 +55,12 @@ class CheckAndConvertUtils {
const std::pair<int, int> &range, const std::string &prim_name);
static std::vector<int> ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape,
const std::string &prim_name);
static TypeId ConvertTypePtrToTypeId(const std::string &arg_name, const TypePtr &type_ptr,
const std::string &prim_name);
static void Check(const std::string &arg_name, int arg_value, CompareEnum compare_type, const std::string &value_name,
int value, const std::string &prim_name = "", ExceptionType exception_type = ValueError);
static void Check(const std::string &arg_name, const std::vector<int> &arg_value, CompareEnum compare_type,
const std::string &value_name, const std::vector<int> &value, const std::string &prim_name = "",
ExceptionType exception_type = ValueError);
static void CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypeId> &check_list,
static TypeId CheckTensorTypeSame(const std::map<std::string, TypePtr> &types, const std::set<TypeId> &check_list,
const std::string &prim_name);
private:

Loading…
Cancel
Save