|
|
@ -234,18 +234,17 @@ static bool AllGradInSet(const std::vector<std::string>& names,
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind> CreatBackwardOps(
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> MakeGradOpDescs(
|
|
|
|
const std::unique_ptr<OpDescBind>& op_desc_ptr,
|
|
|
|
const std::unique_ptr<OpDescBind>& op_desc,
|
|
|
|
unordered_map<std::string>& no_grad_vars) {
|
|
|
|
unordered_set<std::string>& no_grad_vars) {
|
|
|
|
const OpDescBind& op_desc = *op_desc_ptr;
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
|
|
|
|
std::vector<OpDescBind> grad_op_descs;
|
|
|
|
|
|
|
|
// All input gradients of forwarding operator do not need to calculate.
|
|
|
|
// All input gradients of forwarding operator do not need to calculate.
|
|
|
|
if (AllGradInSet(op_desc_.InputArgumentNames(), kGradVarSuffix,
|
|
|
|
if (AllGradInSet(op_desc->InputArgumentNames(), kGradVarSuffix,
|
|
|
|
no_grad_vars)) {
|
|
|
|
no_grad_vars)) {
|
|
|
|
return grad_op_descs; // empty vector
|
|
|
|
return grad_op_descs; // empty vector
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// All output gradients of forwarding operator do not need to calculate.
|
|
|
|
// All output gradients of forwarding operator do not need to calculate.
|
|
|
|
const std::vector<std::string>& outputs = op_desc_.OutputArugumentNames();
|
|
|
|
const std::vector<std::string>& outputs = op_desc->OutputArugumentNames();
|
|
|
|
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
|
|
|
|
if (AllGradInSet(outputs, kGradVarSuffix, no_grad_vars)) {
|
|
|
|
for (const std::string& name : outputs) {
|
|
|
|
for (const std::string& name : outputs) {
|
|
|
|
no_grad_vars.insert(GradVarName(name));
|
|
|
|
no_grad_vars.insert(GradVarName(name));
|
|
|
@ -255,50 +254,54 @@ std::vector<OpDescBind> CreatBackwardOps(
|
|
|
|
|
|
|
|
|
|
|
|
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
|
|
|
|
grad_op_descs = OpRegistry::CreateGradOpDescs(op_desc);
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind> fill_zeros_ops;
|
|
|
|
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
|
|
|
|
for (OpDescBind& desc : grad_op_descs) {
|
|
|
|
for (auto& desc : grad_op_descs) {
|
|
|
|
for (const std::string& in_name : desc.InputArgumentNames()) {
|
|
|
|
for (const std::string& in_name : desc->InputArgumentNames()) {
|
|
|
|
if (no_grad_vars.count(in_name)) {
|
|
|
|
if (no_grad_vars.count(in_name)) {
|
|
|
|
std::string prefix = in_name.substr(
|
|
|
|
std::string prefix = in_name.substr(
|
|
|
|
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
|
|
|
|
std::string new_name = prefix + kZeroVarSuffix;
|
|
|
|
std::string new_name = prefix + kZeroVarSuffix;
|
|
|
|
desc.Rename(in_name, new_name);
|
|
|
|
desc->Rename(in_name, new_name);
|
|
|
|
OpDescBind op_desc_bind(
|
|
|
|
OpDescBind* fill_zeros_op = new OpDescBind(
|
|
|
|
{"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {}});
|
|
|
|
"fill_zeros_like", {{"X", {prefix}}}, {{"Y", {new_name}}}, {});
|
|
|
|
fill_zeros_ops.push_back(op_desc_bind);
|
|
|
|
pending_fill_zeros_ops.push_back({fill_zeros_op});
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (const std::string& out_name : desc.OutputName()) {
|
|
|
|
for (const std::string& out_name : desc->OutputArgumentName()) {
|
|
|
|
if (no_grad_vars.count(out_name)) {
|
|
|
|
if (no_grad_vars.count(out_name)) {
|
|
|
|
desc.Rename(out_name, kEmptyVarName);
|
|
|
|
desc->Rename(out_name, kEmptyVarName);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
grad_op_descs.insert(grad_op_descs.begin(), fill_zeros_ops.begin(),
|
|
|
|
grad_op_descs.insert(std::begin(grad_op_descs),
|
|
|
|
fill_zeros_ops.end());
|
|
|
|
std::begin(pending_fill_zeros_ops),
|
|
|
|
|
|
|
|
std::end(pending_fill_zeros_ops));
|
|
|
|
|
|
|
|
|
|
|
|
// TODO (fengjiayi): RNN op
|
|
|
|
// TODO (fengjiayi): RNN op
|
|
|
|
return grad_op_descs;
|
|
|
|
return grad_op_descs;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void AppendBackwardOps(BlockDescBind& block_desc,
|
|
|
|
void AppendBackwardOpDescs(
|
|
|
|
|
|
|
|
BlockDescBind& block_desc,
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
const std::unordered_set<std::string>& no_grad_vars) {
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
|
|
|
|
size_t grad_desc_idx = 0;
|
|
|
|
size_t grad_desc_idx = 0;
|
|
|
|
std::deque<std::unique_ptr<OpDescBind>> op_descs = block_desc.ops_;
|
|
|
|
std::deque<std::unique_ptr<OpDescBind>> block_op_descs = block_desc.ops_;
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> backward_descs;
|
|
|
|
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
|
|
|
|
for (auto it = block_op_descs.rbegin(); it != block_op_descs.rend(); ++it) {
|
|
|
|
std::vector<OpDescBind> op_grads = CreatBackwardOps(*it, no_grad_vars);
|
|
|
|
std::vector<std::unique_ptr<OpDescBind>> op_grads =
|
|
|
|
for (const OpDescBind& desc : op_grads) {
|
|
|
|
MakeGradOpDescs(*it, no_grad_vars);
|
|
|
|
for (const std::string& out_name : desc.OutputArugumentNames()) {
|
|
|
|
for (const auto& desc : op_grads) {
|
|
|
|
|
|
|
|
for (const std::string& out_name : desc->OutputArugumentNames()) {
|
|
|
|
dup_out_ops[out_name].emplace_back(grad_desc_idx);
|
|
|
|
dup_out_ops[out_name].emplace_back(grad_desc_idx);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
++grad_desc_idx;
|
|
|
|
++grad_desc_idx;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
grad_op_descs.insert(grad_op_descs.end(), op_grads.begin(), op_grads.end());
|
|
|
|
backward_descs.insert(backward_descs.end(), op_grads.begin(),
|
|
|
|
|
|
|
|
op_grads.end());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Check whether some variables are written more than once
|
|
|
|
// Check whether some variables are written more than once
|
|
|
|
std::list<std::pair<size_t, OpDescBind>> pending_sum_ops;
|
|
|
|
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
|
|
|
|
for (const auto& dup : dup_out_ops) {
|
|
|
|
for (const auto& dup : dup_out_ops) {
|
|
|
|
const std::string& out_name = dup.first;
|
|
|
|
const std::string& out_name = dup.first;
|
|
|
|
const std::vector<size_t> dup_op = dup.second;
|
|
|
|
const std::vector<size_t> dup_op = dup.second;
|
|
|
@ -306,25 +309,27 @@ void AppendBackwardOps(BlockDescBind& block_desc,
|
|
|
|
std::vector<std::string> sum_op_inputs;
|
|
|
|
std::vector<std::string> sum_op_inputs;
|
|
|
|
for (size_t i = 0; i < dup_op.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < dup_op.size(); ++i) {
|
|
|
|
std::string new_name = out_name + "@RENAME@" + std::to_string(i);
|
|
|
|
std::string new_name = out_name + "@RENAME@" + std::to_string(i);
|
|
|
|
grad_op_descs[dup_op[i]].Rename(out_name, new_name);
|
|
|
|
backward_descs[dup_op[i]]->Rename(out_name, new_name);
|
|
|
|
sum_op_inputs.emplace_back(new_name);
|
|
|
|
sum_op_inputs.emplace_back(new_name);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
pending_sum_ops.push_back(
|
|
|
|
OpDescBind* sum_op = new OpDescBind("sum", {{"X", sum_op_inputs}},
|
|
|
|
{dup_op.back(),
|
|
|
|
{{"Out", {out_name}}}, {});
|
|
|
|
OpDescBind(
|
|
|
|
pending_sum_ops.push_back({dup_op.back(), {sum_op}});
|
|
|
|
{"sum", {{"X", {sum_op_inputs}}}, {{"Out", {out_name}}}, {}})});
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
pending_sum_ops.sort(
|
|
|
|
pending_sum_ops.sort(
|
|
|
|
[](const std::pair<size_t, OpDescBind>& a,
|
|
|
|
[](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
|
|
|
|
const std::pair<size_t, OpDescBind>& b) { return a.first > b.first; });
|
|
|
|
const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
|
|
|
|
|
|
|
|
return a.first > b.first;
|
|
|
|
|
|
|
|
});
|
|
|
|
for (auto& p : pending_sum_ops) {
|
|
|
|
for (auto& p : pending_sum_ops) {
|
|
|
|
grad_op_descs.insert(grad_op_descs.begin() + p.first + 1,
|
|
|
|
backward_descs.insert(backward_descs.begin() + p.first + 1,
|
|
|
|
std::move(p.second));
|
|
|
|
std::move(p.second));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
// Append grad_op_descs to BlockDescBind::ops_
|
|
|
|
// Append backward_descs to BlockDescBind::ops_
|
|
|
|
for () {
|
|
|
|
block_op_descs.insert(std::end(block_op_descs), std::begin(backward_descs),
|
|
|
|
}
|
|
|
|
std::end(backward_descs));
|
|
|
|
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
} // namespace framework
|
|
|
|