|
|
|
@ -731,6 +731,7 @@ bool DistSSAGraphBuilder::DealWithSpecialOp(ir::Graph *result,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
insert_op = true;
|
|
|
|
|
need_broadcast_var_ = true;
|
|
|
|
|
} else if (OpHaveRole(*node, OpRole::kDist)) {
|
|
|
|
|
int op_dev_id = CreateDistTrainOp(result, node);
|
|
|
|
|
if (node->Op()->Type() == "concat") {
|
|
|
|
@ -925,14 +926,17 @@ void DistSSAGraphBuilder::InsertCollectiveOp(ir::Graph *result,
|
|
|
|
|
|
|
|
|
|
void DistSSAGraphBuilder::InsertPostprocessOps(ir::Graph *result) const {
|
|
|
|
|
// only GPU reduce mode need to broadcast parameters to each device.
|
|
|
|
|
if (UseGPU() && strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
|
if (strategy_.fuse_broadcast_op_) {
|
|
|
|
|
CreateFusedBroadcastOp(result, bcast_var_name_set_);
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
|
|
|
|
|
auto &to_bcast_set = bcast_var_name_set_[dev_id];
|
|
|
|
|
for (auto &bcast_name : to_bcast_set) {
|
|
|
|
|
CreateBroadcastOp(result, bcast_name, dev_id);
|
|
|
|
|
if (UseGPU()) {
|
|
|
|
|
if (need_broadcast_var_ ||
|
|
|
|
|
strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) {
|
|
|
|
|
if (strategy_.fuse_broadcast_op_) {
|
|
|
|
|
CreateFusedBroadcastOp(result, bcast_var_name_set_);
|
|
|
|
|
} else {
|
|
|
|
|
for (size_t dev_id = 0; dev_id < bcast_var_name_set_.size(); ++dev_id) {
|
|
|
|
|
auto &to_bcast_set = bcast_var_name_set_[dev_id];
|
|
|
|
|
for (auto &bcast_name : to_bcast_set) {
|
|
|
|
|
CreateBroadcastOp(result, bcast_name, dev_id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|