@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import copy
 import paddle
 from paddle.fluid.framework import core
 from paddle.fluid import compiler
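The hoisted import backs the copy.copy call added to _setup_nccl_op in the next hunk (the function-local import inside _try_to_compile goes away further down). A shallow copy suffices there because the endpoint list is only mutated at the list level, never element-wise; a minimal standalone sketch, with made-up endpoints:

    import copy

    endpoints = ["10.0.0.1:6170", "10.0.0.2:6170"]
    others = copy.copy(endpoints)   # new list object, same string elements
    others.remove(endpoints[0])     # mutating the copy...
    assert endpoints == ["10.0.0.1:6170", "10.0.0.2:6170"]   # ...leaves the original intact
    assert others == ["10.0.0.2:6170"]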
@@ -51,13 +52,21 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program, build_strategy):
         trainer_endpoints = self.role_maker._get_trainer_endpoints()
-        trainers = trainer_endpoints
+        other_trainers = copy.copy(trainer_endpoints)
+
         trainer_id = self.role_maker._worker_index()
         current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id]
+        other_trainers.remove(current_endpoint)
+
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker._worker_num()
+
+        if trainer_id == 0:
+            wait_server_ready(other_trainers)
+
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+
         for i in range(1, build_strategy.nccl_comm_num):
             startup_program.global_block().create_var(
                 name="NCCLID_{}".format(i),
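Each trainer now builds other_trainers, the full endpoint list minus its own endpoint, and trainer 0 blocks on wait_server_ready(other_trainers) before the NCCLID variables are created, so NCCL ID distribution cannot start while peers are still coming up. wait_server_ready is Paddle's own helper; the stand-in below is hypothetical, not the real implementation, and only illustrates the polling idea:

    import socket
    import time

    def wait_until_reachable(endpoints, retry_secs=3):
        # Block until every "ip:port" endpoint accepts a TCP connection.
        pending = list(endpoints)
        while pending:
            ip, port = pending[0].rsplit(":", 1)
            try:
                with socket.create_connection((ip, int(port)), timeout=2):
                    pending.pop(0)        # peer is up, check the next one
            except OSError:
                time.sleep(retry_secs)    # peer not ready yet, retry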
@@ -90,7 +99,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             })
 
     def _try_to_compile(self, startup_program, main_program, loss):
-        import copy
         dist_strategy = self.user_defined_strategy
         local_build_strategy = paddle.fluid.BuildStrategy()
         local_build_strategy.enable_sequential_execution = \
@@ -148,13 +156,12 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
 
         sync_allreduce = dist_strategy.sync_nccl_allreduce
         if sync_allreduce:
-            exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
-            if local_build_strategy.use_hierarchical_allreduce:
-                exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
-            if exe_strategy.num_threads > 4:
+            exe_strategy.num_threads = max(
+                local_build_strategy.nccl_comm_num + 1,
+                exe_strategy.num_threads)
+            if local_build_strategy.nccl_comm_num > 1:
                 logging.warn(
-                    "if you use hierachical_allreduce or "
-                    "with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False"
+                    "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap"
                 )
 
         sync_batch_norm = local_build_strategy.sync_batch_norm
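Previously num_threads was overwritten outright (nccl_comm_num + 1, doubled under hierarchical allreduce) with a warning once the result exceeded 4. The rewrite treats nccl_comm_num + 1 only as a floor, so a larger user-chosen exe_strategy.num_threads survives, and it warns whenever several NCCL communicators run with synchronous allreduce, since sync mode keeps them from overlapping. A quick check of the new rule, with made-up values:

    nccl_comm_num = 2    # stands in for local_build_strategy.nccl_comm_num
    preset = 7           # value exe_strategy.num_threads already had
    assert max(nccl_comm_num + 1, preset) == 7   # larger user setting is kept
    assert max(nccl_comm_num + 1, 1) == 3        # floor of comm_num + 1 still enforced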
@@ -167,6 +174,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
                 "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False."
             )
 
+        # NOTE. compatible with compiler, otherwise these values will be overwritten by compiler
+        main_program._nccl_comm_num = local_build_strategy.nccl_comm_num
+        main_program._use_hierarchical_allreduce = local_build_strategy.use_hierarchical_allreduce
+        main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks
+
         # TODO(guru4elephant): should be an independent optimizer
         self._setup_nccl_op(startup_program, main_program, local_build_strategy)
 
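The three _-prefixed attributes pin the NCCL settings from the build strategy onto main_program before _setup_nccl_op and compilation run; per the NOTE in the hunk, the compiler would otherwise overwrite these program attributes with its own values, and _setup_nccl_op sizes its extra NCCLID_{i} variables from the same nccl_comm_num, so the two must agree.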