|
|
|
@ -186,7 +186,7 @@ void SyncBatchNormFunctor(const framework::ExecutionContext &ctx,
|
|
|
|
|
auto gplace = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
|
|
|
|
|
memory::Copy(platform::CPUPlace(), c_g_st_d, gplace, stats, bytes, 0);
|
|
|
|
|
|
|
|
|
|
#ifndef WIN32
|
|
|
|
|
#ifdef PADDLE_WITH_NCCL
|
|
|
|
|
auto *comm = dev_ctx.nccl_comm();
|
|
|
|
|
if (comm) {
|
|
|
|
|
int dtype = platform::ToNCCLDataType(mean_out->type());
|
|
|
|
@ -460,7 +460,7 @@ void SyncBatchNormGradFunctor(
|
|
|
|
|
dy_d, x_d, saved_mean, N, fsize, C, stats);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifndef WIN32
|
|
|
|
|
#ifdef PADDLE_WITH_NCCL
|
|
|
|
|
auto *comm = dev_ctx.nccl_comm();
|
|
|
|
|
if (comm) {
|
|
|
|
|
int dtype = platform::ToNCCLDataType(scale->type());
|
|
|
|
|