|
|
|
@ -12,8 +12,10 @@
|
|
|
|
|
// See the License for the specific language governing permissions and
|
|
|
|
|
// limitations under the License.
|
|
|
|
|
#include "paddle/fluid/framework/details/sparse_all_reduce_op_handle.h"
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <utility>
|
|
|
|
|
|
|
|
|
|
#include "dgc/dgc.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/container_cast.h"
|
|
|
|
|
#include "paddle/fluid/framework/details/reduce_and_gather.h"
|
|
|
|
@ -38,18 +40,23 @@ SparseAllReduceOpHandle::SparseAllReduceOpHandle(
|
|
|
|
|
is_encoded_(is_encoded),
|
|
|
|
|
nranks_(nranks) {
|
|
|
|
|
// TODO(gongwb) :polish them!
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_encoded, true);
|
|
|
|
|
PADDLE_ENFORCE_EQ(is_encoded, true, platform::errors::InvalidArgument(
|
|
|
|
|
"The argument is_encoded is false."));
|
|
|
|
|
VLOG(1) << "Use dgc allreduce mode"
|
|
|
|
|
<< ", nranks:" << nranks_;
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(local_scopes_.size(), 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(local_scopes_.size(), 0,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of local scope should be > 0, but got %zu.",
|
|
|
|
|
local_scopes_.size()));
|
|
|
|
|
auto nranks_name = g_dgc_nranks;
|
|
|
|
|
for (size_t i = 0; i < local_scopes_.size(); ++i) {
|
|
|
|
|
auto *local_scope = local_scopes_[i];
|
|
|
|
|
auto nranks_var = local_scope->FindVar(nranks_name);
|
|
|
|
|
if (nranks_var == nullptr) {
|
|
|
|
|
PADDLE_THROW("not find nranks_var:%s", nranks_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
nranks_var, platform::errors::NotFound(
|
|
|
|
|
"Variable %s is not found in scope.", nranks_name));
|
|
|
|
|
|
|
|
|
|
float *dgc_nranks = nranks_var->GetMutable<LoDTensor>()->data<float>();
|
|
|
|
|
*dgc_nranks = nranks;
|
|
|
|
@ -64,10 +71,18 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
|
|
|
|
|
auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_var_handles.size(), places_.size(),
|
|
|
|
|
"The NoDummyInputSize should be equal to the number of places.");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of input variables should be equal to the number of "
|
|
|
|
|
"places, but got the number of input variables is %zu and the the "
|
|
|
|
|
"number of places is %zu.",
|
|
|
|
|
in_var_handles.size(), places_.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_var_handles.size(), out_var_handles.size(),
|
|
|
|
|
"The NoDummyInputSize and NoDummyOutputSize should be equal.");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of input variables should be equal to the number of "
|
|
|
|
|
"output variables, but got the number of input variables is %zu and "
|
|
|
|
|
"the the number of output variables is %zu.",
|
|
|
|
|
in_var_handles.size(), out_var_handles.size()));
|
|
|
|
|
|
|
|
|
|
std::vector<const LoDTensor *> ins;
|
|
|
|
|
std::vector<LoDTensor *> gathers;
|
|
|
|
@ -80,14 +95,17 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
|
|
|
|
|
|
|
|
|
|
auto encode_var_name = original_name + g_dgc_encoded;
|
|
|
|
|
auto *in_var = local_scope->FindVar(encode_var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(in_var, "%s should not be null", encode_var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
in_var, platform::errors::NotFound("Variable %s is not found in scope.",
|
|
|
|
|
encode_var_name));
|
|
|
|
|
auto &in = in_var->Get<LoDTensor>();
|
|
|
|
|
ins.emplace_back(&in);
|
|
|
|
|
|
|
|
|
|
auto gather_var_name = original_name + g_dgc_gather;
|
|
|
|
|
auto *gather_var = local_scope->FindVar(gather_var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(gather_var, "%s should not be null",
|
|
|
|
|
gather_var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
gather_var, platform::errors::NotFound(
|
|
|
|
|
"Variable %s is not found in scope.", gather_var));
|
|
|
|
|
auto *gather = gather_var->GetMutable<LoDTensor>();
|
|
|
|
|
gathers.emplace_back(gather);
|
|
|
|
|
|
|
|
|
@ -100,14 +118,26 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(ins[0]->place()));
|
|
|
|
|
PADDLE_ENFORCE(platform::is_gpu_place(outs[0]->place()));
|
|
|
|
|
PADDLE_ENFORCE(nccl_ctxs_, "nccl_ctxs should not be nullptr.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_gpu_place(ins[0]->place()), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The place of input variable should be CUDAPlace, but got %s.",
|
|
|
|
|
ins[0]->place()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
platform::is_gpu_place(outs[0]->place()), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The place of input variable should be CUDAPlace, but got %s.",
|
|
|
|
|
outs[0]->place()));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(nccl_ctxs_, platform::errors::PreconditionNotMet(
|
|
|
|
|
"The nccl contexts are NULL."));
|
|
|
|
|
|
|
|
|
|
int dtype = -1;
|
|
|
|
|
size_t in_numel = 0;
|
|
|
|
|
size_t out_numel = 0;
|
|
|
|
|
PADDLE_ENFORCE(nranks_ > 1);
|
|
|
|
|
PADDLE_ENFORCE_GT(
|
|
|
|
|
nranks_, 1,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of ranks should be > 1, but got %d.", nranks_));
|
|
|
|
|
std::vector<std::function<void()>> all_gather_calls;
|
|
|
|
|
std::vector<std::function<void()>> sparse_reduce_calls;
|
|
|
|
|
|
|
|
|
@ -123,8 +153,16 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
|
|
|
|
|
|
|
|
|
|
dtype = (dtype == -1) ? platform::ToNCCLDataType(in.type()) : dtype;
|
|
|
|
|
in_numel = (in_numel == 0) ? static_cast<size_t>(in.numel()) : in_numel;
|
|
|
|
|
PADDLE_ENFORCE(in_numel % 2 == 0);
|
|
|
|
|
PADDLE_ENFORCE(in_numel / 2 == static_cast<size_t>(k));
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_numel % 2, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of elements of input variable should be "
|
|
|
|
|
"even, but got %zu.",
|
|
|
|
|
in_numel));
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_numel / 2, static_cast<size_t>(k),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of elements of input variable should be "
|
|
|
|
|
"even, but got %zu.",
|
|
|
|
|
in_numel));
|
|
|
|
|
out_numel = (out_numel == 0) ? static_cast<size_t>(out.numel()) : out_numel;
|
|
|
|
|
|
|
|
|
|
int dev_id = BOOST_GET_CONST(platform::CUDAPlace, place).device;
|
|
|
|
@ -154,7 +192,8 @@ void SparseAllReduceOpHandle::RunImplEncoded() {
|
|
|
|
|
PADDLE_ENFORCE_EQ(paddle::communication::dgc::sparseReduce(
|
|
|
|
|
gather_buff, k, out_tensor_buf,
|
|
|
|
|
static_cast<int>(out_numel), nranks_, stream),
|
|
|
|
|
true);
|
|
|
|
|
true, platform::errors::Unavailable(
|
|
|
|
|
"Calling sparseReduce() failed."));
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -187,11 +226,16 @@ void SparseAllReduceOpHandle::SparseAllReduceFunc(
|
|
|
|
|
int SparseAllReduceOpHandle::GetKValue(const std::string &grad_name) {
|
|
|
|
|
auto original_name = paddle::framework::GradOriginalVarName(grad_name);
|
|
|
|
|
auto var_name = original_name + g_dgc_k;
|
|
|
|
|
PADDLE_ENFORCE(local_scopes_.size() > 0);
|
|
|
|
|
PADDLE_ENFORCE_GT(local_scopes_.size(), 0,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of local scope should be > 0, but got %zu.",
|
|
|
|
|
local_scopes_.size()));
|
|
|
|
|
|
|
|
|
|
auto *scope = local_exec_scopes_[0];
|
|
|
|
|
auto var = scope->FindVar(var_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var, platform::errors::NotFound("Variable %s is not found in scope.",
|
|
|
|
|
var_name));
|
|
|
|
|
auto tensor = var->Get<LoDTensor>().data<float>();
|
|
|
|
|
return *tensor;
|
|
|
|
|
}
|
|
|
|
@ -202,15 +246,22 @@ bool SparseAllReduceOpHandle::IsEncoded() {
|
|
|
|
|
}
|
|
|
|
|
auto counter_name = g_dgc_counter_name;
|
|
|
|
|
auto step_name = g_dgc_rampup_begin_step;
|
|
|
|
|
PADDLE_ENFORCE(local_scopes_.size() > 0);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GT(local_scopes_.size(), 0,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of local scope should be > 0, but got %zu.",
|
|
|
|
|
local_scopes_.size()));
|
|
|
|
|
|
|
|
|
|
auto *local_scope = local_exec_scopes_[0];
|
|
|
|
|
auto count_var = local_scope->FindVar(counter_name);
|
|
|
|
|
auto step_var = local_scope->FindVar(step_name);
|
|
|
|
|
if (count_var == nullptr || step_var == nullptr) {
|
|
|
|
|
PADDLE_THROW("not find count_var:%s or step_var:%s", counter_name,
|
|
|
|
|
step_var);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
count_var, platform::errors::NotFound(
|
|
|
|
|
"Variable %s is not found in scope.", counter_name));
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
step_var, platform::errors::NotFound("Variable %s is not found in scope.",
|
|
|
|
|
step_var));
|
|
|
|
|
|
|
|
|
|
float count = *count_var->Get<LoDTensor>().data<float>();
|
|
|
|
|
float step = *step_var->Get<LoDTensor>().data<float>();
|
|
|
|
|