From 0681c2597bef41bf976b3abc3df979019699be99 Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com"
Date: Mon, 1 Mar 2021 21:42:27 +0800
Subject: [PATCH] develop mindir load and run

---
 mindspore/ccsrc/pipeline/jit/action.cc        | 76 +++++++++++++++-
 mindspore/ccsrc/pipeline/jit/init.cc          |  1 +
 .../pipeline/jit/parse/data_converter.cc      | 86 ++++++++++++-------
 mindspore/ccsrc/pipeline/jit/parse/resolve.cc | 44 ++++++++++
 mindspore/ccsrc/pipeline/jit/pipeline.cc      |  3 +
 mindspore/ccsrc/pipeline/jit/pipeline.h       |  1 +
 mindspore/ccsrc/pipeline/jit/validator.cc     |  3 +
 mindspore/core/ir/anf.h                       |  4 +
 mindspore/core/ir/func_graph.cc               | 12 +++
 mindspore/core/ir/func_graph.h                |  2 +
 mindspore/core/ir/func_graph_cloner.cc        |  1 +
 .../core/load_mindir/anf_model_parser.cc      | 41 +++++----
 mindspore/nn/__init__.py                      |  4 +-
 mindspore/nn/cell.py                          | 38 +++++++-
 mindspore/train/__init__.py                   |  4 +-
 mindspore/train/model.py                      | 17 ++++
 mindspore/train/serialization.py              | 44 ++++++++++
 .../test_train_mindir.py                      | 25 +++++-
 .../{export => export_and_load}/text_air.py   |  0
 .../text_lite_mindir.py                       |  0
 20 files changed, 351 insertions(+), 55 deletions(-)
 rename tests/st/{export => export_and_load}/test_train_mindir.py (81%)
 rename tests/st/{export => export_and_load}/text_air.py (100%)
 rename tests/st/{export => export_and_load}/text_lite_mindir.py (100%)

diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc
index e3c4659909..1295874497 100644
--- a/mindspore/ccsrc/pipeline/jit/action.cc
+++ b/mindspore/ccsrc/pipeline/jit/action.cc
@@ -66,6 +66,10 @@ abstract::AnalysisResult AbstractAnalyze(const ResourcePtr &res, const FuncGraph
   for (auto &node : manager->all_nodes()) {
     MS_EXCEPTION_IF_NULL(node);
     const AbstractBasePtr &prev_inferred = node->abstract();
+    // Keep previous inferred value for CNode if it is loaded from MindIR.
+    if (node->isa<CNode>() && node->cast<CNodePtr>()->get_load_flag()) {
+      continue;
+    }
     // Keep previous inferred value for ValueNode if the inferred value is not AbstractFunction.
     if (!node->isa<ValueNode>() || (prev_inferred != nullptr && prev_inferred->isa<abstract::AbstractFunction>())) {
       node->set_abstract(nullptr);
@@ -113,6 +117,69 @@ FuncGraphPtr Renormalize(const ResourcePtr &res, const FuncGraphPtr &func_graph,
   return ret;
 }

+const FuncGraphPtr GetLoadedGraph(const ResourcePtr &res) {
+  MS_EXCEPTION_IF_NULL(res);
+  auto manager = res->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  FuncGraphPtr loaded_graph = nullptr;
+  size_t loaded_graph_num = 0;
+  auto all_graphs = manager->func_graphs();
+  for (auto &graph : all_graphs) {
+    MS_EXCEPTION_IF_NULL(graph);
+    if (graph->has_attr("is_load")) {
+      loaded_graph = graph;
+      loaded_graph_num += 1;
+    }
+  }
+  if (loaded_graph_num == 0) {
+    return nullptr;
+  }
+  if (loaded_graph_num == 1) {
+    return loaded_graph;
+  }
+  MS_LOG(EXCEPTION) << "The number of loaded sub graphs should currently be less than 2, but got " << loaded_graph_num;
+}
+
+void CheckRootInputShapeAndType(const ResourcePtr &res, const FuncGraphPtr &loaded_graph) {
+  MS_EXCEPTION_IF_NULL(res);
+  auto manager = res->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  FuncGraphPtr root_graph = *(manager->roots().begin());
+  auto root_inputs = root_graph->get_inputs();
+  auto loaded_inputs = loaded_graph->get_inputs();
+
+  size_t root_inputs_num = root_inputs.size();
+  size_t loaded_inputs_num = loaded_inputs.size();
+  if (root_inputs_num != loaded_inputs_num) {
+    MS_LOG(EXCEPTION) << "The number of inputs " << root_inputs_num
+                      << " is not equal to the number of inputs of the loaded graph " << loaded_inputs_num;
+  }
+  for (size_t index = 0; index < root_inputs_num; index++) {
+    auto root_input = root_inputs[index];
+    auto loaded_input = loaded_inputs[index];
+
+    auto root_shape = root_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(root_input->Shape());
+    auto loaded_shape = loaded_input->Shape() == nullptr ? nullptr : dyn_cast<abstract::Shape>(loaded_input->Shape());
+    auto root_type = root_input->Type() == nullptr ? nullptr : dyn_cast<Type>(root_input->Type());
+    auto loaded_type = loaded_input->Type() == nullptr ? nullptr : dyn_cast<Type>(loaded_input->Type());
+    MS_EXCEPTION_IF_NULL(root_shape);
+    MS_EXCEPTION_IF_NULL(loaded_shape);
+    MS_EXCEPTION_IF_NULL(root_type);
+    MS_EXCEPTION_IF_NULL(loaded_type);
+
+    if (root_shape->shape() != loaded_shape->shape()) {
+      MS_EXCEPTION(ValueError) << "The " << index
+                               << "th input shape differs from the loaded graph. Input shape: "
+                               << root_shape->ToString() << ", input shape of loaded graph: "
+                               << loaded_shape->ToString();
+    }
+    if (root_type->type_id() != loaded_type->type_id()) {
+      MS_EXCEPTION(TypeError) << "The " << std::to_string(index)
+                              << "th input type differs from the loaded graph. Input type: "
+                              << root_type->ToString() << ", input type of loaded graph: "
+                              << loaded_type->ToString();
+    }
+  }
+}
+
 bool ParseAction(const ResourcePtr &res) {
   if (!res->input()) {
     MS_LOG(EXCEPTION) << "Parse error";
@@ -255,12 +322,14 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
   if (res->func_graph() == nullptr) {
     MS_LOG(EXCEPTION) << "AbstractSpecialize error";
   }
-
   FuncGraphPtr func_graph = res->func_graph();
   abstract::AbstractBasePtrList args_spec = res->args_spec();
   auto context = parallel::ParallelContext::GetInstance();
   MS_EXCEPTION_IF_NULL(parallel::ParallelContext::GetInstance());
   context->ParallelParameterContextInitShape(func_graph);
+
+  // get the original loaded graph to check inputs later
+  auto loaded_graph_ptr = GetLoadedGraph(res);
   // suppose that there is not KeywordArgument for the top graph
   // get the hyper parameter
   for (const auto &param : func_graph->parameters()) {
@@ -294,7 +363,10 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
       }
     }
   }
-
+  // check the inputs after abstract specialization when there is a loaded graph
+  if (loaded_graph_ptr != nullptr) {
+    CheckRootInputShapeAndType(res, loaded_graph_ptr);
+  }
   MS_LOG(DEBUG) << "End graph: " << new_fg->ToString() << ", return: " << new_fg->get_return()->DebugString(true);
   return true;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index 57349f6502..952196292f 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -111,6 +111,7 @@ PYBIND11_MODULE(_c_expression, m) {
   (void)m.def("init_pipeline", &mindspore::pipeline::InitPipeline, "Init Pipeline.");
   (void)m.def("export_graph", &mindspore::pipeline::ExportGraph, "Export Graph.");
+  (void)m.def("load_mindir", &mindspore::pipeline::LoadMindIR, py::arg("file_name"), "Load model as Graph.");

   (void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
     .def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")
diff --git a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
index 14b219d3d0..fcad535cc9 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/data_converter.cc
@@ -203,6 +203,19 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
   return true;
 }

+bool ConvertFuncGraph(const py::object &obj, ValuePtr *const data) {
+  MS_LOG(DEBUG) << "Converting FuncGraph object";
+  auto func_graph = obj.cast<FuncGraphPtr>();
+  if (func_graph == nullptr) {
+    MS_LOG(ERROR) << "Resolve FuncGraph error, the pointer is null";
+    return false;
+  }
+  auto new_fg = BasicClone(func_graph);
+  new_fg->set_attr("is_load", MakeValue(true));
+  *data = new_fg;
+  return true;
+}
+
 bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
   MS_LOG(DEBUG) << "Converting slice object";

@@ -368,47 +381,21 @@ bool ConvertFloatWithType(const float &obj, ValuePtr *const data, TypePtr dtype
 }
 }  // namespace

-bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
-  // check parameter valid
-  if (data == nullptr) {
-    MS_LOG(ERROR) << "Data is null pointer";
-    return false;
-  }
-
-  bool ret = true;
+bool ConvertSingleData(const py::object &obj, ValuePtr *const data) {
+  MS_EXCEPTION_IF_NULL(data);
   ValuePtr converted = nullptr;
   if (py::isinstance<py::none>(obj)) {
     converted = kNone;
   } else if (py::isinstance<py::bool_>(obj)) {
     converted = std::make_shared<BoolImm>(py::cast<bool>(obj));
-  } else if (py::isinstance<py::int_>(obj)) {
-    ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
-  } else if (py::isinstance<py::float_>(obj)) {
-    ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
   } else if (py::isinstance<py::str>(obj)) {
     converted = std::make_shared<StringImm>(py::cast<std::string>(obj));
-  } else if (py::isinstance<py::dict>(obj)) {
-    ret = ConvertDict(obj, &converted, use_signature);
-  } else if (py::isinstance<py::slice>(obj)) {
-    ret = ConvertSlice(obj, &converted);
   } else if (py::isinstance<py::ellipsis>(obj)) {
     converted = kEllipsis;
-  } else if (py::isinstance<py::tuple>(obj)) {
-    ret = ConvertTuple(obj, &converted, use_signature);
-  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
-    ret = ConvertCellList(obj, &converted, use_signature);
-  } else if (py::isinstance<Cell>(obj)) {
-    return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
-  } else if (py::isinstance<py::list>(obj)) {
-    ret = ConvertList(obj, &converted, use_signature);
   } else if (py::isinstance<py::module>(obj)) {
     ConvertNameSpace(obj, &converted);
   } else if (py::hasattr(obj, PYTHON_DATACLASS_FIELDS)) {
     ConvertDataClass(obj, &converted);
-  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
-    ret = ConvertPrimitive(obj, &converted, use_signature);
-  } else if (py::isinstance<MetaFuncGraph>(obj)) {
-    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
   } else if (py::isinstance<Type>(obj)) {
     converted = obj.cast<TypePtr>();
   } else if (py::isinstance<Tensor>(obj)) {
@@ -425,9 +412,50 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
   } else if (py::hasattr(obj, PYTHON_CLASS_MEMBER_NAMESPACE)) {
     converted = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, obj);
   } else {
-    ret = ConvertOtherObj(obj, &converted);
+    return false;
+  }
+  *data = converted;
+  return true;
+}
+
+bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature, TypePtr dtype) {
+  // check parameter valid
+  if (data == nullptr) {
+    MS_LOG(ERROR) << "Data is null pointer";
+    return false;
+  }
+  ValuePtr converted = nullptr;
+  bool ret = ConvertSingleData(obj, &converted);
+  if (ret) {
+    *data = converted;
+    return true;
+  }
+  if (py::isinstance<py::int_>(obj)) {
+    ret = ConvertIntegerWithType(py::cast<int64_t>(obj), &converted, dtype);
+  } else if (py::isinstance<py::float_>(obj)) {
+    ret = ConvertFloatWithType(py::cast<float>(obj), &converted, dtype);
+  } else if (py::isinstance<py::dict>(obj)) {
+    ret = ConvertDict(obj, &converted, use_signature);
+  } else if (py::isinstance<py::slice>(obj)) {
+    ret = ConvertSlice(obj, &converted);
+  } else if (py::isinstance<py::tuple>(obj)) {
+    ret = ConvertTuple(obj, &converted, use_signature);
+  } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
+    ret = ConvertCellList(obj, &converted, use_signature);
+  } else if (py::isinstance<Cell>(obj)) {
+    return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
+  } else if (py::isinstance<py::list>(obj)) {
+    ret = ConvertList(obj, &converted, use_signature);
+  } else if (py::hasattr(obj, PYTHON_PRIMITIVE_FLAG)) {
+    ret = ConvertPrimitive(obj, &converted, use_signature);
+  } else if (py::isinstance<MetaFuncGraph>(obj)) {
+    ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
+  } else if (py::isinstance<FuncGraph>(obj)) {
+    ret = ConvertFuncGraph(obj, &converted);
+  } else {
+    ret = ConvertOtherObj(obj, &converted);
   }

   *data = converted;
   return ret;
 }
diff --git a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
index a6d2f1cc3d..b4387a8e7c 100644
--- a/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
+++ b/mindspore/ccsrc/pipeline/jit/parse/resolve.cc
@@ -113,6 +113,49 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
   return para_node;
 }

+void BroadenCNodeAbstract(const FuncGraphPtr &func_graph) {
+  std::vector<AnfNodePtr> nodes = TopoSort(func_graph->get_return(), SuccIncoming, AlwaysInclude);
+  for (const AnfNodePtr &node : nodes) {
+    if (!node->isa<CNode>()) {
+      continue;
+    }
+    auto abstract = node->abstract();
+    if (abstract != nullptr) {
+      node->set_abstract(abstract->Broaden());
+    }
+  }
+}
+
+void ConvertLoadedGraph(const FuncGraphPtr &func_graph, const ValuePtr &value) {
+  if (!value->isa<FuncGraph>()) {
+    return;
+  }
+  auto resolved_graph = value->cast<FuncGraphPtr>();
+  MS_EXCEPTION_IF_NULL(resolved_graph);
+  if (!resolved_graph->has_attr("is_load")) {
+    return;
+  }
+  auto top_graph = Parser::GetTopFuncGraph();
+  std::vector<AnfNodePtr> input_params;
+  for (auto const &param : resolved_graph->parameters()) {
+    auto param_ptr = dyn_cast<Parameter>(param);
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (param_ptr->has_default()) {
+      param_ptr->set_func_graph(top_graph);
+      func_graph->add_used_global_parameters(param_ptr);
+
+      // update top_graph
+      top_graph->add_parameter(param_ptr);
+      size_t hyper_param_count = top_graph->hyper_param_count();
+      top_graph->set_hyper_param_count(hyper_param_count + 1);
+    } else {
+      input_params.push_back(param_ptr);
+    }
+  }
+  resolved_graph->set_parameters(input_params);
+  BroadenCNodeAbstract(resolved_graph);
+}
+
 bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj, AnfNodePtr *const node) {
   AnfNodePtr output = nullptr;
   if (py::hasattr(obj, "__parameter__") && py::isinstance<tensor::MetaTensor>(obj)) {
@@ -146,6 +189,7 @@ bool ResolveObjectToNode(const FuncGraphPtr &func_graph, const py::object &obj,
     return false;
   }
   MS_EXCEPTION_IF_NULL(convert_result);
+  ConvertLoadedGraph(func_graph, convert_result);
   output = NewValueNode(convert_result);
   if (convert_result->isa<FuncGraph>()) {
     output = GetMixedPrecisionCastHelp(func_graph, output);
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc
index 874314feaa..3866530b16 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.cc
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc
@@ -48,6 +48,7 @@
 #include "pybind_api/pybind_patch.h"
 #include "utils/shape_utils.h"
 #include "utils/info.h"
+#include "load_mindir/load_model.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/constants.h"
 #include "ps/util.h"
@@ -1096,6 +1097,8 @@ void ExportGraph(const std::string &file_name, const std::string &, const std::s
 #endif
 }

+FuncGraphPtr LoadMindIR(const std::string &file_name) { return mindspore::LoadMindIR(file_name); }
+
 void ReleaseGeTsd() {
   auto context_ptr = MsContext::GetInstance();
   if (context_ptr != nullptr) {
diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.h b/mindspore/ccsrc/pipeline/jit/pipeline.h
index 2d0caba826..6ead471423 100644
--- a/mindspore/ccsrc/pipeline/jit/pipeline.h
+++ b/mindspore/ccsrc/pipeline/jit/pipeline.h
@@ -140,6 +140,7 @@ void ClearResAtexit();
 void ReleaseGeTsd();

 void ExportGraph(const std::string &file_name, const std::string &, const std::string &phase);
+FuncGraphPtr LoadMindIR(const std::string &file_name);

 // init and exec dataset sub graph
 bool InitExecDataset(const std::string &queue_name, int64_t iter_num, int64_t batch_size,
diff --git a/mindspore/ccsrc/pipeline/jit/validator.cc b/mindspore/ccsrc/pipeline/jit/validator.cc
index 17426ad4e0..944dc9d1c0 100644
--- a/mindspore/ccsrc/pipeline/jit/validator.cc
+++ b/mindspore/ccsrc/pipeline/jit/validator.cc
@@ -51,6 +51,9 @@ void ValidateOperation(const AnfNodePtr &node) {
   if (abstract::IsInWhiteList(prim)) {
     return;
   }
+  if (prim->HasAttr("is_load")) {
+    return;
+  }
   if (prim->HasPyEvaluator()) {
     MS_LOG(DEBUG) << "Primitive " << prim->name() << " has python evaluator.";
     return;
diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h
index 63989555a1..966dc834ca 100644
--- a/mindspore/core/ir/anf.h
+++ b/mindspore/core/ir/anf.h
@@ -273,6 +273,9 @@ class CNode : public AnfNode, public EffectInfoHolder {
   void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
   bool in_forward_flag() const { return in_forward_flag_; }

+  void set_load_flag(bool is_load) { is_load_ = is_load; }
+  bool get_load_flag() { return is_load_; }
+
   VarPtr func_graph_as_var() const { return func_graph_as_var_; }

   const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
@@ -304,6 +307,7 @@ class CNode : public AnfNode, public EffectInfoHolder {
   bool stop_gradient_;
   bool in_forward_flag_ = false;
   bool effect_handled_ = false;
+  bool is_load_ = false;
   // inputs_value_ store cnode input value and id in pynative mode
   // output_value_ store cnode value and id in pynative mode
   std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc
index 6d25f77d14..aae51466d1 100644
--- a/mindspore/core/ir/func_graph.cc
+++ b/mindspore/core/ir/func_graph.cc
@@ -68,6 +68,18 @@ AnfNodePtr FuncGraph::output() const {
   }
 }

+const std::vector<AnfNodePtr> FuncGraph::get_inputs() const {
+  std::vector<AnfNodePtr> input_params;
+  for (auto const &node : parameters_) {
+    MS_EXCEPTION_IF_NULL(node);
+    auto parameter = dyn_cast<Parameter>(node);
+    if (!parameter->has_default()) {
+      input_params.push_back(parameter);
+    }
+  }
+  return input_params;
+}
+
 ParameterPtr FuncGraph::add_parameter() {
   FuncGraphPtr this_func_graph = shared_from_base<FuncGraph>();
   ParameterPtr p = std::make_shared<Parameter>(this_func_graph);
diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h
index 11dff388a8..f82618dd29 100644
--- a/mindspore/core/ir/func_graph.h
+++ b/mindspore/core/ir/func_graph.h
@@ -160,6 +160,8 @@ class FuncGraph : public FuncGraphBase, public EffectInfoHolder {
   abstract::AbstractFunctionPtr abstract();
   abstract::AbstractBasePtr ToAbstract() override;

+  // Get function graph inputs, excluding parameters that have a default value.
+  const std::vector<AnfNodePtr> get_inputs() const;
   // Return the graph's output, or nullptr if not yet deduced.
   AnfNodePtr output() const;
   void set_output(const AnfNodePtr &value, bool force_new_ret = false);
diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc
index a9dfb4c8d6..c9cc36ec04 100644
--- a/mindspore/core/ir/func_graph_cloner.cc
+++ b/mindspore/core/ir/func_graph_cloner.cc
@@ -91,6 +91,7 @@ void Cloner::CloneCNode(const AnfNodePtr &node, const FuncGraphPtr &target) {
   new_node->set_forward(old_node->forward().first, old_node->forward().second);
   new_node->set_inputs_value(old_node->inputs_value());
   new_node->set_attrs(old_node->attrs());
+  new_node->set_load_flag(old_node->get_load_flag());
   ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
   new_node->set_scope(scope);
   new_node->CloneUserData(old_node);
diff --git a/mindspore/core/load_mindir/anf_model_parser.cc b/mindspore/core/load_mindir/anf_model_parser.cc
index 3164c1a85b..4d06d53a48 100644
--- a/mindspore/core/load_mindir/anf_model_parser.cc
+++ b/mindspore/core/load_mindir/anf_model_parser.cc
@@ -228,17 +228,14 @@ tensor::TensorPtr MSANFModelParser::BuildTensorInfoForFuncGraph(const mind_ir::T
   }

   if (!tensor_proto.has_data_type()) {
-    MS_LOG(ERROR) << "mind_ir TensorProto has no data_type or name!";
-    return nullptr;
+    MS_LOG(EXCEPTION) << "mind_ir TensorProto has no data_type or name!";
   }
   if (kDefaultValueSwitchMap.find(tensor_proto.data_type()) == kDefaultValueSwitchMap.end()) {
-    MS_LOG(ERROR) << "mind_ir TensorProto data_type is not support yet!";
-    return nullptr;
+    MS_LOG(EXCEPTION) << "mind_ir TensorProto data_type is not supported yet!";
   }

   tensor::TensorPtr tensor_info = std::make_shared<tensor::Tensor>(kDefaultValueSwitchMap[tensor_proto.data_type()], shape);
-  MS_EXCEPTION_IF_NULL(tensor_info);
   return tensor_info;
 }

@@ -253,9 +250,14 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node,
   string debug_info_name = ParseParameterName(parameter_proto.name());
   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
   node->set_debug_info(debug_info_ptr);
-  node->set_name(parameter_proto.name());
+  node->set_name(debug_info_name);

   tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(parameter_proto);
+  MS_EXCEPTION_IF_NULL(tensor_info);
+  ParamInfoPtr param_info = std::make_shared<ParamInfo>();
+  param_info->set_name(debug_info_name);
+  tensor_info->set_param_info(param_info);
+
   auto tensor_abstract = tensor_info->ToAbstract();
   MS_EXCEPTION_IF_NULL(tensor_abstract);
   node->set_abstract(tensor_abstract);
@@ -284,13 +286,13 @@ bool MSANFModelParser::BuildInputForFuncGraph(const ParameterPtr &node, const mi
   string debug_info_name = ParseParameterName(value_proto.name());
   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
   node->set_debug_info(debug_info_ptr);
-  node->set_name(value_proto.name());
+  node->set_name(debug_info_name);

   const mind_ir::TensorProto &tensor_proto = value_proto.tensor(0);
   tensor::TensorPtr tensor_info = BuildTensorInfoForFuncGraph(tensor_proto);
+  MS_EXCEPTION_IF_NULL(tensor_info);
   auto tensor_abstract = tensor_info->ToAbstract();
-  MS_EXCEPTION_IF_NULL(tensor_abstract);
   node->set_abstract(tensor_abstract);

   anfnode_build_map_[value_proto.name()] = node;
@@ -300,15 +302,6 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
                                                 const mind_ir::GraphProto &importProto) {
   MS_EXCEPTION_IF_NULL(outputFuncGraph);
-  MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
-  for (int i = 0; i < importProto.parameter_size(); ++i) {
-    const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
-    if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
-      MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
-      return false;
-    }
-  }
-
   MS_LOG(INFO) << "All inputs size is: " << importProto.input_size();
   for (int i = 0; i < importProto.input_size(); ++i) {
     const mind_ir::ValueInfoProto &input_proto = importProto.input(i);
@@ -317,6 +310,15 @@ bool MSANFModelParser::ImportParametersForGraph(const FuncGraphPtr &outputFuncGr
       return false;
     }
   }
+
+  MS_LOG(INFO) << "All Parameters size is: " << importProto.parameter_size();
+  for (int i = 0; i < importProto.parameter_size(); ++i) {
+    const mind_ir::TensorProto &parameter_proto = importProto.parameter(i);
+    if (!BuildParameterForFuncGraph(outputFuncGraph->add_parameter(), parameter_proto)) {
+      MS_LOG(ERROR) << "Build parameter for funcgraph fail at index: " << i;
+      return false;
+    }
+  }
   return true;
 }

@@ -745,7 +747,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
     inputs.push_back(anfnode_build_map_[input_name]);
   }
-
+  prim->set_attr("is_load", MakeValue(true));
   auto cnode_ptr = outputFuncGraph->NewCNode(prim, inputs);
   MS_EXCEPTION_IF_NULL(cnode_ptr);

@@ -777,6 +779,7 @@ CNodePtr MSANFModelParser::BuildCNodeForFuncGraph(const FuncGraphPtr &outputFunc
   auto debug_info_ptr = std::make_shared<NodeDebugInfo>(debug_info_name);
   cnode_ptr->set_debug_info(debug_info_ptr);
   cnode_ptr->set_fullname_with_scope(fullname_with_scope);
+  cnode_ptr->set_load_flag(true);

   anfnode_build_map_[node_name] = cnode_ptr;
   return cnode_ptr;
@@ -804,6 +807,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
     inputs.push_back(maketuple_ptr);
     auto return_node = outputFuncGraph->NewCNode(inputs);
     MS_EXCEPTION_IF_NULL(return_node);
+    return_node->set_load_flag(true);
     outputFuncGraph->set_return(return_node);
     MS_LOG(INFO) << "Construct funcgraph finined, all success.";
   } else {
@@ -812,6 +816,7 @@ bool MSANFModelParser::BuildReturnForFuncGraph(const FuncGraphPtr &outputFuncGra
     inputs.push_back(cnode_ptr);
     auto return_node = outputFuncGraph->NewCNode(inputs);
     MS_EXCEPTION_IF_NULL(return_node);
+    return_node->set_load_flag(true);
     outputFuncGraph->set_return(return_node);
     MS_LOG(INFO) << "Construct funcgraph finined, all success!";
   }
diff --git a/mindspore/nn/__init__.py b/mindspore/nn/__init__.py
index 9a486023e7..0dcad2f6fd 100644
--- a/mindspore/nn/__init__.py
+++ b/mindspore/nn/__init__.py
@@ -20,7 +20,7 @@ Pre-defined building blocks or computing units to construct neural networks.
 from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
 from .learning_rate_schedule import *
 from .dynamic_lr import *
-from .cell import Cell, GraphKernel
+from .cell import Cell, GraphKernel, GraphCell
 from .layer import *
 from .loss import *
 from .optim import *
@@ -29,7 +29,7 @@ from .wrap import *
 from .sparse import *

-__all__ = ["Cell", "GraphKernel"]
+__all__ = ["Cell", "GraphKernel", "GraphCell"]
 __all__.extend(layer.__all__)
 __all__.extend(loss.__all__)
 __all__.extend(optim.__all__)
diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py
index 817b0091d3..a27a2bad5d 100755
--- a/mindspore/nn/cell.py
+++ b/mindspore/nn/cell.py
@@ -25,7 +25,7 @@ from mindspore import log as logger
 from mindspore.common.parameter import PARAMETER_NAME_DEFAULT
 from mindspore.context import ParallelMode
 from .. import context
-from .._c_expression import init_pipeline, Cell_
+from .._c_expression import init_pipeline, Cell_, FuncGraph
 from .._checkparam import Validator
 from ..common import dtype as mstype
 from ..common.api import _executor, _pynative_exec
@@ -1191,3 +1191,39 @@ class GraphKernel(Cell):

     def construct(self):
         raise NotImplementedError
+
+
+class GraphCell(Cell):
+    """
+    Base class for running the graph loaded from a MindIR.
+
+    This feature is still under development. Currently `GraphCell` does not support modifying the structure of the
+    loaded graph, and can only use inputs of the same shape and type as those used when exporting the MindIR.
+
+    Args:
+        graph (object): A compiled graph loaded from MindIR.
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor
+        >>> from mindspore.train import export, load
+        >>>
+        >>> net = nn.Conv2d(1, 1, kernel_size=3)
+        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        >>> export(net, input, file_name="net", file_format="MINDIR")
+        >>> graph = load("net.mindir")
+        >>> net = nn.GraphCell(graph)
+        >>> output = net(input)
+    """
+    def __init__(self, graph):
+        super(GraphCell, self).__init__(auto_prefix=True)
+        if not isinstance(graph, FuncGraph):
+            raise TypeError(f"graph must be a FuncGraph loaded from MindIR, but got {type(graph)}.")
+        self.graph = graph
+
+    def construct(self, *inputs):
+        return self.graph(*inputs)
+
+    def __call__(self, *inputs):
+        return self.compile_and_run(*inputs)
diff --git a/mindspore/train/__init__.py b/mindspore/train/__init__.py
index b6aff9b773..5a05c655b0 100644
--- a/mindspore/train/__init__.py
+++ b/mindspore/train/__init__.py
@@ -22,10 +22,10 @@ from .dataset_helper import DatasetHelper, connect_network_with_dataset
 from . import amp
 from .amp import build_train_network
 from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
-from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
+from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, load, parse_print,\
     build_searched_strategy, merge_sliced_parameter, load_distributed_checkpoint

 __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network",
            "LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
-           "load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
+           "load_param_into_net", "export", "load", "parse_print", "build_searched_strategy", "merge_sliced_parameter",
            "load_distributed_checkpoint"]
diff --git a/mindspore/train/model.py b/mindspore/train/model.py
index 5c2f8ed3d5..269d4817b9 100755
--- a/mindspore/train/model.py
+++ b/mindspore/train/model.py
@@ -139,10 +139,22 @@ class Model:
         self._global_rank = _get_global_rank()
         self._parameter_broadcast = _get_parameter_broadcast()

+        self._check_for_graph_cell(kwargs)
         self._train_network = self._build_train_network()
         self._build_eval_network(metrics, eval_network, eval_indexes)
         self._build_predict_network()

+    def _check_for_graph_cell(self, kwargs):
+        if not isinstance(self._network, nn.GraphCell):
+            return
+        if self._amp_level != "O0":
+            logger.warning("amp_level will not work when network is a GraphCell.")
+
+        if self._loss_fn is not None or self._optimizer is not None:
+            raise ValueError("Currently loss_fn and optimizer should be None when network is a GraphCell.")
+        if kwargs:
+            raise ValueError("Currently kwargs should be empty when network is a GraphCell.")
+
     def _process_amp_args(self, kwargs):
         if self._amp_level in ["O0", "O3"]:
             self._keep_bn_fp32 = False
@@ -586,6 +598,8 @@ class Model:
             >>> model.train(2, dataset)
         """
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
+        if isinstance(self._train_network, nn.GraphCell) and dataset_sink_mode is True:
+            raise ValueError("Sink mode is currently not supported when training with a GraphCell.")
         Validator.check_is_int(sink_size)
         dataset_size = train_dataset.get_dataset_size()
         if dataset_size == 0:
@@ -704,9 +718,12 @@ class Model:
             >>> acc = model.eval(dataset, dataset_sink_mode=False)
         """
         dataset_sink_mode = Validator.check_bool(dataset_sink_mode)
+        _device_number_check(self._parallel_mode, self._device_number)
         if not self._metric_fns:
             raise ValueError("metric fn can not be None or empty.")
+        if isinstance(self._eval_network, nn.GraphCell) and dataset_sink_mode is True:
+            raise ValueError("Sink mode is currently not supported when evaluating with a GraphCell.")

         cb_params = _InternalCallbackParam()
         cb_params.eval_network = self._eval_network
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 498e54de3f..8b50de18d0 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -38,6 +38,7 @@ from mindspore._checkparam import check_input_data, Validator
 from mindspore.compression.export import quant_export
 from mindspore.parallel._tensor import _load_tensor
 from mindspore.parallel._utils import _infer_rank_list, _remove_repeated_slices
+from .._c_expression import load_mindir

 tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@@ -228,6 +229,49 @@ def _check_param_prefix(filter_prefix, param_name):
     return False


+def load(file_name):
+    """
+    Load MindIR.
+
+    The returned object can be executed by a `GraphCell`. However, there are some limitations to the current
+    use of `GraphCell`; see :class:`mindspore.nn.GraphCell` for more details.
+
+    Args:
+        file_name (str): MindIR file name.
+
+    Returns:
+        Object, a compiled graph that can be executed by `GraphCell`.
+
+    Raises:
+        ValueError: MindIR file is incorrect.
+
+    Examples:
+        >>> import numpy as np
+        >>> import mindspore.nn as nn
+        >>> from mindspore import Tensor
+        >>> from mindspore.train import export, load
+        >>>
+        >>> net = nn.Conv2d(1, 1, kernel_size=3)
+        >>> input = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
+        >>> export(net, input, file_name="net", file_format="MINDIR")
+        >>> graph = load("net.mindir")
+        >>> net = nn.GraphCell(graph)
+        >>> output = net(input)
+    """
+    if not isinstance(file_name, str):
+        raise ValueError("The file name must be a string.")
+    if not os.path.exists(file_name):
+        raise ValueError("The file does not exist.")
+    if not file_name.endswith(".mindir"):
+        raise ValueError("The MindIR file name should end with '.mindir', please input the correct file name.")
+
+    logger.info("Execute the process of loading mindir.")
+    graph = load_mindir(file_name)
+    if graph is None:
+        raise RuntimeError("Load MindIR failed.")
+    return graph
+
+
 def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=None):
     """
     Loads checkpoint info from a specified file.
diff --git a/tests/st/export/test_train_mindir.py b/tests/st/export_and_load/test_train_mindir.py
similarity index 81%
rename from tests/st/export/test_train_mindir.py
rename to tests/st/export_and_load/test_train_mindir.py
index 70579d5ed7..12197b124a 100644
--- a/tests/st/export/test_train_mindir.py
+++ b/tests/st/export_and_load/test_train_mindir.py
@@ -22,7 +22,7 @@ from mindspore.common.initializer import TruncatedNormal
 from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
-from mindspore.train.serialization import export
+from mindspore.train.serialization import export, load


 def weight_variable():
@@ -112,3 +112,26 @@ def test_export_lenet_grad_mindir():
     export(net, predict, label, file_name="lenet_grad", file_format='MINDIR')
     verify_name = "lenet_grad.mindir"
     assert os.path.exists(verify_name)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.env_onecard
+def test_load_mindir_and_run():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    network = LeNet5()
+    network.set_train()
+
+    inputs0 = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
+    outputs0 = network(inputs0)
+
+    inputs = Tensor(np.zeros([32, 1, 32, 32]).astype(np.float32))
+    export(network, inputs, file_name="test_lenet_load", file_format='MINDIR')
+    mindir_name = "test_lenet_load.mindir"
+    assert os.path.exists(mindir_name)
+
+    graph = load(mindir_name)
+    loaded_net = nn.GraphCell(graph)
+    outputs_after_load = loaded_net(inputs0)
+    assert np.allclose(outputs0.asnumpy(), outputs_after_load.asnumpy())
diff --git a/tests/st/export/text_air.py b/tests/st/export_and_load/text_air.py
similarity index 100%
rename from tests/st/export/text_air.py
rename to tests/st/export_and_load/text_air.py
diff --git a/tests/st/export/text_lite_mindir.py b/tests/st/export_and_load/text_lite_mindir.py
similarity index 100%
rename from tests/st/export/text_lite_mindir.py
rename to tests/st/export_and_load/text_lite_mindir.py
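
Usage sketch (reviewer's addition, not part of the patch): the snippet below strings together the
user-facing pieces introduced above, namely `mindspore.train.load`, `nn.GraphCell`, and the
`load_mindir` binding behind them, assuming a build that includes this change. The file name
"conv" is illustrative. Inputs fed to the loaded graph must keep the shape and dtype used at
export time, since the new CheckRootInputShapeAndType pass rejects any mismatch.

    import numpy as np

    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.train import export, load

    context.set_context(mode=context.GRAPH_MODE)

    # Export a small network to MindIR.
    net = nn.Conv2d(1, 1, kernel_size=3)
    x = Tensor(np.ones([1, 1, 3, 3]).astype(np.float32))
    export(net, x, file_name="conv", file_format="MINDIR")

    # load() returns the compiled FuncGraph; GraphCell makes it callable like a Cell.
    graph = load("conv.mindir")
    graph_net = nn.GraphCell(graph)

    # The input must match the exported shape and dtype exactly.
    output = graph_net(x)
    print(output.shape)  # (1, 1, 3, 3), since Conv2d defaults to pad_mode="same"

A GraphCell can also back a Model for inference-style use, but as enforced in model.py above,
loss_fn and optimizer must then be None, no extra kwargs are accepted, and dataset sink mode
is rejected for both train() and eval().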