diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index f69a67f7c3..f95b7f40d6 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -444,14 +444,9 @@ KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_nod if (!anf_node->isa()) { MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode."; } - auto cnode = anf_node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - if (input_idx + 1 >= cnode->inputs().size()) { - MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode); - } - auto node = cnode->input(input_idx + 1); - MS_EXCEPTION_IF_NULL(node); - return VisitKernelWithReturnType(node, 0); + auto input_node = AnfAlgo::GetInputNode(anf_node->cast(), input_idx); + MS_EXCEPTION_IF_NULL(input_node); + return VisitKernelWithReturnType(input_node, 0); } std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) { @@ -975,7 +970,7 @@ bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) { AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) { MS_EXCEPTION_IF_NULL(node); auto get_input_index = index + 1; - if (index + 1 > node->inputs().size()) { + if (index + 1 >= node->inputs().size()) { MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just" << node->inputs().size(); } diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 29a0795afb..7f320d9564 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -1061,5 +1061,10 @@ void AscendSession::UpdateRefOutputMap(NotNull graph, } } } + +GraphId AscendSession::CompileGraph(NotNull func_graph, const vector &inputs) { + RunInfer(func_graph, 
inputs); + return CompileGraph(func_graph); +} } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 91a86e6d10..5ddf77354f 100755 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -52,6 +52,7 @@ class AscendSession : public SessionBasic { } GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraph(NotNull func_graph) override; + GraphId CompileGraph(NotNull func_graph, const std::vector &inputs) override; void RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) override; void BuildGraph(GraphId) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info, diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 865f7adc5d..945f037ec8 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -17,6 +17,7 @@ #include #include #include +#include "c_ops/primitive_c.h" #include "pipeline/jit/parse/data_converter.h" #include "ir/manager.h" #include "ir/param_info.h" @@ -1039,6 +1040,45 @@ void SessionBasic::RegisterSummaryCallBackFunc(const CallBackFunc &callback) { void SessionBasic::Reorder(std::vector *node_list) { AnfAlgo::ReorderExecList(NOT_NULL(node_list)); } +void SessionBasic::RunInfer(NotNull func_graph, const std::vector &inputs) { + auto node_list = TopoSort(func_graph->get_return()); + size_t tensor_index = 0; + for (const auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (node->isa()) { + AbstractBasePtrList input_abstracts; + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(node); ++index) { + auto input_node = AnfAlgo::GetInputNode(node->cast(), index); + MS_EXCEPTION_IF_NULL(input_node); + auto abstract = input_node->abstract(); + 
MS_EXCEPTION_IF_NULL(abstract); + input_abstracts.emplace_back(abstract); + } + auto prim = AnfAlgo::GetCNodePrimitive(node); + if (prim->isa()) { + auto prim_c = prim->cast>(); + MS_EXCEPTION_IF_NULL(prim_c); + auto abstract = prim_c->Infer(input_abstracts); + node->set_abstract(abstract); + } else { + node->set_abstract( + std::make_shared(kNumberTypeFloat32, std::vector{32, 64, 218, 218})->ToAbstract()); + } + } else if (node->isa()) { + if (tensor_index >= inputs.size()) { + MS_EXCEPTION(IndexError) << "Index " << tensor_index << " is out of " << inputs.size() << " tensor's size"; + } + node->set_abstract(inputs[tensor_index++]->ToAbstract()); + } else { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + value_node->set_abstract(value->ToAbstract()); + } + } +} + void SessionBasic::SetSummaryNodes(KernelGraph *graph) { MS_LOG(DEBUG) << "Update summary Start"; MS_EXCEPTION_IF_NULL(graph); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 863a222034..c0e0ed8d22 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -70,6 +70,9 @@ class SessionBasic : public std::enable_shared_from_this { virtual GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) = 0; virtual GraphId CompileGraph(NotNull func_graph) { return kInvalidGraphId; } + virtual GraphId CompileGraph(NotNull func_graph, const std::vector &inputs) { + MS_EXCEPTION(NotExistsError) << "Call an empty function"; + } // build graph, used to handle multiple child graphs virtual void BuildGraph(GraphId) {} @@ -129,6 +132,7 @@ class SessionBasic : public std::enable_shared_from_this { void CreateCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector *cnode_inputs); protected: + void RunInfer(NotNull func_graph, const std::vector &inputs); // Get graph by graph id 
,if not exist return null ptr KernelGraphPtr GetGraph(GraphId graph_id) const; diff --git a/mindspore/core/c_ops/conv2d.cc b/mindspore/core/c_ops/conv2d.cc index c29bb1b1a9..92974731db 100644 --- a/mindspore/core/c_ops/conv2d.cc +++ b/mindspore/core/c_ops/conv2d.cc @@ -37,10 +37,10 @@ constexpr auto kPadList = "pad_list"; constexpr auto kConv2DName = "Conv2D"; abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { MS_EXCEPTION_IF_NULL(primitive); - auto conv_prim = std::dynamic_pointer_cast(primitive); + auto conv_prim = primitive->cast(); MS_EXCEPTION_IF_NULL(conv_prim); auto prim_name = conv_prim->name(); - CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeLeft, {2, 3}, prim_name); + CheckAndConvertUtils::CheckInRange("Conv2d Infer", input_args.size(), kIncludeBoth, {2, 3}, prim_name); auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShape("x_shape", input_args[0]->GetShapeTrack(), prim_name); auto w_shape = CheckAndConvertUtils::ConvertShapePtrToShape("w_shape", input_args[1]->GetShapeTrack(), prim_name); @@ -99,7 +99,7 @@ abstract::ShapePtr Conv2dInferShape(const PrimitivePtr &primitive, const std::ve } TypePtr Conv2dInferType(const PrimitivePtr &prim, const std::vector &input_args) { - CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeLeft, {2, 3}, prim->name()); + CheckAndConvertUtils::CheckInRange("", input_args.size(), kIncludeBoth, {2, 3}, prim->name()); for (const auto &item : input_args) { MS_EXCEPTION_IF_NULL(item); } diff --git a/mindspore/core/c_ops/conv2d.h b/mindspore/core/c_ops/conv2d.h index ed637c5829..7e8412e2f5 100644 --- a/mindspore/core/c_ops/conv2d.h +++ b/mindspore/core/c_ops/conv2d.h @@ -29,6 +29,7 @@ class Conv2d : public PrimitiveC { public: Conv2d(); ~Conv2d() = default; + MS_DECLARE_PARENT(Conv2d, PrimitiveC); void Init(int out_channel, const std::vector &kernel_size, int mode = 1, const std::string &pad_mode = "valid", const std::vector &pad 
= {0, 0, 0, 0}, const std::vector &stride = {1, 1, 1, 1}, const std::vector &dilation = {1, 1, 1, 1}, int group = 1); diff --git a/mindspore/core/c_ops/primitive_c.h b/mindspore/core/c_ops/primitive_c.h index e0ff2df6b6..f85d4559f4 100644 --- a/mindspore/core/c_ops/primitive_c.h +++ b/mindspore/core/c_ops/primitive_c.h @@ -25,6 +25,7 @@ namespace mindspore { class PrimitiveC : public Primitive { public: explicit PrimitiveC(const std::string &name) : Primitive(name) {} + MS_DECLARE_PARENT(PrimitiveC, Primitive); ~PrimitiveC() = default; AbstractBasePtr Infer(const AbstractBasePtrList &abstract_list); diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 9ec8a6e1c5..3a6808ca71 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -640,7 +640,7 @@ CNodePtr FuncGraph::NewCNode(const PrimitivePtr &primitive, const std::vectorset_default_param(MakeValue(meta_tensor)); parameter->set_abstract(meta_tensor->ToAbstract()); diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h index fe4ca7afa9..d813a958b6 100644 --- a/mindspore/core/ir/func_graph.h +++ b/mindspore/core/ir/func_graph.h @@ -173,7 +173,7 @@ class FuncGraph : public FuncGraphBase { CNodePtr NewCNodeWithScope(const std::vector &inputs, const ScopePtr &scope); virtual CNodePtr NewCNode(const PrimitivePtr &primitive, const std::vector &prim_inputs); - virtual ParameterPtr add_parameter(const tensor::MetaTensorPtr &meta_tensor); + virtual ParameterPtr add_weight(const tensor::MetaTensorPtr &meta_tensor); // Functions for handling variable argument, keyword-only arguments and variable keyword argument AnfNodePtr GetDefaultValueByName(const std::string &name); void set_param_default_value(const std::string &name, const AnfNodePtr &node) { diff --git a/mindspore/core/utils/check_convert_utils.cc b/mindspore/core/utils/check_convert_utils.cc index 1c7fe9abf9..5b7dd6faf1 100644 --- a/mindspore/core/utils/check_convert_utils.cc +++ 
b/mindspore/core/utils/check_convert_utils.cc @@ -64,23 +64,36 @@ std::vector CheckAndConvertUtils::CheckPositiveVector(const std::string &ar const std::vector &arg_value, const std::string &prim_name, bool allow_four, bool ret_four) { + auto raise_message = [allow_four, prim_name, arg_value, arg_name]() -> void { + std::ostringstream buffer; + buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; + if (allow_four) { + buffer << "or four "; + } + buffer << " positive int numbers , but got ["; + for (auto item : arg_value) { + buffer << item << ","; + } + buffer << "]"; + MS_EXCEPTION(ValueError) << buffer.str(); + }; + for (auto item : arg_value) { + if (item < 0) { + raise_message(); + } + } + if (arg_value.size() == 1) { + return ret_four ? std::vector{1, 1, arg_value[0], arg_value[0]} : std::vector{arg_value[0], arg_value[0]}; + } if (arg_value.size() == 2) { return ret_four ? std::vector{1, 1, arg_value[0], arg_value[1]} : arg_value; } else if (arg_value.size() == 4 && allow_four) { return ret_four ? 
arg_value : std::vector{arg_value[2], arg_value[3]}; } - std::ostringstream buffer; - buffer << "For " << prim_name << " attr " << arg_name << " should be a positive vector of size two "; - if (allow_four) { - buffer << "or four "; - } - buffer << " positive int numbers , but got ["; - for (auto item : arg_value) { - buffer << item << ","; - } - buffer << "]"; - MS_EXCEPTION(ValueError) << buffer.str(); + raise_message(); + return arg_value; } + std::string CheckAndConvertUtils::CheckString(const std::string &arg_name, const std::string &arg_value, const std::set &check_list, const std::string &prim_name) { if (check_list.find(arg_value) != check_list.end()) { @@ -131,6 +144,10 @@ void CheckAndConvertUtils::CheckInRange(const std::string &arg_name, int arg_val if (iter == kCompareRangeMap.end()) { MS_EXCEPTION(NotExistsError) << "compare_operator " << compare_operator << " cannot find in the compare map"; } + if (range.first >= range.second) { + MS_EXCEPTION(ArgumentError) << "the check range left must be smaller than the right number, but got [ " << range.first + << "," << range.second; + } if (iter->second(arg_value, range)) { return; } diff --git a/mindspore/core/utils/check_convert_utils.h b/mindspore/core/utils/check_convert_utils.h index 77b78d2189..1085c4f135 100644 --- a/mindspore/core/utils/check_convert_utils.h +++ 
*/ -#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H -#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H +#ifndef MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ +#define MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ #include #include #include @@ -67,4 +67,4 @@ class CheckAndConvertUtils { static bool IsEqualVector(const std::vector &vec_1, const std::vector &vec_2); }; } // namespace mindspore -#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H +#endif // MINDSPORE_CORE_UTILS_CHECK_CONVERT_UTILS_H_ diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc new file mode 100644 index 0000000000..66764bb463 --- /dev/null +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "utils/tensor_construct_utils.h" +#include +#include +namespace mindspore { +namespace { +template +void SetTensorData(void *data, float num, size_t data_length) { + MS_EXCEPTION_IF_NULL(data); + auto tensor_data = reinterpret_cast(data); + MS_EXCEPTION_IF_NULL(tensor_data); + for (size_t index = 0; index < data_length; ++index) { + *tensor_data = num; + ++tensor_data; + } +} +} // namespace +tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector &shape) { + tensor::TensorPtr tensor = std::make_shared(type, shape); + + size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); + auto tensor_data = tensor->data_c(); + char *data = reinterpret_cast(tensor_data); + MS_EXCEPTION_IF_NULL(data); + (void)memset_s(data, mem_size, 0, mem_size); + + return tensor; +} + +tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector &shape) { + tensor::TensorPtr tensor = std::make_shared(type, shape); + auto mem_size = IntToSize(tensor->ElementsNum()); + if (tensor->data_type() == kNumberTypeFloat32) { + SetTensorData(tensor->data_c(), 1.0, mem_size); + } else if (tensor->data_type() == kNumberTypeInt) { + SetTensorData(tensor->data_c(), 1, mem_size); + } + return tensor; +} + +tensor::TensorPtr TensorConstructUtils::CreateTensor(TypeId type, const std::vector &shape, void *data) { + tensor::TensorPtr tensor = std::make_shared(type, shape, data, type); + return tensor; +} +} // namespace mindspore diff --git a/mindspore/core/utils/tensor_construct_utils.h b/mindspore/core/utils/tensor_construct_utils.h new file mode 100644 index 0000000000..b4e48a2ff9 --- /dev/null +++ b/mindspore/core/utils/tensor_construct_utils.h @@ -0,0 +1,28 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ +#define MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_ +#include +#include "ir/tensor.h" +namespace mindspore { +class TensorConstructUtils { + public: + static tensor::TensorPtr CreateZerosTensor(TypeId type, const std::vector &shape); + static tensor::TensorPtr CreateOnesTensor(TypeId type, const std::vector &shape); + static tensor::TensorPtr CreateTensor(TypeId type, const std::vector &shape, void *data); +}; +} // namespace mindspore +#endif // MINDSPORE_CORE_UTILS_TENSOR_CONSTRUCT_UTILS_H_