diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index 41d2c7e726..d2f25c0fac 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -149,7 +149,9 @@ add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/utils util) list(APPEND SUB_OBJECTS_SRC $) add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/ir ir) list(APPEND SUB_OBJECTS_SRC $) -add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj proto_input ) +add_subdirectory(${CMAKE_SOURCE_DIR}/mindspore/core/c_ops c_ops) +list(APPEND SUB_OBJECTS_SRC $) +add_dependencies(_mindspore_core_utils_obj _mindspore_base_obj _mindspore_ir_obj _mindspore_abstract_obj _mindspore_c_ops_obj proto_input) set_property(SOURCE ${SUB_OBJECTS_SRC} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_ME) add_library(mindspore STATIC ${SUB_OBJECTS_SRC}) diff --git a/mindspore/core/c_ops/CMakeLists.txt b/mindspore/core/c_ops/CMakeLists.txt new file mode 100644 index 0000000000..04709ed635 --- /dev/null +++ b/mindspore/core/c_ops/CMakeLists.txt @@ -0,0 +1,2 @@ +file(GLOB_RECURSE _C_OPS_ALL_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +add_library(_mindspore_c_ops_obj OBJECT ${_C_OPS_ALL_SRC_FILES}) diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc new file mode 100644 index 0000000000..347ab79c1a --- /dev/null +++ b/mindspore/core/c_ops/conv2d.cc @@ -0,0 +1,139 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "c_ops/conv2d.h" +#include +#include +#include +#include +#include +#include "utils/check_convert_utils.h" + +namespace mindspore { +namespace { +using PrimConv2dPtr = std::shared_ptr; +abstract::ShapePtr InferShape(const PrimitivePtr &primitive, const std::vector &input_args) { + MS_EXCEPTION_IF_NULL(primitive); + auto conv_prim = primitive->cast(); + MS_EXCEPTION_IF_NULL(conv_prim); + auto prim_name = conv_prim->name(); + CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); + auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[0]->GetShapeTrack(), prim_name); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[1]->GetShapeTrack(), prim_name); + + CheckAndConvertUtils::CheckInteger("weight rank", w_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::CheckInteger("x rank", x_shape.size(), kEqual, 4, prim_name); + CheckAndConvertUtils::Check("x_shape[1] / group", x_shape[1] / conv_prim->GetGroup(), kEqual, "w_shape[1]", + w_shape[1], conv_prim->name()); + auto out_channel = conv_prim->GetOutputChannel(); + CheckAndConvertUtils::Check("out_channel", out_channel, kEqual, "w_shape[0]", w_shape[0], conv_prim->name()); + std::vector temp_w; + std::copy(w_shape.begin() + 2, w_shape.end(), std::back_inserter(temp_w)); + CheckAndConvertUtils::Check("kernel_size", conv_prim->GetKernelSize(), kEqual, "w_shape[2:4]", temp_w, + conv_prim->name()); + + auto kernel_size_h = w_shape[2]; + auto kernel_size_w = w_shape[3]; + auto stride = conv_prim->GetStride(); + auto dilation = conv_prim->GetDilation(); + auto stride_h = stride[2]; + auto stride_w = stride[3]; + auto dilation_h = dilation[2]; + auto dilation_w = dilation[3]; + int h_out = -1; + int w_out = -1; + std::vector pad_list(4, 0); + auto pad_mode = conv_prim->GetPadMode(); + if (pad_mode == "valid") { + h_out = ceil((x_shape[2] - dilation_h * (kernel_size_h - 1)) / stride_h); + w_out = ceil((x_shape[3] - dilation_w * (kernel_size_w - 1)) / stride_w); + } else if (pad_mode == "same") { + h_out = ceil(x_shape[2] / stride_h); + w_out = ceil(x_shape[3] / stride_w); + + auto pad_needed_h = std::max(0, (h_out - 1) * stride_h + dilation_h * (kernel_size_h - 1) + 1 - x_shape[2]); + pad_list.emplace_back(floor(pad_needed_h / 2)); + pad_list.emplace_back(pad_needed_h / 2); + auto pad_needed_w = std::max(0, (w_out - 1) * stride_w + dilation_w * (kernel_size_w - 1) + 1 - x_shape[3]); + auto pad_left = floor(pad_needed_w / 2); + pad_list.emplace_back(pad_left); + pad_list.emplace_back(pad_needed_h - pad_left); + } else if (pad_mode == "pad") { + std::copy(conv_prim->GetPad().begin(), conv_prim->GetPad().end(), std::back_inserter(pad_list)); + auto pad_top = conv_prim->GetPad()[0]; + auto pad_bottom = conv_prim->GetPad()[1]; + auto pad_right = conv_prim->GetPad()[2]; + auto pad_left = conv_prim->GetPad()[3]; + + h_out = 1 + (x_shape[2] + pad_top + pad_bottom - kernel_size_h - (kernel_size_h - 1) * (dilation_h - 1)) / stride_h; + w_out = 1 + (x_shape[3] + pad_left + pad_right - kernel_size_w - (kernel_size_w - 1) * (dilation_w - 1)) / stride_w; + h_out = floor(h_out); + w_out = floor(w_out); + } + conv_prim->SetPadList(pad_list); + std::vector out_shape = {x_shape[0], out_channel, h_out, w_out}; + return std::make_shared(out_shape); +} + +TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) { + CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); + for (const auto &item : input_args) { + MS_EXCEPTION_IF_NULL(item); + } + auto x_type = CheckAndConvertUtils::ConvertTypePtrToTypeId("x_dtype", input_args[0]->GetTypeTrack(), prim->name()); + const std::set valid_types = {kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat16, kNumberTypeFloat32}; + std::map types; + types.emplace("x", input_args[0]->GetTypeTrack()); + types.emplace("w", input_args[1]->GetTypeTrack()); + CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name()); + if (x_type == kNumberTypeInt8) { + return std::make_shared(TypeIdToType(kNumberTypeInt32)); + } + return std::make_shared(TypeIdToType(x_type)); +} +} // namespace +void Conv2d::Init(int out_channel, const std::vector &kernel_size, int mode, const std::string &pad_mode, + const std::vector &pad, const std::vector &stride, const std::vector &dilation, + int group) { + auto prim_name = this->name(); + this->AddAttr("data_format", MakeValue("NCHW")); + this->AddAttr("offset_a", MakeValue(0)); + this->SetKernelSize(CheckAndConvertUtils::CheckPositiveVector(kKernelSize, kernel_size, prim_name)); + this->SetStride(CheckAndConvertUtils::CheckPositiveVector(kStride, stride, this->name(), true, true)); + this->SetDilation(CheckAndConvertUtils::CheckPositiveVector(kDilation, dilation, this->name(), true, true)); + this->SetPadMode(CheckAndConvertUtils::CheckString(kPadMode, pad_mode, {"valid", "same", "pad"}, prim_name)); + CheckAndConvertUtils::CheckInteger("pad size", pad.size(), kEqual, 4, prim_name); + if (pad_mode == "pad") { + for (auto item : pad) { + CheckAndConvertUtils::Check("pad item", item, kGreaterEqual, "zeros list", 0, prim_name); + } + } else { + CheckAndConvertUtils::Check(kPad, pad, kEqual, "zeros list", {0, 0, 0, 0}, prim_name); + } + this->SetPad(CheckAndConvertUtils::CheckPositiveVector(kPad, pad, this->name(), true, true)); + this->SetMode(CheckAndConvertUtils::CheckInteger("mode", mode, kEqual, 1, prim_name)); + this->SetOutChannel(CheckAndConvertUtils::CheckInteger("out_channel", out_channel, kGreaterThan, 0, prim_name)); + this->SetGroup(CheckAndConvertUtils::CheckInteger("group", group, kGreaterThan, 0, prim_name)); +} + +AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args) { + return std::make_shared(InferType(primitive, input_args), + InferShape(primitive, input_args)->shape()); +} +} // namespace mindspore diff --git a/mindspore/core/c_ops/conv2d.h b/mindspore/core/c_ops/conv2d.h new file mode 100644 index 0000000000..cbba5fa068 --- /dev/null +++ b/mindspore/core/c_ops/conv2d.h @@ -0,0 +1,94 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_C_OPS_CONV2D_H +#define MINDSPORE_CORE_C_OPS_CONV2D_H +#include +#include +#include +#include "c_ops/primitive_c.h" +#include "abstract/abstract_value.h" +#include "utils/check_convert_utils.h" +namespace mindspore { +class Conv2d : public PrimitiveC { + public: + Conv2d() : PrimitiveC(kConv2DName) { InitIOName({"x", "w"}, {"output"}); } + void Init(int out_channel, const std::vector &kernel_size, int mode = 1, const std::string &pad_mode = "valid", + const std::vector &pad = {0, 0, 0, 0}, const std::vector &stride = {1, 1, 1, 1}, + const std::vector &dilation = {1, 1, 1, 1}, int group = 1); + std::vector GetKernelSize() const { + auto value_ptr = this->GetAttr(kKernelSize); + return GetValue>(value_ptr); + } + std::vector GetStride() const { + auto value_ptr = GetAttr(kStride); + return GetValue>(value_ptr); + } + std::vector GetDilation() const { + auto value_ptr = GetAttr(kDilation); + return GetValue>(value_ptr); + } + std::string GetPadMode() const { + auto value_ptr = this->GetAttr(kPadMode); + return GetValue(value_ptr); + } + std::vector GetPad() const { + auto value_ptr = this->GetAttr(kPad); + return GetValue>(value_ptr); + } + int GetMode() const { + auto value_ptr = this->GetAttr(kMode); + return GetValue(value_ptr); + } + + int GetGroup() const { + auto value_ptr = this->GetAttr(kGroup); + return GetValue(value_ptr); + } + int GetOutputChannel() const { + auto value_ptr = this->GetAttr(kOutputChannel); + return GetValue(value_ptr); + } + + void SetKernelSize(const std::vector &kernel_size) { this->AddAttr(kKernelSize, MakeValue(kernel_size)); } + void SetStride(const std::vector &stride) { this->AddAttr(kStride, MakeValue(stride)); } + void SetDilation(const std::vector &dilation) { this->AddAttr(kDilation, MakeValue(dilation)); } + void SetPadMode(const std::string &pad_mode) { this->AddAttr(kPadMode, MakeValue(pad_mode)); } + void SetPad(const std::vector &pad) { this->AddAttr(kPad, MakeValue(pad)); } + void SetMode(int mode) { this->AddAttr(kMode, MakeValue(mode)); } + void SetGroup(int group) { this->AddAttr(kGroup, MakeValue(group)); } + void SetOutChannel(int output_channel) { this->AddAttr(kOutputChannel, MakeValue(output_channel)); } + void SetPadList(const std::vector &pad_list) { this->AddAttr(kPadList, MakeValue(pad_list)); } + + private: + inline static const string kKernelSize = "kernel_size"; + inline static const string kStride = "stride"; + inline static const string kDilation = "dilation"; + inline static const string kPadMode = "pad_mode"; + inline static const string kPad = "pad"; + inline static const string kMode = "mode"; + inline static const string kGroup = "group"; + inline static const string kOutputChannel = "output channel"; + inline static const string kPadList = "pad_list"; + inline static const string kConv2DName = "Conv2D"; +}; +AbstractBasePtr Conv2dInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, + const std::vector &input_args); +} // namespace mindspore + +#endif // MINDSPORE_CORE_C_OPS_CONV2D_H diff --git a/mindspore/core/c_ops/primitive_c.h b/mindspore/core/c_ops/primitive_c.h new file mode 100644 index 0000000000..69c85b31ae --- /dev/null +++ b/mindspore/core/c_ops/primitive_c.h @@ -0,0 +1,37 @@ +/** + * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). + * + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H +#define MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H +#include +#include +#include "ir/primitive.h" +#include "ir/value.h" +namespace mindspore { +class PrimitiveC : public Primitive { + public: + explicit PrimitiveC(const std::string &name) : Primitive(name) { attrs_ = {}; } + + protected: + void InitIOName(const std::vector &inputs_name, const std::vector &outputs_name) { + this->AddAttr("input_names", MakeValue(inputs_name)); + this->AddAttr("output_names", MakeValue(outputs_name)); + } +}; +} // namespace mindspore +#endif // MINDSPORE_CORE_C_OPS_PRIMITIVE_C_H diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index e8992fda8a..6832a94e82 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -632,6 +632,19 @@ void FuncGraph::CheckOrder() { MS_LOG(DEBUG) << "Check order okay."; } } +CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vector &inputs) { + auto primitive_node = std::make_shared(primitive); + std::vector input_node_list = {primitive_node}; + std::copy(inputs.begin(), inputs.end(), std::back_inserter(input_node_list)); + return NewCNode(input_node_list); +} + +ParameterPtr FuncGraph::add_parameter(const tensor::MetaTensorPtr &meta_tensor) { + auto parameter = add_parameter(); + parameter->set_default_param(MakeValue(meta_tensor)); + parameter->set_abstract(meta_tensor->ToAbstract()); + return parameter; +} size_t NewFgSeenGeneration() { static size_t fg_seen_generation = 0; diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index 3ce74cfb5b..8bcbd3fdc1 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -170,7 +170,9 @@ class FuncGraph : public FuncGraphBase { // create a cnode with given inputs, bound to this graph, and set to specific scope CNodePtr NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope); + virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector &prim_inputs); + virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor); // Functions for handling variable argument, keyword-only arguments and variable keyword argument AnfNodePtr GetDefaultValueByName(const std::string &name); void set_param_default_value(const std::string &name, const AnfNodePtr &node) { diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc new file mode 100644 index 0000000000..140f3efc0d --- /dev/null +++ b/mindspore/core/utils/check_convert_utils.cc @@ -0,0 +1,270 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "utils/check_convert_utils.h" +#include +#include "abstract/abstract_value.h" + +namespace mindspore { +namespace { +const std::map> kCompareMap = { + {kEqual, [](int num1, int num2) -> bool { return num1 == num2; }}, + {kNotEqual, [](int num1, int num2) -> bool { return num1 != num2; }}, + {kLessThan, [](int num1, int num2) -> bool { return num1 < num2; }}, + {kLessEqual, [](int num1, int num2) -> bool { return num1 <= num2; }}, + {kGreaterThan, [](int num1, int num2) -> bool { return num1 > num2; }}, + {kGreaterEqual, [](int num1, int num2) -> bool { return num1 >= num2; }}}; + +const std::map)>> kCompareRangeMap = { + {kIncludeNeither, + [](int num1, std::pair range) -> bool { return num1 > range.first && num1 < range.second; }}, + {kIncludeLeft, + [](int num1, std::pair range) -> bool { return num1 >= range.first && num1 < range.second; }}, + {kIncludeRight, + [](int num1, std::pair range) -> bool { return num1 > range.first && num1 <= range.second; }}, + {kIncludeBoth, + [](int num1, std::pair range) -> bool { return num1 >= range.first && num1 <= range.second; }}}; + +const std::map kCompareToString = { + {kEqual, "equal"}, {kNotEqual, "not equal"}, {kLessThan, "less than"}, + {kLessEqual, "less eqaul"}, {kGreaterThan, "greater than"}, {kGreaterEqual, "greate equal"}}; + +const std::map> kCompareRangeToString = { + {kIncludeNeither, {"in (", ")"}}, + {kIncludeLeft, {" in [", ")"}}, + {kIncludeRight, {"in (", "]"}}, + {kIncludeBoth, {"in [", "]"}}}; +} // namespace +bool CheckAndConvertUtils::IsEqualVector(const std::vector &vec_1, const std::vector &vec_2) { + if (vec_1.size() != vec_2.size()) { + return false; + } + for (size_t index = 0; index < vec_1.size(); ++index) { + if (vec_1[index] != vec_2[index]) { + return false; + } + } + return true; +} + +std::vector CheckAndConvertUtils::CheckPositiveVector(const std::string &arg_name, + const std::vector &arg_value, + const std::string &prim_name, bool allow_four, + bool ret_four) { + if (arg_value.size() == 2) { + return ret_four ? std::vector{1, 1, arg_value[0], arg_value[1]} : arg_value; + } else if (arg_value.size() == 4 && allow_four) { + return ret_four ? arg_value : std::vector{arg_value[2], arg_value[3]}; + } + std::ostringstream buffer; + buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; + if (allow_four) { + buffer << "or four "; + } + buffer << " positive int numbers , but got ["; + for (auto item : arg_value) { + buffer << item << ","; + } + buffer << "]"; + MS_EXCEPTION(ValueError) << buffer.str(); +} +std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value, + const std::set &check_list, const std::string &prim_name) { + if (check_list.find(arg_value) != check_list.end()) { + return arg_value; + } + std::ostringstream buffer; + buffer << "For " << prim_name << " the " << arg_name << " should be str and must be "; + if (check_list.size() == 1) { + buffer << (*check_list.begin()) << "but got " << arg_value; + MS_EXCEPTION(ValueError) << buffer.str(); + } + buffer << "one of {"; + for (const auto &item : check_list) { + buffer << item << " ,"; + } + buffer << " }" + << " but got " << arg_value; + MS_EXCEPTION(ValueError) << buffer.str(); +} + +int CheckAndConvertUtils::CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator, + int match_value, const std::string &prim_name) { + auto iter = kCompareMap.find(compare_operator); + if (iter == kCompareMap.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; + } + if (iter->second(arg_value, match_value)) { + return arg_value; + } + std::ostringstream buffer; + if (prim_name.empty()) { + buffer << "The "; + } else { + buffer << "For " << prim_name << " the "; + } + buffer << arg_name << " must "; + auto iter_to_string = kCompareToString.find(compare_operator); + if (iter_to_string == kCompareToString.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map"; + } + buffer << iter_to_string->second << match_value << " , but got " << arg_value; + MS_EXCEPTION(ValueError) << buffer.str(); +} + +void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator, + const std::pair &range, const std::string &prim_name) { + auto iter = kCompareRangeMap.find(compare_operator); + if (iter == kCompareRangeMap.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; + } + if (iter->second(arg_value, range)) { + return; + } + std::ostringstream buffer; + if (prim_name.empty()) { + buffer << "The "; + } else { + buffer << "For " << prim_name << " the "; + } + buffer << arg_name << " must "; + auto iter_to_string = kCompareRangeToString.find(compare_operator); + if (iter_to_string == kCompareRangeToString.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare string map"; + } + auto range_strng = iter_to_string->second; + buffer << range_strng.first << range.first << "," << range_strng.second << " , but got " << arg_value; + MS_EXCEPTION(ValueError) << buffer.str(); +} + +std::vector CheckAndConvertUtils::ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, + const std::string &prim_name) { + MS_EXCEPTION_IF_NULL(shape); + if (!shape->isa()) { + MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << shape->ToString() + << "should be a common shape!"; + } + auto shape_element = shape->cast(); + MS_EXCEPTION_IF_NULL(shape_element); + return shape_element->shape(); +} + +TypeId CheckAndConvertUtils::ConvertTypePtrToTypeId(const string &arg_name, const TypePtr &type_ptr, + const string &prim_name) { + MS_EXCEPTION_IF_NULL(type_ptr); + if (!type_ptr->isa() || !type_ptr->isa()) { + MS_EXCEPTION(ValueError) << "The " << arg_name << "'s shape is " << type_ptr->ToString() + << "should be a common type!(tensor_type && numbertype)"; + } + return type_ptr->type_id(); +} + +void CheckAndConvertUtils::Check(const string &arg_name, int arg_value, CompareEnum compare_type, + const string &value_name, int value, const string &prim_name, + ExceptionType exception_type) { + auto iter = kCompareMap.find(compare_type); + if (iter == kCompareMap.end()) { + MS_EXCEPTION(NotExistsError) << "the compare type :" << compare_type << " is not in the compare map"; + } + if (iter->second(arg_value, value)) { + return; + } + std::ostringstream buffer; + if (prim_name.empty()) { + buffer << "The "; + } else { + buffer << "For " << prim_name << " the "; + } + auto iter_to_string = kCompareToString.find(compare_type); + if (iter_to_string == kCompareToString.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; + } + MS_EXCEPTION(exception_type) << buffer.str() << arg_name << " should be " << iter_to_string->second << value + << " but got " << arg_value; +} +void CheckAndConvertUtils::Check(const string &arg_name, const std::vector &arg_value, CompareEnum compare_type, + const string &value_name, const std::vector &value, const string &prim_name, + ExceptionType exception_type) { + if (compare_type != kEqual) { + auto iter = kCompareToString.find(compare_type); + if (iter != kCompareToString.end()) { + MS_EXCEPTION(NotSupportError) << "Only supported equal to compare two vectors but got " << iter->second; + } + MS_EXCEPTION(UnknownError) << "Cannot find the operator " << compare_type << "in the compare map!"; + } + if (arg_value == value) { + return; + } + std::ostringstream buffer; + if (prim_name.empty()) { + buffer << "The "; + } else { + buffer << "For " << prim_name << " the "; + } + auto iter_to_string = kCompareToString.find(compare_type); + if (iter_to_string == kCompareToString.end()) { + MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_type << " cannot find in the compare string map"; + } + buffer << arg_name << "should be " << iter_to_string->second << " ["; + for (auto item : value) { + buffer << item << ","; + } + buffer << "] " + << "but got ["; + for (auto item : arg_value) { + buffer << item << " ,"; + } + buffer << "]"; + MS_EXCEPTION(exception_type) << buffer.str(); +} + +void CheckAndConvertUtils::CheckTensorTypeSame(const std::map &types, + const std::set &check_list, const std::string &prim_name) { + if (types.empty()) { + MS_LOG(WARNING) << "Tryinh to use the function to check a empty types map!"; + return; + } + std::set types_id; + std::ostringstream buffer; + buffer << "For " << prim_name; + for (const auto &type : types) { + MS_EXCEPTION_IF_NULL(type.second); + if (!type.second->isa()) { + MS_EXCEPTION(TypeError) << "The " << prim_name << "'s" << type.first << " input must be tensor type but got " + << type.second->ToString(); + } + types_id.emplace(type.second->type_id()); + } + if (types_id.size() > 1) { + buffer << "'s input type is not same : "; + for (const auto &item : types) { + buffer << "[ name : " << item.first << " ,type : " << item.second->ToString() << "]"; + } + MS_EXCEPTION(TypeError) << buffer.str(); + } + if (check_list.find(*(types_id.begin())) != check_list.end()) { + buffer << " type of "; + for (const auto &elem : types) { + buffer << elem.first << " should be in ["; + for (auto type_elem : check_list) { + buffer << type_elem << " ,"; + } + buffer << "] , but got " << types.begin()->second->ToString(); + } + } + MS_EXCEPTION(TypeError) << buffer.str(); +} +} // namespace mindspore diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h new file mode 100644 index 0000000000..f61f22d19b --- /dev/null +++ b/mindspore/core/utils/check_convert_utils.h @@ -0,0 +1,72 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H +#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H +#include +#include +#include +#include +#include +#include "base/base.h" +#include "ir/anf.h" +#include "ir/dtype/type_id.h" +#include "utils/log_adapter.h" +namespace mindspore { +enum CompareEnum : int { + kEqual = 1, // == + kNotEqual = 2, // != + kLessThan = 3, // < + kLessEqual = 4, // <= + kGreaterThan = 5, // > + kGreaterEqual = 6, // >= +}; + +enum CompareRange { + kIncludeNeither = 1, // (a,b) + kIncludeLeft = 2, // [a,b) + kIncludeRight = 3, // (a,b] + kIncludeBoth = 4, // [a,b] +}; + +class CheckAndConvertUtils { + public: + static std::vector CheckPositiveVector(const std::string &arg_name, const std::vector &arg_value, + const std::string &prim_name, bool allow_four = false, + bool ret_four = false); + static std::string CheckString(const std::string &arg_name, const std::string &arg_value, + const std::set &check_list, const std::string &prim_name); + static int CheckInteger(const std::string &arg_name, int arg_value, CompareEnum compare_operator, int match_value, + const std::string &prim_name); + static void CheckInRange(const std::string &arg_name, int arg_value, CompareRange compare_operator, + const std::pair &range, const std::string &prim_name); + static std::vector ConvertShapePtrToShape(const std::string &arg_name, const BaseShapePtr &shape, + const std::string &prim_name); + static TypeId ConvertTypePtrToTypeId(const std::string &arg_name, const TypePtr &type_ptr, + const std::string &prim_name); + static void Check(const std::string &arg_name, int arg_value, CompareEnum compare_type, const std::string &value_name, + int value, const std::string &prim_name = "", ExceptionType exception_type = ValueError); + static void Check(const std::string &arg_name, const std::vector &arg_value, CompareEnum compare_type, + const std::string &value_name, const std::vector &value, const std::string &prim_name = "", + ExceptionType exception_type = ValueError); + static void CheckTensorTypeSame(const std::map &types, const std::set &check_list, + const std::string &prim_name); + + private: + static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2); +}; +} // namespace mindspore +#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H