|
|
|
@ -145,12 +145,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
} else if (IsDistTrainOp(*op, send_op)) {
|
|
|
|
|
CreateComputationalOps(&result, *op, 1);
|
|
|
|
|
} else if (IsScaleLossOp(*op)) {
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
|
// user can customize loss@grad if not use_default_grad_scale_
|
|
|
|
|
if (use_default_grad_scale_) {
|
|
|
|
|
CreateScaleLossGradOp(&result);
|
|
|
|
|
}
|
|
|
|
|
is_forwarding = false;
|
|
|
|
|
} else {
|
|
|
|
|
if (IsScaleLossGradOp(*op)) continue;
|
|
|
|
|
int op_dev_id = GetOpDeviceID(var_name_on_devices, *op);
|
|
|
|
|
if (op_dev_id == -1) { // var on all device
|
|
|
|
|
CreateComputationalOps(&result, *op, places_.size());
|
|
|
|
@ -399,6 +401,12 @@ void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const {
|
|
|
|
|
// FIXME(yy): Do not hard code like this
|
|
|
|
|
return op.OutputArgumentNames().size() == 1 &&
|
|
|
|
|
(op.OutputArgumentNames()[0]) == loss_var_name_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool MultiDevSSAGraphBuilder::IsScaleLossGradOp(const OpDesc &op) const {
|
|
|
|
|
// FIXME(yy): Do not hard code like this
|
|
|
|
|
return op.OutputArgumentNames().size() == 1 &&
|
|
|
|
|
op.OutputArgumentNames()[0] == GradVarName(loss_var_name_);
|
|
|
|
|