|
|
@ -42,7 +42,7 @@ static std::unordered_set<std::string>& CtrlFlowOps() {
|
|
|
|
static inline std::unique_ptr<OperatorBase> CreateGradOp(
|
|
|
|
static inline std::unique_ptr<OperatorBase> CreateGradOp(
|
|
|
|
const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
|
|
|
|
const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
OpDescBind op_desc;
|
|
|
|
OpDesc op_desc;
|
|
|
|
op_desc.SetInputMap(op.Inputs());
|
|
|
|
op_desc.SetInputMap(op.Inputs());
|
|
|
|
op_desc.SetOutputMap(op.Outputs());
|
|
|
|
op_desc.SetOutputMap(op.Outputs());
|
|
|
|
op_desc.SetType(op.Type());
|
|
|
|
op_desc.SetType(op.Type());
|
|
|
@ -53,7 +53,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
|
|
|
|
grad_ops.reserve(grad_descs.size());
|
|
|
|
grad_ops.reserve(grad_descs.size());
|
|
|
|
std::transform(grad_descs.begin(), grad_descs.end(),
|
|
|
|
std::transform(grad_descs.begin(), grad_descs.end(),
|
|
|
|
std::back_inserter(grad_ops),
|
|
|
|
std::back_inserter(grad_ops),
|
|
|
|
[](const std::unique_ptr<OpDescBind>& grad_desc) {
|
|
|
|
[](const std::unique_ptr<OpDesc>& grad_desc) {
|
|
|
|
return OpRegistry::CreateOp(*grad_desc);
|
|
|
|
return OpRegistry::CreateOp(*grad_desc);
|
|
|
|
});
|
|
|
|
});
|
|
|
|
PADDLE_ENFORCE(!grad_ops.empty());
|
|
|
|
PADDLE_ENFORCE(!grad_ops.empty());
|
|
|
@ -296,7 +296,7 @@ static std::string FwdName(const std::string& grad_name) {
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
static void CreateGradVarInBlock(
|
|
|
|
size_t grad_op_start_index,
|
|
|
|
size_t grad_op_start_index,
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map,
|
|
|
|
const std::unordered_map<std::string, std::string>& param_name_map,
|
|
|
|
BlockDescBind* block_desc,
|
|
|
|
BlockDesc* block_desc,
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
|
|
|
|
std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
auto ops = block_desc->AllOps();
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
|
for (size_t op_index = grad_op_start_index; op_index < ops.size();
|
|
|
@ -350,12 +350,11 @@ static void CreateGradVarInBlock(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
|
|
|
|
std::vector<std::unique_ptr<OpDesc>> MakeOpGrad(
|
|
|
|
const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
const OpDesc* op_desc, std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
const std::vector<BlockDescBind*>& grad_block =
|
|
|
|
const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>()) {
|
|
|
|
std::vector<BlockDescBind*>()) {
|
|
|
|
std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
|
|
|
|
|
|
|
|
// All input gradients of forwarding operator do not need to calculate.
|
|
|
|
// All input gradients of forwarding operator do not need to calculate.
|
|
|
|
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
|
|
|
|
const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
|
|
|
|
if (AllGradInSet(inputs, *no_grad_vars)) {
|
|
|
|
if (AllGradInSet(inputs, *no_grad_vars)) {
|
|
|
@ -386,7 +385,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
|
|
|
|
.Get(op_desc->Type())
|
|
|
|
.Get(op_desc->Type())
|
|
|
|
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);
|
|
|
|
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);
|
|
|
|
|
|
|
|
|
|
|
|
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
|
|
|
|
std::list<std::unique_ptr<OpDesc>> pending_fill_zeros_ops;
|
|
|
|
for (auto& desc : grad_op_descs) {
|
|
|
|
for (auto& desc : grad_op_descs) {
|
|
|
|
for (const std::string& in_name : desc->InputArgumentNames()) {
|
|
|
|
for (const std::string& in_name : desc->InputArgumentNames()) {
|
|
|
|
if (no_grad_vars->count(in_name)) {
|
|
|
|
if (no_grad_vars->count(in_name)) {
|
|
|
@ -394,8 +393,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
|
|
|
|
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
std::string new_name = prefix + kZeroVarSuffix;
|
|
|
|
std::string new_name = prefix + kZeroVarSuffix;
|
|
|
|
desc->Rename(in_name, new_name);
|
|
|
|
desc->Rename(in_name, new_name);
|
|
|
|
std::unique_ptr<OpDescBind> fill_zeros_op(
|
|
|
|
std::unique_ptr<OpDesc> fill_zeros_op(
|
|
|
|
new OpDescBind("fill_zeros_like", {{"X", {prefix}}},
|
|
|
|
new OpDesc("fill_zeros_like", {{"X", {prefix}}},
|
|
|
|
{{"Y", {new_name}}}, AttributeMap{}));
|
|
|
|
{{"Y", {new_name}}}, AttributeMap{}));
|
|
|
|
pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
|
|
|
|
pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -408,34 +407,33 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
|
|
|
|
return grad_op_descs;
|
|
|
|
return grad_op_descs;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static BlockDescBind* CreateStepBlock(
|
|
|
|
static BlockDesc* CreateStepBlock(
|
|
|
|
ProgramDescBind& program_desc,
|
|
|
|
ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
int step_block_idx);
|
|
|
|
int step_block_idx);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
std::vector<std::unique_ptr<OpDesc>> MakeBlockBackward(
|
|
|
|
ProgramDescBind& program_desc, int block_idx,
|
|
|
|
ProgramDesc& program_desc, int block_idx,
|
|
|
|
std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var) {
|
|
|
|
VLOG(5) << "MakeBlockBackward";
|
|
|
|
VLOG(5) << "MakeBlockBackward";
|
|
|
|
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
|
|
|
|
BlockDesc* cur_block = program_desc.MutableBlock(block_idx);
|
|
|
|
std::vector<OpDescBind*> op_descs = cur_block->AllOps();
|
|
|
|
std::vector<OpDesc*> op_descs = cur_block->AllOps();
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
size_t grad_desc_idx = 0;
|
|
|
|
size_t grad_desc_idx = 0;
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
|
|
|
|
std::vector<std::unique_ptr<OpDesc>> backward_descs;
|
|
|
|
|
|
|
|
|
|
|
|
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
|
|
|
|
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
|
|
|
|
VLOG(5) << "Making backward " << (*it)->Type() << " op";
|
|
|
|
VLOG(5) << "Making backward " << (*it)->Type() << " op";
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> op_grads;
|
|
|
|
std::vector<std::unique_ptr<OpDesc>> op_grads;
|
|
|
|
|
|
|
|
|
|
|
|
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
|
|
|
|
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
|
|
|
|
int step_block_idx = (*it)->GetBlockAttr("sub_block");
|
|
|
|
int step_block_idx = (*it)->GetBlockAttr("sub_block");
|
|
|
|
BlockDescBind* backward_block = CreateStepBlock(
|
|
|
|
BlockDesc* backward_block = CreateStepBlock(program_desc, no_grad_vars,
|
|
|
|
program_desc, no_grad_vars, grad_to_var, step_block_idx);
|
|
|
|
grad_to_var, step_block_idx);
|
|
|
|
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
|
|
|
|
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
|
|
|
|
} else if ((*it)->Type() == "conditional_block") {
|
|
|
|
} else if ((*it)->Type() == "conditional_block") {
|
|
|
|
BlockDescBind* backward_block =
|
|
|
|
BlockDesc* backward_block =
|
|
|
|
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
|
|
|
|
CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
|
|
|
|
(*it)->GetBlockAttr("sub_block"));
|
|
|
|
(*it)->GetBlockAttr("sub_block"));
|
|
|
|
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
|
|
|
|
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
|
|
|
@ -463,14 +461,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
++grad_desc_idx;
|
|
|
|
++grad_desc_idx;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::transform(
|
|
|
|
std::transform(op_grads.begin(), op_grads.end(),
|
|
|
|
op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
|
|
|
|
std::back_inserter(backward_descs),
|
|
|
|
[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
|
|
|
|
[](std::unique_ptr<OpDesc>& ptr) { return std::move(ptr); });
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(5) << "Appending Sums";
|
|
|
|
VLOG(5) << "Appending Sums";
|
|
|
|
// Check whether some variables are written more than once
|
|
|
|
// Check whether some variables are written more than once
|
|
|
|
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
|
|
|
|
std::list<std::pair<size_t, std::unique_ptr<OpDesc>>> pending_sum_ops;
|
|
|
|
for (const auto& dup : dup_out_ops) {
|
|
|
|
for (const auto& dup : dup_out_ops) {
|
|
|
|
const std::string& out_name = dup.first;
|
|
|
|
const std::string& out_name = dup.first;
|
|
|
|
const std::vector<size_t> dup_op = dup.second;
|
|
|
|
const std::vector<size_t> dup_op = dup.second;
|
|
|
@ -486,16 +484,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
sum_op_inputs.emplace_back(new_name);
|
|
|
|
sum_op_inputs.emplace_back(new_name);
|
|
|
|
next_g_name = sum_op_inputs.back();
|
|
|
|
next_g_name = sum_op_inputs.back();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
std::unique_ptr<OpDescBind> sum_op(
|
|
|
|
std::unique_ptr<OpDesc> sum_op(new OpDesc("sum", {{"X", sum_op_inputs}},
|
|
|
|
new OpDescBind("sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}},
|
|
|
|
{{"Out", {out_name}}},
|
|
|
|
AttributeMap{}));
|
|
|
|
AttributeMap{}));
|
|
|
|
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
|
|
|
|
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
pending_sum_ops.sort(
|
|
|
|
pending_sum_ops.sort([](const std::pair<size_t, std::unique_ptr<OpDesc>>& a,
|
|
|
|
[](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
|
|
|
|
const std::pair<size_t, std::unique_ptr<OpDesc>>& b) {
|
|
|
|
const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
|
|
|
|
|
|
|
|
return a.first > b.first;
|
|
|
|
return a.first > b.first;
|
|
|
|
});
|
|
|
|
});
|
|
|
|
for (auto& p : pending_sum_ops) {
|
|
|
|
for (auto& p : pending_sum_ops) {
|
|
|
@ -508,14 +505,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
|
|
|
|
return backward_descs;
|
|
|
|
return backward_descs;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static BlockDescBind* CreateStepBlock(
|
|
|
|
static BlockDesc* CreateStepBlock(
|
|
|
|
ProgramDescBind& program_desc,
|
|
|
|
ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
std::unordered_set<std::string>* no_grad_vars,
|
|
|
|
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
std::unordered_map<std::string, std::string>* grad_to_var,
|
|
|
|
int step_block_idx) {
|
|
|
|
int step_block_idx) {
|
|
|
|
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
|
|
|
|
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
|
|
|
|
no_grad_vars, grad_to_var);
|
|
|
|
no_grad_vars, grad_to_var);
|
|
|
|
BlockDescBind* backward_block =
|
|
|
|
BlockDesc* backward_block =
|
|
|
|
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
|
|
|
|
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
|
|
|
|
for (auto& ptr : backward_block_op_descs) {
|
|
|
|
for (auto& ptr : backward_block_op_descs) {
|
|
|
|
backward_block->AppendAllocatedOp(move(ptr));
|
|
|
|
backward_block->AppendAllocatedOp(move(ptr));
|
|
|
@ -524,7 +520,7 @@ static BlockDescBind* CreateStepBlock(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
ParamGradInfoMap AppendBackward(
|
|
|
|
ParamGradInfoMap AppendBackward(
|
|
|
|
ProgramDescBind& program_desc, const VarDescBind& target,
|
|
|
|
ProgramDesc& program_desc, const VarDesc& target,
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
std::unordered_set<std::string> no_grad_var_names;
|
|
|
|
std::unordered_set<std::string> no_grad_var_names;
|
|
|
|
no_grad_var_names.reserve(no_grad_vars.size() + 1);
|
|
|
|
no_grad_var_names.reserve(no_grad_vars.size() + 1);
|
|
|
@ -541,8 +537,8 @@ ParamGradInfoMap AppendBackward(
|
|
|
|
PADDLE_ENFORCE(is_scalar, "target should be scalar");
|
|
|
|
PADDLE_ENFORCE(is_scalar, "target should be scalar");
|
|
|
|
VLOG(3) << "backward from loss=" << target.Name()
|
|
|
|
VLOG(3) << "backward from loss=" << target.Name()
|
|
|
|
<< " data_type=" << target.GetDataType();
|
|
|
|
<< " data_type=" << target.GetDataType();
|
|
|
|
std::unique_ptr<OpDescBind> fill_one_op(
|
|
|
|
std::unique_ptr<OpDesc> fill_one_op(
|
|
|
|
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
|
|
|
|
new OpDesc("fill_constant", {}, {{"Out", {fill_one_op_out}}},
|
|
|
|
{{"shape", std::vector<int>{1}},
|
|
|
|
{{"shape", std::vector<int>{1}},
|
|
|
|
{"value", static_cast<float>(1.0)},
|
|
|
|
{"value", static_cast<float>(1.0)},
|
|
|
|
{"dtype", target.GetDataType()}}));
|
|
|
|
{"dtype", target.GetDataType()}}));
|
|
|
|