diff --git a/CMakeLists.txt b/CMakeLists.txt
index b309ff37e5..5df83499d5 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -16,8 +16,6 @@ cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
-SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
-SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
include(system)
@@ -201,6 +199,10 @@ if(WITH_GOLANG)
endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
+
+SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
+SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
+
add_subdirectory(paddle)
if(WITH_PYTHON)
add_subdirectory(python)
diff --git a/benchmark/IntelOptimizedPaddle.md b/benchmark/IntelOptimizedPaddle.md
index 8ee7fd28c5..6cc9598947 100644
--- a/benchmark/IntelOptimizedPaddle.md
+++ b/benchmark/IntelOptimizedPaddle.md
@@ -22,6 +22,7 @@ On each machine, we will test and compare the performance of training on single
#### Training
Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
+Pay attetion that the speed below includes forward, backward and parameter update time. So we can not directly compare the data with the benchmark of caffe `time` [command](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/caffe/image/run.sh#L9), which only contain forward and backward. The updating time of parameter would become very heavy when the weight size are large, especially on alexnet.
Input image size - 3 * 224 * 224, Time: images/second
@@ -55,6 +56,16 @@ Input image size - 3 * 224 * 224, Time: images/second
+- Alexnet
+
+| BatchSize | 64 | 128 | 256 |
+|--------------|--------| ------ | -------|
+| OpenBLAS | 2.13 | 2.45 | 2.68 |
+| MKLML | 66.37 | 105.60 | 144.04 |
+| MKL-DNN | 399.00 | 498.94 | 626.53 |
+
+chart TBD
+
#### Inference
Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
- VGG-19
diff --git a/doc/design/mkl/mkldnn_fluid.md b/doc/design/mkl/mkldnn_fluid.md
new file mode 100644
index 0000000000..bef126f3f0
--- /dev/null
+++ b/doc/design/mkl/mkldnn_fluid.md
@@ -0,0 +1,149 @@
+# Design Doc: Add MKLDNN Kernel in Fluid Operator
+
+## Principles
+
+First of all, we should follow some basical principles like:
+1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc.
+2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
+3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure the data synchronization between different kernel types, which is this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override `GetExpectedKernelType` and `trans` functions to support switching kernels.
+4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel Type is a pivotal conception which can record the `Place`, `Library`, `DataType` and `Layout`.
+
+## Sulution
+
+In general, there are four parts we should follow to run a MKL-DNN primitive.
+- Create a primitive descriptor that describe this operator
+- Create a primitive itself by primitive descriptor and the engine
+- Create all memory buffers that primitive needed
+- Launch a stream to execute the primitive created
+More details can refer to [here](http://01org.github.io/mkl-dnn).
+
+It's better to avoid reinitialization of primitives and memory handles in the first three stages in every iteration. \
+So we plan to create a map to record all the `primitive` and `memory`, which should not take too much memories as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
+
+It's assumed that following three conditions should be satisfied.
+1. there is a unique key for each operator instance. May be the actual name of `Output Tensor`.
+2. the `Input Tensor` inside `Compute` function is the one after converted.
+3. we can get the phase(eg. `is_test`) inside `Compute` function, otherwise we need to expose this attribue to user.
+
+### Compute
+The algorithm of `Compute` would be described as follow, let's take conv like an example.
+
+```c++
+
+ PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
+ PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");
+
+ auto& dev_ctx = ctx.template device_context();
+
+ // find primitive by unique key from mkldnn context
+ // the op_key should be a unique name of this op instance
+ auto& p = dev_ctx.findPrimitive(op_key + "_fwd");
+
+ // assuming the input tensor inside this compute function is the one after converted
+ // this point should be guarantee by another mechanism
+ auto& i = dev_ctx.findMemory(op_key + "_input");
+
+ if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) {
+ auto fwd_primitive_desc = createPrimitiveDesc(ctx);
+ auto* input = ctx.Input("Input");
+ auto* filter = ctx.Input("Filter");
+ auto* output = ctx.Output("Output");
+ shared_ptr in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data()));
+ shared_ptr wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data()));
+ shared_ptr out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data(ctx.GetPlace())));
+ shared_ptr fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out));
+
+ dev_ctx.addMemory(op_key+"_input", in);
+ dev_ctx.addMemory(op_key+"_output", out);
+ dev_ctx.addMemory(op_key+"_filer", wgt);
+ dev_ctx.addPrimitive(op_key+"_fwd", fwd_primitive);
+ dev_ctx.addPrimitiveDesc(op_key+"_fwd_PD", fwd_primitive_desc);
+ }
+
+ p = dev_ctx.findPrimitive(op_key + "_fwd");
+
+ PADDLE_ENFORCE(p, "Should have forward Primitive");
+ PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_input"), "Should have input memory");
+ PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_output"), "Should have output memory");
+ PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_filter"), "Should have filter memory");
+ PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_unique_key+"_fwd_PD"), "Should have forward PrimitiveDesc");
+ dev_ctx.submit(p);
+ dev_ctx.execute(); // the convert primitive should have already contained.
+
+```
+
+The `createPrimitiveDesc` returns the primitive descripotor of this operator, would be like this:
+```c++
+ auto* input = ctx.Input("Input");
+ auto* filter = ctx.Input("Filter");
+ auto* output = ctx.Output("Output");
+ std::vector strides = ctx.Attr>("strides");
+ std::vector paddings = ctx.Attr>("paddings");
+ std::vector dilations = ctx.Attr>("dilations");
+ int groups = ctx.Attr("groups");
+ algorithm algo = static_cast(ctx.Attr("convolution_algorithm_option"));
+ prop_kind pk = ctx.Attr("is_test") ? prop_kind::forward_inference : prop_kind::forward_training;
+
+ auto fwd_desc = mkldnn::conv_fwd::desc(/* all the setting above*/);
+ shared_ptr fwd_primitive_desc(new mkldnn::conv_fwd::primitive_desc(fwd_desc, ctx.getEngine()));
+
+ return fwd_primitive_desc;
+ }
+```
+
+### MKLDNNDeviceContext
+`MKLDNNDeviceContext`, which is very straightforward, should contain some base information like: `stream`, `engine` and the map needed.
+
+
+### mkldnn_helper
+Some functions would be put in `paddle/platform/mkldnn_helper.h`.
+- create MKLDNN memories
+- create MKLDNN primitives
+- error check function
+- etc
+
+
+### Kernel Switch
+We should `reorder` the different Layout from other device or to other device. `GetExpectedKernelType` and `trans` functions can help us to implement it.
+
+`GetExpectedKernelType` should get the context, and this operator can return the best `KernelType`.
+`trans` would be like this:
+
+```c++
+void trans(inputs, ctx) override {
+ if (NoNeedTrans()) {
+ return;
+ }
+ // find reorder primitive by op_key from context
+ auto& dev_ctx = ctx.template device_context();
+ auto& p = dev_ctx.findPrimitive(op_key + "_reorder_input");
+ auto& i = dev_ctx.findMemory(op_key + "_src_input");
+
+ if (p == nullptr || i == nullptr || changeSized(i, input)) {
+ auto prim = createPrimitiveDesc(ctx);
+ auto src = createMemory(memoryDesc(input->dims(), actual_layout), input->data);
+ auto newbuffer = paddle::memory::Alloc(ctx.GetPlace(), input->size_in_bytes());
+ auto dst = createMemory(p->expected_desc(), newbuffer->data);
+ auto reorder_primitive(new mkldnn::reorder(src, dst));
+
+ dev_ctx.addMemory(op_key+"_src_input", src);
+ dev_ctx.addMemory(op_key+"_input", dst);
+ dev_ctx.addPrimitive(op_key+"_reorder_input", reorder_primitive);
+ }
+
+ p = dev_ctx.findPrimitive(op_key + "_reorder_input");
+ PADDLE_ENFORCE(p, "Should have Reorder Primitive");
+ dev_ctx.submit(p);
+ if (! this->isMKLDNNKernel()) {
+ // execute immediately only if this is not mkldnn kernel function.
+ // otherwise, it can be executed with the operator primitive in Compute
+ dev_ctx.stream();
+ }
+ // after submit, the input tensor in ExecutionContext should be changed as the converted one
+ // there should be another mechanism to ensure this
+}
+```
+
+### Unit Test
+All the functions should be tested corresponding.
+TBD
diff --git a/doc/design/support_new_device.md b/doc/design/support_new_device.md
index fd23dc211a..f54b2b3694 100644
--- a/doc/design/support_new_device.md
+++ b/doc/design/support_new_device.md
@@ -25,13 +25,14 @@ There are mainly three parts that we have to consider while integrating a new de
### Place and DeviceContext
+Please remind that device and computing library are not one-to-one corresponding. A device can have a lot of computing libraries and a computing library can also support several devices.
#### Place
-Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent different devices and computing libraries. There are inheritance relationships between different kinds of `Place`.
+Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent the device memory where data is located. If we add another device, we have to add corresponding `DevicePlace`.
```
- | CPUPlace --> MKLDNNPlace
-Place --| CUDAPlace --> CUDNNPlace
+ | CPUPlace
+Place --| CUDAPlace
| FPGAPlace
```
@@ -43,7 +44,7 @@ typedef boost::variant Place;
#### DeviceContext
-Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different hardwares, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`.
+Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different libraries, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`.
```
@@ -106,7 +107,7 @@ template
size_t Used(Place place);
```
-To implementing these interfaces, we have to implement MemoryAllocator for different Devices
+To implement these interfaces, we have to implement MemoryAllocator for different Devices.
#### Tensor
@@ -243,6 +244,7 @@ REGISTER_OP_CUDA_KERNEL(
Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library.
-We will discuss how to implement an efficient OpKernel switch policy.
+For more details, please refer to following docs:
-- TBD
+- operator kernel type [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md)
+- switch kernel [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md)
diff --git a/doc/faq/build_and_install/index_cn.rst b/doc/faq/build_and_install/index_cn.rst
index a2bdeead78..ed8a0c7e87 100644
--- a/doc/faq/build_and_install/index_cn.rst
+++ b/doc/faq/build_and_install/index_cn.rst
@@ -109,3 +109,31 @@ PaddlePaddle使用avx SIMD指令提高cpu执行效率,因此错误的使用二
解决办法是:
* 卸载PaddlePaddle包 :code:`pip uninstall paddle`, 清理掉老旧的PaddlePaddle安装包,使得单元测试有一个干净的环境。如果PaddlePaddle包已经在python的site-packages里面,单元测试会引用site-packages里面的python包,而不是源码目录里 :code:`/python` 目录下的python包。同时,即便设置 :code:`PYTHONPATH` 到 :code:`/python` 也没用,因为python的搜索路径是优先已经安装的python包。
+
+8. 下载MKLML库失败
+------------------
+
+.. code-block:: bash
+
+ make[2]: *** [third_party/mklml/src/extern_mklml-stamp/extern_mklml-download] 错误 4
+ make[1]: *** [CMakeFiles/extern_mklml.dir/all] 错误 2
+ make[1]: *** 正在等待未完成的任务....
+
+原因:网速或SSL链接原因,导致MKLML库下载不成功。
+
+解决办法是:手动下载并安装,具体步骤如下。
+
+.. code-block:: bash
+
+ // 1. 进入对应的目录
+ cd build/third_party/mklml/src/extern_mklml
+
+ // 2. 查看包的大小, 正常情况下是75M,如果小于75M,即下载失败:
+ du -sh mklml_lnx_2018.0.1.20171007.tgz
+
+ // 3. 手动下载且解压缩,并手动生成download成功标签:
+ wget --no-check-certificate https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171007.tgz -c -O mklml_lnx_2018.0.1.20171007.tgz
+ tar zxf mklml_lnx_2018.0.1.20171007.tgz
+ touch ../extern_mklml-stamp/extern_mklml-download
+
+ // 4. 接着编译即可
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index be9c01fb04..c2a57a95ee 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -59,5 +59,9 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
+cc_library(threadpool SRCS threadpool.cc)
+cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
+
+cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context)
diff --git a/paddle/framework/data_layout.h b/paddle/framework/data_layout.h
index 7429de7ee3..4a8669c3a4 100644
--- a/paddle/framework/data_layout.h
+++ b/paddle/framework/data_layout.h
@@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
+#include "paddle/platform/enforce.h"
+
+#include
+#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
-enum DataLayout {
+enum class DataLayout {
kNHWC = 0,
kNCHW = 1,
kAnyLayout = 2,
@@ -33,5 +37,23 @@ inline DataLayout StringToDataLayout(const std::string& str) {
}
}
+inline std::string DataLayoutToString(const DataLayout& data_layout) {
+ switch (data_layout) {
+ case DataLayout::kNHWC:
+ return "NHWC";
+ case DataLayout::kNCHW:
+ return "NCHW";
+ case DataLayout::kAnyLayout:
+ return "ANY_LAYOUT";
+ default:
+ PADDLE_THROW("unknown DataLayou %d", data_layout);
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& out, DataLayout l) {
+ out << DataLayoutToString(l);
+ return out;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
index 4deb4fa903..3ff2da3446 100644
--- a/paddle/framework/init.cc
+++ b/paddle/framework/init.cc
@@ -54,7 +54,7 @@ bool InitDevices(const std::vector &devices) {
#ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1);
- places.emplace_back(platform::GPUPlace(std::stoi(number)));
+ places.emplace_back(platform::CUDAPlace(std::stoi(number)));
#else
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
diff --git a/paddle/framework/library_type.h b/paddle/framework/library_type.h
index 49b273656b..6baae6c2bb 100644
--- a/paddle/framework/library_type.h
+++ b/paddle/framework/library_type.h
@@ -20,7 +20,25 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
-enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
+enum class LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
+
+inline std::string LibraryTypeToString(const LibraryType& library_type) {
+ switch (library_type) {
+ case LibraryType::kPlain:
+ return "PLAIN";
+ case LibraryType::kMKLDNN:
+ return "MKLDNN";
+ case LibraryType::kCUDNN:
+ return "CUDNN";
+ default:
+ PADDLE_THROW("unknown LibraryType %d", library_type);
+ }
+}
+
+inline std::ostream& operator<<(std::ostream& out, LibraryType l) {
+ out << LibraryTypeToString(l);
+ return out;
+}
} // namespace
} // framework
diff --git a/paddle/framework/lod_tensor.cc b/paddle/framework/lod_tensor.cc
index 465f8c62b5..d766d3c416 100644
--- a/paddle/framework/lod_tensor.cc
+++ b/paddle/framework/lod_tensor.cc
@@ -224,7 +224,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast(size));
memory::Copy(cpu, buf.get(),
- boost::get(tensor.place()),
+ boost::get(tensor.place()),
reinterpret_cast(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
index 5b90fbfca7..e8508ad265 100644
--- a/paddle/framework/lod_tensor_test.cu
+++ b/paddle/framework/lod_tensor_test.cu
@@ -27,7 +27,7 @@ __global__ void test(size_t* a, int size) {
TEST(LoDTensor, LoDInGPU) {
paddle::framework::LoDTensor lod_tensor;
- paddle::platform::GPUPlace place(0);
+ paddle::platform::CUDAPlace place(0);
paddle::framework::LoD src_lod;
src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14});
diff --git a/paddle/framework/op_kernel_type.h b/paddle/framework/op_kernel_type.h
index a1dea0d9d8..97b542e345 100644
--- a/paddle/framework/op_kernel_type.h
+++ b/paddle/framework/op_kernel_type.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h"
+#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
namespace paddle {
@@ -39,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8;
+
proto::DataType data_type_;
DataLayout data_layout_;
platform::Place place_;
@@ -68,5 +70,13 @@ struct OpKernelType {
}
};
+inline std::ostream& operator<<(std::ostream& os,
+ const OpKernelType& kernel_key) {
+ os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
+ << kernel_key.data_layout_ << "]:place[" << kernel_key.place_
+ << "]:library_type[" << kernel_key.library_type_ << "]";
+ return os;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/op_kernel_type_test.cc b/paddle/framework/op_kernel_type_test.cc
new file mode 100644
index 0000000000..8753d7cc37
--- /dev/null
+++ b/paddle/framework/op_kernel_type_test.cc
@@ -0,0 +1,51 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/op_kernel_type.h"
+#include
+#include
+
+TEST(OpKernelType, ToString) {
+ using OpKernelType = paddle::framework::OpKernelType;
+ using DataType = paddle::framework::proto::DataType;
+ using CPUPlace = paddle::platform::CPUPlace;
+ using DataLayout = paddle::framework::DataLayout;
+ using LibraryType = paddle::framework::LibraryType;
+
+ OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+
+ std::ostringstream stream;
+ stream << op_kernel_type;
+ ASSERT_EQ(
+ stream.str(),
+ "data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
+}
+
+TEST(OpKernelType, Hash) {
+ using OpKernelType = paddle::framework::OpKernelType;
+ using DataType = paddle::framework::proto::DataType;
+ using CPUPlace = paddle::platform::CPUPlace;
+ using CUDAPlace = paddle::platform::CUDAPlace;
+ using DataLayout = paddle::framework::DataLayout;
+ using LibraryType = paddle::framework::LibraryType;
+
+ OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+ OpKernelType op_kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW,
+ LibraryType::kCUDNN);
+
+ OpKernelType::Hash hasher;
+ ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
+}
\ No newline at end of file
diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h
index 244c117465..9bb2a3b5c2 100644
--- a/paddle/framework/op_registry.h
+++ b/paddle/framework/op_registry.h
@@ -188,7 +188,7 @@ class OpKernelRegistrar : public Registrar {
}
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
- REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__)
+ REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 06184f6ba9..66840a2e03 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -242,13 +242,6 @@ std::vector ExecutionContext::MultiOutput(
return res;
}
-std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
- os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
- << kernel_key.data_layout_ << "]:place[" << kernel_key.place_
- << "]:library_type[" << kernel_key.library_type_ << "]";
- return os;
-}
-
bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type);
@@ -409,19 +402,28 @@ void OperatorWithKernel::Run(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx);
- auto kernel_key = GetKernelType(ctx);
- auto kernel_iter = kernels.find(kernel_key);
+ auto actual_kernel_key = GetActualKernelType(ctx);
+ auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
+ auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) {
- PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
+ PADDLE_THROW("The operator %s does not support %s", type_,
+ expected_kernel_key);
}
kernel_iter->second->Compute(ctx);
}
-OpKernelType OperatorWithKernel::GetKernelType(
+
+OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}
+
+OpKernelType OperatorWithKernel::GetExpectedKernelType(
+ const OpKernelType& actual_kernel_type) const {
+ return actual_kernel_type;
+}
+
proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index aba34c5bcb..55eed57e66 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -52,6 +52,11 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";
+// define some kernel hint
+const std::string kUseCPU = "use_cpu";
+const std::string kUseCUDNN = "use_cudnn";
+const std::string kUseMKLDNN = "use_mkldnn";
+
inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix;
}
@@ -373,7 +378,9 @@ class OperatorWithKernel : public OperatorBase {
}
protected:
- virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
+ virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
+ virtual OpKernelType GetExpectedKernelType(
+ const OpKernelType& actual_kernel_type) const;
private:
// indicate kernel DataType by input data. Defaultly all input data must be
@@ -381,8 +388,6 @@ class OperatorWithKernel : public OperatorBase {
proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
};
-std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
-
extern bool OpSupportGPU(const std::string& op_type);
} // namespace framework
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index fbca45b59d..4d38a7ada9 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
- OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
+ OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
}
};
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 6a0c5133c9..b9f6884f7c 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -20,12 +20,12 @@ limitations under the License. */
#include
#include
+#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
-#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle {
@@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const;
+ inline DataLayout layout() const { return layout_; }
+
+ inline void set_layout(const DataLayout layout) { layout_ = layout; }
+
private:
friend class LoDTensor;
@@ -173,6 +177,19 @@ class Tensor {
DDim dims_;
+ /**
+ * @brief the layout of memory block, default is NCHW.
+ *
+ * @note the memory allocation order, describe how weight/data is stored
+ * For example, in 4-D Tensor(rank=4), there are three commonly
+ * used layout. They are
+ * NCHW, NHWC, CHWN.
+ * N,C,H,W for respectively the batch size, the number of
+ * feature maps, the height.
+ */
+
+ DataLayout layout_ = DataLayout::kNHWC;
+
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
diff --git a/paddle/framework/tensor.md b/paddle/framework/tensor.md
index 7a80816d8e..0a27ac9bb6 100644
--- a/paddle/framework/tensor.md
+++ b/paddle/framework/tensor.md
@@ -71,7 +71,7 @@ private:
```
```c++
-typedef boost::variant Place;
+typedef boost::variant Place;
typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> DDimVar;
typedef boost::variant<
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index aba1f9f093..6c6f298edc 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -125,11 +125,11 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
boost::get(place), size, type));
} else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
- PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
+ PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
}
#else
- holder_.reset(new PlaceholderImpl(
- boost::get(place), size, type));
+ holder_.reset(new PlaceholderImpl(
+ boost::get(place), size, type));
}
#endif
offset_ = 0;
@@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
+ dst.set_layout(layout_);
DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims);
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index ceca64365a..ca76a9fcb9 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -80,20 +80,20 @@ TEST(Tensor, MutableData) {
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
- p1 = src_tensor.mutable_data(make_ddim({1, 2, 3}), GPUPlace());
+ p1 = src_tensor.mutable_data(make_ddim({1, 2, 3}), CUDAPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
- p2 = src_tensor.mutable_data(make_ddim({3, 4}), GPUPlace());
+ p2 = src_tensor.mutable_data(make_ddim({3, 4}), CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
- p1 = src_tensor.mutable_data(make_ddim({2, 2, 3}), GPUPlace());
+ p1 = src_tensor.mutable_data(make_ddim({2, 2, 3}), CUDAPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
- p2 = src_tensor.mutable_data(make_ddim({2, 2}), GPUPlace());
+ p2 = src_tensor.mutable_data(make_ddim({2, 2}), CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
@@ -130,7 +130,7 @@ TEST(Tensor, ShareDataWith) {
{
Tensor src_tensor;
Tensor dst_tensor;
- src_tensor.mutable_data(make_ddim({2, 3, 4}), GPUPlace());
+ src_tensor.mutable_data(make_ddim({2, 3, 4}), CUDAPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data(), dst_tensor.data());
}
@@ -166,7 +166,7 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
- src_tensor.mutable_data(make_ddim({6, 9}), GPUPlace());
+ src_tensor.mutable_data(make_ddim({6, 9}), CUDAPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
@@ -176,11 +176,11 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address =
reinterpret_cast(src_tensor.data());
uintptr_t src_mutable_data_address = reinterpret_cast(
- src_tensor.mutable_data(src_tensor.dims(), GPUPlace()));
+ src_tensor.mutable_data(src_tensor.dims(), CUDAPlace()));
uintptr_t slice_data_address =
reinterpret_cast(slice_tensor.data());
uintptr_t slice_mutable_data_address = reinterpret_cast(
- slice_tensor.mutable_data(slice_tensor.dims(), GPUPlace()));
+ slice_tensor.mutable_data(slice_tensor.dims(), CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@@ -200,3 +200,12 @@ TEST(Tensor, ReshapeToMatrix) {
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
+
+TEST(Tensor, Layout) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ ASSERT_EQ(src.layout(), DataLayout::kNHWC);
+ src.set_layout(DataLayout::kAnyLayout);
+ ASSERT_EQ(src.layout(), DataLayout::kAnyLayout);
+}
diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h
index 4e34b90d57..692f5f1af7 100644
--- a/paddle/framework/tensor_util.h
+++ b/paddle/framework/tensor_util.h
@@ -33,6 +33,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size();
dst->Resize(src.dims());
+ dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data();
@@ -47,11 +48,11 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
- auto src_gpu_place = boost::get(src_place);
+ auto src_gpu_place = boost::get(src_place);
auto dst_cpu_place = boost::get(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
- auto ctx_gpu_place = boost::get(ctx_place);
+ auto ctx_gpu_place = boost::get(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
@@ -59,21 +60,21 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get(src_place);
- auto dst_gpu_place = boost::get(dst_place);
+ auto dst_gpu_place = boost::get(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
- auto ctx_gpu_place = boost::get(ctx_place);
+ auto ctx_gpu_place = boost::get(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
- auto src_gpu_place = boost::get(src_place);
- auto dst_gpu_place = boost::get(dst_place);
+ auto src_gpu_place = boost::get(src_place);
+ auto dst_gpu_place = boost::get(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
- auto ctx_gpu_place = boost::get(ctx_place);
+ auto ctx_gpu_place = boost::get(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
@@ -82,6 +83,29 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#endif
}
+/**
+ * @brief CopyFrom support CPU <-> CPU
+ */
+inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
+ Tensor* dst) {
+ src.check_memory_size();
+ dst->Resize(src.dims());
+ dst->set_layout(src.layout());
+
+ auto src_place = src.place();
+ auto src_ptr = src.data();
+
+ auto dst_ptr = dst->mutable_data(dst_place, src.type());
+
+ auto size = src.numel() * SizeOfType(src.type());
+
+ PADDLE_ENFORCE(platform::is_cpu_place(src_place) &&
+ platform::is_cpu_place(dst_place));
+
+ memory::Copy(boost::get(dst_place), dst_ptr,
+ boost::get(src_place), src_ptr, size);
+}
+
/**
* @brief Copy the content of an external vector to a tensor.
*
@@ -108,13 +132,28 @@ inline void CopyFromVector(const std::vector& src,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) { // NOLINT
memory::Copy(
- boost::get(dst_place), dst_ptr, src_place, src_ptr,
+ boost::get(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast(ctx).stream());
}
#endif
}
+/**
+ * @brief CopyFromVector CPU vector -> CPU Tensor
+ */
+template
+inline void CopyFromVector(const std::vector& src, Tensor* dst) {
+ platform::CPUPlace dst_place = platform::CPUPlace();
+ auto src_ptr = static_cast(src.data());
+ platform::CPUPlace src_place;
+ dst->Resize({static_cast(src.size())});
+ auto dst_ptr = static_cast(dst->mutable_data(dst_place));
+ auto size = src.size() * sizeof(T);
+
+ memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
+}
+
/**
* @brief Copy the content of a tensor to a vector
*
@@ -141,12 +180,30 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
- dst_place, dst_ptr, boost::get(src.place()),
+ dst_place, dst_ptr, boost::get(src.place()),
src_ptr, size,
reinterpret_cast(ctx).stream());
}
#endif
}
+/**
+ * @brief CopyToVector CPUTensor <-> CPU Vector
+ */
+template
+inline void CopyToVector(const Tensor& src, std::vector* dst) {
+ auto src_ptr = static_cast(src.data());
+ auto size = src.numel() * sizeof(T);
+
+ platform::CPUPlace dst_place;
+ dst->resize(src.numel());
+ auto dst_ptr = static_cast(dst->data());
+
+ PADDLE_ENFORCE(platform::is_cpu_place(src.place()));
+
+ memory::Copy(dst_place, dst_ptr, boost::get(src.place()),
+ src_ptr, size);
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc
index 03a70de182..f388c19f28 100644
--- a/paddle/framework/tensor_util_test.cc
+++ b/paddle/framework/tensor_util_test.cc
@@ -17,6 +17,7 @@
namespace paddle {
namespace framework {
+
TEST(CopyFrom, Tensor) {
Tensor src_tensor;
Tensor dst_tensor;
@@ -27,9 +28,10 @@ TEST(CopyFrom, Tensor) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
+ src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace();
- CopyFrom(src_tensor, *cpu_place, cpu_ctx, &dst_tensor);
+ CopyFrom(src_tensor, *cpu_place, &dst_tensor);
const int* dst_ptr = dst_tensor.data();
ASSERT_NE(src_ptr, dst_ptr);
@@ -37,14 +39,18 @@ TEST(CopyFrom, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
Tensor slice_tensor = src_tensor.Slice(1, 2);
- CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor);
+ CopyFrom(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data();
dst_ptr = dst_tensor.data();
ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
+
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
@@ -58,7 +64,7 @@ TEST(CopyFrom, Tensor) {
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
- auto gpu_place = new platform::GPUPlace(0);
+ auto gpu_place = new platform::CUDAPlace(0);
platform::CUDADeviceContext gpu_ctx(*gpu_place);
CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
@@ -90,6 +96,8 @@ TEST(CopyFrom, Tensor) {
for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
}
+
+ EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
}
#endif
}
@@ -104,8 +112,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
- CPUDeviceContext cpu_ctx(*cpu_place);
- CopyFromVector(src_vec, cpu_ctx, &cpu_tensor);
+ CopyFromVector(src_vec, &cpu_tensor);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data();
@@ -117,7 +124,7 @@ TEST(CopyFromVector, Tensor) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
- CopyFromVector(src_vec, cpu_ctx, &cpu_tensor);
+ CopyFromVector(src_vec, &cpu_tensor);
cpu_ptr = cpu_tensor.data();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
@@ -143,7 +150,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
- auto gpu_place = new paddle::platform::GPUPlace();
+ auto gpu_place = new paddle::platform::CUDAPlace();
CUDADeviceContext gpu_ctx(*gpu_place);
CopyFromVector(src_vec, gpu_ctx, &gpu_tensor);
// Copy from GPU to CPU tensor for comparison
@@ -198,9 +205,8 @@ TEST(CopyToVector, Tensor) {
}
CPUPlace place;
- CPUDeviceContext cpu_ctx(place);
std::vector dst;
- CopyToVector(src, cpu_ctx, &dst);
+ CopyToVector(src, &dst);
for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_ptr[i], dst[i]);
@@ -210,7 +216,7 @@ TEST(CopyToVector, Tensor) {
{
std::vector src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor gpu_tensor;
- GPUPlace place;
+ CUDAPlace place;
CUDADeviceContext gpu_ctx(place);
CopyFromVector(src_vec, gpu_ctx, &gpu_tensor);
diff --git a/paddle/framework/threadpool.cc b/paddle/framework/threadpool.cc
new file mode 100644
index 0000000000..2b9be0646c
--- /dev/null
+++ b/paddle/framework/threadpool.cc
@@ -0,0 +1,24 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/threadpool.h"
+
+namespace paddle {
+namespace framework {
+
+std::unique_ptr ThreadPool::threadpool(nullptr);
+std::once_flag ThreadPool::init_flag;
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h
new file mode 100644
index 0000000000..5f6b2d458f
--- /dev/null
+++ b/paddle/framework/threadpool.h
@@ -0,0 +1,156 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "paddle/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+
+typedef std::function Task;
+
+class ThreadPool {
+ public:
+ /**
+ * @brief Get a instance of threadpool, the thread number will
+ * be specified as the number of hardware thread contexts
+ */
+ static ThreadPool* GetInstance() {
+ std::call_once(init_flag, &ThreadPool::Init);
+ return threadpool.get();
+ }
+
+ ~ThreadPool() {
+ {
+ // notify all threads to stop running
+ running_ = false;
+ scheduled_.notify_all();
+ }
+
+ for (auto& t : threads_) {
+ t->join();
+ t.reset(nullptr);
+ }
+ }
+
+ int GetNumThreads() const { return num_threads_; }
+
+ int GetAvailable() {
+ std::unique_lock lock(mutex_);
+ return available_;
+ }
+
+ /**
+ * @brief Push a function to the queue, and will be scheduled and
+ * executed if a thread is available.
+ * @param[in] Task will be pushed to the task queue.
+ */
+ void Run(const Task& fn) {
+ std::unique_lock lock(mutex_);
+ tasks_.push(fn);
+ lock.unlock();
+ scheduled_.notify_one();
+ }
+
+ /**
+ * @brief Wait until all the tasks are completed.
+ */
+ void Wait() {
+ std::unique_lock lock(mutex_);
+ completed_.wait(lock, [=] { return Done() == true; });
+ }
+
+ private:
+ DISABLE_COPY_AND_ASSIGN(ThreadPool);
+
+ explicit ThreadPool(int num_threads)
+ : num_threads_(num_threads), available_(num_threads), running_(true) {
+ threads_.resize(num_threads);
+ for (auto& thread : threads_) {
+ // TODO(Yancey1989): binding the thread on the specify CPU number
+ thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
+ }
+ }
+
+ /**
+ * @brief If the task queue is empty and avaialbe
+ * is equal to the number of threads, means that
+ * all tasks are completed.
+ *
+ * Note: this function is not thread-safe.
+ *
+ * @return true if all tasks are completed.
+ */
+ bool Done() { return tasks_.empty() && available_ == num_threads_; }
+
+ void TaskLoop() {
+ while (running_) {
+ std::unique_lock lock(mutex_);
+ scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
+
+ if (!running_) {
+ break;
+ }
+ // pop a task from the task queue
+ auto task = tasks_.front();
+ tasks_.pop();
+
+ --available_;
+ lock.unlock();
+
+ // run the task
+ task();
+
+ {
+ std::unique_lock lock(mutex_);
+ ++available_;
+ if (Done()) {
+ completed_.notify_all();
+ }
+ }
+ }
+ }
+
+ static void Init() {
+ if (threadpool.get() == nullptr) {
+ // TODO(Yancey1989): specify the max threads number
+ int num_threads = std::thread::hardware_concurrency();
+ PADDLE_ENFORCE_GT(num_threads, 0);
+ threadpool.reset(new ThreadPool(num_threads));
+ }
+ }
+
+ private:
+ static std::unique_ptr threadpool;
+ static std::once_flag init_flag;
+
+ int num_threads_;
+ int available_;
+ bool running_;
+ std::queue tasks_;
+ std::vector> threads_;
+ std::mutex mutex_;
+ std::condition_variable scheduled_;
+ std::condition_variable completed_;
+};
+
+} // namespace framework
+} // namespace paddle
diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc
new file mode 100644
index 0000000000..012d92a5ed
--- /dev/null
+++ b/paddle/framework/threadpool_test.cc
@@ -0,0 +1,56 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include
+#include
+
+#include "threadpool.h"
+
+namespace framework = paddle::framework;
+
+void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) {
+ for (int i = 0; i < cnt; ++i) {
+ pool->Run([&sum]() { sum.fetch_add(1); });
+ }
+}
+
+TEST(ThreadPool, ConcurrentInit) {
+ framework::ThreadPool* pool;
+ int concurrent_cnt = 50;
+ std::vector threads;
+ for (int i = 0; i < concurrent_cnt; ++i) {
+ std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
+ threads.push_back(std::move(t));
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+}
+
+TEST(ThreadPool, ConcurrentStart) {
+ framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
+ std::atomic sum(0);
+ std::vector threads;
+ int concurrent_cnt = 50;
+ // sum = (n * (n + 1)) / 2
+ for (int i = 1; i <= concurrent_cnt; ++i) {
+ std::thread t(do_sum, pool, std::ref(sum), i);
+ threads.push_back(std::move(t));
+ }
+ for (auto& t : threads) {
+ t.join();
+ }
+ pool->Wait();
+ EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
+}
diff --git a/paddle/memory/README.md b/paddle/memory/README.md
index 6cb003c50b..7cf61d089b 100644
--- a/paddle/memory/README.md
+++ b/paddle/memory/README.md
@@ -12,13 +12,13 @@ p = memory::Alloc(platform::CPUPlace(), 4*1024);
To allocate 4KB memory on the 3rd GPU:
```cpp
-p = memory::Alloc(platform::GPUPlace(2), 4*1024);
+p = memory::Alloc(platform::CUDAPlace(2), 4*1024);
```
To free memory and check the so-far used amount of memory on a place:
```cpp
-auto pl = platform::GPUPlace(0);
+auto pl = platform::CUDAPlace(0);
p = memory::Alloc(pl, 4*1024);
cout << memory::Used(pl);
memory::Free(pl, p);
@@ -36,7 +36,7 @@ template size_t Used(Place);
} // namespace memory
```
-These function templates have specializations on either `platform::CPUPlace` or `platform::GPUPlace`:
+These function templates have specializations on either `platform::CPUPlace` or `platform::CUDAPlace`:
```cpp
template<>
@@ -49,7 +49,7 @@ and
```cpp
template<>
-void Alloc(GPUPlace p, size_t size) {
+void Alloc(CUDAPlace p, size_t size) {
return GetGPUBuddyAllocator(p.id)->Alloc(size);
}
```
@@ -122,7 +122,7 @@ There are two implementations of `Context`:
1. [`CPUContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L105), whose [`New` method](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L131) calls [`g_cpu_allocator.get()->New(size_t)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.cc#L15) to allocate the memory.
-1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::GPUPlace`, who also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory.
+1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::CUDAPlace`, who also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory.
### Majel
diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc
index 5c629dc3d2..b46141aafd 100644
--- a/paddle/memory/memcpy.cc
+++ b/paddle/memory/memcpy.cc
@@ -28,31 +28,25 @@ void Copy(platform::CPUPlace, void* dst,
#ifdef PADDLE_WITH_CUDA
template <>
-void Copy(platform::CPUPlace dst_place,
- void* dst,
- platform::GPUPlace src_place,
- const void* src, size_t num,
- cudaStream_t stream) {
+void Copy(
+ platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
+ const void* src, size_t num, cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
template <>
-void Copy(platform::GPUPlace dst_place,
- void* dst,
- platform::CPUPlace src_place,
- const void* src, size_t num,
- cudaStream_t stream) {
+void Copy(
+ platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
+ const void* src, size_t num, cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
template <>
-void Copy(platform::GPUPlace dst_place,
- void* dst,
- platform::GPUPlace src_place,
- const void* src, size_t num,
- cudaStream_t stream) {
+void Copy(
+ platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
+ const void* src, size_t num, cudaStream_t stream) {
if (dst_place == src_place) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc
index 9cafdfda75..c4bb6baee7 100644
--- a/paddle/memory/memory.cc
+++ b/paddle/memory/memory.cc
@@ -83,12 +83,12 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
}
template <>
-size_t Used(platform::GPUPlace place) {
+size_t Used(platform::CUDAPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}
template <>
-void* Alloc(platform::GPUPlace place, size_t size) {
+void* Alloc(platform::CUDAPlace place, size_t size) {
auto* buddy_allocator = GetGPUBuddyAllocator(place.device);
auto* ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) {
@@ -101,14 +101,14 @@ void* Alloc(platform::GPUPlace place, size_t size) {
LOG(WARNING) << "total " << total;
LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize();
- LOG(WARNING) << "GPU memory used: " << Used(place);
+ LOG(WARNING) << "GPU memory used: " << Used(place);
platform::SetDeviceId(cur_dev);
}
return ptr;
}
template <>
-void Free(platform::GPUPlace place, void* p) {
+void Free(platform::CUDAPlace place, void* p) {
GetGPUBuddyAllocator(place.device)->Free(p);
}
diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc
index 2444931e26..f476bf7126 100644
--- a/paddle/memory/memory_test.cc
+++ b/paddle/memory/memory_test.cc
@@ -82,7 +82,7 @@ TEST(BuddyAllocator, CPUMultAlloc) {
#ifdef PADDLE_WITH_CUDA
-size_t align(size_t size, paddle::platform::GPUPlace place) {
+size_t align(size_t size, paddle::platform::CUDAPlace place) {
size += sizeof(paddle::memory::detail::Metadata);
size_t alignment = paddle::platform::GpuMinChunkSize();
size_t remaining = size % alignment;
@@ -94,7 +94,7 @@ TEST(BuddyAllocator, GPUAllocation) {
EXPECT_EQ(p, nullptr);
- paddle::platform::GPUPlace gpu(0);
+ paddle::platform::CUDAPlace gpu(0);
p = paddle::memory::Alloc(gpu, 4096);
EXPECT_NE(p, nullptr);
@@ -103,7 +103,7 @@ TEST(BuddyAllocator, GPUAllocation) {
}
TEST(BuddyAllocator, GPUMultAlloc) {
- paddle::platform::GPUPlace gpu;
+ paddle::platform::CUDAPlace gpu;
std::unordered_map ps;
diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc
index b8ed93f4eb..d7baa6e905 100644
--- a/paddle/operators/accuracy_op.cc
+++ b/paddle/operators/accuracy_op.cc
@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
}
protected:
- framework::OpKernelType GetKernelType(
+ framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input("Out")->type()),
diff --git a/paddle/operators/accuracy_op.cu b/paddle/operators/accuracy_op.cu
index dd51aad105..0aadd5af41 100644
--- a/paddle/operators/accuracy_op.cu
+++ b/paddle/operators/accuracy_op.cu
@@ -56,7 +56,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "It must use CUDAPlace.");
auto* inference = ctx.Input("Out");
auto* indices = ctx.Input("Indices");
auto* label = ctx.Input("Label");
diff --git a/paddle/operators/adam_op.h b/paddle/operators/adam_op.h
index 45157842a6..c4e2c8bb88 100644
--- a/paddle/operators/adam_op.h
+++ b/paddle/operators/adam_op.h
@@ -13,59 +13,113 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
-#include "paddle/framework/eigen.h"
+#include // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
+#include "paddle/operators/detail/safe_ref.h"
+#include "paddle/platform/for_range.h"
namespace paddle {
namespace operators {
+template
+struct AdamFunctor {
+ T beta1_;
+ T beta2_;
+ T epsilon_;
+
+ const T* beta1_pow_;
+ const T* beta2_pow_;
+ const T* moment1_;
+ T* moment1_out_;
+ const T* moment2_;
+ T* moment2_out_;
+ const T* lr_;
+ const T* grad_;
+ const T* param_;
+ T* param_out_;
+
+ AdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
+ const T* beta2_pow, const T* mom1, T* mom1_out, const T* mom2,
+ T* mom2_out, const T* lr, const T* grad, const T* param,
+ T* param_out)
+ : beta1_(beta1),
+ beta2_(beta2),
+ epsilon_(epsilon),
+ beta1_pow_(beta1_pow),
+ beta2_pow_(beta2_pow),
+ moment1_(mom1),
+ moment1_out_(mom1_out),
+ moment2_(mom2),
+ moment2_out_(mom2_out),
+ lr_(lr),
+ grad_(grad),
+ param_(param),
+ param_out_(param_out) {}
+
+ inline HOSTDEVICE void operator()(size_t i) const {
+ // Merge all memory access together.
+ T g = grad_[i];
+ T mom1 = moment1_[i];
+ T mom2 = moment2_[i];
+ T lr = *lr_;
+ T beta1_pow = *beta1_pow_;
+ T beta2_pow = *beta2_pow_;
+ T p = param_[i];
+
+ // Calculation
+ lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
+ mom1 = beta1_ * mom1 + (1 - beta1_) * g;
+ mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
+ p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
+
+ // Write back to global memory
+ moment1_out_[i] = mom1;
+ moment2_out_[i] = mom2;
+ param_out_[i] = p;
+ }
+};
+
template
class AdamOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto param_out_tensor = ctx.Output("ParamOut");
- auto moment1_out_tensor = ctx.Output("Moment1Out");
- auto moment2_out_tensor = ctx.Output("Moment2Out");
-
- param_out_tensor->mutable_data(ctx.GetPlace());
- moment1_out_tensor->mutable_data(ctx.GetPlace());
- moment2_out_tensor->mutable_data(ctx.GetPlace());
+ using paddle::framework::LoDTensor;
+ using paddle::operators::detail::Ref;
T beta1 = static_cast(ctx.Attr("beta1"));
T beta2 = static_cast(ctx.Attr("beta2"));
T epsilon = static_cast(ctx.Attr("epsilon"));
+ auto& param = Ref(ctx.Input("Param"), "Must set Param");
+ auto& grad = Ref(ctx.Input("Grad"), "Must set Grad");
+ auto& mom1 = Ref(ctx.Input("Moment1"), "Must set Moment1");
+ auto& mom2 = Ref(ctx.Input("Moment2"), "Must set Moment2");
+ auto& lr =
+ Ref(ctx.Input("LearningRate"), "Must set LearningRate");
+
+ auto& beta1_pow =
+ Ref(ctx.Input("Beta1Pow"), "Must set Beta1Pow");
+ auto& beta2_pow =
+ Ref(ctx.Input("Beta2Pow"), "Must set Beta2Pow");
+
+ auto& param_out =
+ Ref(ctx.Output("ParamOut"), "Must set ParamOut");
+ auto& mom1_out =
+ Ref(ctx.Output("Moment1Out"), "Must set Moment1Out");
+ auto& mom2_out =
+ Ref(ctx.Output("Moment2Out"), "Must set Moment1Out");
- auto param = framework::EigenVector::Flatten(
- *ctx.Input("Param"));
- auto grad = framework::EigenVector::Flatten(
- *ctx.Input("Grad"));
- auto moment1 = framework::EigenVector::Flatten(
- *ctx.Input("Moment1"));
- auto moment2 = framework::EigenVector::Flatten(
- *ctx.Input("Moment2"));
- auto lr = framework::EigenVector::Flatten(
- *ctx.Input("LearningRate"));
- auto beta1_pow = framework::EigenVector::Flatten(
- *ctx.Input("Beta1Pow"));
- auto beta2_pow = framework::EigenVector::Flatten(
- *ctx.Input("Beta2Pow"));
- auto param_out = framework::EigenVector::Flatten(*param_out_tensor);
- auto moment1_out = framework::EigenVector::Flatten(*moment1_out_tensor);
- auto moment2_out = framework::EigenVector::Flatten(*moment2_out_tensor);
- auto* place = ctx.template device_context().eigen_device();
-
- moment1_out.device(*place) = beta1 * moment1 + (1 - beta1) * grad;
- moment2_out.device(*place) = beta2 * moment2 + (1 - beta2) * grad.square();
-
- // All of these are tensors of 1 element
- auto lr_t = lr * (1 - beta2_pow).sqrt() / (1 - beta1_pow);
- // Eigen does not support automatic broadcast
- // Get dimensions of moment vector to broadcast lr_t
- Eigen::DSizes m_dsize(moment1_out_tensor->numel());
- param_out.device(*place) =
- param -
- lr_t.broadcast(m_dsize) *
- (moment1_out / (moment2_out.sqrt() + epsilon));
+ AdamFunctor functor(beta1, beta2, epsilon, beta1_pow.template data(),
+ beta2_pow.template data(),
+ mom1.template data(),
+ mom1_out.template mutable_data(ctx.GetPlace()),
+ mom2.template data(),
+ mom2_out.template mutable_data(ctx.GetPlace()),
+ lr.template data(), grad.template data(),
+ param.template data(),
+ param_out.template mutable_data(ctx.GetPlace()));
+ platform::ForRange for_range(
+ static_cast(ctx.device_context()), param.numel());
+ for_range(functor);
}
};
diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc
index 811c487089..c16bc11931 100644
--- a/paddle/operators/auc_op.cc
+++ b/paddle/operators/auc_op.cc
@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
}
protected:
- framework::OpKernelType GetKernelType(
+ framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input("Out")->type()),
diff --git a/paddle/operators/batch_norm_op.cc b/paddle/operators/batch_norm_op.cc
index 1c14acbe11..49cb0fa4d9 100644
--- a/paddle/operators/batch_norm_op.cc
+++ b/paddle/operators/batch_norm_op.cc
@@ -304,7 +304,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
}
protected:
- framework::OpKernelType GetKernelType(
+ framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
diff --git a/paddle/operators/batch_norm_op.cu.cc b/paddle/operators/batch_norm_op.cu.cc
index 55d0736a4c..3d17725ab4 100644
--- a/paddle/operators/batch_norm_op.cu.cc
+++ b/paddle/operators/batch_norm_op.cu.cc
@@ -53,7 +53,7 @@ class BatchNormKernel
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "It must use CUDAPlace.");
double epsilon = static_cast(ctx.Attr("epsilon"));
const float momentum = ctx.Attr("momentum");
const bool is_test = ctx.Attr("is_test");
@@ -179,7 +179,7 @@ class BatchNormGradKernel
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "It must use CUDAPlace.");
double epsilon = static_cast(ctx.Attr("epsilon"));
const std::string data_layout_str = ctx.Attr("data_layout");
const DataLayout data_layout =
diff --git a/paddle/operators/chunk_eval_op.cc b/paddle/operators/chunk_eval_op.cc
index f1f274a7af..a040404266 100644
--- a/paddle/operators/chunk_eval_op.cc
+++ b/paddle/operators/chunk_eval_op.cc
@@ -55,7 +55,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
}
protected:
- framework::OpKernelType GetKernelType(
+ framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(framework::proto::DataType::FP32,
ctx.device_context());
diff --git a/paddle/operators/compare_op.cc b/paddle/operators/compare_op.cc
index 1148172f3a..10bf3d4bbc 100644
--- a/paddle/operators/compare_op.cc
+++ b/paddle/operators/compare_op.cc
@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
- framework::OpKernelType GetKernelType(
+ framework::OpKernelType GetActualKernelType(
const framework::ExecutionContext &ctx) const override {
- framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+ framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
// CompareOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input("X")->place();
return kt;
diff --git a/paddle/operators/conv_cudnn_op.cu.cc b/paddle/operators/conv_cudnn_op.cu.cc
index 3da0a9001a..79e020b755 100644
--- a/paddle/operators/conv_cudnn_op.cu.cc
+++ b/paddle/operators/conv_cudnn_op.cu.cc
@@ -36,7 +36,7 @@ class CudnnConvOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "It must use CUDAPlace.");
auto* input = ctx.Input("Input");
auto* filter = ctx.Input("Filter");
auto* output = ctx.Output("Output");
@@ -130,7 +130,7 @@ class CudnnConvOpKernel : public framework::OpKernel {
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
- platform::GPUPlace gpu = boost::get(ctx.GetPlace());
+ platform::CUDAPlace gpu = boost::get(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
T alpha = 1.0f, beta = 0.0f;
@@ -151,7 +151,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
- "It must use GPUPlace.");
+ "It must use CUDAPlace.");
auto input = ctx.Input("Input");
auto filter = ctx.Input("Filter");
auto output_grad = ctx.Input(framework::GradVarName("Output"));
@@ -277,7 +277,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel