|
|
|
@ -222,5 +222,61 @@ std::unique_ptr<OperatorBase> Backward(
|
|
|
|
|
return BackwardRecursive(forwardOp, no_grad_names, uid);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ==================================== //
|
|
|
|
|
|
|
|
|
|
static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
|
const std::unordered_set<std::string>& set) {
|
|
|
|
|
for (const std::string& name : names) {
|
|
|
|
|
if (!set.count(GradVarName(name))) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind> CreatBackwardOps(
|
|
|
|
|
const OpDescBind& op_desc, unordered_map<std::string>& no_grad_vars) {
|
|
|
|
|
std::vector<OpDescBind> grad_op_descs;
|
|
|
|
|
// All input gradients of forwarding operator do not need to calculat.
|
|
|
|
|
if (AllGradInSet(op_desc_.InputNames(), kGradVarSuffix, no_grad_vars)) {
|
|
|
|
|
return grad_op_descs; // empty vector
|
|
|
|
|
}
|
|
|
|
|
// All output gradients of forwarding operator do not need to calculate.
|
|
|
|
|
const std::vector<std::string>& outputs = op_desc_.OutputNames();
|
|
|
|
|
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
|
|
|
|
|
for (const std::string& name : outputs) {
|
|
|
|
|
no_grad_vars.insert(GradVarName(name));
|
|
|
|
|
}
|
|
|
|
|
return grad_op_descs; // empty vector
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind> fill_zeros_ops;
|
|
|
|
|
for (OpDescBind& desc : grad_op_descs) {
|
|
|
|
|
for (const std::string& in_name : desc.InputNames()) {
|
|
|
|
|
if (no_grad_vars.count(in_name)) {
|
|
|
|
|
std::string prefix = in_name.substr(
|
|
|
|
|
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
|
std::string new_name = prefix + kZeroVarSuffix;
|
|
|
|
|
desc.Rename(in_name, new_name);
|
|
|
|
|
OpDescBind op_desc_bind(
|
|
|
|
|
{"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}});
|
|
|
|
|
fill_zeros_ops.push_back(op_desc_bind);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
for (const std::string& out_name : desc.OutputName()) {
|
|
|
|
|
if (no_grad_vars.count(out_name)) {
|
|
|
|
|
desc.Rename(out_name, kEmptyVarName);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(),
|
|
|
|
|
fill_zeros_ops.end());
|
|
|
|
|
|
|
|
|
|
// TODO (fengjiayi): RNN op
|
|
|
|
|
return grad_op_descs;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|