Fix bprop describe

Signed-off-by: zjun <zhangjun0@huawei.com>
pull/10348/head
zjun 4 years ago
parent 32dc10a735
commit 5ce00d4de2

@ -52,6 +52,7 @@ FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
outputs.push_back(NewValueNode(fake_bprop));
py::object code_obj = py::getattr(bprop_func, "__code__");
// Three parameters self, out and dout need to be excluded
size_t inputs_num = py::cast<int64_t>(py::getattr(code_obj, "co_argcount")) - 3;
for (size_t i = 0; i < inputs_num; ++i) {
auto param = bprop_graph->add_parameter();

@ -2488,7 +2488,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
inputs.emplace_back(GetInput(args[i], false));
}
if (newfg->parameters().size() > inputs_size) {
SetNestedWeigthsParam(newfg, cell_id, &inputs);
SetNestedWeightsParam(newfg, cell_id, &inputs);
}
auto out_id = GetId(out);
auto cnode = graph_prev->NewCNode(inputs);
@ -2497,7 +2497,7 @@ void PynativeExecutor::MakeNestedCnode(const std::string &cell_id, const py::arg
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString(4);
}
void PynativeExecutor::SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
void PynativeExecutor::SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id,
std::vector<AnfNodePtr> *inputs) {
FuncGraphPtr forward_graph = nullptr;
auto ic = std::find_if(cell_graph_list_.begin(), cell_graph_list_.end(),

@ -240,7 +240,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
const py::object &out, bool has_sens);
void SetNestedWeigthsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
void SetNestedWeightsParam(const FuncGraphPtr &newfg, const std::string &cell_id, std::vector<AnfNodePtr> *inputs);
bool MakeBpropNestedCnode(const py::object &cell, const py::object &out, const std::string &cell_id);
// Hold graph(forward and grad) info

@ -49,7 +49,7 @@ class Cell(Cell_):
The bprop implementation will receive a Tensor `dout` containing the gradient of the loss w.r.t.
the output, and a Tensor `out` containing the forward result. The bprop needs to compute the
gradient of the loss w.r.t. the inputs, gradient of the loss w.r.t. Parameter variables are not supported
currently.
currently. The bprop method must contain the self parameter.
Args:
auto_prefix (bool): Recursively generate namespaces. Default: True.

Loading…
Cancel
Save