diff --git a/doc/design/optimizer.md b/doc/design/optimizer.md new file mode 100644 index 0000000000..17440fae50 --- /dev/null +++ b/doc/design/optimizer.md @@ -0,0 +1,105 @@ +## Optimizer Design + +### The Problem + +A PaddlePaddle program, or a block, is a sequence of operators operating variables. A training program needs to do three kinds of works: + +1. the forward pass, which computes intermediate results and the cost(s), +1. the backward pass, which derives gradients from intermediate results and costs, and +1. the optimization pass, which update model parameters to optimize the cost(s). + +These works rely on three kinds of operators: + +1. forward operators, +1. gradient operators, and +1. optimization operators. + +It's true that users should be able to create all these operators manually by calling some low-level API, but it would be much more convenient if they could only describe the forward pass and let PaddlePaddle create the backward and optimization operators automatically. + +In this design, we propose a high-level API that automatically derives the optimisation pass and operators from the forward pass. + + +### High-level Python API to describe the training process + +1. User write code to describe the network: + + ```python + images = layer.data("images") + labels = layer.data("labels") + w1 = pd.var("w1") + b1 = pd.var("b1") + hidden = layer.fc(images, w=w1, b=b1) + cost = layer.mse(hidden, labels) + ``` + + The above code snippet will create forward operators in [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.md). + + +2. Users create a certain kind of Optimizer with some argument. + + ```python + optimizer = AdagradOptimizer(learing_rate=0.001) + ``` + +3. Users use the optimizer to `minimize` a certain `cost` through updating parameters in parameter_list. + + ```python + opt_op_list = optimizer.minimize(cost, parameter_list=[w1, b1]) + ``` + The above code snippet will create gradient and optimization operators in Block. The return value of `minimize()` is list of optimization operators that will be run by session. + +4. Users use Session/Executor to run this opt_op_list as target to do training. + + ```python + sess.run(target= opt_op_list, ...) + ``` + +#### Optimizer Python interface: + +```python +class Optimizer(object): + """Optimizer Base class. + + """ + + def __init__(self): + pass + + def create_backward_pass(self, loss, parameter_list=None): + """ + create and add gradient Operators in BlockDesc to Compute gradients of `loss` + for parameters in parameter_list + + Args: + loss: an variable generated by cost function. + parameter_list: parameters that need to compute gradient and update to optimize the lost. + + Returns: + list of (parameters, gradients) pair. + """ + return None + + def create_optimization_pass(self, parameters_and_grads): + """Add optimization operators to update gradients to variables. + + Args: + parameters_and_grads: a list of (variable, gradient) pair to update. + + Returns: + optmization_op_list: a list of optimization operator that will update parameter using gradient. + """ + return None + + def minimize(self, loss, parameter_list): + """Add operations to minimize `loss` by updating `parameter_list`. + + This method combines interface `create_backward_pass()` and + `create_optimization_pass()` into one. + """ + params_grads = self.create_backward_pass(loss, parameter_list) + update_ops = self.create_optimization_pass(params_grads) + return update_ops + +``` + +Users can inherit the Optimizer above to create their own Optimizer with some special logic, such as AdagradOptimizer. diff --git a/doc/design/python_api.md b/doc/design/python_api.md index c4665e44fc..56ae1d925a 100644 --- a/doc/design/python_api.md +++ b/doc/design/python_api.md @@ -214,3 +214,7 @@ def fc_layer(input, size, ...): out.writer = op return out ``` + +## Optimizer + +[Optimizer Design Doc](./optimizer.md) diff --git a/doc/design/selected_rows.md b/doc/design/selected_rows.md new file mode 100644 index 0000000000..9e6f3b20cb --- /dev/null +++ b/doc/design/selected_rows.md @@ -0,0 +1,74 @@ +# Design Doc: Selected Rows + +`SelectedRows` is a kind of sparse tensor data type, which is designed to support `embedding` operators. The gradient of embedding table is a sparse tensor. Only a few rows are non-zero values in that tensor. It is straightforward to represent the sparse tensor by the following sparse tensor data structure: + +```cpp +class SelectedRows { + private: + vector rows_; + Tensor value_; + int height_; +}; +``` + +The field `height_` shows the first dimension of `SelectedRows`. The `rows` are the indices of which rows of `SelectedRows` are non-zeros. The `value_` field is an N-dim tensor and shape is `[rows.size() /* NUM_ROWS */, ...]`, which supplies values for each row. The dimension of `SelectedRows` satisfies `[height_] + value_.shape[1:]`. + +Suppose that a SelectedRows-typed variable `x` has many rows, but only two of them have values -- row 73 is `[1, 2]` and row 84 is `[3, 4]`, the `SelectedRows` representation would be: + +``` +x = SelectedRow { + rows = [73, 84], + value = [[1, 2], [3,4]] +} +``` + + +## SelectedRows in Protobuf + +`SelectedRows` is a kind of `Variable`. `VarDesc` in protobuf should describe the `SelectedRows` information. Only the tensor dimension of a `SelectedRows` will be described in compile-time since the `rows_` and `value_` are related to training data. +So we use `TensorDesc` to unify `data_type` and `dims`. A LodTensorDesc contains a `TensorDesc` and `lod_level`. The description of `SelectedRows` is a Tensor description. + +```proto +message TensorDesc { + required DataType data_type = 1; + repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] +} + +message LodTensorDesc { + required TensorDesc tensor = 1; + optional int lod_level = 2; +} + +message VarDesc { + required string name = 1; + enum VarType { + LOD_TENSOR = 0; + SELECTED_ROWS = 1; + } + required VarType type = 2; + optional LodTensorDesc lod_desc = 3; + optional TensorDesc selected_rows_desc = 4; + optional bool persistable = 5 [ default = false ]; +} +``` + +## InferShape for Selected Rows + +Just like `LoD` information, `InferShape` method will inference output tensor type as well. The operator should decide whether its output is a `SelectedRows` or `Dense` tensor. + +For example, the gradient operator of `TableLookup` will always generate `SelectedRows`. Its `InferShape` method should be like following + +```cpp +void TableLookupGrad::InferShape(context) { + ... + context.SetDataType("Embedding.Grad", kSelectedRows); +} +``` + + +## Sparse Operators + +There are several operators should be written to support `SelectedRows`. They are: + +1. Operators which generates `SelectedRows` gradient. e.g. Gradient of `TableLookupOp`. +2. Optimize operators which support `SelectedRows` gradient. e.g. `SGD` or `AdaGrad` for `SelectedRows`. However, there should be only one `SGD` operator. `OpWithKernel::Run` should select a suitable kernel for both `dense` tensor or `SelectedRows`. diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 1bf80b3e58..148610aa2c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -42,5 +42,12 @@ add_custom_command(TARGET framework_py_proto POST_BUILD cc_library(backward SRCS backward.cc DEPS net_op) cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) +cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto backward ${GLOB_OP_LIB}) +if(WITH_GPU) + nv_test(executor_test SRCS executor_test.cc DEPS executor) +else() + cc_test(executor_test SRCS executor_test.cc DEPS executor) +endif() + cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc new file mode 100644 index 0000000000..886e9ab33e --- /dev/null +++ b/paddle/framework/executor.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/executor.h" + +#include +#include +#include +#include +#include + +#include "paddle/framework/lod_tensor.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/scope.h" + +#include + +namespace paddle { +namespace framework { + +const std::string kFeedOpType = "feed"; +const std::string kFetchOpType = "fetch"; + +Executor::Executor(const std::vector& places) { + PADDLE_ENFORCE_GT(places.size(), 0); + device_contexts_.resize(places.size()); + for (size_t i = 0; i < places.size(); i++) { + if (platform::is_cpu_place(places[i])) { + device_contexts_[i] = new platform::CPUDeviceContext( + boost::get(places[i])); + } else if (platform::is_gpu_place(places[i])) { +#ifdef PADDLE_WITH_CUDA + device_contexts_[i] = new platform::CUDADeviceContext( + boost::get(places[i])); +#else + PADDLE_THROW( + "'GPUPlace' is not supported, Please re-compile with WITH_GPU " + "option"); +#endif + } + } +} + +Executor::~Executor() { + for (auto& device_context : device_contexts_) { + delete device_context; + } +} + +void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { + // TODO(tonyyang-svail): + // - only runs on the first device (i.e. no interdevice communication) + // - will change to use multiple blocks for RNN op and Cond Op + PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id); + auto& block = pdesc.blocks(block_id); + auto& device = device_contexts_[0]; + + // Instantiate all the vars in the global scope + for (auto& var : block.vars()) { + scope->NewVar(var.name()); + } + + Scope& local_scope = scope->NewScope(); + + std::vector should_run = Prune(pdesc, block_id); + PADDLE_ENFORCE_EQ(should_run.size(), static_cast(block.ops_size())); + for (size_t i = 0; i < should_run.size(); ++i) { + if (should_run[i]) { + for (auto& var : block.ops(i).outputs()) { + for (auto& argu : var.arguments()) { + if (local_scope.FindVar(argu) == nullptr) { + local_scope.NewVar(argu); + } + } + } + auto op = paddle::framework::OpRegistry::CreateOp(block.ops(i)); + op->Run(local_scope, *device); + } + } + + // TODO(tonyyang-svail): + // - Destroy local_scope +} + +std::vector Prune(const ProgramDesc& pdesc, int block_id) { + // TODO(tonyyang-svail): + // - will change to use multiple blocks for RNN op and Cond Op + + auto& block = pdesc.blocks(block_id); + auto& ops = block.ops(); + + bool expect_feed = true; + for (auto& op_desc : ops) { + PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed, + "All FeedOps are at the beginning of the ProgramDesc"); + expect_feed = (op_desc.type() == kFeedOpType); + } + + bool expect_fetch = true; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch, + "All FetchOps must at the end of the ProgramDesc"); + expect_fetch = (op_desc.type() == kFetchOpType); + } + + std::set dependent_vars; + std::vector should_run; + for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) { + auto& op_desc = *op_iter; + + bool found_dependent_vars = false; + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + if (dependent_vars.count(argu) != 0) { + found_dependent_vars = true; + } + } + } + + if (op_desc.type() == kFetchOpType || found_dependent_vars) { + // erase its output to the dependency graph + for (auto& var : op_desc.outputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.erase(argu); + } + } + + // insert its input to the dependency graph + for (auto& var : op_desc.inputs()) { + for (auto& argu : var.arguments()) { + dependent_vars.insert(argu); + } + } + + should_run.push_back(true); + } else { + should_run.push_back(false); + } + } + + // TODO(tonyyang-svail): + // - check this after integration of Init + // PADDLE_ENFORCE(dependent_vars.empty()); + + // since we are traversing the ProgramDesc in reverse order + // we reverse the should_run vector + std::reverse(should_run.begin(), should_run.end()); + + return should_run; +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h new file mode 100644 index 0000000000..4e3bc2c0a5 --- /dev/null +++ b/paddle/framework/executor.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/framework.pb.h" +#include "paddle/framework/op_info.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace framework { + +class Executor { + public: + explicit Executor(const std::vector& places); + ~Executor(); + + /* @Brief + * Runtime evaluation of the given ProgramDesc under certain Scope + * + * @param + * ProgramDesc + * Scope + */ + void Run(const ProgramDesc&, Scope*, int); + + private: + std::vector device_contexts_; +}; + +/* @Brief + * Pruning the graph + * + * @param + * ProgramDesc + * + * @return + * vector Same size as ops. Indicates whether an op should be run. + */ +std::vector Prune(const ProgramDesc& pdesc, int block_id); + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc new file mode 100644 index 0000000000..137e53d849 --- /dev/null +++ b/paddle/framework/executor_test.cc @@ -0,0 +1,318 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/executor.h" + +#include +#include + +#include "gtest/gtest.h" +#include "paddle/framework/attribute.h" +#include "paddle/framework/backward.h" +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/operator.h" + +USE_OP(elementwise_add); +USE_OP(gaussian_random); +USE_OP(feed); +USE_OP(fetch); +USE_OP(mul); +USE_OP(sum); +USE_OP(squared_l2_distance); +USE_OP(fill_constant); +USE_OP(sgd); + +using namespace paddle::platform; +using namespace paddle::framework; + +void AddOp(const std::string& type, const VariableNameMap& inputs, + const VariableNameMap& outputs, AttributeMap attrs, + paddle::framework::BlockDescBind* block) { + // insert output + for (auto kv : outputs) { + for (auto v : kv.second) { + auto var = block->NewVar(v); + var->SetDataType(paddle::framework::DataType::FP32); + } + } + + // insert op + auto op = block->AppendOp(); + op->SetType(type); + for (auto& kv : inputs) { + op->SetInput(kv.first, kv.second); + } + for (auto& kv : outputs) { + op->SetOutput(kv.first, kv.second); + } + op->SetAttrMap(attrs); +} + +// Tensors in feed value variable will only be in CPUPlace +// So we can memcpy the data from vector to feed_value +template +void SetFeedVariable(const std::vector>& inputs, + const std::vector>& dims) { + Variable* g_feed_value = GetGlobalScope().FindVar("feed_value"); + auto& feed_inputs = + *(g_feed_value->GetMutable>()); + size_t size = inputs.size(); + feed_inputs.resize(size); + for (size_t i = 0; i < size; i++) { + T* dst = feed_inputs[i].mutable_data(make_ddim(dims[i]), CPUPlace()); + memcpy(dst, inputs[i].data(), inputs[i].size() * sizeof(T)); + } +} + +// Tensors in fetch value variable will only be in CPUPlace +// So we can memcpy the data from fetch_value to vector +template +std::vector> GetFetchVariable() { + Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value"); + auto& fetch_outputs = + *(g_fetch_value->GetMutable>()); + + size_t size = fetch_outputs.size(); + std::vector> result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + std::vector tmp; + tmp.resize(fetch_outputs[i].numel()); + memcpy(tmp.data(), fetch_outputs[i].data(), + fetch_outputs[i].numel() * sizeof(T)); + result.push_back(tmp); + } + + return result; +} + +class ExecutorTesterRandom : public ::testing::Test { + public: + virtual void SetUp() override { + int input_dim = 3, batch_size = 2, embed_dim = 5; + + auto temp_init_root_block = init_pdesc_.add_blocks(); + temp_init_root_block->set_idx(0); + temp_init_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& init_program = + paddle::framework::ProgramDescBind::Instance(&init_pdesc_); + paddle::framework::BlockDescBind* init_root_block = init_program.Block(0); + + AddOp("gaussian_random", {}, {{"Out", {"w1"}}}, + {{"dims", std::vector{input_dim, embed_dim}}}, init_root_block); + AddOp("gaussian_random", {}, {{"Out", {"w2"}}}, + {{"dims", std::vector{embed_dim, input_dim}}}, init_root_block); + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block); + + // flush + init_program.Proto(); + + // run block + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + // feed data + inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + dims_.push_back({batch_size, input_dim}); + AddOp("feed", {}, {{"Out", {"a"}}}, + {{"dims", std::vector{batch_size, input_dim}}, {"col", 0}}, + root_block); + + // forward + AddOp("mul", {{"X", {"a"}}, {"Y", {"w1"}}}, {{"Out", {"b"}}}, {}, + root_block); + AddOp("mul", {{"X", {"b"}}, {"Y", {"w2"}}}, {{"Out", {"a_out"}}}, {}, + root_block); + AddOp("squared_l2_distance", {{"X", {"a"}}, {"Y", {"a_out"}}}, + {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, + root_block); + + // backward + AddOp("fill_constant", {}, {{"Out", {"l2_distance@GRAD"}}}, + {{"shape", std::vector{batch_size, 1}}, {"value", float(1.0)}}, + root_block); + AppendBackward(program, {}); + + // update + AddOp("fill_constant", {}, {{"Out", {"learning_rate"}}}, + {{"shape", std::vector{1}}, {"value", float(0.001)}}, + root_block); + AddOp("sgd", {{"Param", {"w1"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w1@GRAD"}}}, + {{"ParamOut", {"w1"}}}, {}, root_block); + AddOp("sgd", {{"Param", {"w2"}}, + {"LearningRate", {"learning_rate"}}, + {"Grad", {"w2@GRAD"}}}, + {{"ParamOut", {"w2"}}}, {}, root_block); + + AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block); + AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block); + + // flush + program.Proto(); + } + + protected: + ProgramDesc init_pdesc_; + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +class ExecutorTesterFeedAndFetch : public ::testing::Test { + public: + virtual void SetUp() override { + auto temp_root_block = pdesc_.add_blocks(); + temp_root_block->set_idx(0); + temp_root_block->set_parent_idx(-1); + + // wrap to BlockDescBind + paddle::framework::ProgramDescBind& program = + paddle::framework::ProgramDescBind::Instance(&pdesc_); + paddle::framework::BlockDescBind* root_block = program.Block(0); + + std::vector dim{6}; + + AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}}, + root_block); + AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}}, + root_block); + AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block); + AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block); + + // flush + program.Proto(); + + std::vector vec1 = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0}; + std::vector vec2 = {4.0, 5.0, 6.0, 7.0, 8.0, 9.0}; + inputs_.push_back(vec1); + inputs_.push_back(vec2); + dims_.push_back({static_cast(vec1.size())}); + dims_.push_back({static_cast(vec2.size())}); + } + + protected: + ProgramDesc pdesc_; + std::vector> inputs_; + std::vector> dims_; +}; + +#ifndef PADDLE_WITH_CUDA +TEST_F(ExecutorTesterRandom, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); +} + +TEST_F(ExecutorTesterFeedAndFetch, CPU) { + std::vector places; + CPUPlace cpu_place; + places.push_back(cpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + paddle::memory::Used(cpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#else +TEST_F(ExecutorTesterRandom, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + executor->Run(init_pdesc_, &GetGlobalScope(), 0); + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + } +} + +TEST_F(ExecutorTesterFeedAndFetch, GPU) { + std::vector places; + GPUPlace gpu_place(0); + places.push_back(gpu_place); + // We have a global Scope and BuddyAllocator, and we must ensure + // global BuddyAllocator is initialized before global Scope. Thus, + // global Scope will deconstruct before BuddyAllocator. Otherwise, + // "pointer being freed was not allocated" error will appear. + // If paddle is compiled with GPU, both CPU and GPU BuddyAllocator + // need to be used at first. + paddle::memory::Used(CPUPlace()); + paddle::memory::Used(gpu_place); + + std::unique_ptr executor(new Executor(places)); + + for (int batch_id = 0; batch_id < 3; batch_id++) { + SetFeedVariable(inputs_, dims_); + executor->Run(pdesc_, &GetGlobalScope(), 0); + std::vector> result = GetFetchVariable(); + PADDLE_ENFORCE_EQ(result.size(), inputs_.size()); + for (size_t i = 0; i < result.size(); ++i) { + PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size()); + for (size_t j = 0; j < result[i].size(); ++j) { + PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]); + } + } + } +} +#endif diff --git a/paddle/framework/scope.cc b/paddle/framework/scope.cc index 080b4ac621..5821bac928 100644 --- a/paddle/framework/scope.cc +++ b/paddle/framework/scope.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/scope.h" + +#include // for unique_ptr +#include // for call_once #include "paddle/string/printf.h" namespace paddle { @@ -62,5 +65,17 @@ void Scope::DropKids() { kids_.clear(); } +std::once_flag feed_variable_flag; + +framework::Scope& GetGlobalScope() { + static std::unique_ptr g_scope{nullptr}; + std::call_once(feed_variable_flag, [&]() { + g_scope.reset(new framework::Scope()); + g_scope->NewVar("feed_value"); + g_scope->NewVar("fetch_value"); + }); + return *(g_scope.get()); +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index 7047f0d55e..a8cfb107c2 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -73,5 +73,7 @@ class Scope { DISABLE_COPY_AND_ASSIGN(Scope); }; +framework::Scope& GetGlobalScope(); + } // namespace framework } // namespace paddle diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index a6bb738af3..ced14a8923 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -137,6 +137,24 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { } }; +template +class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HardShrinkOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of HardShrink operator"); + AddOutput("Y", "Output of HardShrink operator"); + AddComment( + "HardShrink activation operator, " + "hard_shrink(x) = x if x > lambda" + "hard_shrink(x) = x if x < -lambda" + "hard_shrink(x) = 0 otherwise"); + AddAttr("threshold", "The value of threshold for HardShrink") + .SetDefault(static_cast(0.5)); + } +}; + class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { public: SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -188,6 +206,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker { } }; +class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SoftplusOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Softplus operator"); + AddOutput("Y", "Output of Softplus operator"); + AddComment("Softplus activation operator, softplus(x) = log(1 + exp(x))"); + } +}; + class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker { public: SoftsignOpMaker(framework::OpProto *proto, @@ -333,6 +362,9 @@ REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, ops::ActivationOpGrad); +REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad, + ops::ActivationOpGrad); + REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad, ops::ActivationOpGrad); @@ -357,6 +389,9 @@ REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, ops::ActivationOpGrad); +REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker, + hard_shrink_grad, ops::ActivationOpGrad); + #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ REGISTER_OP_CPU_KERNEL( \ act_type, \ diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 70d5a62052..f88c9c48eb 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -199,6 +199,39 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor { } }; +// tanhshrink(x) = x - tanh(x) +// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +template +struct HardShrinkFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(Device d, X x, Y y) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + y.device(d) = x * (temp1 + temp2); + } +}; + +template +struct HardShrinkGradFunctor : public BaseActivationFunctor { + float threshold; + + typename BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + + template + void operator()(Device d, X x, Y y, dY dy, dX dx) const { + auto temp1 = (x < (threshold * -1)).template cast().eval(); + auto temp2 = (x > threshold).template cast().eval(); + dx.device(d) = dy * (temp1 + temp2).template cast(); + } +}; + // softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < lambda; 0 // otherwise template @@ -351,8 +384,6 @@ template struct Relu6Functor : public BaseActivationFunctor { float threshold; - // NOTE: Explicit hides the `BaseActivationFunctor::GetAttrs` - // not polymorphism for speed. typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } @@ -376,6 +407,33 @@ struct Relu6GradFunctor : public BaseActivationFunctor { } }; +// softplus(x) = log(1 + exp(x)) +// When x is a very large positive number, exp(x) may explode to inf, +// Using trick below for numerical stability +// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ +// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) +template +struct SoftplusFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); + } +}; + +// d(softplus(x))/dx = exp(x) / (1 + exp(x)) +// For numerical stability: +// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) + +// exp(x - max(x, 0))) +template +struct SoftplusGradFunctor : public BaseActivationFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) + dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); + } +}; + // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { @@ -551,8 +609,10 @@ struct STanhGradFunctor : public BaseActivationFunctor { __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \ __macro(pow, PowFunctor, PowGradFunctor); \ __macro(stanh, STanhFunctor, STanhGradFunctor); \ + __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \ __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \ __macro(relu6, Relu6Functor, Relu6GradFunctor); \ __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \ __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \ - __macro(elu, ELUFunctor, ELUGradFunctor) + __macro(elu, ELUFunctor, ELUGradFunctor); \ + __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor) diff --git a/paddle/operators/feed_op.cc b/paddle/operators/feed_op.cc new file mode 100644 index 0000000000..fa325bb282 --- /dev/null +++ b/paddle/operators/feed_op.cc @@ -0,0 +1,59 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace paddle { +namespace operators { + +class FeedOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); + auto& shape = ctx->Attrs().Get>("dims"); + std::vector shape_int64(shape.size(), 0); + std::transform(shape.begin(), shape.end(), shape_int64.begin(), + [](int a) { return static_cast(a); }); + ctx->SetOutputDim("Out", framework::make_ddim(shape_int64)); + // TODO(qijun): need to handle LodTensor later + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FeedOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global feed variable").SetDefault(0); + AddAttr>("dims", "The dimension of feed tensor."); + AddOutput("Out", "The output of feed op."); + AddComment(R"DOC(Feed data from global feed variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker); +REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.cu b/paddle/operators/feed_op.cu new file mode 100644 index 0000000000..7b6a2ac91e --- /dev/null +++ b/paddle/operators/feed_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/feed_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel); diff --git a/paddle/operators/feed_op.h b/paddle/operators/feed_op.h new file mode 100644 index 0000000000..9d8158299f --- /dev/null +++ b/paddle/operators/feed_op.h @@ -0,0 +1,42 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FeedKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + framework::Tensor* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + framework::Variable* g_feed_variable = + framework::GetGlobalScope().FindVar("feed_value"); + const auto& tensors = + g_feed_variable->Get>(); + int col = ctx.template Attr("col"); + PADDLE_ENFORCE_GT(tensors.size(), static_cast(col)); + // TODO(qijun): + // check tensors[col].dims() with attribute, + // except the first dimenson. + out->CopyFrom(tensors[col], ctx.GetPlace()); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/fetch_op.cc b/paddle/operators/fetch_op.cc new file mode 100644 index 0000000000..90737c8c55 --- /dev/null +++ b/paddle/operators/fetch_op.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace paddle { +namespace operators { + +class FetchOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); + } + + framework::DataType IndicateDataType( + const framework::ExecutionContext& ctx) const override { + return static_cast(Attr("dataType")); + } +}; + +class FetchOpMaker : public framework::OpProtoAndCheckerMaker { + public: + FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dataType", "output data type") + .SetDefault(framework::DataType::FP32); + AddAttr("col", "The col in global fetch variable").SetDefault(0); + AddInput("Input", "The output of fetch op."); + AddComment(R"DOC(Fetch data to global fetch variable)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker); +REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.cu b/paddle/operators/fetch_op.cu new file mode 100644 index 0000000000..ca39d24c79 --- /dev/null +++ b/paddle/operators/fetch_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/fetch_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel); diff --git a/paddle/operators/fetch_op.h b/paddle/operators/fetch_op.h new file mode 100644 index 0000000000..eb9c3a7b59 --- /dev/null +++ b/paddle/operators/fetch_op.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class FetchKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const framework::Tensor* input = ctx.Input("Input"); + framework::Variable* g_fetch_variable = + framework::GetGlobalScope().FindVar("fetch_value"); + auto* tensors = + g_fetch_variable->GetMutable>(); + int col = ctx.template Attr("col"); + if (tensors->size() < static_cast(col + 1)) { + tensors->resize(col + 1); + } + PADDLE_ENFORCE_GT(tensors->size(), static_cast(col)); + (*tensors)[col].Resize(input->dims()); + (*tensors)[col].mutable_data(platform::CPUPlace()); + (*tensors)[col].CopyFrom(*input, platform::CPUPlace()); + // TODO(qijun): need to handle LodTensor later + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 6e2611af7b..2c1bc6d910 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -4,12 +4,15 @@ if(WITH_GPU) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS operator) + nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS operator) + cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) +cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/vol2col.cc b/paddle/operators/math/vol2col.cc new file mode 100644 index 0000000000..e9718a0473 --- /dev/null +++ b/paddle/operators/math/vol2col.cc @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * vol = [input_channels, input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + const T* vol_data = vol.data(); + T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int c_in = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * stride_depth - padding_depth + d_offset; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * stride_height - padding_height + h_offset; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * stride_width - padding_width + w_offset; + + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + w; + if (h_pad < 0 || h_pad >= input_height || w_pad < 0 || + w_pad >= input_width || d_pad < 0 || d_pad >= input_depth) { + col_data[col_idx] = static_cast(0); + } else { + int vol_idx = + ((c_in * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + col_data[col_idx] = vol_data[vol_idx]; + } + } + } + } + } + } +}; + +/* + * vol = [input_channels,input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + int channels_col = + input_channels * filter_depth * filter_height * filter_width; + + T* vol_data = vol.data(); + const T* col_data = col.data(); + + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % filter_width; + int h_offset = (c / filter_width) % filter_height; + int d_offset = (c / filter_width / filter_height) % filter_depth; + int cIm = c / filter_width / filter_height / filter_depth; + for (int d = 0; d < output_depth; ++d) { + int d_pad = d * stride_depth - padding_depth + d_offset; + for (int h = 0; h < output_height; ++h) { + int h_pad = h * stride_height - padding_height + h_offset; + for (int w = 0; w < output_width; ++w) { + int w_pad = w * stride_width - padding_width + w_offset; + + if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && + w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { + int vol_idx = + ((cIm * input_depth + d_pad) * input_height + h_pad) * + input_width + + w_pad; + int col_idx = + ((c * output_depth + d) * output_height + h) * output_width + + w; + vol_data[vol_idx] += col_data[col_idx]; + } + } + } + } + } + } +}; + +template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col.cu b/paddle/operators/math/vol2col.cu new file mode 100644 index 0000000000..27b11fb237 --- /dev/null +++ b/paddle/operators/math/vol2col.cu @@ -0,0 +1,204 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void vol2col(int num_kernels, const T* data_vol, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_col) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + int w_out = index % output_width; + int h_out = (index / output_width) % output_height; + int d_out = (index / output_width / output_height) % output_detph; + int channel_in = index / output_width / output_height / output_detph; + int channel_out = channel_in * filter_depth * filter_height * filter_width; + int w_in = w_out * stride_width - padding_width; + int h_in = h_out * stride_height - padding_height; + int d_in = d_out * stride_depth - padding_depth; + + data_col += ((channel_out * output_detph + d_out) * output_height + h_out) * + output_width + + w_out; + data_vol += ((channel_in * depth + d_in) * height + h_in) * width + w_in; + for (int k = 0; k < filter_depth; ++k) { + for (int i = 0; i < filter_height; ++i) { + for (int j = 0; j < filter_width; ++j) { + int d = d_in + k; + int h = h_in + i; + int w = w_in + j; + *data_col = (d >= 0 && d < depth && h >= 0 && h < height && w >= 0 && + w < width) + ? data_vol[(k * height + i) * width + j] + : 0; + data_col += output_detph * output_height * output_width; + } + } + } + } +} + +/* + * im = [input_channels,intpu_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + + int num_outputs = + input_channels * output_depth * output_height * output_width; + + const int threads = 1024; + const int blocks = (num_outputs + 1024 - 1) / 1024; + vol2col<<(context) + .stream()>>>( + num_outputs, vol.data(), input_depth, input_height, input_width, + filter_depth, filter_height, filter_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + output_depth, output_height, output_width, col.data()); + } +}; + +template +__global__ void col2vol(int num_kernels, const T* data_col, int depth, + int height, int width, int filter_depth, + int filter_height, int filter_width, int stride_depth, + int stride_height, int stride_width, int padding_depth, + int padding_height, int padding_width, int output_detph, + int output_height, int output_width, T* data_vol) { + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; + index += blockDim.x * gridDim.x) { + T src_val = 0; + int w = index % width + padding_width; + int h = (index / width) % height + padding_height; + int d = (index / width / height) % depth + padding_depth; + int c = index / width / height / depth; + // compute the start and end of the output + int w_col_start = + (w < filter_width) ? 0 : (w - filter_width) / stride_width + 1; + int w_col_end = min(w / stride_width + 1, output_width); + int h_col_start = + (h < filter_height) ? 0 : (h - filter_height) / stride_height + 1; + int h_col_end = min(h / stride_height + 1, output_height); + int d_col_start = + (d < filter_depth) ? 0 : (d - filter_depth) / stride_depth + 1; + int d_col_end = min(d / stride_depth + 1, output_detph); + + int offset = (c * filter_depth * filter_height * filter_width + + d * filter_width * filter_height + h * filter_width + w) * + output_detph * output_height * output_width; + + int coeff_d_col = + (1 - stride_depth * filter_width * filter_height * output_detph) * + output_height * output_width; + int coeff_h_col = + (1 - stride_height * filter_width * output_detph * output_height) * + output_width; + int coeff_w_col = + (1 - stride_width * output_detph * output_height * output_width); + + for (int d_col = d_col_start; d_col < d_col_end; ++d_col) { + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + src_val += data_col[offset + d_col * coeff_d_col + + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + } + data_vol[index] = src_val; + } +} + +/* + * im = [input_channels, input_depth, input_height, input_width] + * col = + * [input_channels, filter_depth, filter_height, filter_width, + * output_depth, output_height, output_width] + */ +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const { + PADDLE_ENFORCE(vol.dims().size() == 4); + PADDLE_ENFORCE(col.dims().size() == 7); + + int input_channels = vol.dims()[0]; + int input_depth = vol.dims()[1]; + int input_height = vol.dims()[2]; + int input_width = vol.dims()[3]; + int filter_depth = col.dims()[1]; + int filter_height = col.dims()[2]; + int filter_width = col.dims()[3]; + int output_depth = col.dims()[4]; + int output_height = col.dims()[5]; + int output_width = col.dims()[6]; + + int num_kernels = input_channels * input_depth * input_height * input_width; + + const int threads = 1024; + const int blocks = (num_kernels + 1024 - 1) / 1024; + + col2vol<<(context) + .stream()>>>( + num_kernels, col.data(), input_depth, input_height, input_width, + filter_depth, filter_height, filter_width, stride_depth, stride_height, + stride_width, padding_depth, padding_height, padding_width, + output_depth, output_height, output_width, vol.data()); + } +}; + +template class Vol2ColFunctor; +template class Vol2ColFunctor; +template class Col2VolFunctor; +template class Col2VolFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col.h b/paddle/operators/math/vol2col.h new file mode 100644 index 0000000000..f022365a16 --- /dev/null +++ b/paddle/operators/math/vol2col.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { +/* + * \brief Converts the feature data of four dimensions(CDHW) into a colData of + * seven dimensions in the Vol2ColFunctor calculation, + * And in the Col2VolFunctor calculation, it is reversed. + * + * \param volData Vol data. + * \param volShape The shape of volData, + * [input_channels, input_depth, input_height, input_width]. + * \param colData Column data. + * \param colShape The shape of colData. + * + * The shape of colData is: + * [input_channels, filter_depth, filter_height, filter_width, output_depth, + * output_height, output_width] + * So, it is easy to reshape into a convolution matrix for convolution + * calculation based on matrix multiplication. + * The shape of convolution matrix is [height, width], where the height is equal + * input_channels * filter_depth * filter_height * filter_width, and the width + * is equal output_depth * output_height * output_width. + * + * Reshape: + * shape of colData shape of convolution matrix + * [input_channels, + * filter_depth, + * filter_height, + * filter_width, ======> [height, width] + * output_depth, + * output_height, + * output_width] + * + * \note The caller needs to ensure that volShape.inputChannels is equal to + * colShape.inputChannels. + */ +template +class Vol2ColFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& vol, framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const; +}; + +template +class Col2VolFunctor { + public: + void operator()(const platform::DeviceContext& context, + framework::Tensor& vol, const framework::Tensor& col, + int stride_depth, int stride_height, int stride_width, + int padding_depth, int padding_height, + int padding_width) const; +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/vol2col_test.cc b/paddle/operators/math/vol2col_test.cc new file mode 100644 index 0000000000..81225e9a98 --- /dev/null +++ b/paddle/operators/math/vol2col_test.cc @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/vol2col.h" +#include +#include + +template +void testVol2col() { + paddle::framework::Tensor input; + paddle::framework::Tensor input_tmp; + paddle::framework::Tensor output; + paddle::framework::Tensor output_tmp; + + auto* place = new Place(); + paddle::platform::DeviceContext* context; + if (paddle::platform::is_cpu_place(*place)) { + context = + new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); + } else { +#ifdef PADDLE_WITH_CUDA + context = + new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); +#else + PADDLE_THROW("no GPU support"); +#endif // PADDLE_WITH_CUDA + } + + /** + * input = [[0, 1, 2, + * 3, 4, 5] + * [6, 7, 8, + * 9, 10, 11]] + * + * output = [0, 1 + * 1, 2 + * 3, 4 + * 4, 5 + * 6, 7 + * 7, 8 + * 9, 10 + * 10, 11] + * + * col2vol = [[0, 2, 2, + * 3, 8, 5] + * [6, 14, 8, + * 9, 20, 11]] + * + */ + int input_depth = 2; + int input_height = 2; + int input_width = 3; + int filter_size = 2; + int stride = 1; + int padding = 0; + int output_depth = (input_depth - filter_size + 2 * padding) / stride + 1; + int output_height = (input_height - filter_size + 2 * padding) / stride + 1; + int output_width = (input_width - filter_size + 2 * padding) / stride + 1; + + // Vol2Col test + float* input_ptr = + input_tmp.mutable_data({1, input_depth, input_height, input_width}, + paddle::platform::CPUPlace()); + float arr[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input_ptr, arr, 12 * sizeof(float)); + + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place); + } + output.mutable_data({1, filter_size, filter_size, filter_size, + output_depth, output_height, output_width}, + *place); + + paddle::operators::math::Vol2ColFunctor vol2col; + vol2col(*context, input, output, stride, stride, stride, padding, padding, + padding); + + float vol_2_col[] = {0, 1, 1, 2, 3, 4, 4, 5, 6, 7, 7, 8, 9, 10, 10, 11}; + float* out_cfo_ptr; + if (paddle::platform::is_cpu_place(*place)) { + out_cfo_ptr = output.data(); + } else { + output_tmp.CopyFrom(output, paddle::platform::CPUPlace()); + out_cfo_ptr = output_tmp.data(); + } + + for (int i = 0; i < 16; ++i) { + EXPECT_EQ(out_cfo_ptr[i], vol_2_col[i]); + } + + // Col2Vol test + float col_2_vol[] = {0, 2, 2, 3, 8, 5, 6, 14, 8, 9, 20, 11}; + memset(input_ptr, 0, 12 * sizeof(float)); + if (paddle::platform::is_cpu_place(*place)) { + input = input_tmp; + } else { + input.CopyFrom(input_tmp, *place); + } + + paddle::operators::math::Col2VolFunctor col2vol; + col2vol(*context, input, output, stride, stride, stride, padding, padding, + padding); + + float* in_ptr; + if (paddle::platform::is_cpu_place(*place)) { + in_ptr = input.data(); + } else { + input_tmp.CopyFrom(input, paddle::platform::CPUPlace()); + in_ptr = input_tmp.data(); + } + + for (int i = 0; i < 12; ++i) { + EXPECT_EQ(in_ptr[i], col_2_vol[i]); + } +} + +TEST(math, vol2col) { + testVol2col(); +#ifdef PADDLE_WITH_CUDA + testVol2col(); +#endif // PADDLE_WITH_CUDA +} diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc index 70ad611d5d..0cab5ffc56 100644 --- a/paddle/platform/gpu_info.cc +++ b/paddle/platform/gpu_info.cc @@ -43,6 +43,8 @@ int GetCurrentDeviceId() { } void SetDeviceId(int id) { + // TODO(qijun): find a better way to cache the cuda device count + PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count"); PADDLE_ENFORCE(cudaSetDevice(id), "cudaSetDevice failed in paddle::platform::SetDeviceId"); } diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 9157e00f6e..a28c4431e1 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -78,6 +78,26 @@ class TestTanhShrink(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.008) +class TestHardShrink(OpTest): + def setUp(self): + self.op_type = "hard_shrink" + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + threshold = 0.5 + + self.inputs = {'X': x} + self.attrs = {'lambda': threshold} + + t = np.copy(x) + t[(t >= -threshold) & (t <= threshold)] = 0 + self.outputs = {'Y': t} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.005) + + class TestSoftShrink(OpTest): def setUp(self): self.op_type = "softshrink" @@ -311,6 +331,21 @@ class TestSTanh(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) +class TestSoftplus(OpTest): + def setUp(self): + self.op_type = "softplus" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestSoftsign(OpTest): def setUp(self): self.op_type = "softsign"