|
|
|
@ -80,7 +80,7 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
|
|
|
|
|
if (IsVarNameEndsWith(merged_grad_var, kGradVarSuffix) &&
|
|
|
|
|
merged_grad_var->outputs.size() == 1u) {
|
|
|
|
|
ir::Node* opt_node = merged_grad_var->outputs[0];
|
|
|
|
|
LOG(ERROR) << "Found opt node " << opt_node->Name();
|
|
|
|
|
VLOG(3) << "Found opt node " << opt_node->Name();
|
|
|
|
|
|
|
|
|
|
// find the backward op connected with sum op
|
|
|
|
|
for (ir::Node* unmerged_grad_var : node->inputs) {
|
|
|
|
@ -88,13 +88,13 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
|
|
|
|
|
unmerged_grad_var->inputs.size() == 1u) {
|
|
|
|
|
ir::Node* backward_op = unmerged_grad_var->inputs[0];
|
|
|
|
|
|
|
|
|
|
LOG(ERROR) << "Found backward_op " << backward_op->Name();
|
|
|
|
|
VLOG(3) << "Found backward_op " << backward_op->Name();
|
|
|
|
|
|
|
|
|
|
// find the forward op related to the backward op
|
|
|
|
|
ir::Node* forward_op =
|
|
|
|
|
FindForwardOpViaBackwardOp(graph.get(), backward_op);
|
|
|
|
|
|
|
|
|
|
LOG(ERROR) << "Found forward_op " << forward_op->Name();
|
|
|
|
|
VLOG(3) << "Found forward_op " << forward_op->Name();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(forward_op);
|
|
|
|
|
|
|
|
|
@ -114,29 +114,28 @@ std::unique_ptr<ir::Graph> LockFreeOptimizePass::ApplyImpl(
|
|
|
|
|
for (Node* optimize_op : sum_op_output->outputs) {
|
|
|
|
|
if (optimize_op->NodeType() == Node::Type::kOperation &&
|
|
|
|
|
optimize_op->Name() == kOptimizerType) {
|
|
|
|
|
LOG(ERROR) << "remove optimize_op: " << optimize_op->Name() << "_"
|
|
|
|
|
<< optimize_op->id();
|
|
|
|
|
VLOG(3) << "remove optimize_op: " << optimize_op->Name() << "_"
|
|
|
|
|
<< optimize_op->id();
|
|
|
|
|
graph->RemoveNode(optimize_op);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
LOG(ERROR) << "remove sum_op_output: " << sum_op_output->Name() << "_"
|
|
|
|
|
<< sum_op_output->id();
|
|
|
|
|
VLOG(3) << "remove sum_op_output: " << sum_op_output->Name() << "_"
|
|
|
|
|
<< sum_op_output->id();
|
|
|
|
|
graph->RemoveNode(sum_op_output);
|
|
|
|
|
}
|
|
|
|
|
LOG(ERROR) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
|
|
|
|
|
VLOG(3) << "remove sum_op: " << sum_op->Name() << "_" << sum_op->id();
|
|
|
|
|
graph->RemoveNode(sum_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto* node : graph->Nodes()) {
|
|
|
|
|
for (Node* output_node : node->outputs) {
|
|
|
|
|
if (output_node->Name() == "sgd") {
|
|
|
|
|
LOG(ERROR) << "Node link to SGD: " << node->Name() << "_" << node->id()
|
|
|
|
|
<< " --> " << output_node->Name() << "_"
|
|
|
|
|
<< output_node->id();
|
|
|
|
|
VLOG(3) << "Node link to SGD: " << node->Name() << "_" << node->id()
|
|
|
|
|
<< " --> " << output_node->Name() << "_" << output_node->id();
|
|
|
|
|
for (Node* input_node : node->inputs) {
|
|
|
|
|
LOG(ERROR) << "SGD Input link: " << input_node->Name() << "_"
|
|
|
|
|
<< input_node->id() << " --> " << node->Name() << "_"
|
|
|
|
|
<< node->id();
|
|
|
|
|
VLOG(3) << "SGD Input link: " << input_node->Name() << "_"
|
|
|
|
|
<< input_node->id() << " --> " << node->Name() << "_"
|
|
|
|
|
<< node->id();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -226,8 +225,7 @@ ir::Node* LockFreeOptimizePass::CreateNewSGDNode(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
LOG(ERROR) << "Create new opt node" << sgd_node->Name() << "_"
|
|
|
|
|
<< sgd_node->id();
|
|
|
|
|
VLOG(3) << "Create new opt node" << sgd_node->Name() << "_" << sgd_node->id();
|
|
|
|
|
|
|
|
|
|
return sgd_node;
|
|
|
|
|
}
|
|
|
|
|