|
|
|
@ -18,6 +18,7 @@ limitations under the License. */
|
|
|
|
|
#include "lod_tensor.h"
|
|
|
|
|
#include "lod_tensor_array.h"
|
|
|
|
|
#include "op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/feed_fetch_type.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/concat.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -158,15 +159,8 @@ struct ScaleLossGradOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct FetchedData {
|
|
|
|
|
public:
|
|
|
|
|
std::vector<framework::LoDTensor> tensors_;
|
|
|
|
|
|
|
|
|
|
explicit FetchedData(size_t num_fetched) { tensors_.resize(num_fetched); }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct FetchOpHandle : public OpHandle {
|
|
|
|
|
std::shared_ptr<FetchedData> data_;
|
|
|
|
|
FeedFetchList *data_;
|
|
|
|
|
size_t offset_;
|
|
|
|
|
std::vector<Scope *> *local_scopes_;
|
|
|
|
|
std::vector<LoDTensor> tensors_;
|
|
|
|
@ -175,15 +169,26 @@ struct FetchOpHandle : public OpHandle {
|
|
|
|
|
for (auto *input_var : inputs_) {
|
|
|
|
|
input_var->pending_ops_.erase(this);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Lazily merge tensors. Will faster code.
|
|
|
|
|
MergeTensors();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Wait(platform::DeviceContext *waited_dev) override {
|
|
|
|
|
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void WaitAndMergeCPUTensors() const {
|
|
|
|
|
// Wait fetch stream done.
|
|
|
|
|
for (auto &ctx : dev_ctx_) {
|
|
|
|
|
ctx.second->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<const LoDTensor *> tensors_ptr;
|
|
|
|
|
tensors_ptr.reserve(tensors_.size());
|
|
|
|
|
for (auto &t : tensors_) {
|
|
|
|
|
tensors_ptr.emplace_back(&t);
|
|
|
|
|
}
|
|
|
|
|
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void RunImpl() override {
|
|
|
|
|
for (auto *input : inputs_) {
|
|
|
|
@ -208,15 +213,6 @@ struct FetchOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void MergeTensors() const {
|
|
|
|
|
std::vector<const LoDTensor *> tensors_ptr;
|
|
|
|
|
for (auto &t : tensors_) {
|
|
|
|
|
tensors_ptr.emplace_back(&t);
|
|
|
|
|
}
|
|
|
|
|
data_->tensors_[offset_].MergeLoDTensor(tensors_ptr, platform::CPUPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
@ -325,7 +321,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
: member_(member) {}
|
|
|
|
|
|
|
|
|
|
void Wait(platform::DeviceContext *waited_dev) override {
|
|
|
|
|
VLOG(3) << "Wait nccl all reduce op";
|
|
|
|
|
OpHandle::Wait(waited_dev);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -355,6 +350,11 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
|
|
|
|
|
|
auto &lod_tensor = s->FindVar(var_name)->Get<framework::LoDTensor>();
|
|
|
|
|
void *buffer = const_cast<void *>(lod_tensor.data<void>());
|
|
|
|
|
uintptr_t buf = reinterpret_cast<uintptr_t>(buffer);
|
|
|
|
|
if (buf % sizeof(float) != 0) {
|
|
|
|
|
VLOG(3) << "Buffer is not aligned " << buf;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dtype == -1) {
|
|
|
|
|
dtype = ToNCCLDataType(lod_tensor.type());
|
|
|
|
|
}
|
|
|
|
@ -680,7 +680,7 @@ void ParallelExecutor::BuildNCCLCommunicator() const {
|
|
|
|
|
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
const std::string &fetched_var_name) {
|
|
|
|
|
bool use_event = true;
|
|
|
|
|
auto fetched_data = std::make_shared<FetchedData>(fetch_tensors.size());
|
|
|
|
|
FeedFetchList fetched_data(fetch_tensors.size());
|
|
|
|
|
// Version --> VarHandle
|
|
|
|
|
member_->exception_.reset();
|
|
|
|
|
std::unordered_map<VarHandleBase *, std::atomic<bool>> pending_vars;
|
|
|
|
@ -728,7 +728,7 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
auto &vars = fetched_vars[var_name];
|
|
|
|
|
fetch_ops.emplace_back();
|
|
|
|
|
FetchOpHandle *op = &fetch_ops.back();
|
|
|
|
|
op->data_ = fetched_data;
|
|
|
|
|
op->data_ = &fetched_data;
|
|
|
|
|
op->offset_ = i;
|
|
|
|
|
op->local_scopes_ = &member_->local_scopes_;
|
|
|
|
|
for (auto &p : member_->places_) {
|
|
|
|
@ -786,9 +786,12 @@ void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
|
|
|
|
|
platform::DeviceContextPool::Instance().Get(p)->Wait();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fetch_ops.clear();
|
|
|
|
|
*member_->global_scope_->Var(fetched_var_name)->GetMutable<LoDTensorArray>() =
|
|
|
|
|
fetched_data->tensors_;
|
|
|
|
|
for (auto &fetch_op : fetch_ops) {
|
|
|
|
|
fetch_op.WaitAndMergeCPUTensors();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
|
|
|
|
|
fetched_data;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParallelExecutor::RunOp(
|
|
|
|
|