|
|
|
|
@ -81,8 +81,10 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
|
|
|
|
|
void BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app, const AdjointPtr &node_adjoint);
|
|
|
|
|
AnfNodePtr AttachFvDoutToTape(const AnfNodePtr &grad_fv);
|
|
|
|
|
AnfNodePtr AttachIndirectFvDoutToTape(const AnfNodePtr &grad_fv);
|
|
|
|
|
// Map Anfnode object from D category to K category.
|
|
|
|
|
// Map AnfNode object from D category to K category.
|
|
|
|
|
AnfNodePtr MapToK(const AnfNodePtr &primal);
|
|
|
|
|
// Map CNode object from D category to K category.
|
|
|
|
|
AnfNodePtr MapToK(const CNodePtr &primal_user, size_t index);
|
|
|
|
|
// Map FuncGraph object from D category to K category.
|
|
|
|
|
AnfNodePtr MapToK(const FuncGraphPtr &primal);
|
|
|
|
|
// MapObject impls.
|
|
|
|
|
@ -129,7 +131,8 @@ class KPrim {
|
|
|
|
|
KPrim() = default;
|
|
|
|
|
~KPrim() = default;
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr KPrimitive(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
|
|
|
|
FuncGraphPtr KPrimitive(const CNodePtr &primal_user, const ValueNodePtr &value_node,
|
|
|
|
|
const pipeline::ResourceBasePtr &resources);
|
|
|
|
|
MetaFuncGraphPtr KMetaFuncGraph(const PrimitivePtr &prim);
|
|
|
|
|
FuncGraphPtr KUserDefinedCellBprop(FuncGraphPtr bprop);
|
|
|
|
|
|
|
|
|
|
@ -145,7 +148,7 @@ class KPrim {
|
|
|
|
|
FuncGraphPtr BpropCut(const ValueNodePtr &value_node, const pipeline::ResourceBasePtr &resources);
|
|
|
|
|
// Given a bprop rule, do the K mapping.
|
|
|
|
|
template <typename T>
|
|
|
|
|
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g);
|
|
|
|
|
FuncGraphPtr BpropToK(const T &primal, const FuncGraphPtr &bprop_g, const CNodePtr &cnode);
|
|
|
|
|
AnfNodePtr BuildOutput(const FuncGraphPtr &bprop_fg);
|
|
|
|
|
void TransformArgs(const FuncGraphManagerPtr &mng, const FuncGraphPtr &bprop_fg, const FuncGraphPtr &outer,
|
|
|
|
|
std::vector<AnfNodePtr> *const transf_args);
|
|
|
|
|
@ -156,7 +159,7 @@ class KPrim {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
|
|
|
|
FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg, const CNodePtr &cnode) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primal);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(bprop_fg);
|
|
|
|
|
CheckBprop(bprop_fg, primal->ToString());
|
|
|
|
|
@ -197,8 +200,13 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
|
|
|
|
TransformArgs(mng, cloned_bprop_fg, outer, &transf_args);
|
|
|
|
|
|
|
|
|
|
(void)transf_args.insert(transf_args.begin(), NewValueNode(primal));
|
|
|
|
|
auto out_value = outer->NewCNode(transf_args);
|
|
|
|
|
|
|
|
|
|
CNodePtr out_value = nullptr;
|
|
|
|
|
if (cnode != nullptr) { // Set equiv debug info. for Primitive CNode out.
|
|
|
|
|
TraceGuard trace_guard(std::make_shared<TraceEquiv>(cnode->debug_info()));
|
|
|
|
|
out_value = outer->NewCNode(transf_args);
|
|
|
|
|
} else {
|
|
|
|
|
out_value = outer->NewCNode(transf_args);
|
|
|
|
|
}
|
|
|
|
|
(void)mng->Replace(out_param, out_value);
|
|
|
|
|
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceGradSens>(out_param->debug_info()));
|
|
|
|
|
@ -207,7 +215,6 @@ FuncGraphPtr KPrim::BpropToK(const T &primal, const FuncGraphPtr &bprop_fg) {
|
|
|
|
|
// We remove all parameters except new_dout.
|
|
|
|
|
std::vector<AnfNodePtr> newBpropParams = {new_dout};
|
|
|
|
|
cloned_bprop_fg->set_parameters(newBpropParams);
|
|
|
|
|
|
|
|
|
|
outer->set_output(outer->NewCNode({NewValueNode(prim::kPrimMakeTuple), out_value, NewValueNode(cloned_bprop_fg)}));
|
|
|
|
|
return BasicClone(outer);
|
|
|
|
|
}
|
|
|
|
|
|