|
|
|
@ -96,34 +96,37 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|
|
|
|
MS_LOG(EXCEPTION) << "Primitive node is not valid.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto prim = value_node->value()->cast<PrimitivePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
|
|
|
|
|
auto iter = bprop_registry_.find(prim);
|
|
|
|
|
if (iter != bprop_registry_.end()) {
|
|
|
|
|
return iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == "switch_layer") {
|
|
|
|
|
auto prim = GetValueNode<PrimitivePtr>(value_node);
|
|
|
|
|
if (prim->Hash() == prim::kPrimSwitchLayer->Hash() && prim->name() == prim::kPrimSwitchLayer->name()) {
|
|
|
|
|
auto fprop = GetFprop(prim);
|
|
|
|
|
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
|
|
|
|
|
bprop_registry_[prim::kPrimSwitchLayer] = fprop;
|
|
|
|
|
return fprop;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
|
|
|
|
} else if (prim->Hash() == prim::kPrimMakeTuple->Hash() && prim->name() == prim::kPrimMakeTuple->name()) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool is_faked_bprop = false;
|
|
|
|
|
FuncGraphPtr bprop_fg = nullptr;
|
|
|
|
|
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == "HookBackward") {
|
|
|
|
|
bprop_fg = BpropCut(value_node, resources);
|
|
|
|
|
} else {
|
|
|
|
|
bprop_fg = GetBprop(prim);
|
|
|
|
|
if (bprop_fg == nullptr) {
|
|
|
|
|
bprop_fg = FakeBprop(value_node, resources);
|
|
|
|
|
is_faked_bprop = true;
|
|
|
|
|
auto iter = bprop_registry_.find(prim);
|
|
|
|
|
if (iter != bprop_registry_.end()) {
|
|
|
|
|
bprop_fg = iter->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (bprop_fg == nullptr) {
|
|
|
|
|
bool is_faked_bprop = false;
|
|
|
|
|
if (prim->Hash() == prim::kPrimHookBackward->Hash() && prim->name() == prim::kPrimHookBackward->name()) {
|
|
|
|
|
bprop_fg = BpropCut(value_node, resources);
|
|
|
|
|
} else {
|
|
|
|
|
bprop_fg = GetBprop(prim);
|
|
|
|
|
if (bprop_fg == nullptr) {
|
|
|
|
|
bprop_fg = FakeBprop(value_node, resources);
|
|
|
|
|
is_faked_bprop = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// To support primitives with variable params, do not cache faked bprop
|
|
|
|
|
if (!is_faked_bprop) {
|
|
|
|
|
// Set bprop_g graph cache
|
|
|
|
|
bprop_registry_[prim] = bprop_fg;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -134,11 +137,6 @@ FuncGraphPtr KPrim::KPrimitive(const ValueNodePtr &value_node, const pipeline::R
|
|
|
|
|
<< trace::GetDebugInfo(bprop_fg->debug_info());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// To support primitives with variable params, do not cache faked bprop
|
|
|
|
|
if (!is_faked_bprop) {
|
|
|
|
|
// Set bprop_g graph cache
|
|
|
|
|
bprop_registry_[prim] = expanded_fg;
|
|
|
|
|
}
|
|
|
|
|
return expanded_fg;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|