|
|
|
|
@ -525,105 +525,88 @@ GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr GradOperation::GetGrad(AnfNodePtr node, const AnfNodePtr &weights,
|
|
|
|
|
const std::vector<AnfNodePtr> ¶ms_list, const std::vector<AnfNodePtr> &args,
|
|
|
|
|
bool applyJ) {
|
|
|
|
|
FuncGraphPtr ret = std::make_shared<FuncGraph>();
|
|
|
|
|
ret->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
FuncGraphPtr GradOperation::GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights,
|
|
|
|
|
const std::vector<AnfNodePtr> &forward_graph_params,
|
|
|
|
|
const std::vector<AnfNodePtr> &weight_args) {
|
|
|
|
|
FuncGraphPtr k_child = std::make_shared<FuncGraph>();
|
|
|
|
|
k_child->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
|
|
|
|
|
auto weights_node = weights;
|
|
|
|
|
if (weights == nullptr && !args.empty()) {
|
|
|
|
|
weights_node = ret->NewCNode(args);
|
|
|
|
|
AnfNodePtr weights_node = nullptr;
|
|
|
|
|
if (weights != nullptr) {
|
|
|
|
|
weights_node = weights;
|
|
|
|
|
} else if (!weight_args.empty()) {
|
|
|
|
|
weights_node = k_child->NewCNode(weight_args);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr opsJ = NewValueNode(prim::kPrimJ);
|
|
|
|
|
ValueNodePtr opsTupleItem = NewValueNode(prim::kPrimTupleGetItem);
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
if (applyJ) {
|
|
|
|
|
inputs.push_back(opsJ);
|
|
|
|
|
inputs.push_back(node);
|
|
|
|
|
node = ret->NewCNode(inputs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> params;
|
|
|
|
|
for (size_t i = 0; i < params_list.size(); ++i) {
|
|
|
|
|
params.push_back(ret->add_parameter());
|
|
|
|
|
inputs.push_back(k);
|
|
|
|
|
for (size_t i = 0; i < forward_graph_params.size(); ++i) {
|
|
|
|
|
inputs.push_back(k_child->add_parameter());
|
|
|
|
|
}
|
|
|
|
|
auto k_app = k_child->NewCNode(inputs);
|
|
|
|
|
|
|
|
|
|
inputs.clear();
|
|
|
|
|
inputs.push_back(node);
|
|
|
|
|
(void)std::copy(params.begin(), params.end(), std::back_inserter(inputs));
|
|
|
|
|
AnfNodePtr cnode = ret->NewCNode(inputs);
|
|
|
|
|
|
|
|
|
|
inputs.clear();
|
|
|
|
|
inputs.push_back(opsTupleItem);
|
|
|
|
|
inputs.push_back(cnode);
|
|
|
|
|
inputs.push_back(NewValueNode(static_cast<int64_t>(0)));
|
|
|
|
|
auto out = ret->NewCNode(inputs);
|
|
|
|
|
|
|
|
|
|
inputs.clear();
|
|
|
|
|
inputs.push_back(opsTupleItem);
|
|
|
|
|
inputs.push_back(cnode);
|
|
|
|
|
inputs.push_back(NewValueNode(static_cast<int64_t>(1)));
|
|
|
|
|
AnfNodePtr ptr_bprop = ret->NewCNode(inputs);
|
|
|
|
|
auto tuple_get_item = NewValueNode(prim::kPrimTupleGetItem);
|
|
|
|
|
auto f_app = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(0))});
|
|
|
|
|
auto bprop = k_child->NewCNode({tuple_get_item, k_app, NewValueNode(static_cast<int64_t>(1))});
|
|
|
|
|
|
|
|
|
|
doGetGrad(ret, out, ptr_bprop, weights_node, opsTupleItem);
|
|
|
|
|
return ret;
|
|
|
|
|
GradByParameter(k_child, f_app, bprop, weights_node);
|
|
|
|
|
return k_child;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GradOperation::doGetGrad(const FuncGraphPtr &func_graph, AnfNodePtr out, AnfNodePtr ptr_bprop, AnfNodePtr weights,
|
|
|
|
|
ValueNodePtr opsTupleItem) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(func_graph);
|
|
|
|
|
// Do grad by the parameter of GradOperation.
|
|
|
|
|
void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
|
|
|
|
const AnfNodePtr &weights) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(k_child);
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ptr_bprop_arg = nullptr;
|
|
|
|
|
AnfNodePtr bprop_arg = nullptr;
|
|
|
|
|
if (sens_param_) {
|
|
|
|
|
ptr_bprop_arg = func_graph->add_parameter();
|
|
|
|
|
bprop_arg = k_child->add_parameter();
|
|
|
|
|
} else {
|
|
|
|
|
auto ones_like = prim::GetPythonOps("ones_like");
|
|
|
|
|
ptr_bprop_arg = func_graph->NewCNode({NewValueNode(ones_like), out});
|
|
|
|
|
bprop_arg = k_child->NewCNode({NewValueNode(ones_like), f_app});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr ptr_bapp = func_graph->NewCNode({ptr_bprop, ptr_bprop_arg});
|
|
|
|
|
AnfNodePtr b_app = k_child->NewCNode({bprop, bprop_arg});
|
|
|
|
|
|
|
|
|
|
CNodePtr fv_bprop = nullptr;
|
|
|
|
|
if (get_by_list_) {
|
|
|
|
|
// python code: grads = hyper_map(F.partial(env_get, env), weights)
|
|
|
|
|
AnfNodePtr env =
|
|
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ptr_bapp, NewValueNode(static_cast<int64_t>(0))});
|
|
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(0))});
|
|
|
|
|
AnfNodePtr partial_env_get =
|
|
|
|
|
func_graph->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
|
|
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimPartial), NewValueNode(prim::GetPythonOps("env_get")), env});
|
|
|
|
|
MetaFuncGraphPtr hyper_map = std::make_shared<HyperMap>();
|
|
|
|
|
fv_bprop = func_graph->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
|
|
|
|
|
fv_bprop = k_child->NewCNode({NewValueNode(hyper_map), partial_env_get, weights});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr inputs_bprop = nullptr;
|
|
|
|
|
if (get_all_) {
|
|
|
|
|
TailPtr tail = std::make_shared<Tail>("tail", true);
|
|
|
|
|
inputs_bprop = func_graph->NewCNode({NewValueNode(tail), ptr_bapp});
|
|
|
|
|
inputs_bprop = k_child->NewCNode({NewValueNode(tail), b_app});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt inputs and parameters
|
|
|
|
|
if (fv_bprop != nullptr && inputs_bprop != nullptr) {
|
|
|
|
|
func_graph->set_output(func_graph->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
|
|
|
|
|
k_child->set_output(k_child->NewCNode({NewValueNode(kPrimMakeTuple), inputs_bprop, fv_bprop}));
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt parameters
|
|
|
|
|
if (fv_bprop != nullptr) {
|
|
|
|
|
func_graph->set_output(fv_bprop);
|
|
|
|
|
k_child->set_output(fv_bprop);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt inputs
|
|
|
|
|
if (inputs_bprop != nullptr) {
|
|
|
|
|
func_graph->set_output(inputs_bprop);
|
|
|
|
|
k_child->set_output(inputs_bprop);
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Gradients wrt first input.
|
|
|
|
|
// ptr_bapp returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
|
|
|
|
func_graph->set_output(func_graph->NewCNode({opsTupleItem, ptr_bapp, NewValueNode(static_cast<int64_t>(1))}));
|
|
|
|
|
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...), so 1 is for first input
|
|
|
|
|
k_child->set_output(
|
|
|
|
|
k_child->NewCNode({NewValueNode(prim::kPrimTupleGetItem), b_app, NewValueNode(static_cast<int64_t>(1))}));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Generate the graph.
|
|
|
|
|
@ -643,39 +626,39 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|
|
|
|
auto real_fn = dyn_cast<FuncGraphAbstractClosure>(fn);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_fn);
|
|
|
|
|
|
|
|
|
|
FuncGraphPtr ptr_graph = real_fn->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ptr_graph);
|
|
|
|
|
FuncGraphPtr df_builder = nullptr;
|
|
|
|
|
FuncGraphPtr forward_graph = real_fn->func_graph();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(forward_graph);
|
|
|
|
|
FuncGraphPtr grad_fg = nullptr;
|
|
|
|
|
{
|
|
|
|
|
TraceGuard g(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
|
|
|
|
df_builder = std::make_shared<FuncGraph>();
|
|
|
|
|
TraceGuard g(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
|
|
|
|
|
grad_fg = std::make_shared<FuncGraph>();
|
|
|
|
|
}
|
|
|
|
|
auto nparam = ptr_graph->parameters().size();
|
|
|
|
|
auto nparam = forward_graph->parameters().size();
|
|
|
|
|
|
|
|
|
|
std::ostringstream ss;
|
|
|
|
|
ss << "grad{" << nparam << "}";
|
|
|
|
|
df_builder->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
df_builder->debug_info()->set_name(ss.str());
|
|
|
|
|
ParameterPtr param_graph = df_builder->add_parameter();
|
|
|
|
|
grad_fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
|
|
|
|
grad_fg->debug_info()->set_name(ss.str());
|
|
|
|
|
ParameterPtr param_graph = grad_fg->add_parameter();
|
|
|
|
|
|
|
|
|
|
AnfNodePtr weights = nullptr;
|
|
|
|
|
if (get_by_list_) {
|
|
|
|
|
weights = df_builder->add_parameter();
|
|
|
|
|
weights = grad_fg->add_parameter();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> inputs;
|
|
|
|
|
inputs.push_back(NewValueNode(prim::kPrimJ));
|
|
|
|
|
inputs.push_back(param_graph);
|
|
|
|
|
auto jf = df_builder->NewCNode(inputs);
|
|
|
|
|
auto j = grad_fg->NewCNode(inputs);
|
|
|
|
|
// df is checked in GetGrad
|
|
|
|
|
FuncGraphPtr df = nullptr;
|
|
|
|
|
FuncGraphPtr k_child = nullptr;
|
|
|
|
|
{
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceGradOperation>(ptr_graph->debug_info()));
|
|
|
|
|
df = GetGrad(jf, weights, ptr_graph->parameters());
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
|
|
|
|
|
k_child = GetGrad(j, weights, forward_graph->parameters());
|
|
|
|
|
}
|
|
|
|
|
df_builder->set_output(NewValueNode(df));
|
|
|
|
|
grad_fg->set_output(NewValueNode(k_child));
|
|
|
|
|
|
|
|
|
|
return df_builder;
|
|
|
|
|
return grad_fg;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
REGISTER_PYBIND_DEFINE(GradOperation_, ([](const py::module *m) {
|
|
|
|
|
|