parent
5c333e4143
commit
f28ae6e4b1
@ -0,0 +1,74 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
const std::vector<platform::Place> &places,
|
||||
const platform::NCCLContextMap &ctxs)
|
||||
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
|
||||
for (auto &p : places_) {
|
||||
this->dev_ctx_[p] = nccl_ctxs_.DevCtx(p);
|
||||
}
|
||||
}
|
||||
|
||||
void NCCLAllReduceOpHandle::RunImpl() {
|
||||
if (inputs_.size() == 1) {
|
||||
return; // No need to all reduce when GPU count = 1;
|
||||
} else {
|
||||
// Wait input done
|
||||
for (auto *in : inputs_) {
|
||||
auto &p = static_cast<VarHandle *>(in)->place_;
|
||||
in->generated_op_->Wait(dev_ctx_[p]);
|
||||
}
|
||||
|
||||
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
|
||||
int dtype = -1;
|
||||
size_t numel = 0;
|
||||
|
||||
platform::NCCLGroupGuard guard;
|
||||
|
||||
for (size_t i = 0; i < local_scopes_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto *s = local_scopes_[i];
|
||||
int dev_id = boost::get<platform::CUDAPlace>(p).device;
|
||||
|
||||
auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
|
||||
void *buffer = const_cast<void *>(lod_tensor.data<void>());
|
||||
uintptr_t buf = reinterpret_cast<uintptr_t>(buffer);
|
||||
if (buf % sizeof(float) != 0) {
|
||||
VLOG(3) << "Buffer is not aligned " << buf;
|
||||
}
|
||||
|
||||
if (dtype == -1) {
|
||||
dtype = platform::ToNCCLDataType(lod_tensor.type());
|
||||
}
|
||||
|
||||
if (numel == 0) {
|
||||
numel = static_cast<size_t>(lod_tensor.numel());
|
||||
}
|
||||
auto &nccl_ctx = nccl_ctxs_.at(dev_id);
|
||||
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
||||
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
||||
nccl_ctx.comm_, nccl_ctx.stream()));
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/nccl_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct NCCLAllReduceOpHandle : public OpHandleBase {
|
||||
const std::vector<Scope *> &local_scopes_;
|
||||
const std::vector<platform::Place> &places_;
|
||||
const platform::NCCLContextMap &nccl_ctxs_;
|
||||
|
||||
NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
|
||||
const std::vector<platform::Place> &places,
|
||||
const platform::NCCLContextMap &ctxs);
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Loading…
Reference in new issue