[Kunlun]PR3: add xpu executor, multi xpu card train function optimization (#30317)

revert-31068-fix_conv3d_windows
liuyuhui 5 years ago committed by GitHub
parent 8489d4f76f
commit 843dc3cdbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -265,7 +265,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy collective_helper
graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper)
cc_library(executor_cache SRCS executor_cache.cc DEPS executor)

@ -101,6 +101,8 @@ cc_library(scope_buffered_monitor SRCS scope_buffered_monitor.cc DEPS scope prof
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor scope_buffered_monitor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle )
cc_library(bind_threaded_ssa_graph_executor SRCS bind_threaded_ssa_graph_executor.cc
DEPS fetch_op_handle gflags ssa_graph_executor scope simple_threadpool device_context)
cc_library(fast_threaded_ssa_graph_executor SRCS fast_threaded_ssa_graph_executor.cc
DEPS fetch_async_op_handle ssa_graph_executor scope simple_threadpool device_context)
cc_test(fused_broadcast_op_test SRCS fused_broadcast_op_handle_test.cc DEPS fused_broadcast_op_handle)

@ -0,0 +1,107 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <ThreadPool.h>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/blocking_queue.h"
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#if defined(PADDLE_WITH_XPU)
namespace paddle {
namespace framework {
class Scope;
namespace details {
struct RunningItem {
std::atomic<int> dep_num;
OpHandleBase *op;
};
class OpHandleBase;
class BindThreadedSSAGraphExecutor : public SSAGraphExecutor {
public:
BindThreadedSSAGraphExecutor(const ExecutionStrategy &strategy,
const std::vector<Scope *> &local_scopes,
const std::vector<Scope *> &local_exec_scopes,
const std::vector<platform::Place> &places,
ir::Graph *graph);
// FeedFetchList Run(const std::vector<std::string> &fetch_tensors) override;
// Run a SSAGraph by a thread pool
// Use topological sort algorithm
FetchResultType Run(const std::vector<std::string> &fetch_tensors,
bool return_merged) override;
const ir::Graph &Graph() const override;
private:
FetchResultType RunMainStream(const std::vector<std::string> &fetch_tensors,
bool return_merged);
// Note(zcd): the ThreadPool should be placed last so that ThreadPool should
// be destroyed first.
ExecutionStrategy strategy_;
std::vector<Scope *> local_scopes_;
std::vector<Scope *> local_exec_scopes_;
std::vector<platform::Place> places_;
ir::Graph *graph_;
std::unordered_map<OpHandleBase *, int> op_deps_;
std::unordered_map<int, int> place_to_index_;
std::vector<OpHandleBase *> bootstrap_ops_;
std::unique_ptr<int[]> stream_op_count_;
std::future<
std::unique_ptr<std::unordered_map<OpHandleBase *, struct RunningItem>>>
atomic_op_deps_;
ExceptionHolder exception_;
std::vector<std::unique_ptr<::ThreadPool>> pool_;
::ThreadPool prepare_pool_;
::ThreadPool multi_device_op_pool_;
void RunOpAsyncMainStream(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops, int index);
void RunMultiDeviceOpAsync(
OpHandleBase *op,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::shared_ptr<BlockingQueue<OpHandleBase *>> ready_ops);
void PrepareAtomicOpDeps();
int get_pool_thread_index(int device_id);
inline void ExecutionFinal(std::vector<OpHandleBase *> *fetch_ops);
void InsertFetchOps(
const std::vector<std::string> &fetch_tensors, FetchResultType *fetches,
std::unordered_map<std::string, std::vector<VarHandleBase *>>
*fetched_vars,
std::unordered_map<OpHandleBase *, struct RunningItem> *op_deps,
std::vector<OpHandleBase *> *fetch_ops,
std::vector<OpHandleBase *> *ready_fetch_ops, bool return_merged);
};
} // namespace details
} // namespace framework
} // namespace paddle
#endif

@ -215,13 +215,6 @@ void OpHandleBase::WaitInputVarGenerated(bool wait_for_feed) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.
@ -271,19 +264,6 @@ void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with CUDA."));
#endif
} else if (platform::is_xpu_place(in_var_handle->place())) {
#ifdef PADDLE_WITH_XPU
PADDLE_ENFORCE_EQ(
platform::is_same_place(place, in_var_handle->place()), true,
platform::errors::InvalidArgument(
"The place of output(%s) is not consistent with the "
"place of current op(%s).",
in_var_handle->Name(), Name()));
dev_ctxes_.at(place)->Wait();
#else
PADDLE_THROW(
platform::errors::PreconditionNotMet("Not compiled with XPU."));
#endif
}
// There are nothing to do when the place is CPUPlace.

@ -22,6 +22,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
@ -933,10 +934,23 @@ ParallelExecutor::ParallelExecutor(const std::vector<platform::Place> &places,
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, member_->local_exec_scopes_,
member_->places_, graph));
if (member_->use_device_ == p::kXPU) {
#if defined(PADDLE_WITH_XPU)
VLOG(3) << "use BindThreadedSSAGraphExecutor";
member_->executor_.reset(new details::BindThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_,
member_->local_exec_scopes_, member_->places_, graph));
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support."));
#endif
} else {
VLOG(3) << "use FastThreadedSSAGraphExecutor";
member_->executor_.reset(new details::FastThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_,
member_->local_exec_scopes_, member_->places_, graph));
}
}
final_graphs.emplace_back(graph);
}

@ -211,7 +211,7 @@ void XPUDeviceContext::Wait() const {
"XPU API return wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
ret));
xpu_wait();
xpu_wait(context_->xpu_stream);
}
Place XPUDeviceContext::GetPlace() const { return place_; }

Loading…
Cancel
Save