Add FetchAsyncOpHandle, and use it in FastThreadedExecutor (#26643)

* optimized the transformation from tensor to numpy, test=develop

* Modify fetch op handle from memcpy Sync to memcpy Async (see the CUDA sketch below the commit metadata), test=develop

* modify CUDAPinnedPlace to CPUPlace, test=develop

* modify CPUPlace to CUDAPinnedPlace, and set default inplace to false, test=develop

* revert fetch_op_handle, add fetch_async_op_handle, test=develop

* revert fetch_op_handle, add fetch_async_op_handle, test=develop

* fix error msg report, test=develop

* fix bug in cpuplace, test=develop

* fix bug in unmerged and TensorArray model, test=develop

* fix bug, double copy gpu memory, test=develop

* address chenweihang's review comments, test=develop
wanghuancoder 5 years ago committed by GitHub
parent 5205748481
commit 2d2c31a63a
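The commits above replace the fetch path's blocking copies with asynchronous copies into CUDA pinned memory, issued on dedicated fetch device contexts and synchronized once before results are returned. The following standalone sketch, written against the plain CUDA runtime API rather than Paddle's internals, illustrates that pattern; the buffer and stream names are illustrative only.

// Minimal sketch of the async-fetch pattern, using the plain CUDA runtime API.
// pinned_host stands in for a CUDAPinnedPlace tensor and fetch_stream for a
// per-place fetch context; none of this is Paddle's actual implementation.
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const size_t n = 1 << 20;
  float *device_buf = nullptr;
  float *pinned_host = nullptr;
  cudaStream_t fetch_stream;

  cudaMalloc(&device_buf, n * sizeof(float));
  cudaMemset(device_buf, 0, n * sizeof(float));
  cudaMallocHost(&pinned_host, n * sizeof(float));  // pinned memory: required for truly async D2H copies
  cudaStreamCreate(&fetch_stream);

  // Old behaviour: a blocking copy ("memcpy Sync").
  // cudaMemcpy(pinned_host, device_buf, n * sizeof(float), cudaMemcpyDeviceToHost);

  // New behaviour: enqueue the copy ("memcpy Async") so it can overlap other work.
  cudaMemcpyAsync(pinned_host, device_buf, n * sizeof(float),
                  cudaMemcpyDeviceToHost, fetch_stream);

  // One explicit wait before the data is handed back, mirroring the
  // fetch_ctxs_ Wait() loop added to FastThreadedSSAGraphExecutor::Run below.
  cudaStreamSynchronize(fetch_stream);
  printf("first element: %f\n", pinned_host[0]);

  cudaStreamDestroy(fetch_stream);
  cudaFreeHost(pinned_host);
  cudaFree(device_buf);
  return 0;
}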

@@ -3,6 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_async_op_handle SRCS fetch_async_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(share_tensor_buffer_functor SRCS share_tensor_buffer_functor.cc DEPS framework_proto scope place operator op_registry)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
@@ -98,7 +99,7 @@ cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_execu
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle ssa_graph_executor scope simple_threadpool device_context)
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)
cc_test(exception_holder_test SRCS exception_holder_test.cc )

@@ -18,7 +18,8 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/details/fetch_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/profiler.h"
@@ -120,6 +121,11 @@ FetchResultType FastThreadedSSAGraphExecutor::Run(
}
// Wait FetchOps.
ClearFetchOp(graph_, &fetch_ops);
for (auto &place : places_) {
fetch_ctxs_.Get(place)->Wait();
}
return fetches;
}
@@ -162,8 +168,8 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
ir::Node *fetch_node =
graph_->CreateEmptyNode("fetch", ir::Node::Type::kOperation);
auto *op = new FetchOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
auto *op = new FetchAsyncOpHandle(fetch_node, fetches, i, &local_scopes_,
&local_exec_scopes_, return_merged);
fetch_ops->emplace_back(op);
for (auto &p : places_) {
@@ -174,6 +180,14 @@ void FastThreadedSSAGraphExecutor::InsertFetchOps(
op->AddInput(var);
}
for (auto *var : vars) {
auto *op = var->GeneratedOp();
auto *compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
if (compute_op) {
compute_op->SetLockAndRecordEventFree(false);
}
}
int dep = static_cast<int>(op->NotReadyInputSize());
(*op_deps)[op] = dep;
if (dep == 0) {
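The new loop above finds the ComputationOpHandle that produces each fetched variable and disables its lock-and-record-event-free mode, i.e. forces it to record a CUDA event after it runs. The sketch below shows, in plain CUDA with hypothetical names, why that event is what keeps the fetch asynchronous: the fetch stream waits only on the event instead of synchronizing whole devices.

// General event-based dependency pattern (not Paddle's op-handle code).
#include <cuda_runtime.h>

void FetchAfterCompute(const float* device_src, float* pinned_dst, size_t bytes,
                       cudaStream_t compute_stream, cudaStream_t fetch_stream) {
  cudaEvent_t ready;
  cudaEventCreateWithFlags(&ready, cudaEventDisableTiming);

  // ... kernels that produce device_src have been launched on compute_stream ...
  cudaEventRecord(ready, compute_stream);       // compute op records "output ready"

  cudaStreamWaitEvent(fetch_stream, ready, 0);  // fetch waits only on that event
  cudaMemcpyAsync(pinned_dst, device_src, bytes,
                  cudaMemcpyDeviceToHost, fetch_stream);

  cudaEventDestroy(ready);                      // safe: destruction completes after the event is consumed
}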
@@ -261,7 +275,7 @@ void FastThreadedSSAGraphExecutor::PrepareAtomicOpDeps() {
const ir::Graph &FastThreadedSSAGraphExecutor::Graph() const { return *graph_; }
void FastThreadedSSAGraphExecutor::RecordOps(OpHandleBase *op) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchOpHandle *>(op)) {
if (strategy_.num_threads_ == 1 && !dynamic_cast<FetchAsyncOpHandle *>(op)) {
traced_ops_.emplace_back(op);
}
}

File diff suppressed because it is too large

@@ -0,0 +1,63 @@
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
struct FetchAsyncOpHandle : public OpHandleBase {
public:
FetchAsyncOpHandle(ir::Node *node, FetchResultType *data, size_t offset,
std::vector<Scope *> *local_scopes,
std::vector<Scope *> *local_exec_scopes,
bool return_merged);
~FetchAsyncOpHandle();
void RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) override;
std::string Name() const override;
bool IsMultiDeviceTransfer() override;
protected:
void RunImpl() override;
std::vector<Scope *> GetLocalScopes() override { return *local_scopes_; }
void FetchMergedLodTensor(
const std::vector<const LoDTensor *> &src_lodtensors,
LoDTensor *dst_lodtensor);
private:
FetchResultType *data_;
size_t offset_;
std::vector<Scope *> *local_scopes_;
std::vector<Scope *> *local_exec_scopes_;
bool return_merged_;
};
} // namespace details
} // namespace framework
} // namespace paddle
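FetchMergedLodTensor is the piece that, when return_merged is true, combines the per-device partial results into a single fetched tensor. As a rough illustration only, ignoring LoD metadata, dtype dispatch, and Paddle's allocator, the sketch below shows the underlying idea: copy each device's slice asynchronously into the right offset of one pinned host buffer, then wait once.

// Simplified sketch of a merged fetch (plain CUDA, hypothetical helper).
// rows_per_device[i] rows of `width` floats live on device i in device_bufs[i];
// pinned_dst must be a page-locked buffer large enough to hold all rows.
#include <cuda_runtime.h>
#include <vector>

void MergeFetch(const std::vector<float*>& device_bufs,
                const std::vector<size_t>& rows_per_device, size_t width,
                float* pinned_dst, const std::vector<cudaStream_t>& streams) {
  size_t row_offset = 0;
  for (size_t i = 0; i < device_bufs.size(); ++i) {
    cudaSetDevice(static_cast<int>(i));
    // Each device writes its own slice of the destination, so all copies
    // can be in flight at the same time.
    cudaMemcpyAsync(pinned_dst + row_offset * width, device_bufs[i],
                    rows_per_device[i] * width * sizeof(float),
                    cudaMemcpyDeviceToHost, streams[i]);
    row_offset += rows_per_device[i];
  }
  for (size_t i = 0; i < streams.size(); ++i) {
    cudaSetDevice(static_cast<int>(i));
    cudaStreamSynchronize(streams[i]);
  }
}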

@@ -36,7 +36,8 @@ FetchOpHandle::FetchOpHandle(ir::Node *node, FetchResultType *data,
FetchOpHandle::~FetchOpHandle() {}
void FetchOpHandle::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
PADDLE_THROW(platform::errors::PermissionDenied(
"No nodes need to wait for FetchOp. Unexpected Error."));
}
static void CheckDims(const framework::DDim &tensor_dims,

@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fetch_async_op_handle.h"
namespace paddle {
namespace framework {
@@ -23,9 +24,11 @@ void ClearFetchOp(ir::Graph* graph, std::vector<OpHandleBase*>* fetch_ops) {
if (fetch_ops->empty()) return;
for (auto& op : *fetch_ops) {
PADDLE_ENFORCE_NOT_NULL(
dynamic_cast<FetchOpHandle*>(op),
"The input ops of ClearFetchOp function should be FetchOpHandle.");
PADDLE_ENFORCE_EQ(dynamic_cast<FetchOpHandle*>(op) != nullptr ||
dynamic_cast<FetchAsyncOpHandle*>(op) != nullptr,
true,
"The input ops of ClearFetchOp function should be "
"FetchOpHandle or FetchAsyncOpHandle.");
for (auto& out_var : op->Node()->outputs) {
graph->RemoveNode(out_var);
}
