|
|
|
@ -42,7 +42,8 @@
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace abstract {
|
|
|
|
|
PrimitiveEvalImplMap PrimitiveToInferImplMap = {
|
|
|
|
|
PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|
|
|
|
static PrimitiveEvalImplMap prim_eval_implement_map = {
|
|
|
|
|
// Statements
|
|
|
|
|
{prim::kPrimReturn, {InferImplReturn, true}},
|
|
|
|
|
{prim::kPrimTypeOf, {InferImplTypeof, false}},
|
|
|
|
@ -127,7 +128,9 @@ PrimitiveEvalImplMap PrimitiveToInferImplMap = {
|
|
|
|
|
{prim::kPrimScalarSummary, {InferImplScalarSummary, true}},
|
|
|
|
|
{prim::kPrimImageSummary, {InferImplTensorSummary, true}},
|
|
|
|
|
{prim::kPrimTensorSummary, {InferImplTensorSummary, true}},
|
|
|
|
|
};
|
|
|
|
|
};
|
|
|
|
|
return prim_eval_implement_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
using mindspore::parse::PyObjectWrapper;
|
|
|
|
|
|
|
|
|
@ -961,10 +964,7 @@ class PartialEvaluator : public Evaluator {
|
|
|
|
|
new_nodes_inputs[1] = NewValueNode(new_signature_value);
|
|
|
|
|
FuncGraphPtr func_graph = cnode->func_graph();
|
|
|
|
|
|
|
|
|
|
ScopePtr scope = kDefaultScope;
|
|
|
|
|
if (out_conf != nullptr) {
|
|
|
|
|
scope = out_conf->node()->scope();
|
|
|
|
|
}
|
|
|
|
|
ScopePtr scope = out_conf->node()->scope();
|
|
|
|
|
ScopeGuard scope_guard(scope);
|
|
|
|
|
|
|
|
|
|
CNodePtr new_cnode = func_graph->NewCNode(new_nodes_inputs);
|
|
|
|
@ -981,8 +981,8 @@ struct PrimitiveImplInferValue {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using PrimitiveToImplMap = std::unordered_map<PrimitivePtr, PrimitiveImplInferValue, PrimitiveHasher, PrimitiveEqual>;
|
|
|
|
|
|
|
|
|
|
PrimitiveToImplMap UniformPrimitiveToImplMapValue = {
|
|
|
|
|
PrimitiveToImplMap &GetUniformPrimitiveToImplMap() {
|
|
|
|
|
static PrimitiveToImplMap uniform_prim_implement_map = {
|
|
|
|
|
{prim::kPrimScalarAdd, {prim::ScalarAdd, true, nullptr, true}},
|
|
|
|
|
{prim::kPrimScalarSub, {prim::ScalarSub, true, nullptr, true}},
|
|
|
|
|
{prim::kPrimScalarMul, {prim::ScalarMul, true, nullptr, true}},
|
|
|
|
@ -1001,19 +1001,21 @@ PrimitiveToImplMap UniformPrimitiveToImplMapValue = {
|
|
|
|
|
{prim::kPrimBoolAnd, {prim::BoolAnd, true, std::make_shared<Bool>(), true}},
|
|
|
|
|
{prim::kPrimBoolEq, {prim::BoolEq, true, std::make_shared<Bool>(), true}},
|
|
|
|
|
{prim::kPrimBoolOr, {prim::BoolOr, true, std::make_shared<Bool>(), true}},
|
|
|
|
|
};
|
|
|
|
|
};
|
|
|
|
|
return uniform_prim_implement_map;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PrimEvaluatorMap PrimEvaluatorConstructors = PrimEvaluatorMap();
|
|
|
|
|
std::mutex PrimEvaluatorConstructorMutex;
|
|
|
|
|
|
|
|
|
|
void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_map) {
|
|
|
|
|
void InitPrimEvaluatorConstructors() {
|
|
|
|
|
PrimEvaluatorMap &constructor = PrimEvaluatorConstructors;
|
|
|
|
|
|
|
|
|
|
for (const auto &iter : prim_eval_impl_map) {
|
|
|
|
|
for (const auto &iter : GetPrimitiveToEvalImplMap()) {
|
|
|
|
|
constructor[iter.first] = InitStandardPrimEvaluator(iter.first, iter.second.impl_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (const auto &iter : UniformPrimitiveToImplMapValue) {
|
|
|
|
|
for (const auto &iter : GetUniformPrimitiveToImplMap()) {
|
|
|
|
|
constructor[iter.first] =
|
|
|
|
|
InitUniformPrimEvaluator(iter.first, iter.second.impl_, iter.second.eval_value_, iter.second.specify_out_type_);
|
|
|
|
|
}
|
|
|
|
@ -1028,20 +1030,20 @@ void InitPrimEvaluatorConstructors(const PrimitiveEvalImplMap &prim_eval_impl_ma
|
|
|
|
|
|
|
|
|
|
void ClearPrimEvaluatorMap() {
|
|
|
|
|
PrimEvaluatorConstructors.clear();
|
|
|
|
|
PrimitiveToInferImplMap.clear();
|
|
|
|
|
UniformPrimitiveToImplMapValue.clear();
|
|
|
|
|
GetPrimitiveToEvalImplMap().clear();
|
|
|
|
|
GetUniformPrimitiveToImplMap().clear();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsInWhiteList(const PrimitivePtr primitive) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
|
|
|
|
|
auto iter = PrimitiveToInferImplMap.find(primitive);
|
|
|
|
|
if (iter != PrimitiveToInferImplMap.end()) {
|
|
|
|
|
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
|
|
|
|
if (iter != GetPrimitiveToEvalImplMap().end()) {
|
|
|
|
|
return iter->second.in_white_list_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto uni_iter = UniformPrimitiveToImplMapValue.find(primitive);
|
|
|
|
|
if (uni_iter != UniformPrimitiveToImplMapValue.end()) {
|
|
|
|
|
auto uni_iter = GetUniformPrimitiveToImplMap().find(primitive);
|
|
|
|
|
if (uni_iter != GetUniformPrimitiveToImplMap().end()) {
|
|
|
|
|
return uni_iter->second.in_white_list_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1050,8 +1052,8 @@ bool IsInWhiteList(const PrimitivePtr primitive) {
|
|
|
|
|
|
|
|
|
|
StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
|
auto iter = PrimitiveToInferImplMap.find(primitive);
|
|
|
|
|
if (iter == PrimitiveToInferImplMap.end()) {
|
|
|
|
|
auto iter = GetPrimitiveToEvalImplMap().find(primitive);
|
|
|
|
|
if (iter == GetPrimitiveToEvalImplMap().end()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return iter->second.impl_;
|
|
|
|
@ -1064,7 +1066,7 @@ PrimEvaluatorMap &GetPrimEvaluatorConstructors() {
|
|
|
|
|
}
|
|
|
|
|
std::lock_guard<std::mutex> initLock(PrimEvaluatorConstructorMutex);
|
|
|
|
|
if (constructor.empty()) {
|
|
|
|
|
InitPrimEvaluatorConstructors(PrimitiveToInferImplMap);
|
|
|
|
|
InitPrimEvaluatorConstructors();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return constructor;
|
|
|
|
|