From 047ac018da4b9f6c194b7d0d9b7fc2c3ee6891e9 Mon Sep 17 00:00:00 2001 From: zhousiyi Date: Sun, 26 Apr 2020 03:42:55 +0000 Subject: [PATCH] specialize hypermap paramater --- mindspore/ccsrc/ir/func_graph.h | 1 + .../ccsrc/operator/composite/composite.cc | 2 + mindspore/ccsrc/optimizer/clean.cc | 7 +- mindspore/ccsrc/optimizer/clean.h | 2 +- .../ccsrc/optimizer/irpass/cast_eliminate.cc | 6 +- mindspore/ccsrc/pipeline/pass.cc | 8 +- .../static_analysis/abstract_function.cc | 4 +- .../static_analysis/abstract_function.h | 13 ++- .../ccsrc/pipeline/static_analysis/prim.cc | 32 ++++++-- .../static_analysis/program_specialize.cc | 74 ++++++++++++++--- .../static_analysis/program_specialize.h | 3 + tests/ut/python/ops/test_math_ops_check.py | 80 ------------------- .../infer/test_hypermap_specialize.py | 54 +++++++++++++ 13 files changed, 175 insertions(+), 111 deletions(-) create mode 100644 tests/ut/python/pipeline/infer/test_hypermap_specialize.py diff --git a/mindspore/ccsrc/ir/func_graph.h b/mindspore/ccsrc/ir/func_graph.h index bca5759807..91fea89eb3 100644 --- a/mindspore/ccsrc/ir/func_graph.h +++ b/mindspore/ccsrc/ir/func_graph.h @@ -42,6 +42,7 @@ using CNodeIndexCounterMap = OrderedMap #include "ir/anf.h" +#include "ir/func_graph.h" #include "pipeline/static_analysis/abstract_value.h" #include "pipeline/static_analysis/abstract_function.h" #include "pipeline/static_analysis/dshape.h" @@ -334,6 +335,7 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) { FuncGraphPtr ptrGraph = std::make_shared(); ptrGraph->set_flags(FUNC_GRAPH_FLAG_CORE, true); + ptrGraph->set_flags(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true); ptrGraph->debug_info()->set_name("hyper_map"); AnfNodePtr ptrFnArg = nullptr; diff --git a/mindspore/ccsrc/optimizer/clean.cc b/mindspore/ccsrc/optimizer/clean.cc index fe11191546..97ac72e3fb 100644 --- a/mindspore/ccsrc/optimizer/clean.cc +++ b/mindspore/ccsrc/optimizer/clean.cc @@ -278,10 +278,12 @@ AnfNodePtr ConvertValueListNodeToValueTupleNode(const ValueNodePtr &node) { // Convert class to Tuple // Convert getattr to getitem // Convert make_record to make_tuple -void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager) { MS_EXCEPTION_IF_NULL(manager); manager->AddFuncGraph(root); + bool changed = false; + // Since `manager->Replace(...);` will modify member `all_nodes_`, so `all_node` can't be a ref var AnfNodeSet all_node = manager->all_nodes(); for (auto &node : all_node) { @@ -316,7 +318,9 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr if (new_node != nullptr) { new_node->set_abstract(node->abstract()); + MS_LOG(DEBUG) << "Replace node: " << node->DebugString() << " with new_node: " << new_node->DebugString(); (void)manager->Replace(node, new_node); + changed = true; } } @@ -324,6 +328,7 @@ void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr auto ret = Reabs(node->abstract()); node->set_abstract(ret); } + return changed; } // expand tuples in graph parameters diff --git a/mindspore/ccsrc/optimizer/clean.h b/mindspore/ccsrc/optimizer/clean.h index 01db7d363d..0130ecfb32 100644 --- a/mindspore/ccsrc/optimizer/clean.h +++ b/mindspore/ccsrc/optimizer/clean.h @@ -31,7 +31,7 @@ namespace mindspore { namespace opt { // Remove the class type from graphs -void SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); +bool SimplifyDataStructures(const FuncGraphPtr &root, const FuncGraphManagerPtr &manager); // Remove most uses of tuples from the graph // tuples that are returned will be kept diff --git a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc index c5a6b6672c..a497f3d5bd 100644 --- a/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc +++ b/mindspore/ccsrc/optimizer/irpass/cast_eliminate.cc @@ -38,13 +38,11 @@ AnfNodePtr CastSameTypeEliminater::operator()(const OptimizerPtr &, const AnfNod // src type check auto src_type = src_->Type(); - if (src_type == nullptr) { + if (src_type == nullptr || !src_type->isa()) { return nullptr; } - if (src_type->isa()) { - src_type = src_type->cast()->element(); - } + src_type = src_type->cast()->element(); // tgt type check auto tgt_type = GetValueNode(tgt_); diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc index d9f805fdc9..4614f19442 100644 --- a/mindspore/ccsrc/pipeline/pass.cc +++ b/mindspore/ccsrc/pipeline/pass.cc @@ -52,14 +52,16 @@ bool SimplifyDataStructuresPass(const ResourcePtr &res) { MS_EXCEPTION_IF_NULL(res->func_graph()); FuncGraphPtr func_graph = res->func_graph(); - opt::SimplifyDataStructures(func_graph, res->manager()); + bool changed = opt::SimplifyDataStructures(func_graph, res->manager()); abstract::AbstractBasePtrList args_spec; auto parameters = func_graph->parameters(); (void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec), [](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); }); - FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); - res->set_func_graph(new_fg); + if (changed) { + FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec); + res->set_func_graph(new_fg); + } res->set_args_spec(args_spec); return true; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc index 98d9b49a79..ced4a518cb 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.cc @@ -177,8 +177,8 @@ std::size_t FuncGraphAbstractClosure::hash() const { std::string FuncGraphAbstractClosure::ToString() const { std::stringstream ss; - ss << "FuncGraphAbstractClosure: " << this << "FuncGraph: " << func_graph_.get() << ", " << func_graph_->ToString() - << "; Context: " << context_.get() << context_->ToString(); + ss << "FuncGraphAbstractClosure: " + << "FuncGraph: " << func_graph_->ToString() << "; Context: " << context_->ToString(); return ss.str(); } diff --git a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h index 513b290a9d..9e1cf9ba83 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h +++ b/mindspore/ccsrc/pipeline/static_analysis/abstract_function.h @@ -166,8 +166,9 @@ class PartialAbstractClosure : public AbstractFuncAtom { public: // Represents a partial application. // args_spec_list: The first few arguments of that function - PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list) - : fn_(fn), args_spec_list_(args_spec_list) {} + PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_spec_list, + const AnfNodePtr &node = nullptr) + : fn_(fn), args_spec_list_(args_spec_list), node_(AnfNodePtr(node)) {} ~PartialAbstractClosure() override = default; MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) @@ -175,7 +176,11 @@ class PartialAbstractClosure : public AbstractFuncAtom { AbstractFunctionPtr fn() { return fn_; } AbstractBasePtrList args() { return args_spec_list_; } - AbstractFunctionPtr Copy() const override { return std::make_shared(fn_, args_spec_list_); } + AnfNodePtr node() { return node_.lock(); } + void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } + AbstractFunctionPtr Copy() const override { + return std::make_shared(fn_, args_spec_list_, node_.lock()); + } bool operator==(const AbstractFunction &other) const override; std::size_t hash() const override; @@ -184,6 +189,8 @@ class PartialAbstractClosure : public AbstractFuncAtom { private: AbstractFuncAtomPtr fn_; AbstractBasePtrList args_spec_list_; + // The CNode which this PartialAbstractClosure evaluated from. + AnfNodeWeakPtr node_; }; class JTransformedAbstractClosure : public AbstractFuncAtom { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/static_analysis/prim.cc index f80a0cdcc2..f8cf9d83bf 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.cc @@ -951,8 +951,19 @@ class PartialEvaluator : public Evaluator { if (args_conf_list.size() == 0) { MS_LOG(EXCEPTION) << "Args size should be greater than 0"; } + MS_EXCEPTION_IF_NULL(out_conf); + MS_EXCEPTION_IF_NULL(out_conf->node()); + auto arg0_value = args_conf_list[0]->GetEvaluatedValue(); AbstractBasePtrList args_spec_list{arg0_value}; + // Func in hypermap(partial(Func, arg0), arg1, arg2) may become Poly Node. + if (arg0_value->isa()) { + auto ret = std::make_shared(arg0_value->GetValueTrack()->cast(), out_conf->node()); + MS_LOG(DEBUG) << "AbstractError for node: " << out_conf->node()->DebugString() + << " as func is: " << arg0_value->ToString(); + (*cache_)[args_spec_list] = ret; + return ret; + } auto func = CheckArg("partial", args_spec_list, 0); // Sometimes, node[0] in out_conf becomes phi0; if (func->isa()) { @@ -962,19 +973,26 @@ class PartialEvaluator : public Evaluator { return HandleDoSignature(engine, do_signature_prim->function(), out_conf); } } - (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), - [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue(); }); + (void)std::transform(args_conf_list.begin() + 1, args_conf_list.end(), std::back_inserter(args_spec_list), + [](const ConfigPtr &config) -> AbstractBasePtr { return config->GetEvaluatedValue(); }); AbstractBasePtrList args(args_spec_list.begin() + 1, args_spec_list.end()); - AbstractFuncAtomPtrList partialPtrList; - auto build_partial = [args, &partialPtrList](const AbstractFuncAtomPtr &atom_func) { - auto new_func = std::make_shared(atom_func, args); - partialPtrList.push_back(new_func); + auto cnode = out_conf->node()->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (cnode->size() != (args_conf_list.size() + 1)) { + MS_LOG(EXCEPTION) << "Out_conf node: " << cnode->DebugString() + << ", args_conf_list: " << mindspore::ToString(args_conf_list); + } + + AbstractFuncAtomPtrList partial_funcs_list; + auto build_partial = [args, cnode, &partial_funcs_list](const AbstractFuncAtomPtr &atom_func) { + auto new_func = std::make_shared(atom_func, args, cnode); + partial_funcs_list.push_back(new_func); }; func->Visit(build_partial); - auto ret = AbstractFunction::MakeAbstractFunction(partialPtrList); + auto ret = AbstractFunction::MakeAbstractFunction(partial_funcs_list); (*cache_)[args_spec_list] = ret; return ret; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc index 987c5d1db0..49b1bb3dea 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.cc @@ -23,7 +23,9 @@ #include "./common.h" #include "operator/ops.h" #include "operator/composite/do_signature.h" +#include "pipeline/static_analysis/abstract_function.h" #include "utils/graph_utils.h" +#include "utils/log_adapter.h" #include "utils/profile.h" #include "debug/trace.h" @@ -232,6 +234,13 @@ void FuncGraphSpecializer::ProcessNode(const AnfNodePtr &node) { return; } new_node->set_abstract(GetEvaluatedValueWrap(conf)); + if (new_node->isa() && new_node->abstract()->isa()) { + auto partial_abstract = dyn_cast(new_node->abstract()); + if (partial_abstract->node() == node) { + partial_abstract->set_node(new_node); + } + } + MS_LOG(DEBUG) << "Set new_node: " << new_node->ToString() << ", abstract as: " << new_node->abstract()->ToString(); if (node->isa()) { @@ -383,6 +392,56 @@ AnfNodePtr FuncGraphSpecializer::BuildSpecializedNodeInner(const AbstractBasePtr return BuildValueNode(v, abs); } +AnfNodePtr FuncGraphSpecializer::BuildSpecializedParameterNode(const CNodePtr &new_node) { + auto new_inputs = new_node->inputs(); + AnfNodePtr func = new_inputs[0]; + AbstractBasePtr fnval = new_inputs[0]->abstract(); + + AbstractBasePtrList args; + auto backed_fnval = fnval; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + backed_fnval = partial_closure->fn(); + args = partial_closure->args(); + } + std::transform(new_inputs.cbegin() + 1, new_inputs.cend(), std::back_inserter(args), + [](const AnfNodePtr &inp) { return inp->abstract(); }); + + ScopeGuard scope_guard(new_node->scope()); + + auto specialized_node = BuildSpecializedNode(func, backed_fnval, args); + auto wrapped_node = specialized_node; + if (fnval->isa()) { + auto partial_closure = dyn_cast(fnval); + AnfNodePtrList partial_node_list = {BuildValueNode(prim::kPrimPartial, FromValueInside(prim::kPrimPartial)), + specialized_node}; + auto anf_node = partial_closure->node(); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << "Must be cnode, but " << anf_node->DebugString(); + } + auto cnode = anf_node->cast(); + if (cnode->size() != partial_closure->args().size() + 2) { + MS_LOG(EXCEPTION) << "Size of cnode: " << cnode->DebugString() + << " is not equal to 2 added to size of args: " << mindspore::ToString(partial_closure->args()); + } + for (size_t i = 0; i < partial_closure->args().size(); i++) { + auto old_node = cnode->input(i + 2); + auto possibile_value_node = BuildPossibleValueNode(old_node, partial_closure->args()[i]); + if (possibile_value_node != nullptr) { + partial_node_list.push_back(possibile_value_node); + } else { + if (!(old_node->isa() || old_node->isa())) { + MS_LOG(EXCEPTION) << "Old node should be CNode or Parameter, but " << old_node->ToString(); + } + partial_node_list.push_back(old_node); + } + } + wrapped_node = new_node->func_graph()->NewCNode(partial_node_list); + wrapped_node->set_abstract(partial_closure); + } + return wrapped_node; +} + const EvaluatorCacheMapPtr &FuncGraphSpecializer::GetEvalCache(const EvaluatorPtr &eval) { auto cache_iter = evalcaches_.find(eval); if (cache_iter == evalcaches_.end()) { @@ -465,6 +524,11 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { << new_inputs[i]->DebugString() << ", abstract: " << new_inputs[i]->abstract()->ToString(); } + if (func->isa() && func->func_graph()->has_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER)) { + auto wrapped_node = BuildSpecializedParameterNode(new_node); + new_inputs[0] = wrapped_node; + } + if (CanSpecializeNode(func)) { new_inputs[0] = BuildSpecializedNode(func, fnval, argvals); } @@ -474,16 +538,6 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { if (CanSpecializeNode(args[i])) { new_inputs[next] = BuildSpecializedNode(args[i], argvals[i], std::vector{}); } - // support for partial(Multitype) which Multitype should not be inferred to POLY. - // after one or more times clone, Multitype metafuncgraph evaluator will specialized to one type only, - // so even with partial parameter, it will specialize to that graph. - // Maybe a better idea should inline graph with partial node first, then it will have full - // parameter list to infer and specialize. - MS_EXCEPTION_IF_NULL(new_inputs[next]); - if (new_inputs[next]->isa() && (GetValueNode(new_inputs[next]) == kPolyNode) && - IsPrimitive(func, prim::kPrimPartial)) { - new_inputs[next] = args[i]; - } i = next; } diff --git a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h index ea3a3007d4..e3c2027d41 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h +++ b/mindspore/ccsrc/pipeline/static_analysis/program_specialize.h @@ -106,6 +106,9 @@ class FuncGraphSpecializer : public std::enable_shared_from_thisnode; it may be a replicated forwared CNode in static analysis or just a diff --git a/tests/ut/python/ops/test_math_ops_check.py b/tests/ut/python/ops/test_math_ops_check.py index be6e5691ea..86e2480631 100755 --- a/tests/ut/python/ops/test_math_ops_check.py +++ b/tests/ut/python/ops/test_math_ops_check.py @@ -87,11 +87,6 @@ class CumSumNet(nn.Cell): raise_set = [ - # one input is scalar, and another is Tensor(float32) - ('TensorAdd0', { - 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('TensorAdd1', { 'block': (P.TensorAdd(), {'exception': TypeError, 'error_keywords': ['TensorAdd']}), @@ -271,11 +266,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Sub0', { - 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Sub1', { 'block': (P.Sub(), {'exception': TypeError, 'error_keywords': ['Sub']}), @@ -287,11 +277,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Mul0', { - 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Mul1', { 'block': (P.Mul(), {'exception': TypeError, 'error_keywords': ['Mul']}), @@ -352,11 +337,6 @@ raise_set = [ 'desc_inputs': [5.0], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Minimum0', { - 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Minimum1', { 'block': (P.Minimum(), {'exception': TypeError, 'error_keywords': ['Minimum']}), @@ -368,11 +348,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Maximum0', { - 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Maximum1', { 'block': (P.Maximum(), {'exception': TypeError, 'error_keywords': ['Maximum']}), @@ -384,11 +359,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('RealDiv0', { - 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('RealDiv1', { 'block': (P.RealDiv(), {'exception': TypeError, 'error_keywords': ['RealDiv']}), @@ -400,11 +370,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Div0', { - 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Div1', { 'block': (P.Div(), {'exception': TypeError, 'error_keywords': ['Div']}), @@ -416,11 +381,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 5]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.float32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('FloorDiv0', { - 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('FloorDiv1', { 'block': (P.FloorDiv(), {'exception': TypeError, 'error_keywords': ['FloorDiv']}), @@ -439,11 +399,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.int32))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('FloorMod0', { - 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('FloorMod1', { 'block': (P.FloorMod(), {'exception': TypeError, 'error_keywords': ['FloorMod']}), @@ -462,11 +417,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([2, 3]).astype(np.bool_))], 'skip': ['backward']}), - # input is not tensor - ('Equal0', { - 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('Equal1', { 'block': (P.Equal(), {'exception': TypeError, 'error_keywords': ['Equal']}), @@ -490,11 +440,6 @@ raise_set = [ 'skip': ['backward']}), # shape of x and y not match - # input is not tensor - ('NotEqual0', { - 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('NotEqual1', { 'block': (P.NotEqual(), {'exception': TypeError, 'error_keywords': ['NotEqual']}), @@ -506,11 +451,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], 'skip': ['backward']}), - # input is not tensor - ('Greater0', { - 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('Greater1', { 'block': (P.Greater(), {'exception': TypeError, 'error_keywords': ['Greater']}), @@ -522,11 +462,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], 'skip': ['backward']}), - # input is not tensor - ('GreaterEqual0', { - 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('GreaterEqual1', { 'block': (P.GreaterEqual(), {'exception': TypeError, 'error_keywords': ['GreaterEqual']}), @@ -538,11 +473,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], 'skip': ['backward']}), - # input is not tensor - ('Less0', { - 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('Less1', { 'block': (P.Less(), {'exception': TypeError, 'error_keywords': ['Less']}), @@ -554,11 +484,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 2]).astype(np.float32))], 'skip': ['backward']}), - # input is not tensor - ('LessEqual0', { - 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # type of x and y not match ('LessEqual1', { 'block': (P.LessEqual(), {'exception': TypeError, 'error_keywords': ['LessEqual']}), @@ -728,11 +653,6 @@ raise_set = [ 'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.bool_))], 'skip': ['backward']}), - # one input is scalar, and another is Tensor(float32) - ('Atan20', { - 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), - 'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))], - 'skip': ['backward']}), # input two tensors, but element types are not same ('Atan21', { 'block': (P.Atan2(), {'exception': TypeError, 'error_keywords': ['Atan2']}), diff --git a/tests/ut/python/pipeline/infer/test_hypermap_specialize.py b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py new file mode 100644 index 0000000000..633e696dbe --- /dev/null +++ b/tests/ut/python/pipeline/infer/test_hypermap_specialize.py @@ -0,0 +1,54 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_hypermap_partial """ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, context +import mindspore.common.dtype as mstype +from mindspore.ops import composite as C +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.api import ms_function + +context.set_context(mode=context.GRAPH_MODE) + +def test_hypermap_specialize_param(): + class Net(nn.Cell): + """ Net definition """ + def __init__(self): + super(Net, self).__init__() + self.mul = P.Mul() + + def construct(self, x, y): + ret = self.mul(x, y) + return ret + + factor1 = Tensor(5, dtype=mstype.int32) + x = Tensor(np.ones([1]).astype(np.int32)) + y = Tensor(np.ones([2]).astype(np.int32)) + net = Net() + hypermap = C.HyperMap() + + @ms_function + def hypermap_specialize_param(): + ret1 = hypermap(F.partial(net, factor1), (x, y)) + # List will be converted to Tuple in SimlifyDataStructurePass. + ret2 = hypermap(F.partial(net, factor1), [x, y]) + return ret1, ret2 + + expected_ret = (Tensor(np.full(1, 5).astype(np.int32)), Tensor(np.full(2, 5).astype(np.int32))) + ret = hypermap_specialize_param() + assert(ret == (expected_ret, expected_ret))