|
|
|
@ -199,6 +199,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
BuildStrategy::GradientScaleStrategy::kCustomized) {
|
|
|
|
|
CreateScaleLossGradOp(&result);
|
|
|
|
|
}
|
|
|
|
|
// This assumes the backward generating code will ensure IsScaleLossOp
|
|
|
|
|
// is true only for the op that scale the final scalar loss.
|
|
|
|
|
// It also assumes backward op will always follow the forward op in
|
|
|
|
|
// the block.
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
|
|
|
|
@ -243,6 +247,9 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
InsertAllReduceOp(&result, g_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
LOG(FATAL) << "Unknown reduce strategy ";
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} catch (boost::bad_get e) {
|
|
|
|
@ -261,7 +268,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
/*
|
|
|
|
|
Dependency graph has been constructed. However, there are still data
|
|
|
|
|
harzaeds need to be handled.
|
|
|
|
|
hazards need to be handled.
|
|
|
|
|
*/
|
|
|
|
|
PolishGraphToSupportDataHazards(&result);
|
|
|
|
|
|
|
|
|
@ -449,6 +456,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Find the first occurence of `prev_op_name` and make current `op` depend
|
|
|
|
|
// on it.
|
|
|
|
|
void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
|
|
|
|
|
const std::string &prev_op_name) const {
|
|
|
|
|
for (auto &prev_op : result->ops_) {
|
|
|
|
@ -469,6 +478,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Create RPC related op handles that connects its in ops and out ops.
|
|
|
|
|
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
|
|
|
|
|
const OpDesc &op) const {
|
|
|
|
|
result->ops_.emplace_back(
|
|
|
|
|