commit
fa21436d0d
@ -0,0 +1,30 @@
|
||||
INCLUDE(ExternalProject)
|
||||
|
||||
SET(THREADPOOL_SOURCE_DIR ${THIRD_PARTY_PATH}/threadpool)
|
||||
SET(THREADPOOL_INCLUDE_DIR ${THREADPOOL_SOURCE_DIR}/src/extern_threadpool)
|
||||
INCLUDE_DIRECTORIES(${THREADPOOL_INCLUDE_DIR})
|
||||
|
||||
ExternalProject_Add(
|
||||
extern_threadpool
|
||||
${EXTERNAL_PROJECT_LOG_ARGS}
|
||||
GIT_REPOSITORY "https://github.com/progschj/ThreadPool.git"
|
||||
GIT_TAG 9a42ec1329f259a5f4881a291db1dcb8f2ad9040
|
||||
PREFIX ${THREADPOOL_SOURCE_DIR}
|
||||
UPDATE_COMMAND ""
|
||||
CONFIGURE_COMMAND ""
|
||||
BUILD_COMMAND ""
|
||||
INSTALL_COMMAND ""
|
||||
TEST_COMMAND ""
|
||||
)
|
||||
|
||||
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
|
||||
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/threadpool_dummy.c)
|
||||
file(WRITE ${dummyfile} "const char *dummy_threadpool = \"${dummyfile}\";")
|
||||
add_library(simple_threadpool STATIC ${dummyfile})
|
||||
else()
|
||||
add_library(simple_threadpool INTERFACE)
|
||||
endif()
|
||||
|
||||
add_dependencies(simple_threadpool extern_threadpool)
|
||||
|
||||
LIST(APPEND external_project_dependencies simple_threadpool)
|
@ -0,0 +1,83 @@
|
||||
digraph G {
|
||||
subgraph cluster_init {
|
||||
label="Initialization"
|
||||
startup_program [label="startup", shape=box]
|
||||
node_w_g0 [label="W\nGPU0"]
|
||||
startup_program -> node_w_g0 [label="Initialize"]
|
||||
node_w_g1 [label="W\nGPU1"]
|
||||
node_w_g0 -> node_w_g1 [label="broadcast"]
|
||||
}
|
||||
|
||||
subgraph cluster_train {
|
||||
label="forward_backward"
|
||||
|
||||
subgraph cluster_gpu0 {
|
||||
label="GPU0"
|
||||
fc_0 [label="fc\nGPU0", shape=box]
|
||||
hidden_0 [label="hidden\nGPU0"]
|
||||
node_w_g0 -> fc_0
|
||||
fc_0 -> hidden_0
|
||||
loss0 [label="loss\nGPU0"]
|
||||
hidden_0 -> loss0 [label="many ops omitted"]
|
||||
scale_loss_0 [label="scale_loss_gradient\nGPU0", shape=box]
|
||||
loss_g0 [label="loss_grad\nGPU0"]
|
||||
scale_loss_0->loss_g0
|
||||
|
||||
fc_g_0 [label="w_grad\nGPU0", shape=box]
|
||||
loss0 -> fc_g_0
|
||||
loss_g0 -> fc_g_0
|
||||
hidden_0 -> fc_g_0
|
||||
}
|
||||
|
||||
subgraph cluster_gpu1 {
|
||||
label="GPU1"
|
||||
fc_1 [label="fc\nGPU1", shape=box]
|
||||
hidden_1 [label="hidden\nGPU1"]
|
||||
node_w_g1 -> fc_1
|
||||
fc_1 -> hidden_1
|
||||
loss1 [label="loss\nGPU1"]
|
||||
hidden_1 -> loss1 [label="many ops omitted"]
|
||||
scale_loss_1 [label="scale_loss_gradient\nGPU1", shape=box]
|
||||
loss_g1 [label="loss_grad\nGPU1"]
|
||||
scale_loss_1->loss_g1
|
||||
|
||||
fc_g_1 [label="w_grad\nGPU1", shape=box]
|
||||
loss1 -> fc_g_1
|
||||
loss_g1 -> fc_g_1
|
||||
hidden_1 -> fc_g_1
|
||||
}
|
||||
}
|
||||
|
||||
all_reduce_w [label="Merge Gradients(AllReduce)", shape=box]
|
||||
fc_g_0 -> all_reduce_w
|
||||
fc_g_1 -> all_reduce_w
|
||||
|
||||
fc_g_0_merged [label="w_grad\nMerged\nGPU0"]
|
||||
fc_g_1_merged [label="w_grad\nMerged\nGPU1"]
|
||||
all_reduce_w -> fc_g_0_merged
|
||||
all_reduce_w -> fc_g_1_merged
|
||||
|
||||
subgraph cluster_optimization {
|
||||
label="Optimization"
|
||||
subgraph cluster_opt_gpu0 {
|
||||
label="GPU0"
|
||||
sgd_0 [label="SGD Op\nGPU0", shape=box]
|
||||
|
||||
fc_g_0_merged -> sgd_0
|
||||
node_w_g0 -> sgd_0
|
||||
optimized_w_0 [label="Optimized W\nGPU0"]
|
||||
sgd_0 -> optimized_w_0
|
||||
}
|
||||
subgraph cluster_opt_gpu1 {
|
||||
label="GPU1"
|
||||
sgd_1 [label="SGD Op\nGPU1", shape=box]
|
||||
|
||||
fc_g_1_merged -> sgd_1
|
||||
node_w_g1 -> sgd_1
|
||||
optimized_w_1 [label="Optimized W\nGPU0"]
|
||||
sgd_1 -> optimized_w_1
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
After Width: | Height: | Size: 175 KiB |
@ -0,0 +1,104 @@
|
||||
# ParallelExecutor
|
||||
|
||||
## Background
|
||||
|
||||
Neural network models are defined as a `ProgramDesc` in Fluid. The `ProgramDesc` can be executed by an interpreter(i.e. the `executor` concept in Fluid). The instructions or operators in a `Program` will be executed, and the results will be fetched in Python side.
|
||||
|
||||
The executor is a very naive interpreter. It runs operators one by one. We can use `Parallel.Do` to support data parallelism, however, lacking device information in `ProgramDesc`; it is not possible to optimize the performance of `Parallel.Do`.
|
||||
|
||||
We want a `ProgramDesc` can be run on different nodes. It is better not to contain device information in `ProgramDesc`. However, we can write a high-performance interpreter, which can hold an alternative intermediate representation of `ProgramDesc`, to take full usage of Multi-GPUs.
|
||||
|
||||
ParallelExecutor is an interpreter of `ProgramDesc` which will [out-of-order execute](https://en.wikipedia.org/wiki/Out-of-order_execution) `Program` in data parallelism mode and maximise the utility of Multi-GPUs.
|
||||
|
||||
|
||||
## Overview of MultiGPUs logic
|
||||
|
||||
The ParallelExecutor takes the startup program and main program as inputs. The parameters will be initialised on `GPU0` by startup program and will broadcast to multi-GPUs. The main program will be duplicated into multi-GPUs. The gradient will be merged during each iteration, and each device will optimize parameters independently. Since the gradients on each device will be merged before parameter optimization, the parameters will be the same on each device and it does not need to be broadcast the parameters.
|
||||
|
||||

|
||||
|
||||
There are several optimizations for this logic.
|
||||
|
||||
1. We use an alternate representation in ParallelExecutor. It because the device information is critical for performance optimization.
|
||||
2. The execution is out-of-order, i.e., an operator will be executed whenever the inputs of the operator are ready.
|
||||
* GPU is a high-performance device; only one CPU thread cannot fulfil one GPU. So there is a thread pool to execute operators.
|
||||
* Out-of-order also helps transpilers to generate `ProgramDesc`. It is no need to concern about the best order of performance when implementing a transpiler.
|
||||
3. The streams of computation, merge gradients and fetch data are different.
|
||||
|
||||
The performance of `ResNeXt152` on `TitanX` which `batch_size=12` is shown below.
|
||||
|
||||
| Number of GPUs | 1 | 2 | 3 | 4|
|
||||
| --- | --- | --- | --- | --- |
|
||||
| Image/Sec | 17.9906 | 25.771 | 36.911 | 48.8428 |
|
||||
| Speed Up | N/A | 1.43247029 | 2.05168255 | 2.71490667 |
|
||||
|
||||
|
||||
## Static single assignment Graph
|
||||
|
||||
[Static single assignment form](https://en.wikipedia.org/wiki/Static_single_assignment_form)(`SSA` for short) is a common form for compiler optimization. To implement concurrent execution, we uses an `SSA` graph as an intermedia representation of `ProgramDesc`.
|
||||
|
||||
The `Program` is a directed acyclic graph, since a variable can be assigned multiple times. We enforce a variable will be assigned once, by adding version number to varaibles. We parsing the `Program` into a `SSA` graph. Also, ProgramExecutor duplicate `Program` into multi-devices. We also add a device number to varaibles and insert `NCCLAllReduce` into Graph.
|
||||
|
||||
The data structure of `SSA` graph is:
|
||||
|
||||
```c++
|
||||
struct VarHandleBase {
|
||||
OpHandleBase* generated_op_;
|
||||
vector<OpHandleBase*> pending_ops_;
|
||||
|
||||
string name;
|
||||
Place place;
|
||||
size_t version;
|
||||
};
|
||||
|
||||
struct OpHandleBase {
|
||||
vector<OpHandleBase*> inputs_;
|
||||
vector<OpHnadleBase*> outputs_;
|
||||
};
|
||||
|
||||
struct SSAGraph {
|
||||
// vars on each devices.
|
||||
// * the vars in each map in vector is on different device.
|
||||
// * the map is mapping a variable name to variable handles
|
||||
// with different versions
|
||||
vector<std::unordered_map<string, vector<VarHandleBase>>> vars_;
|
||||
|
||||
// All ops
|
||||
vector<OpHandleBase> ops_;
|
||||
};
|
||||
```
|
||||
The variable handles are the wrapper of `Variables`. The operator handles are the wrapper of `OperatorBase`. Some `OpHandle` is not an `OperatorBase`, such as `NCCLAllReduceOpHandle`, because `AllReduceOpHandle` will use new device contexts.
|
||||
|
||||
When the `ProgramDesc` converted into an `SSA` Graph, the [data hazard](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)) problem is also need to be taken care. The dummy variables, which represent the dependency between operators, will be manually inserted into SSA graph to resolve the [data hazard](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)) problem.
|
||||
|
||||
## Execute SSA Graph
|
||||
|
||||
The SSA graph can be out-of-order executed by an approximate [topological sorting](https://en.wikipedia.org/wiki/Topological_sorting) algorithm. The algorithm is
|
||||
|
||||
1. Maintaining a map of an operator and its needed input number.
|
||||
2. If a variable is not generated by an operator, i.e., `var.generated_op == nullptr`, decrease the needed input number of its pending operators.
|
||||
3. If there is an operator which needed input number is decreased to zero, just run this operator.
|
||||
4. After run this operator, just mark the variables are generated and repeat step 2 until all variables are generated.
|
||||
|
||||
Running an operator can be asynchronized. There is a thread pool to execute an `SSA` graph.
|
||||
|
||||
## Synchronize GPU Kernels
|
||||
|
||||
The GPU is a non-blocking device. The different streams need be synchronized when switing streams. In current implementation, the synchronization based on the following algorithm:
|
||||
|
||||
1. `OpHandle` will record `DeviceContext` that it is used.
|
||||
2. In `OpHandle::Run`, if the `DeviceContext` of current operator is different from `DeviceContext` of any input variable, just wait the generate operator of this input variable.
|
||||
|
||||
The `wait` are implemented by two strategies:
|
||||
|
||||
1. Invoke `DeviceContext->Wait()`, It will wait all operators on this device contexts complete.
|
||||
2. Uses `cudaStreamWaitEvent` to sending a event to the stream. It is a non-blocking call. The wait operators will be executed in GPU.
|
||||
|
||||
Generally, the `cudaStreamWaitEvent` will have a better perforamnce. However, `DeviceContext->Wait()` strategy is easier to debug. The strategy can be changed in runtime.
|
||||
|
||||
## What's next?
|
||||
|
||||
* Merging gradient of dense parameters has been done. However, the merging of sparse parameters has not been done.
|
||||
* The CPU version of Parallel Executor has not been implemented. The out-of-order logic will make CPU compuatation faster, too.
|
||||
* A better strategy to merge gradients can be introduced. We can shrink the gradients from `float32` to `int8` or `int4` while merging. It will significantly speed up multi-GPUs training without much loss of precision.
|
||||
* Combine multi-Nodes implementation. By the benifit of out-of-order, sending and recving operator can be an blocking operator, and the transpiler does not need to concern about the best position of operator.
|
@ -0,0 +1,21 @@
|
||||
cc_library(var_handle SRCS var_handle.cc DEPS place)
|
||||
cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context)
|
||||
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
|
||||
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
|
||||
nv_library(nccl_all_reduce_op_handle SRCS nccl_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
|
||||
dynload_cuda)
|
||||
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
|
||||
|
||||
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
|
||||
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
|
||||
|
||||
if(WITH_GPU)
|
||||
set(multi_devices_graph_builder_deps nccl_all_reduce_op_handle)
|
||||
else()
|
||||
set(multi_devices_graph_builder_deps)
|
||||
endif()
|
||||
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
|
||||
scale_loss_grad_op_handle ${multi_devices_graph_builder_deps})
|
||||
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph)
|
||||
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
|
||||
simple_threadpool device_context)
|
@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
ComputationOpHandle::ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
||||
platform::Place place)
|
||||
: op_(framework::OpRegistry::CreateOp(op_desc)),
|
||||
scope_(scope),
|
||||
place_(place) {}
|
||||
|
||||
void ComputationOpHandle::RunImpl() {
|
||||
auto *cur_ctx = dev_ctxes_[place_];
|
||||
for (auto *in : inputs_) {
|
||||
bool need_wait =
|
||||
in->generated_op_ && in->generated_op_->dev_ctxes_[place_] != cur_ctx;
|
||||
if (need_wait) {
|
||||
in->generated_op_->Wait(cur_ctx);
|
||||
}
|
||||
}
|
||||
|
||||
op_->Run(*scope_->FindVar("@TMP_SCOPE@")->Get<Scope *>(), place_);
|
||||
}
|
||||
|
||||
std::string ComputationOpHandle::Name() const { return op_->Type(); }
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/framework/operator.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
struct ComputationOpHandle : public OpHandleBase {
|
||||
std::unique_ptr<OperatorBase> op_;
|
||||
Scope *scope_;
|
||||
platform::Place place_;
|
||||
|
||||
ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
||||
platform::Place place);
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,79 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/fetch_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
FetchOpHandle::FetchOpHandle(FeedFetchList *data, size_t offset,
|
||||
std::vector<Scope *> *local_scopes)
|
||||
: data_(data), offset_(offset), local_scopes_(local_scopes) {}
|
||||
|
||||
FetchOpHandle::~FetchOpHandle() {
|
||||
for (auto *input_var : inputs_) {
|
||||
input_var->pending_ops_.erase(this);
|
||||
}
|
||||
}
|
||||
|
||||
void FetchOpHandle::Wait(platform::DeviceContext *waited_dev) {
|
||||
PADDLE_THROW("Nobody should wait FetchOp. Unexpceted Error");
|
||||
}
|
||||
|
||||
void FetchOpHandle::WaitAndMergeCPUTensors() const {
|
||||
std::vector<const LoDTensor *> tensors_ptr;
|
||||
tensors_ptr.reserve(tensors_.size());
|
||||
for (auto &t : tensors_) {
|
||||
tensors_ptr.emplace_back(&t);
|
||||
}
|
||||
data_->at(offset_).MergeLoDTensor(tensors_ptr, platform::CPUPlace());
|
||||
}
|
||||
|
||||
void FetchOpHandle::RunImpl() {
|
||||
auto cpu_ctx =
|
||||
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
||||
for (auto *input : inputs_) {
|
||||
auto *var = static_cast<VarHandle *>(input);
|
||||
var->generated_op_->Wait(cpu_ctx);
|
||||
}
|
||||
|
||||
tensors_.resize(inputs_.size());
|
||||
auto *var = static_cast<VarHandle *>(inputs_[0]);
|
||||
auto &var_name = var->name_;
|
||||
platform::CPUPlace cpu;
|
||||
auto &scopes = *local_scopes_;
|
||||
|
||||
for (size_t i = 0; i < scopes.size(); ++i) {
|
||||
auto &scope = scopes[i];
|
||||
auto &t = scope->FindVar(var_name)->Get<framework::LoDTensor>();
|
||||
if (platform::is_gpu_place(var->place_)) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
TensorCopy(t, cpu, *dev_ctxes_[t.place()], &tensors_[i]);
|
||||
dev_ctxes_[t.place()]->Wait();
|
||||
#endif
|
||||
} else {
|
||||
tensors_[i].ShareDataWith(t);
|
||||
tensors_[i].set_lod(t.lod());
|
||||
}
|
||||
}
|
||||
|
||||
this->WaitAndMergeCPUTensors();
|
||||
}
|
||||
|
||||
std::string FetchOpHandle::Name() const { return "Fetch"; }
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,49 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_type.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct FetchOpHandle : public OpHandleBase {
|
||||
FeedFetchList *data_;
|
||||
size_t offset_;
|
||||
std::vector<Scope *> *local_scopes_;
|
||||
std::vector<LoDTensor> tensors_;
|
||||
|
||||
FetchOpHandle(FeedFetchList *data, size_t offset,
|
||||
std::vector<Scope *> *local_scopes);
|
||||
|
||||
~FetchOpHandle();
|
||||
|
||||
void Wait(platform::DeviceContext *waited_dev) override;
|
||||
|
||||
void WaitAndMergeCPUTensors() const;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,174 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
|
||||
#include "paddle/fluid/framework/details/computation_op_handle.h"
|
||||
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
||||
#endif
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
||||
const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
platform::NCCLContextMap *nccl_ctxs)
|
||||
: loss_var_name_(loss_var_name),
|
||||
places_(places),
|
||||
local_scopes_(local_scopes),
|
||||
nccl_ctxs_(nccl_ctxs) {
|
||||
#else
|
||||
MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
|
||||
const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes)
|
||||
: loss_var_name_(loss_var_name),
|
||||
places_(places),
|
||||
local_scopes_(local_scopes) {
|
||||
#endif
|
||||
for (auto &p : params) {
|
||||
grad_names_.insert(GradVarName(p));
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
|
||||
const ProgramDesc &program) const {
|
||||
auto graph = new SSAGraph();
|
||||
SSAGraph &result = *graph;
|
||||
result.vars_.resize(places_.size());
|
||||
|
||||
bool is_forwarding = true;
|
||||
for (auto *op : program.Block(0).AllOps()) {
|
||||
bool change_forward = false;
|
||||
if (!is_forwarding) {
|
||||
// FIXME(yy): Do not hard code like this
|
||||
if (op->OutputArgumentNames().size() == 1 &&
|
||||
op->OutputArgumentNames()[0] == GradVarName(loss_var_name_)) {
|
||||
continue; // Drop fill 1. for backward coeff;
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < places_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto *s = local_scopes_[i];
|
||||
|
||||
result.ops_.emplace_back(new ComputationOpHandle(*op, s, p));
|
||||
auto *op_handle = result.ops_.back().get();
|
||||
op_handle->dev_ctxes_[p] = const_cast<platform::DeviceContext *>(
|
||||
platform::DeviceContextPool::Instance().Get(p));
|
||||
|
||||
auto var_names = op->InputArgumentNames();
|
||||
|
||||
for (auto &each_var_name : var_names) {
|
||||
VarHandle *var =
|
||||
CreateOrGetLatestVarHandle(&result, each_var_name, p, i);
|
||||
op_handle->AddInput(var);
|
||||
}
|
||||
var_names = op->OutputArgumentNames();
|
||||
|
||||
for (auto &each_var_name : var_names) {
|
||||
CreateOpOutput(&result, op_handle, each_var_name, p, i);
|
||||
}
|
||||
|
||||
if (is_forwarding) {
|
||||
if (var_names.size() == 1 && var_names[0] == loss_var_name_) {
|
||||
// Insert ScaleCost OpHandle
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
auto *communication_dev_ctx = nccl_ctxs_->DevCtx(p);
|
||||
#else
|
||||
auto *communication_dev_ctx =
|
||||
platform::DeviceContextPool::Instance().Get(platform::CPUPlace());
|
||||
#endif
|
||||
|
||||
op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), s, p,
|
||||
communication_dev_ctx);
|
||||
result.ops_.emplace_back(op_handle);
|
||||
|
||||
// FIXME: Currently ScaleLossGradOp only use device_count as scale
|
||||
// factor. So it does not depend on any other operators.
|
||||
// VarHandle *loss = GetVarHandle(loss_var_name, place);
|
||||
// loss->pending_ops_.emplace_back(op_handle);
|
||||
// op_handle->inputs_.emplace_back(loss);
|
||||
|
||||
CreateOpOutput(&result, op_handle, GradVarName(loss_var_name_), p, i);
|
||||
change_forward = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (change_forward) {
|
||||
is_forwarding = false;
|
||||
}
|
||||
|
||||
if (!is_forwarding) {
|
||||
auto var_names = op->OutputArgumentNames();
|
||||
for (auto &og : var_names) {
|
||||
if (grad_names_.count(og) != 0) { // is param grad
|
||||
// Insert NCCL AllReduce Op
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
result.ops_.emplace_back(
|
||||
new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_));
|
||||
auto *op_handle = result.ops_.back().get();
|
||||
|
||||
for (size_t i = 0; i < places_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto &vars = result.vars_[i][og];
|
||||
|
||||
if (vars.empty()) { // This device has no data. continue.
|
||||
continue;
|
||||
}
|
||||
auto *prev_grad = &vars[vars.size() - 1];
|
||||
op_handle->AddInput(prev_grad);
|
||||
|
||||
auto &var = vars[vars.size()];
|
||||
var.place_ = p;
|
||||
var.name_ = og;
|
||||
var.version_ = vars.size() - 1;
|
||||
|
||||
op_handle->AddOutput(&var);
|
||||
}
|
||||
#else
|
||||
PADDLE_ENFORCE("Not implemented");
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
Dependency graph has been constructed. However, there are still data
|
||||
harzaeds need to be handled.
|
||||
*/
|
||||
PolishGraphToSupportDataHazards(&result);
|
||||
|
||||
if (VLOG_IS_ON(10)) {
|
||||
std::ostringstream sout;
|
||||
PrintGraphviz(*graph, sout);
|
||||
VLOG(10) << sout.str();
|
||||
}
|
||||
|
||||
return std::unique_ptr<SSAGraph>(graph);
|
||||
} // namespace details
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
class NCCLContextMap;
|
||||
}
|
||||
|
||||
namespace framework {
|
||||
class Scope;
|
||||
namespace details {
|
||||
class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
|
||||
public:
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
platform::NCCLContextMap *nccl_ctxs);
|
||||
#else
|
||||
MultiDevSSAGraphBuilder(const std::vector<platform::Place> &places,
|
||||
const std::string &loss_var_name,
|
||||
const std::unordered_set<std::string> ¶ms,
|
||||
const std::vector<Scope *> &local_scopes);
|
||||
#endif
|
||||
|
||||
std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const override;
|
||||
|
||||
private:
|
||||
std::string loss_var_name_;
|
||||
const std::vector<platform::Place> &places_;
|
||||
const std::vector<Scope *> &local_scopes_;
|
||||
std::unordered_set<std::string> grad_names_;
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
platform::NCCLContextMap *nccl_ctxs_;
|
||||
#endif
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
NCCLAllReduceOpHandle::NCCLAllReduceOpHandle(
|
||||
const std::vector<Scope *> &local_scopes,
|
||||
const std::vector<platform::Place> &places,
|
||||
const platform::NCCLContextMap &ctxs)
|
||||
: local_scopes_(local_scopes), places_(places), nccl_ctxs_(ctxs) {
|
||||
for (auto &p : places_) {
|
||||
this->dev_ctxes_[p] = nccl_ctxs_.DevCtx(p);
|
||||
}
|
||||
}
|
||||
|
||||
void NCCLAllReduceOpHandle::RunImpl() {
|
||||
if (inputs_.size() == 1) {
|
||||
return; // No need to all reduce when GPU count = 1;
|
||||
} else {
|
||||
// Wait input done
|
||||
for (auto *in : inputs_) {
|
||||
auto &p = static_cast<VarHandle *>(in)->place_;
|
||||
in->generated_op_->Wait(dev_ctxes_[p]);
|
||||
}
|
||||
|
||||
auto &var_name = static_cast<VarHandle *>(this->inputs_[0])->name_;
|
||||
int dtype = -1;
|
||||
size_t numel = 0;
|
||||
|
||||
std::vector<std::function<void()>> all_reduce_calls;
|
||||
|
||||
for (size_t i = 0; i < local_scopes_.size(); ++i) {
|
||||
auto &p = places_[i];
|
||||
auto *s = local_scopes_[i];
|
||||
int dev_id = boost::get<platform::CUDAPlace>(p).device;
|
||||
|
||||
auto &lod_tensor = s->FindVar(var_name)->Get<LoDTensor>();
|
||||
void *buffer = const_cast<void *>(lod_tensor.data<void>());
|
||||
|
||||
if (dtype == -1) {
|
||||
dtype = platform::ToNCCLDataType(lod_tensor.type());
|
||||
}
|
||||
|
||||
if (numel == 0) {
|
||||
numel = static_cast<size_t>(lod_tensor.numel());
|
||||
}
|
||||
|
||||
auto &nccl_ctx = nccl_ctxs_.at(dev_id);
|
||||
auto stream = nccl_ctx.stream();
|
||||
auto comm = nccl_ctx.comm_;
|
||||
all_reduce_calls.emplace_back([=] {
|
||||
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
||||
buffer, buffer, numel, static_cast<ncclDataType_t>(dtype), ncclSum,
|
||||
comm, stream));
|
||||
});
|
||||
}
|
||||
|
||||
platform::NCCLGroupGuard guard;
|
||||
for (auto &call : all_reduce_calls) {
|
||||
call();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::string NCCLAllReduceOpHandle::Name() const { return "NCCL AllReduce"; }
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,43 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
#include "paddle/fluid/platform/nccl_helper.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct NCCLAllReduceOpHandle : public OpHandleBase {
|
||||
const std::vector<Scope *> &local_scopes_;
|
||||
const std::vector<platform::Place> &places_;
|
||||
const platform::NCCLContextMap &nccl_ctxs_;
|
||||
|
||||
NCCLAllReduceOpHandle(const std::vector<Scope *> &local_scopes,
|
||||
const std::vector<platform::Place> &places,
|
||||
const platform::NCCLContextMap &ctxs);
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,102 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
std::string OpHandleBase::DebugString() const {
|
||||
std::stringstream ss;
|
||||
ss << "(";
|
||||
for (auto *var : inputs_) {
|
||||
ss << var->DebugString() << ", ";
|
||||
}
|
||||
ss << ") --> (";
|
||||
for (auto *var : outputs_) {
|
||||
ss << var->DebugString() << ", ";
|
||||
}
|
||||
ss << ")\n";
|
||||
return ss.str();
|
||||
}
|
||||
|
||||
OpHandleBase::~OpHandleBase() {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
for (auto &ev : events_) {
|
||||
PADDLE_ENFORCE(cudaEventDestroy(ev.second));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void OpHandleBase::Run(bool use_event) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (events_.empty() && use_event) {
|
||||
for (auto &p : dev_ctxes_) {
|
||||
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
|
||||
PADDLE_ENFORCE(cudaSetDevice(dev_id));
|
||||
PADDLE_ENFORCE(
|
||||
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
|
||||
}
|
||||
}
|
||||
#else
|
||||
PADDLE_ENFORCE(!use_event);
|
||||
#endif
|
||||
|
||||
RunImpl();
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (use_event) {
|
||||
for (auto &p : dev_ctxes_) {
|
||||
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
|
||||
auto stream =
|
||||
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
|
||||
PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
||||
for (auto &dev_ctx : dev_ctxes_) {
|
||||
dev_ctx.second->Wait();
|
||||
}
|
||||
} else {
|
||||
auto stream =
|
||||
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
|
||||
for (auto &ev : events_) {
|
||||
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
|
||||
}
|
||||
}
|
||||
#else
|
||||
for (auto &dev_ctx : dev_ctxes_) {
|
||||
dev_ctx.second->Wait();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void OpHandleBase::AddInput(VarHandleBase *in) {
|
||||
this->inputs_.emplace_back(in);
|
||||
in->pending_ops_.insert(this);
|
||||
}
|
||||
|
||||
void OpHandleBase::AddOutput(VarHandleBase *out) {
|
||||
outputs_.emplace_back(out);
|
||||
out->generated_op_ = this;
|
||||
}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/var_handle.h"
|
||||
#include "paddle/fluid/platform/device_context.h"
|
||||
#include "paddle/fluid/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class OpHandleBase {
|
||||
private:
|
||||
DISABLE_COPY_AND_ASSIGN(OpHandleBase);
|
||||
|
||||
public:
|
||||
std::vector<VarHandleBase *> inputs_;
|
||||
std::vector<VarHandleBase *> outputs_;
|
||||
std::unordered_map<platform::Place, platform::DeviceContext *,
|
||||
platform::PlaceHash>
|
||||
dev_ctxes_;
|
||||
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
std::unordered_map<int, cudaEvent_t> events_;
|
||||
#endif
|
||||
|
||||
OpHandleBase() {}
|
||||
|
||||
std::string DebugString() const;
|
||||
|
||||
virtual std::string Name() const = 0;
|
||||
|
||||
virtual ~OpHandleBase();
|
||||
|
||||
void Run(bool use_event);
|
||||
|
||||
virtual void Wait(platform::DeviceContext *waited_dev);
|
||||
|
||||
void AddInput(VarHandleBase *in);
|
||||
|
||||
void AddOutput(VarHandleBase *out);
|
||||
|
||||
protected:
|
||||
virtual void RunImpl() = 0;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,52 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
ScaleLossGradOpHandle::ScaleLossGradOpHandle(size_t num_dev, Scope *scope,
|
||||
platform::Place place,
|
||||
platform::DeviceContext *dev_ctx)
|
||||
: coeff_(static_cast<float>(1.0 / num_dev)), scope_(scope), place_(place) {
|
||||
dev_ctxes_[place_] = dev_ctx;
|
||||
}
|
||||
|
||||
ScaleLossGradOpHandle::~ScaleLossGradOpHandle() {}
|
||||
|
||||
void ScaleLossGradOpHandle::RunImpl() {
|
||||
std::string var_name = static_cast<VarHandle *>(this->outputs_[0])->name_;
|
||||
|
||||
float *tmp =
|
||||
scope_->FindVar(var_name)->GetMutable<LoDTensor>()->mutable_data<float>(
|
||||
make_ddim({1}), place_);
|
||||
|
||||
if (platform::is_cpu_place(place_)) {
|
||||
*tmp = coeff_;
|
||||
} else {
|
||||
#ifdef PADDLE_WITH_CUDA
|
||||
auto stream =
|
||||
static_cast<platform::CUDADeviceContext *>(this->dev_ctxes_[place_])
|
||||
->stream();
|
||||
memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
|
||||
platform::CPUPlace(), &coeff_, sizeof(float), stream);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
std::string ScaleLossGradOpHandle::Name() const { return "Scale LossGrad"; }
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,43 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/lod_tensor.h"
|
||||
#include "paddle/fluid/framework/scope.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct ScaleLossGradOpHandle : public OpHandleBase {
|
||||
float coeff_;
|
||||
Scope *scope_;
|
||||
platform::Place place_;
|
||||
|
||||
ScaleLossGradOpHandle(size_t num_dev, Scope *scope, platform::Place place,
|
||||
platform::DeviceContext *context);
|
||||
|
||||
~ScaleLossGradOpHandle() final;
|
||||
|
||||
std::string Name() const override;
|
||||
|
||||
protected:
|
||||
void RunImpl() override;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,15 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph.h"
|
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "paddle/fluid/framework/details/op_handle_base.h"
|
||||
#include "paddle/fluid/framework/details/var_handle.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
struct SSAGraph {
|
||||
std::vector<std::unordered_map<std::string, std::map<int, VarHandle>>> vars_;
|
||||
// aux variables to represent dependency. Useful to resolve data hazard.
|
||||
std::unordered_set<std::unique_ptr<VarHandleBase>> dep_vars_;
|
||||
std::vector<std::unique_ptr<OpHandleBase>> ops_;
|
||||
};
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,141 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
void SSAGraphBuilder::PolishGraphToSupportDataHazards(SSAGraph *graph) {
|
||||
for (auto &var_map : graph->vars_) {
|
||||
for (auto &name_pair : var_map) {
|
||||
if (name_pair.second.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
auto it_new = name_pair.second.rbegin();
|
||||
auto it_old = name_pair.second.rbegin();
|
||||
++it_old;
|
||||
for (; it_old != name_pair.second.rend(); it_new = it_old, ++it_old) {
|
||||
auto *write_op = it_new->second.generated_op_;
|
||||
auto &read_ops = it_old->second.pending_ops_;
|
||||
|
||||
for (auto *read_op : read_ops) {
|
||||
// Manually add a dependency var from read_op to write_op;
|
||||
if (read_op == write_op) {
|
||||
// Read Write is the same op.
|
||||
continue;
|
||||
}
|
||||
|
||||
auto *dep_var = new DummyVarHandle();
|
||||
read_op->AddOutput(dep_var);
|
||||
write_op->AddInput(dep_var);
|
||||
graph->dep_vars_.emplace(dep_var);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
VarHandle *SSAGraphBuilder::CreateOrGetLatestVarHandle(
|
||||
SSAGraph *graph, const std::string &each_var_name,
|
||||
const platform::Place &place, size_t place_offset) {
|
||||
auto &var_holders = graph->vars_[place_offset];
|
||||
auto &var_holder = var_holders[each_var_name];
|
||||
VarHandle *var = nullptr;
|
||||
if (var_holder.empty()) {
|
||||
auto &init_var = var_holder[0];
|
||||
init_var.place_ = place;
|
||||
init_var.name_ = each_var_name;
|
||||
init_var.generated_op_ = nullptr;
|
||||
init_var.version_ = 0;
|
||||
var = &init_var;
|
||||
} else {
|
||||
var = &var_holder.rbegin()->second;
|
||||
}
|
||||
return var;
|
||||
}
|
||||
|
||||
void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place,
|
||||
size_t place_offset) {
|
||||
auto &vars = graph->vars_[place_offset][each_var_name];
|
||||
size_t version = vars.size();
|
||||
auto &var = vars[version];
|
||||
var.version_ = version;
|
||||
var.name_ = each_var_name;
|
||||
var.place_ = place;
|
||||
op_handle->AddOutput(&var);
|
||||
}
|
||||
|
||||
template <typename Callback>
|
||||
void IterAllVar(const SSAGraph &graph, Callback callback) {
|
||||
for (auto &each : graph.vars_) {
|
||||
for (auto &pair1 : each) {
|
||||
for (auto &pair2 : pair1.second) {
|
||||
callback(pair2.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto &var : graph.dep_vars_) {
|
||||
callback(*var);
|
||||
}
|
||||
}
|
||||
|
||||
void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
|
||||
size_t var_id = 0;
|
||||
std::unordered_map<const VarHandleBase *, size_t> vars;
|
||||
|
||||
sout << "digraph G {\n";
|
||||
|
||||
IterAllVar(graph, [&](const VarHandleBase &var) {
|
||||
auto *var_ptr = &var;
|
||||
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
|
||||
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);
|
||||
|
||||
size_t cur_var_id = var_id++;
|
||||
vars[var_ptr] = cur_var_id;
|
||||
|
||||
if (var_handle_ptr) {
|
||||
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
|
||||
<< "\\n"
|
||||
<< var_handle_ptr->place_ << "\\n"
|
||||
<< var_handle_ptr->version_ << "\"]" << std::endl;
|
||||
} else if (dummy_ptr) {
|
||||
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
|
||||
}
|
||||
});
|
||||
|
||||
size_t op_id = 0;
|
||||
for (auto &op : graph.ops_) {
|
||||
std::string op_name = "op_" + std::to_string(op_id++);
|
||||
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
|
||||
<< std::endl;
|
||||
for (auto in : op->inputs_) {
|
||||
std::string var_name = "var_" + std::to_string(vars[in]);
|
||||
sout << var_name << " -> " << op_name << std::endl;
|
||||
}
|
||||
|
||||
for (auto out : op->outputs_) {
|
||||
std::string var_name = "var_" + std::to_string(vars[out]);
|
||||
sout << op_name << " -> " << var_name << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
sout << "}\n";
|
||||
}
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,59 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph.h"
|
||||
#include "paddle/fluid/framework/program_desc.h"
|
||||
#include "paddle/fluid/platform/place.h"
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class SSAGraphBuilder {
|
||||
public:
|
||||
SSAGraphBuilder() {}
|
||||
virtual ~SSAGraphBuilder() {}
|
||||
virtual std::unique_ptr<SSAGraph> Build(const ProgramDesc &program) const = 0;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SSAGraphBuilder);
|
||||
|
||||
protected:
|
||||
/**
|
||||
* We only handle write after read(WAR), since it should not have a write
|
||||
* after write in program. If there are write after write operators, we need
|
||||
* prune them.
|
||||
*
|
||||
* https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Write_after_read_(WAR)
|
||||
*/
|
||||
static void PolishGraphToSupportDataHazards(SSAGraph *graph);
|
||||
|
||||
static VarHandle *CreateOrGetLatestVarHandle(SSAGraph *graph,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place,
|
||||
size_t place_offset);
|
||||
|
||||
static void CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
|
||||
const std::string &each_var_name,
|
||||
const platform::Place &place, size_t place_offset);
|
||||
|
||||
static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
|
||||
: graph_(std::move(graph)) {}
|
||||
|
||||
SSAGraphExecutor::~SSAGraphExecutor() {}
|
||||
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,41 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "paddle/fluid/framework/details/ssa_graph.h"
|
||||
#include "paddle/fluid/framework/feed_fetch_type.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace details {
|
||||
|
||||
class SSAGraphExecutor {
|
||||
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
|
||||
|
||||
public:
|
||||
// Steal graph inside
|
||||
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
|
||||
|
||||
virtual ~SSAGraphExecutor();
|
||||
|
||||
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
|
||||
|
||||
protected:
|
||||
std::unique_ptr<SSAGraph> graph_;
|
||||
};
|
||||
} // namespace details
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue