|
|
|
@ -230,10 +230,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
|
|
|
|
}
|
|
|
|
|
CreateDistTrainOp(&result, *op, rpc_op_device_id);
|
|
|
|
|
}
|
|
|
|
|
if (op->Type() == "oncat") {
|
|
|
|
|
if (op->Type() == "concat") {
|
|
|
|
|
auto got = remote_vars_devices_.find(op->InputArgumentNames()[0]);
|
|
|
|
|
PADDLE_ENFORCE_NE(got != remote_vars_devices_.end(),
|
|
|
|
|
"can not find right place to concat received var.");
|
|
|
|
|
PADDLE_ENFORCE(got != remote_vars_devices_.end(),
|
|
|
|
|
"can not find right place to concat received var.");
|
|
|
|
|
CreateDistTrainOp(&result, *op, got->second);
|
|
|
|
|
} else {
|
|
|
|
|
CreateDistTrainOp(&result, *op, 0);
|
|
|
|
|