Add broadcast operators (#17503)
* This PR adds broadcast for multi-process. And it could be used in dynamic graph to broadcast parameters.
parent
2280f185d7
commit
b5f4d5ed0e
@ -0,0 +1,76 @@
|
|||||||
|
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <ostream>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class BroadcastOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||||
|
"Input(X) of BroadcastOp should not be null.");
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||||
|
"Output(Output) of ConvOp should not be null.");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Proto maker for `broadcast`: declares the single input/output pair and
// the `sync_mode` / `root` attributes consumed by the CUDA kernel.
class BroadcastOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // `Make` overrides the virtual hook declared in OpProtoAndCheckerMaker;
  // marking it `override` lets the compiler verify the signature matches.
  void Make() override {
    AddInput("X", "(Tensor), tensor to be broadcast.");
    AddOutput("Out", "(Tensor) the result of broadcast.");
    // When true, the kernel synchronizes its CUDA stream after ncclBcast.
    AddAttr<bool>(
        "sync_mode",
        "(bool) whether to synchronize the CUDA stream after nccl call.")
        .SetDefault(false);
    // Device id that owns the source data; must be non-negative.
    AddAttr<int>("root", "(int).").SetDefault(0).EqualGreaterThan(0);
    AddComment(R"DOC(
***Broadcast Operator***

Call NCCL Broadcast internally. Note that this op must be used when one
thread is managing one GPU device.
)DOC");
  }
};
|
||||||
|
|
||||||
|
// CPU placeholder kernel: broadcast is only implemented for CUDA places
// (see the CUDA kernel registration), so any CPU invocation fails fast
// with an explanatory error instead of silently doing nothing.
template <typename T>
class BroadcastOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_THROW("Broadcast op can run on gpu place only for now.");
  }
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Broadcast declares no gradient; only the forward op is registered.
REGISTER_OP_WITHOUT_GRADIENT(broadcast, ops::BroadcastOp,
                             ops::BroadcastOpMaker);

// The CPU "kernels" are throwing stubs (see BroadcastOpKernel::Compute);
// they are registered so kernel lookup succeeds and the user gets a clear
// error instead of a missing-kernel failure.
REGISTER_OP_CPU_KERNEL(broadcast, ops::BroadcastOpKernel<float>,
                       ops::BroadcastOpKernel<double>,
                       ops::BroadcastOpKernel<int>,
                       ops::BroadcastOpKernel<int64_t>,
                       ops::BroadcastOpKernel<plat::float16>);
|
@ -0,0 +1,81 @@
|
|||||||
|
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/data_type.h"
|
||||||
|
#include "paddle/fluid/framework/lod_tensor.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
|
||||||
|
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
|
||||||
|
#include "paddle/fluid/platform/nccl_helper.h"
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
namespace plat = paddle::platform;
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
// CUDA kernel that broadcasts Out's buffer from device `root` to every
// device in the NCCL communicator. The op is in-place: X and Out must
// alias the same allocation, which is verified below before the call.
template <typename T>
class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
                   "The place of ExecutionContext should be CUDAPlace.");

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
    int dev_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).device;
    int root_dev_id = ctx.Attr<int>("root");

    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");
    // The output must already hold memory: ncclBcast reads and writes the
    // same buffer, so this op only works in place.
    PADDLE_ENFORCE(out->IsInitialized(),
                   "Currently, the output of broadcast op must be initialized, "
                   "because this op can only be an In-Place operation.");
    void* send_recv_buffer = out->mutable_data<T>(ctx.GetPlace());
    // Enforce the in-place contract: X and Out share one allocation.
    PADDLE_ENFORCE_EQ(
        send_recv_buffer, in->data<void>(),
        "Currently, the broadcast op can only be an In-Place operation.");

    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
    auto comm = dev_ctx.nccl_comm();
    auto stream = dev_ctx.stream();

    // Collective call: every rank in the communicator must reach this
    // point or the group deadlocks. The call is asynchronous on `stream`.
    PADDLE_ENFORCE(platform::dynload::ncclBcast(
        send_recv_buffer, static_cast<size_t>(in->numel()),
        platform::ToNCCLDataType(in->type()), root_dev_id, comm, stream));

    VLOG(3) << "Bcast " << ctx.Inputs("X")[0] << ", (" << in->numel() << ")"
            << " From " << root_dev_id << " to " << dev_id;

    // Optionally block the host until the broadcast finishes on `stream`.
    if (ctx.Attr<bool>("sync_mode")) {
      PADDLE_ENFORCE(cudaStreamSynchronize(stream));
    }
#else
    PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
  }
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
// Register the NCCL-backed CUDA kernels for the same dtype set as the CPU
// stubs, so dispatch is consistent across places.
REGISTER_OP_CUDA_KERNEL(broadcast, ops::NCCLBroadcastOpKernel<float>,
                        ops::NCCLBroadcastOpKernel<double>,
                        ops::NCCLBroadcastOpKernel<int>,
                        ops::NCCLBroadcastOpKernel<int64_t>,
                        ops::NCCLBroadcastOpKernel<plat::float16>);
|
@ -0,0 +1,43 @@
|
|||||||
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import os
|
||||||
|
from ..layers import collective
|
||||||
|
|
||||||
|
__parallel_ctx__clz__ = None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_data_parallel_mode():
    """Return True iff a parallel context is registered and the env var
    PADDLE_TRAINERS_NUM reports more than one trainer process."""
    global __parallel_ctx__clz__
    if __parallel_ctx__clz__ is None:
        return False
    return int(os.getenv("PADDLE_TRAINERS_NUM", "1")) > 1
|
||||||
|
|
||||||
|
|
||||||
|
def _set_parallel_ctx(nccl_parallel_context):
    """Install the process-wide parallel context (may only be set once)."""
    global __parallel_ctx__clz__
    already_set = __parallel_ctx__clz__ is not None
    assert not already_set, "ParallelContext can only be initialized once."
    __parallel_ctx__clz__ = nccl_parallel_context
|
||||||
|
|
||||||
|
|
||||||
|
def _init_parallel_ctx():
    """Call ``init()`` on the previously registered parallel context."""
    global __parallel_ctx__clz__
    ctx = __parallel_ctx__clz__
    assert ctx is not None, "ParallelContext should be initialized."
    ctx.init()
|
||||||
|
|
||||||
|
|
||||||
|
def _broadcast_parameters(parameters):
    """Synchronously broadcast every trainable parameter from rank 0."""
    for param in parameters:
        if not param.trainable:
            continue
        collective._broadcast(param, 0, sync_mode=True)
|
Loading…
Reference in new issue