Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_transpose_doc

del_some_in_makelist
wanghaoshuang 7 years ago
commit 641df37393

@ -16,8 +16,6 @@ cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake") set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}) set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}) set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
include(system) include(system)
@ -201,6 +199,10 @@ if(WITH_GOLANG)
endif(WITH_GOLANG) endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build") set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
add_subdirectory(paddle) add_subdirectory(paddle)
if(WITH_PYTHON) if(WITH_PYTHON)
add_subdirectory(python) add_subdirectory(python)

@ -22,6 +22,7 @@ On each machine, we will test and compare the performance of training on single
#### Training #### Training
Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz Test on batch size 64, 128, 256 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
Pay attetion that the speed below includes forward, backward and parameter update time. So we can not directly compare the data with the benchmark of caffe `time` [command](https://github.com/PaddlePaddle/Paddle/blob/develop/benchmark/caffe/image/run.sh#L9), which only contain forward and backward. The updating time of parameter would become very heavy when the weight size are large, especially on alexnet.
Input image size - 3 * 224 * 224, Time: images/second Input image size - 3 * 224 * 224, Time: images/second
@ -55,6 +56,16 @@ Input image size - 3 * 224 * 224, Time: images/second
<img src="figs/googlenet-cpu-train.png" width="500"> <img src="figs/googlenet-cpu-train.png" width="500">
- Alexnet
| BatchSize | 64 | 128 | 256 |
|--------------|--------| ------ | -------|
| OpenBLAS | 2.13 | 2.45 | 2.68 |
| MKLML | 66.37 | 105.60 | 144.04 |
| MKL-DNN | 399.00 | 498.94 | 626.53 |
chart TBD
#### Inference #### Inference
Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
- VGG-19 - VGG-19

@ -0,0 +1,149 @@
# Design Doc: Add MKLDNN Kernel in Fluid Operator
## Principles
First of all, we should follow some basical principles like:
1. [How to write a new operator](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/new_op_en.md). We are trying to add a new kind of kernel into operators, so basically we should follow this doc.
2. [Supporting new Device/Library](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/support_new_device.md). Since MKLDNN is a new library to fluid, we should add `MKLDNNDeviceContext` and maybe `mkldnn_helper.h`, just like [cudnn_helper.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/cudnn_helper.h).
3. [Switch Kernel](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md). Another important point is that we should ensure the data synchronization between different kernel types, which is this [topic](https://github.com/PaddlePaddle/Paddle/issues/6549). So basically we should override `GetExpectedKernelType` and `trans` functions to support switching kernels.
4. [The Keys of Operator Kernel Type](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md). Kernel Type is a pivotal conception which can record the `Place`, `Library`, `DataType` and `Layout`.
## Sulution
In general, there are four parts we should follow to run a MKL-DNN primitive.
- Create a primitive descriptor that describe this operator
- Create a primitive itself by primitive descriptor and the engine
- Create all memory buffers that primitive needed
- Launch a stream to execute the primitive created
More details can refer to [here](http://01org.github.io/mkl-dnn).
It's better to avoid reinitialization of primitives and memory handles in the first three stages in every iteration. \
So we plan to create a map to record all the `primitive` and `memory`, which should not take too much memories as discussed [here](https://github.com/PaddlePaddle/Paddle/issues/6822).
It's assumed that following three conditions should be satisfied.
1. there is a unique key for each operator instance. May be the actual name of `Output Tensor`.
2. the `Input Tensor` inside `Compute` function is the one after converted.
3. we can get the phase(eg. `is_test`) inside `Compute` function, otherwise we need to expose this attribue to user.
### Compute
The algorithm of `Compute` would be described as follow, let's take conv like an example.
```c++
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), "It must use CPUPlace.");
PADDLE_ENFORCE(platform::is_mkldnn_library(ctx.GetLibrary()), "It must use MKLDNN Library.");
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
// find primitive by unique key from mkldnn context
// the op_key should be a unique name of this op instance
auto& p = dev_ctx.findPrimitive(op_key + "_fwd");
// assuming the input tensor inside this compute function is the one after converted
// this point should be guarantee by another mechanism
auto& i = dev_ctx.findMemory(op_key + "_input");
if (p == nullptr || i == nullptr || inputSizeChanged(p, i)) {
auto fwd_primitive_desc = createPrimitiveDesc(ctx);
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
shared_ptr<mkldnn::memory> in(new mkldnn::memory(fwd_primitive_desc->src_primitive_desc(), input->data<T>()));
shared_ptr<mkldnn::memory> wgt(new mkldnn::memory(fwd_primitive_desc->weights_primitive_desc(), filter->data<T>()));
shared_ptr<mkldnn::memory> out(new mkldnn::memory(fwd_primitive_desc->dst_primitive_desc(), output->mutable_data<T>(ctx.GetPlace())));
shared_ptr<mkldnn::conv_fwd> fwd_primitive(new mkldnn::conv_fwd(*fwd_primitive_desc, *in, *wgt, *out));
dev_ctx.addMemory(op_key+"_input", in);
dev_ctx.addMemory(op_key+"_output", out);
dev_ctx.addMemory(op_key+"_filer", wgt);
dev_ctx.addPrimitive(op_key+"_fwd", fwd_primitive);
dev_ctx.addPrimitiveDesc(op_key+"_fwd_PD", fwd_primitive_desc);
}
p = dev_ctx.findPrimitive(op_key + "_fwd");
PADDLE_ENFORCE(p, "Should have forward Primitive");
PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_input"), "Should have input memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_output"), "Should have output memory");
PADDLE_ENFORCE(dev_ctx.findMemory(op_unique_key+"_filter"), "Should have filter memory");
PADDLE_ENFORCE(dev_ctx.findPrimitiveDesc(op_unique_key+"_fwd_PD"), "Should have forward PrimitiveDesc");
dev_ctx.submit(p);
dev_ctx.execute(); // the convert primitive should have already contained.
```
The `createPrimitiveDesc` returns the primitive descripotor of this operator, would be like this:
```c++
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
algorithm algo = static_cast<algorithm>(ctx.Attr<int>("convolution_algorithm_option"));
prop_kind pk = ctx.Attr<bool>("is_test") ? prop_kind::forward_inference : prop_kind::forward_training;
auto fwd_desc = mkldnn::conv_fwd::desc(/* all the setting above*/);
shared_ptr<mkldnn::conv_fwd::primitive_desc> fwd_primitive_desc(new mkldnn::conv_fwd::primitive_desc(fwd_desc, ctx.getEngine()));
return fwd_primitive_desc;
}
```
### MKLDNNDeviceContext
`MKLDNNDeviceContext`, which is very straightforward, should contain some base information like: `stream`, `engine` and the map needed.
### mkldnn_helper
Some functions would be put in `paddle/platform/mkldnn_helper.h`.
- create MKLDNN memories
- create MKLDNN primitives
- error check function
- etc
### Kernel Switch
We should `reorder` the different Layout from other device or to other device. `GetExpectedKernelType` and `trans` functions can help us to implement it.
`GetExpectedKernelType` should get the context, and this operator can return the best `KernelType`.
`trans` would be like this:
```c++
void trans(inputs, ctx) override {
if (NoNeedTrans()) {
return;
}
// find reorder primitive by op_key from context
auto& dev_ctx = ctx.template device_context<platform::MKLDNNDeviceContext>();
auto& p = dev_ctx.findPrimitive(op_key + "_reorder_input");
auto& i = dev_ctx.findMemory(op_key + "_src_input");
if (p == nullptr || i == nullptr || changeSized(i, input)) {
auto prim = createPrimitiveDesc(ctx);
auto src = createMemory(memoryDesc(input->dims(), actual_layout), input->data);
auto newbuffer = paddle::memory::Alloc(ctx.GetPlace(), input->size_in_bytes());
auto dst = createMemory(p->expected_desc(), newbuffer->data);
auto reorder_primitive(new mkldnn::reorder(src, dst));
dev_ctx.addMemory(op_key+"_src_input", src);
dev_ctx.addMemory(op_key+"_input", dst);
dev_ctx.addPrimitive(op_key+"_reorder_input", reorder_primitive);
}
p = dev_ctx.findPrimitive(op_key + "_reorder_input");
PADDLE_ENFORCE(p, "Should have Reorder Primitive");
dev_ctx.submit(p);
if (! this->isMKLDNNKernel()) {
// execute immediately only if this is not mkldnn kernel function.
// otherwise, it can be executed with the operator primitive in Compute
dev_ctx.stream();
}
// after submit, the input tensor in ExecutionContext should be changed as the converted one
// there should be another mechanism to ensure this
}
```
### Unit Test
All the functions should be tested corresponding.
TBD

@ -25,13 +25,14 @@ There are mainly three parts that we have to consider while integrating a new de
### Place and DeviceContext ### Place and DeviceContext
Please remind that device and computing library are not one-to-one corresponding. A device can have a lot of computing libraries and a computing library can also support several devices.
#### Place #### Place
Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent different devices and computing libraries. There are inheritance relationships between different kinds of `Place`. Fluid uses class [Place](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/place.h#L55) to represent the device memory where data is located. If we add another device, we have to add corresponding `DevicePlace`.
``` ```
| CPUPlace --> MKLDNNPlace | CPUPlace
Place --| CUDAPlace --> CUDNNPlace Place --| CUDAPlace
| FPGAPlace | FPGAPlace
``` ```
@ -43,7 +44,7 @@ typedef boost::variant<CUDAPlace, CPUPlace, FPGAPlace> Place;
#### DeviceContext #### DeviceContext
Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different hardwares, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`. Fluid uses class [DeviceContext](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/device_context.h#L30) to manage the resources in different libraries, such as CUDA stream in `CDUADeviceContext`. There are also inheritance relationships between different kinds of `DeviceContext`.
``` ```
@ -106,7 +107,7 @@ template <typename Place>
size_t Used(Place place); size_t Used(Place place);
``` ```
To implementing these interfaces, we have to implement MemoryAllocator for different Devices To implement these interfaces, we have to implement MemoryAllocator for different Devices.
#### Tensor #### Tensor
@ -243,6 +244,7 @@ REGISTER_OP_CUDA_KERNEL(
Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library. Generally, we will impelement OpKernel for all Device/Library of an Operator. We can easily train a Convolutional Neural Network in GPU. However, some OpKernel is not sutibale on a specific Device. For example, crf operator can only run on CPU, whereas most other operators can run at GPU. To achieve high performance in such circumstance, we have to switch between different Device/Library.
We will discuss how to implement an efficient OpKernel switch policy. For more details, please refer to following docs:
- TBD - operator kernel type [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md)
- switch kernel [doc](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/switch_kernel.md)

@ -109,3 +109,31 @@ PaddlePaddle使用avx SIMD指令提高cpu执行效率因此错误的使用二
解决办法是: 解决办法是:
* 卸载PaddlePaddle包 :code:`pip uninstall paddle`, 清理掉老旧的PaddlePaddle安装包使得单元测试有一个干净的环境。如果PaddlePaddle包已经在python的site-packages里面单元测试会引用site-packages里面的python包而不是源码目录里 :code:`/python` 目录下的python包。同时即便设置 :code:`PYTHONPATH`:code:`/python` 也没用因为python的搜索路径是优先已经安装的python包。 * 卸载PaddlePaddle包 :code:`pip uninstall paddle`, 清理掉老旧的PaddlePaddle安装包使得单元测试有一个干净的环境。如果PaddlePaddle包已经在python的site-packages里面单元测试会引用site-packages里面的python包而不是源码目录里 :code:`/python` 目录下的python包。同时即便设置 :code:`PYTHONPATH`:code:`/python` 也没用因为python的搜索路径是优先已经安装的python包。
8. 下载MKLML库失败
------------------
.. code-block:: bash
make[2]: *** [third_party/mklml/src/extern_mklml-stamp/extern_mklml-download] 错误 4
make[1]: *** [CMakeFiles/extern_mklml.dir/all] 错误 2
make[1]: *** 正在等待未完成的任务....
原因网速或SSL链接原因导致MKLML库下载不成功。
解决办法是:手动下载并安装,具体步骤如下。
.. code-block:: bash
// 1. 进入对应的目录
cd build/third_party/mklml/src/extern_mklml
// 2. 查看包的大小, 正常情况下是75M如果小于75M即下载失败
du -sh mklml_lnx_2018.0.1.20171007.tgz
// 3. 手动下载且解压缩并手动生成download成功标签
wget --no-check-certificate https://github.com/01org/mkl-dnn/releases/download/v0.11/mklml_lnx_2018.0.1.20171007.tgz -c -O mklml_lnx_2018.0.1.20171007.tgz
tar zxf mklml_lnx_2018.0.1.20171007.tgz
touch ../extern_mklml-stamp/extern_mklml-download
// 4. 接着编译即可

@ -59,5 +59,9 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(threadpool SRCS threadpool.cc)
cc_test(threadpool_test SRCS threadpool_test.cc DEPS threadpool)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init) cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context)

@ -13,11 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/platform/enforce.h"
#include <iostream>
#include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
enum DataLayout { enum class DataLayout {
kNHWC = 0, kNHWC = 0,
kNCHW = 1, kNCHW = 1,
kAnyLayout = 2, kAnyLayout = 2,
@ -33,5 +37,23 @@ inline DataLayout StringToDataLayout(const std::string& str) {
} }
} }
inline std::string DataLayoutToString(const DataLayout& data_layout) {
switch (data_layout) {
case DataLayout::kNHWC:
return "NHWC";
case DataLayout::kNCHW:
return "NCHW";
case DataLayout::kAnyLayout:
return "ANY_LAYOUT";
default:
PADDLE_THROW("unknown DataLayou %d", data_layout);
}
}
inline std::ostream& operator<<(std::ostream& out, DataLayout l) {
out << DataLayoutToString(l);
return out;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -54,7 +54,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos); auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1); auto number = device.substr(pos + 1);
places.emplace_back(platform::GPUPlace(std::stoi(number))); places.emplace_back(platform::CUDAPlace(std::stoi(number)));
#else #else
LOG(WARNING) LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option"; << "'GPU' is not supported, Please re-compile with WITH_GPU option";

@ -20,7 +20,25 @@ namespace framework {
// For more details about the design of LibraryType, Please refer to // For more details about the design of LibraryType, Please refer to
// https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library // https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/operator_kernel_type.md#library
enum LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 }; enum class LibraryType { kPlain = 0, kMKLDNN = 1, kCUDNN = 2 };
inline std::string LibraryTypeToString(const LibraryType& library_type) {
switch (library_type) {
case LibraryType::kPlain:
return "PLAIN";
case LibraryType::kMKLDNN:
return "MKLDNN";
case LibraryType::kCUDNN:
return "CUDNN";
default:
PADDLE_THROW("unknown LibraryType %d", library_type);
}
}
inline std::ostream& operator<<(std::ostream& out, LibraryType l) {
out << LibraryTypeToString(l);
return out;
}
} // namespace } // namespace
} // framework } // framework

@ -224,7 +224,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
while (size != 0) { while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size)); size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(), memory::Copy(cpu, buf.get(),
boost::get<platform::GPUPlace>(tensor.place()), boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write, reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream()); gpu_dev_ctx.stream());
gpu_dev_ctx.Wait(); gpu_dev_ctx.Wait();

@ -27,7 +27,7 @@ __global__ void test(size_t* a, int size) {
TEST(LoDTensor, LoDInGPU) { TEST(LoDTensor, LoDInGPU) {
paddle::framework::LoDTensor lod_tensor; paddle::framework::LoDTensor lod_tensor;
paddle::platform::GPUPlace place(0); paddle::platform::CUDAPlace place(0);
paddle::framework::LoD src_lod; paddle::framework::LoD src_lod;
src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14}); src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});

@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/data_layout.h" #include "paddle/framework/data_layout.h"
#include "paddle/framework/data_type.h" #include "paddle/framework/data_type.h"
#include "paddle/framework/library_type.h" #include "paddle/framework/library_type.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
namespace paddle { namespace paddle {
@ -39,6 +40,7 @@ struct OpKernelType {
// place, data_type, library_type kinds less than 2^8 // place, data_type, library_type kinds less than 2^8
constexpr static int LEFT_SHIFT = 8; constexpr static int LEFT_SHIFT = 8;
proto::DataType data_type_; proto::DataType data_type_;
DataLayout data_layout_; DataLayout data_layout_;
platform::Place place_; platform::Place place_;
@ -68,5 +70,13 @@ struct OpKernelType {
} }
}; };
inline std::ostream& operator<<(std::ostream& os,
const OpKernelType& kernel_key) {
os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
<< kernel_key.data_layout_ << "]:place[" << kernel_key.place_
<< "]:library_type[" << kernel_key.library_type_ << "]";
return os;
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -0,0 +1,51 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_kernel_type.h"
#include <gtest/gtest.h>
#include <iostream>
TEST(OpKernelType, ToString) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using CPUPlace = paddle::platform::CPUPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
OpKernelType op_kernel_type(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
std::ostringstream stream;
stream << op_kernel_type;
ASSERT_EQ(
stream.str(),
"data_type[5]:data_layout[NCHW]:place[CPUPlace]:library_type[CUDNN]");
}
TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using CPUPlace = paddle::platform::CPUPlace;
using CUDAPlace = paddle::platform::CUDAPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType op_kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType::Hash hasher;
ASSERT_NE(hasher(op_kernel_type_1), hasher(op_kernel_type_2));
}

@ -188,7 +188,7 @@ class OpKernelRegistrar : public Registrar {
} }
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \ #define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \ #define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)

@ -242,13 +242,6 @@ std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
return res; return res;
} }
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key) {
os << "data_type[" << kernel_key.data_type_ << "]:data_layout["
<< kernel_key.data_layout_ << "]:place[" << kernel_key.place_
<< "]:library_type[" << kernel_key.library_type_ << "]";
return os;
}
bool OpSupportGPU(const std::string& op_type) { bool OpSupportGPU(const std::string& op_type) {
auto& all_kernels = OperatorWithKernel::AllOpKernels(); auto& all_kernels = OperatorWithKernel::AllOpKernels();
auto it = all_kernels.find(op_type); auto it = all_kernels.find(op_type);
@ -409,19 +402,28 @@ void OperatorWithKernel::Run(const Scope& scope,
OpKernelMap& kernels = kernels_iter->second; OpKernelMap& kernels = kernels_iter->second;
ExecutionContext ctx(*this, scope, *dev_ctx); ExecutionContext ctx(*this, scope, *dev_ctx);
auto kernel_key = GetKernelType(ctx); auto actual_kernel_key = GetActualKernelType(ctx);
auto kernel_iter = kernels.find(kernel_key); auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
auto kernel_iter = kernels.find(expected_kernel_key);
if (kernel_iter == kernels.end()) { if (kernel_iter == kernels.end()) {
PADDLE_THROW("The operator %s does not support %s", type_, kernel_key); PADDLE_THROW("The operator %s does not support %s", type_,
expected_kernel_key);
} }
kernel_iter->second->Compute(ctx); kernel_iter->second->Compute(ctx);
} }
OpKernelType OperatorWithKernel::GetKernelType(
OpKernelType OperatorWithKernel::GetActualKernelType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
return OpKernelType(IndicateDataType(ctx), ctx.GetPlace()); return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
} }
OpKernelType OperatorWithKernel::GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const {
return actual_kernel_type;
}
proto::DataType OperatorWithKernel::IndicateDataType( proto::DataType OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const { const ExecutionContext& ctx) const {
auto& scope = ctx.scope(); auto& scope = ctx.scope();

@ -52,6 +52,11 @@ constexpr char kGradVarSuffix[] = "@GRAD";
/// Variables with this suffix are supposed to be filled up with zeros. /// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO"; constexpr char kZeroVarSuffix[] = "@ZERO";
// define some kernel hint
const std::string kUseCPU = "use_cpu";
const std::string kUseCUDNN = "use_cudnn";
const std::string kUseMKLDNN = "use_mkldnn";
inline std::string GradVarName(const std::string& var_name) { inline std::string GradVarName(const std::string& var_name) {
return var_name + kGradVarSuffix; return var_name + kGradVarSuffix;
} }
@ -373,7 +378,9 @@ class OperatorWithKernel : public OperatorBase {
} }
protected: protected:
virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
virtual OpKernelType GetExpectedKernelType(
const OpKernelType& actual_kernel_type) const;
private: private:
// indicate kernel DataType by input data. Defaultly all input data must be // indicate kernel DataType by input data. Defaultly all input data must be
@ -381,8 +388,6 @@ class OperatorWithKernel : public OperatorBase {
proto::DataType IndicateDataType(const ExecutionContext& ctx) const; proto::DataType IndicateDataType(const ExecutionContext& ctx) const;
}; };
std::ostream& operator<<(std::ostream& os, const OpKernelType& kernel_key);
extern bool OpSupportGPU(const std::string& op_type); extern bool OpSupportGPU(const std::string& op_type);
} // namespace framework } // namespace framework

@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override {} void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetKernelType(const ExecutionContext& ctx) const override { OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
return OpKernelType(proto::DataType::FP32, ctx.GetPlace()); return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
} }
}; };

@ -20,12 +20,12 @@ limitations under the License. */
#include <typeindex> #include <typeindex>
#include <vector> #include <vector>
#include "paddle/framework/data_layout.h"
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/memory/memory.h" #include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
namespace paddle { namespace paddle {
@ -115,6 +115,10 @@ class Tensor {
inline void check_memory_size() const; inline void check_memory_size() const;
inline DataLayout layout() const { return layout_; }
inline void set_layout(const DataLayout layout) { layout_ = layout; }
private: private:
friend class LoDTensor; friend class LoDTensor;
@ -173,6 +177,19 @@ class Tensor {
DDim dims_; DDim dims_;
/**
* @brief the layout of memory block, default is NCHW.
*
* @note the memory allocation order, describe how weight/data is stored
* For example, in 4-D Tensor(rank=4), there are three commonly
* used layout. They are
* NCHW, NHWC, CHWN.
* N,C,H,W for respectively the batch size, the number of
* feature maps, the height.
*/
DataLayout layout_ = DataLayout::kNHWC;
/** /**
* @brief A PlaceHolder may be shared by more than one tensor. * @brief A PlaceHolder may be shared by more than one tensor.
* *

@ -71,7 +71,7 @@ private:
``` ```
```c++ ```c++
typedef boost::variant<GpuPlace, CpuPlace> Place; typedef boost::variant<CUDAPlace, CpuPlace> Place;
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> DDimVar; Dim<6>, Dim<7>, Dim<8>, Dim<9>> DDimVar;
typedef boost::variant< typedef boost::variant<

@ -125,11 +125,11 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
boost::get<platform::CPUPlace>(place), size, type)); boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) { } else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA #ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
} }
#else #else
holder_.reset(new PlaceholderImpl<platform::GPUPlace>( holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::GPUPlace>(place), size, type)); boost::get<platform::CUDAPlace>(place), size, type));
} }
#endif #endif
offset_ = 0; offset_ = 0;
@ -165,6 +165,7 @@ inline Tensor Tensor::Slice(int begin_idx, int end_idx) const {
size_t base = numel() / dims_[0]; size_t base = numel() / dims_[0];
Tensor dst; Tensor dst;
dst.holder_ = holder_; dst.holder_ = holder_;
dst.set_layout(layout_);
DDim dst_dims = dims_; DDim dst_dims = dims_;
dst_dims[0] = end_idx - begin_idx; dst_dims[0] = end_idx - begin_idx;
dst.Resize(dst_dims); dst.Resize(dst_dims);

@ -80,20 +80,20 @@ TEST(Tensor, MutableData) {
float* p1 = nullptr; float* p1 = nullptr;
float* p2 = nullptr; float* p2 = nullptr;
// initialization // initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace()); p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
EXPECT_NE(p1, nullptr); EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size // set src_tensor a new dim with large size
// momery is supposed to be re-allocated // momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
EXPECT_NE(p2, nullptr); EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2); EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size // set src_tensor a new dim with same size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace()); p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size // set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged // momery block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace()); p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
EXPECT_EQ(p1, p2); EXPECT_EQ(p1, p2);
} }
#endif #endif
@ -130,7 +130,7 @@ TEST(Tensor, ShareDataWith) {
{ {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace()); src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CUDAPlace());
dst_tensor.ShareDataWith(src_tensor); dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>()); ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
} }
@ -166,7 +166,7 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace()); src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6); Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims(); DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2); ASSERT_EQ(arity(slice_dims), 2);
@ -176,11 +176,11 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address = uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>()); reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>( uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace())); src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
uintptr_t slice_data_address = uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>()); reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>( uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace())); slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address); EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address); EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address); EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
@ -200,3 +200,12 @@ TEST(Tensor, ReshapeToMatrix) {
ASSERT_EQ(res.dims()[0], 2 * 3); ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9); ASSERT_EQ(res.dims()[1], 4 * 9);
} }
TEST(Tensor, Layout) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
ASSERT_EQ(src.layout(), DataLayout::kNHWC);
src.set_layout(DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), DataLayout::kAnyLayout);
}

@ -33,6 +33,7 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
src.check_memory_size(); src.check_memory_size();
dst->Resize(src.dims()); dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place(); auto src_place = src.place();
auto src_ptr = src.data<void>(); auto src_ptr = src.data<void>();
@ -47,11 +48,11 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) { platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place); auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place); auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace(); auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place); auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy( memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
@ -59,21 +60,21 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
} else if (platform::is_cpu_place(src_place) && } else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) { platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place); auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place); auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace(); auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place); auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place); PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy( memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) && } else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) { platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place); auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place); auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace(); auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place)); PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place); auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place); PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy( memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
@ -82,6 +83,29 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#endif #endif
} }
/**
* @brief CopyFrom support CPU <-> CPU
*/
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
dst->set_layout(src.layout());
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
PADDLE_ENFORCE(platform::is_cpu_place(src_place) &&
platform::is_cpu_place(dst_place));
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
/** /**
* @brief Copy the content of an external vector to a tensor. * @brief Copy the content of an external vector to a tensor.
* *
@ -108,13 +132,28 @@ inline void CopyFromVector(const std::vector<T>& src,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) { // NOLINT else if (platform::is_gpu_place(dst_place)) { // NOLINT
memory::Copy( memory::Copy(
boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr, boost::get<platform::CUDAPlace>(dst_place), dst_ptr, src_place, src_ptr,
size, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} }
#endif #endif
} }
/**
* @brief CopyFromVector CPU vector -> CPU Tensor
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) {
platform::CPUPlace dst_place = platform::CPUPlace();
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
dst->Resize({static_cast<int64_t>(src.size())});
auto dst_ptr = static_cast<void*>(dst->mutable_data<T>(dst_place));
auto size = src.size() * sizeof(T);
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
/** /**
* @brief Copy the content of a tensor to a vector * @brief Copy the content of a tensor to a vector
* *
@ -141,12 +180,30 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy( memory::Copy(
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()), dst_place, dst_ptr, boost::get<platform::CUDAPlace>(src.place()),
src_ptr, size, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()); reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} }
#endif #endif
} }
/**
* @brief CopyToVector CPUTensor <-> CPU Vector
*/
template <typename T>
inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
auto src_ptr = static_cast<const void*>(src.data<T>());
auto size = src.numel() * sizeof(T);
platform::CPUPlace dst_place;
dst->resize(src.numel());
auto dst_ptr = static_cast<void*>(dst->data());
PADDLE_ENFORCE(platform::is_cpu_place(src.place()));
memory::Copy(dst_place, dst_ptr, boost::get<platform::CPUPlace>(src.place()),
src_ptr, size);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -17,6 +17,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(CopyFrom, Tensor) { TEST(CopyFrom, Tensor) {
Tensor src_tensor; Tensor src_tensor;
Tensor dst_tensor; Tensor dst_tensor;
@ -27,9 +28,10 @@ TEST(CopyFrom, Tensor) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int)); memcpy(src_ptr, arr, 9 * sizeof(int));
src_tensor.set_layout(DataLayout::kAnyLayout);
auto cpu_place = new platform::CPUPlace(); auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, cpu_ctx, &dst_tensor); CopyFrom(src_tensor, *cpu_place, &dst_tensor);
const int* dst_ptr = dst_tensor.data<int>(); const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr); ASSERT_NE(src_ptr, dst_ptr);
@ -37,14 +39,18 @@ TEST(CopyFrom, Tensor) {
EXPECT_EQ(src_ptr[i], dst_ptr[i]); EXPECT_EQ(src_ptr[i], dst_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
Tensor slice_tensor = src_tensor.Slice(1, 2); Tensor slice_tensor = src_tensor.Slice(1, 2);
CopyFrom(slice_tensor, *cpu_place, cpu_ctx, &dst_tensor); CopyFrom(slice_tensor, *cpu_place, &dst_tensor);
const int* slice_ptr = slice_tensor.data<int>(); const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>(); dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr); ASSERT_NE(dst_ptr, slice_ptr);
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
{ {
Tensor src_tensor; Tensor src_tensor;
@ -58,7 +64,7 @@ TEST(CopyFrom, Tensor) {
memcpy(src_ptr, arr, 9 * sizeof(int)); memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor // CPU Tensor to GPU Tensor
auto gpu_place = new platform::GPUPlace(0); auto gpu_place = new platform::CUDAPlace(0);
platform::CUDADeviceContext gpu_ctx(*gpu_place); platform::CUDADeviceContext gpu_ctx(*gpu_place);
CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor); CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
@ -90,6 +96,8 @@ TEST(CopyFrom, Tensor) {
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
EXPECT_EQ(dst_ptr[i], slice_ptr[i]); EXPECT_EQ(dst_ptr[i], slice_ptr[i]);
} }
EXPECT_TRUE(dst_tensor.layout() == src_tensor.layout());
} }
#endif #endif
} }
@ -104,8 +112,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to CPU Tensor // Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3})); cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace(); auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place); CopyFromVector<int>(src_vec, &cpu_tensor);
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
// Compare Tensors // Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>(); const int* cpu_ptr = cpu_tensor.data<int>();
@ -117,7 +124,7 @@ TEST(CopyFromVector, Tensor) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5); src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2})); cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor); CopyFromVector<int>(src_vec, &cpu_tensor);
cpu_ptr = cpu_tensor.data<int>(); cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data(); src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr); ASSERT_NE(src_ptr, cpu_ptr);
@ -143,7 +150,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to GPUTensor // Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3})); gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace(); auto gpu_place = new paddle::platform::CUDAPlace();
CUDADeviceContext gpu_ctx(*gpu_place); CUDADeviceContext gpu_ctx(*gpu_place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor); CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
// Copy from GPU to CPU tensor for comparison // Copy from GPU to CPU tensor for comparison
@ -198,9 +205,8 @@ TEST(CopyToVector, Tensor) {
} }
CPUPlace place; CPUPlace place;
CPUDeviceContext cpu_ctx(place);
std::vector<int> dst; std::vector<int> dst;
CopyToVector<int>(src, cpu_ctx, &dst); CopyToVector<int>(src, &dst);
for (int i = 0; i < 3 * 3; ++i) { for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_ptr[i], dst[i]); EXPECT_EQ(src_ptr[i], dst[i]);
@ -210,7 +216,7 @@ TEST(CopyToVector, Tensor) {
{ {
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9}; std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor gpu_tensor; Tensor gpu_tensor;
GPUPlace place; CUDAPlace place;
CUDADeviceContext gpu_ctx(place); CUDADeviceContext gpu_ctx(place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor); CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);

@ -0,0 +1,24 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/threadpool.h"
namespace paddle {
namespace framework {
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
} // namespace framework
} // namespace paddle

@ -0,0 +1,156 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
typedef std::function<void()> Task;
class ThreadPool {
public:
/**
* @brief Get a instance of threadpool, the thread number will
* be specified as the number of hardware thread contexts
*/
static ThreadPool* GetInstance() {
std::call_once(init_flag, &ThreadPool::Init);
return threadpool.get();
}
~ThreadPool() {
{
// notify all threads to stop running
running_ = false;
scheduled_.notify_all();
}
for (auto& t : threads_) {
t->join();
t.reset(nullptr);
}
}
int GetNumThreads() const { return num_threads_; }
int GetAvailable() {
std::unique_lock<std::mutex> lock(mutex_);
return available_;
}
/**
* @brief Push a function to the queue, and will be scheduled and
* executed if a thread is available.
* @param[in] Task will be pushed to the task queue.
*/
void Run(const Task& fn) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(fn);
lock.unlock();
scheduled_.notify_one();
}
/**
* @brief Wait until all the tasks are completed.
*/
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
private:
DISABLE_COPY_AND_ASSIGN(ThreadPool);
explicit ThreadPool(int num_threads)
: num_threads_(num_threads), available_(num_threads), running_(true) {
threads_.resize(num_threads);
for (auto& thread : threads_) {
// TODO(Yancey1989): binding the thread on the specify CPU number
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
}
}
/**
* @brief If the task queue is empty and avaialbe
* is equal to the number of threads, means that
* all tasks are completed.
*
* Note: this function is not thread-safe.
*
* @return true if all tasks are completed.
*/
bool Done() { return tasks_.empty() && available_ == num_threads_; }
void TaskLoop() {
while (running_) {
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) {
break;
}
// pop a task from the task queue
auto task = tasks_.front();
tasks_.pop();
--available_;
lock.unlock();
// run the task
task();
{
std::unique_lock<std::mutex> lock(mutex_);
++available_;
if (Done()) {
completed_.notify_all();
}
}
}
}
static void Init() {
if (threadpool.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number
int num_threads = std::thread::hardware_concurrency();
PADDLE_ENFORCE_GT(num_threads, 0);
threadpool.reset(new ThreadPool(num_threads));
}
}
private:
static std::unique_ptr<ThreadPool> threadpool;
static std::once_flag init_flag;
int num_threads_;
int available_;
bool running_;
std::queue<Task> tasks_;
std::vector<std::unique_ptr<std::thread>> threads_;
std::mutex mutex_;
std::condition_variable scheduled_;
std::condition_variable completed_;
};
} // namespace framework
} // namespace paddle

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save