|
|
|
@ -333,28 +333,28 @@ ArgsPairList HyperMap::Harmonize(const FuncGraphPtr &func_graph, const ArgsPairL
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr HyperMap::GenerateFromTypes(const TypePtrList &args_spec_list) {
|
|
|
|
|
FuncGraphPtr ptrGraph = std::make_shared<FuncGraph>();
|
|
|
|
|
ptrGraph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
ptrGraph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
|
|
|
|
ptrGraph->debug_info()->set_name("hyper_map");
|
|
|
|
|
FuncGraphPtr ptr_graph = std::make_shared<FuncGraph>();
|
|
|
|
|
ptr_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
ptr_graph->set_flag(FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER, true);
|
|
|
|
|
ptr_graph->debug_info()->set_name("hyper_map");
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ptrFnArg = nullptr;
|
|
|
|
|
std::size_t i = 0;
|
|
|
|
|
ArgsPairList argmap;
|
|
|
|
|
ArgsPairList argmap2;
|
|
|
|
|
if (fn_leaf_ == nullptr) {
|
|
|
|
|
ptrFnArg = ptrGraph->add_parameter();
|
|
|
|
|
ptrFnArg = ptr_graph->add_parameter();
|
|
|
|
|
i = 1;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::size_t size = args_spec_list.size();
|
|
|
|
|
for (; i < size; ++i) {
|
|
|
|
|
argmap.push_back(std::make_pair(ptrGraph->add_parameter(), args_spec_list[i]));
|
|
|
|
|
argmap.push_back(std::make_pair(ptr_graph->add_parameter(), args_spec_list[i]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
argmap2 = Harmonize(ptrGraph, argmap);
|
|
|
|
|
ptrGraph->set_output(Make(ptrGraph, ptrFnArg, argmap2));
|
|
|
|
|
return ptrGraph;
|
|
|
|
|
argmap2 = Harmonize(ptr_graph, argmap);
|
|
|
|
|
ptr_graph->set_output(Make(ptr_graph, ptrFnArg, argmap2));
|
|
|
|
|
return ptr_graph;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList &args_spec_list) const {
|
|
|
|
@ -582,30 +582,30 @@ FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
|
|
|
|
|
inputs.push_back(opsTupleItem);
|
|
|
|
|
inputs.push_back(cnode);
|
|
|
|
|
inputs.push_back(NewValueNode(1));
|
|
|
|
|
AnfNodePtr ptrBprop = ret->NewCNode(inputs);
|
|
|
|
|
AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
|
|
|
|
|
|
|
|
|
|
doGetGrad(ret, out, ptrBprop, weights_node, opsTupleItem);
|
|
|
|
|
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptrBprop, AnfNodePtr weights,
|
|
|
|
|
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
|
|
|
|
|
ValueNodePtr opsTupleItem) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ptrBPropArg = nullptr;
|
|
|
|
|
AnfNodePtr ptr_bprop_arg = nullptr;
|
|
|
|
|
if (sens_param_) {
|
|
|
|
|
ptrBPropArg = func_graph->add_parameter();
|
|
|
|
|
ptr_bprop_arg = func_graph->add_parameter();
|
|
|
|
|
} else {
|
|
|
|
|
auto ones_like = prim::GetPythonOps("ones_like");
|
|
|
|
|
ptrBPropArg = func_graph->NewCNode({NewValueNode(ones_like), out});
|
|
|
|
|
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ptrBApp = func_graph->NewCNode({ptrBprop, ptrBPropArg});
|
|
|
|
|
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
|
|
|
|
|
|
|
|
|
|
CNodePtr fv_bprop = nullptr;
|
|
|
|
|
if (get_by_list_) {
|
|
|
|
|
// python code: grads = hyper_map(F.partial(env_get, env), weights)
|
|
|
|
|
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptrBApp, NewValueNode(0)});
|
|
|
|
|
AnfNodePtr env = func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(0)});
|
|
|
|
|
AnfNodePtr partial_env_get =
|
|
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
|
|
|
|
|
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
|
|
|
|
@ -614,7 +614,7 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
|
|
|
|
|
|
|
|
|
|
CNodePtr inputs_bprop = nullptr;
|
|
|
|
|
if (get_all_) {
|
|
|
|
|
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptrBApp});
|
|
|
|
|
inputs_bprop = func_graph->NewCNode({NewValueNode(kTail), ptr_bapp});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt inputs and parameters
|
|
|
|
@ -636,8 +636,8 @@ void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, An
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt first input.
|
|
|
|
|
// ptrBApp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
|
|
|
|
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptrBApp, NewValueNode(1)}));
|
|
|
|
|
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
|
|
|
|
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(1)}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Generate the graph.
|
|
|
|
@ -657,35 +657,35 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|
|
|
|
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_fn);
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr ptrGraph = real_fn->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ptrGraph);
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
|
|
|
|
|
FuncGraphPtr dfBuilder = std::make_shared<FuncGraph>();
|
|
|
|
|
FuncGraphPtr ptr_graph = real_fn->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ptr_graph);
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
|
|
|
|
FuncGraphPtr df_builder = std::make_shared<FuncGraph>();
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
|
auto nparam = ptrGraph->parameters().size();
|
|
|
|
|
auto nparam = ptr_graph->parameters().size();
|
|
|
|
|
|
|
|
|
|
std::ostringstream ss;
|
|
|
|
|
ss << "grad{" << nparam << "}";
|
|
|
|
|
dfBuilder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
dfBuilder->debug_info()->set_name(ss.str());
|
|
|
|
|
ParameterPtr param_graph = dfBuilder->add_parameter();
|
|
|
|
|
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
df_builder->debug_info()->set_name(ss.str());
|
|
|
|
|
ParameterPtr param_graph = df_builder->add_parameter();
|
|
|
|
|
|
|
|
|
|
AnfNodePtr weights = nullptr;
|
|
|
|
|
if (get_by_list_) {
|
|
|
|
|
weights = dfBuilder->add_parameter();
|
|
|
|
|
weights = df_builder->add_parameter();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.push_back(NewValueNode(prim::kPrimJ));
|
|
|
|
|
inputs.push_back(param_graph);
|
|
|
|
|
auto jf = dfBuilder->NewCNode(inputs);
|
|
|
|
|
auto jf = df_builder->NewCNode(inputs);
|
|
|
|
|
// df is checked in GetGrad
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptrGraph->debug_info()));
|
|
|
|
|
auto df = GetGrad(jf, weights, ptrGraph->parameters());
|
|
|
|
|
TraceManager::DebugTrace(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
|
|
|
|
auto df = GetGrad(jf, weights, ptr_graph->parameters());
|
|
|
|
|
TraceManager::EndTrace();
|
|
|
|
|
dfBuilder->set_output(NewValueNode(df));
|
|
|
|
|
df_builder->set_output(NewValueNode(df));
|
|
|
|
|
|
|
|
|
|
return dfBuilder;
|
|
|
|
|
return df_builder;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
|
|
|
|
|