From 8dc382e4ee53a9da7f63c42809ebf787b9f8ccc8 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Tue, 26 Sep 2017 15:35:54 +0800 Subject: [PATCH 01/13] Check whether param name is manually set when input is a sequence in fc layer --- python/paddle/trainer_config_helpers/layers.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 74025d2a7b..fffb44152e 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -1044,6 +1044,8 @@ def fc_layer(input, if isinstance(param_attr, collections.Sequence): assert len(input) == len(param_attr) else: + if "parameter_name" in param_attr.attr and len(input) > 1: + logger.fatal("You should set the parameter name for each of the input item.") param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) @@ -4863,6 +4865,8 @@ def selective_fc_layer(input, if isinstance(param_attr, collections.Sequence): assert len(input) == len(param_attr) else: + if "parameter_name" in param_attr.attr and len(input) > 1: + logger.fatal("You should set the parameter name for each of the input item.") param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) @@ -6473,7 +6477,7 @@ def switch_order_layer(input, act=None, layer_attr=None): """ - This layer switch dimension order of image input. + This layer switch dimension order of image input. From order "batchSize, channels, height, width" to order "batchSize, height, width, channels". From a378db3c373b318a1312d1503f019ca3ac15e3a8 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Tue, 26 Sep 2017 16:05:08 +0800 Subject: [PATCH 02/13] fix style issue --- python/paddle/trainer_config_helpers/layers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index fffb44152e..aebdcc134b 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -1045,7 +1045,9 @@ def fc_layer(input, assert len(input) == len(param_attr) else: if "parameter_name" in param_attr.attr and len(input) > 1: - logger.fatal("You should set the parameter name for each of the input item.") + logger.fatal( + "You should set the parameter name for each of the input item." + ) param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) @@ -4866,7 +4868,9 @@ def selective_fc_layer(input, assert len(input) == len(param_attr) else: if "parameter_name" in param_attr.attr and len(input) > 1: - logger.fatal("You should set the parameter name for each of the input item.") + logger.fatal( + "You should set the parameter name for each of the input item." + ) param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) From 2c5d4c6d200c478f9660593cdff67bad10c56402 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 30 Oct 2017 16:19:58 +0800 Subject: [PATCH 03/13] Clean code and update doc. 
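For reference, the usage that patches 01-02 above guard against (log message refined in patch 08 below): when `fc_layer` receives a list of inputs, a `param_attr` whose parameter name was set manually must also be a list, one item per input; otherwise the deep-copied attribute would make every input item share a single parameter. A minimal config-helper sketch of the intended form (layer names and sizes are illustrative, not taken from this series):

```python
from paddle.trainer_config_helpers import data_layer, fc_layer, ParamAttr

img = data_layer(name='img', size=1024)
txt = data_layer(name='txt', size=256)

# One named ParamAttr per input item; a single ParamAttr with a manually
# set name combined with a list input now triggers logger.fatal.
hidden = fc_layer(
    input=[img, txt],
    size=128,
    param_attr=[ParamAttr(name='_proj_img.w'), ParamAttr(name='_proj_txt.w')])
```

Patches 03-07 below rework the LSTM operator and its CPU/GPU kernels. The recurrence those kernels compute, read off the forward functor in `lstm_kernel.h` later in the series (shown with the default tanh/sigmoid activations; `checkI`, `checkF`, `checkO` are the peephole weights `W_ic`, `W_fc`, `W_oc`), is roughly

$$i_t = \sigma(a_i + W_{ic} \circ c_{t-1}), \qquad f_t = \sigma(a_f + W_{fc} \circ c_{t-1})$$
$$c_t = \tanh(a_c) \circ i_t + f_t \circ c_{t-1}$$
$$o_t = \sigma(a_o + W_{oc} \circ c_t), \qquad h_t = o_t \circ \tanh(c_t)$$

where $a_c, a_i, a_f, a_o$ are the four gate pre-activations and $h_t$, $c_t$ correspond to the `Hidden` and `Cell` outputs of shape (T x D).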
--- paddle/operators/lstm_op.cc | 10 +++++----- paddle/operators/lstm_op.h | 14 +------------- python/paddle/v2/framework/tests/test_lstm_op.py | 8 +++++--- 3 files changed, 11 insertions(+), 21 deletions(-) diff --git a/paddle/operators/lstm_op.cc b/paddle/operators/lstm_op.cc index 10b60e3de6..94342d9407 100644 --- a/paddle/operators/lstm_op.cc +++ b/paddle/operators/lstm_op.cc @@ -126,11 +126,11 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.") .AsDispensable(); AddOutput("Hidden", - "(LoDTensor) the hidden state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); + "(LoDTensor) the hidden state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("Cell", - "(LoDTensor) the cell state lod tensor of LSTM operator. " - "The shape and lod is the same with the `Input`."); + "(LoDTensor) the cell state of LSTM operator. " + "The shape is (T x D), and lod is the same with the `Input`."); AddOutput("BatchGate", "(LoDTensor) This LoDTensor contains input gate, forget gate " "and output gate after the nonlinear computation. This " @@ -141,7 +141,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { "in the raw input.") .AsIntermediate(); AddOutput("BatchCellPreAct", - "(LoDTensor) This LoDTensor is get in the forward and used " + "(LoDTensor) This LoDTensor is got in the forward and used " "in the backward.") .AsIntermediate(); AddAttr("usePeepholes", diff --git a/paddle/operators/lstm_op.h b/paddle/operators/lstm_op.h index d147b84aef..af088b80b4 100644 --- a/paddle/operators/lstm_op.h +++ b/paddle/operators/lstm_op.h @@ -155,7 +155,6 @@ class LSTMGradKernel : public framework::OpKernel { auto* batch_cell_pre_act = ctx.Input("BatchCellPreAct"); auto* hidden_g = ctx.Input(framework::GradVarName("Hidden")); - // auto* cell_g = ctx.Input(framework::GradVarName("Cell")); auto* in_g = ctx.Output(framework::GradVarName("Input")); auto* weight_g = ctx.Output(framework::GradVarName("Weight")); @@ -251,7 +250,7 @@ class LSTMGradKernel : public framework::OpKernel { lstm_grad.gateGrad = gate_g.data(); lstm_grad.outputGrad = out_g.data(); - if (n != 0) { + if (n) { int bstart_pre = static_cast(batch_starts[n - 1]); Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); @@ -292,17 +291,6 @@ class LSTMGradKernel : public framework::OpKernel { } if (bias && bias_g) { /* backward bias */ - // Following Eigen computation failed for double type on GPU device. - // bias_g->mutable_data(ctx.GetPlace()); - // Tensor bias_mat; - // bias_mat.ShareDataWith(*bias_g); - // bias_mat.Resize({1, 4 * frame_size}); - - // auto bias_g_e = EigenVector::Flatten(bias_mat); - // auto gate_g_e = EigenMatrix::From(batch_gate_g); - // Eigen::array dims{{0}}; - // bias_g_e.device(ctx.GetEigenDevice()) = gate_g_e.sum(dims); - int m = static_cast(batch_gate_g.dims()[0]); int n = static_cast(batch_gate_g.dims()[1]); diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index f308ba82fa..fe7f9783e4 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -161,9 +161,11 @@ class TestLstmOp(OpTest): #TODO(qingqing) add more unit testing case def test_check_grad(self): - # TODO(qingqing) remove folowing two lines after the check_grad is refined. 
- self.outputs['BatchGate'] = None - self.outputs['BatchCellPreAct'] = None + # TODO(qingqing) remove folowing lines after the check_grad is refined. + N = len(self.lod[0]) - 1 + self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') + self.outputs['BatchCellPreAct'] = np.zeros( + (N, self.D)).astype('float64') self.check_grad( ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02) From 1c8a0c4bd466aa2accbc6fa257142dbe76a01f6d Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Tue, 31 Oct 2017 17:26:52 +0800 Subject: [PATCH 04/13] Refine activation function pointer for LSTM operator. --- paddle/framework/CMakeLists.txt | 3 +- paddle/operators/math/detail/CMakeLists.txt | 4 +- .../math/detail/activation_functions.h | 170 ++++++++++++++++ .../{hl_avx_functions.cc => avx_functions.cc} | 22 +- .../math/detail/hl_activation_functions.h | 188 ------------------ .../operators/math/detail/hl_avx_functions.h | 32 --- .../operators/math/detail/hl_cpu_functions.cc | 89 --------- paddle/operators/math/detail/hl_functions.h | 71 ------- .../operators/math/detail/hl_gpu_functions.h | 93 --------- .../operators/math/detail/lstm_cpu_kernel.h | 28 ++- .../operators/math/detail/lstm_gpu_kernel.h | 30 ++- paddle/operators/math/detail/lstm_kernel.h | 135 +++++-------- .../paddle/v2/framework/tests/test_lstm_op.py | 4 +- 13 files changed, 279 insertions(+), 590 deletions(-) create mode 100644 paddle/operators/math/detail/activation_functions.h rename paddle/operators/math/detail/{hl_avx_functions.cc => avx_functions.cc} (84%) delete mode 100644 paddle/operators/math/detail/hl_activation_functions.h delete mode 100644 paddle/operators/math/detail/hl_avx_functions.h delete mode 100644 paddle/operators/math/detail/hl_cpu_functions.cc delete mode 100644 paddle/operators/math/detail/hl_functions.h delete mode 100644 paddle/operators/math/detail/hl_gpu_functions.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index f4fef055da..2be21e825a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -20,7 +20,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) cc_library(attribute SRCS attribute.cc DEPS framework_proto) -cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc) +cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc +device_context) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) diff --git a/paddle/operators/math/detail/CMakeLists.txt b/paddle/operators/math/detail/CMakeLists.txt index 49cf228de2..92eac9d362 100644 --- a/paddle/operators/math/detail/CMakeLists.txt +++ b/paddle/operators/math/detail/CMakeLists.txt @@ -1,5 +1,3 @@ if(WITH_AVX) - cc_library(activation_functions SRCS hl_cpu_functions.cc hl_avx_functions.cc) -else() - cc_library(activation_functions SRCS hl_cpu_functions.cc) + cc_library(activation_functions SRCS avx_functions.cc) endif() diff --git a/paddle/operators/math/detail/activation_functions.h b/paddle/operators/math/detail/activation_functions.h new file mode 100644 index 0000000000..8a186a51d6 --- /dev/null +++ b/paddle/operators/math/detail/activation_functions.h @@ -0,0 +1,170 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/platform/hostdevice.h" + +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +#define SIGMOID_THRESHOLD_MIN -40.0 +#define SIGMOID_THRESHOLD_MAX 13.0 +#define EXP_MAX_INPUT 40.0 + +namespace forward { + +template +DEVICE T linear(const T a) { + return a; +} + +template +DEVICE T relu(const T a) { + return a > static_cast(0.0) ? a : static_cast(0.0); +} + +template +DEVICE T sigmoid(const T a) { + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + T tmp = (a < min) ? min : ((a > max) ? max : a); + return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); +} + +template +DEVICE T tanh(const T a) { + T tmp = -2.0 * a; + tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; + return (2.0 / (1.0 + exp(tmp))) - 1.0; +} + +} // namespace forward + +namespace backward { + +template +DEVICE T linear(const T a, const T b) { + return a; +} + +template +DEVICE T relu(const T a, const T b) { + return a * (b > 0.0 ? 1.0 : 0.0); +} + +template +DEVICE T sigmoid(const T a, const T b) { + return a * b * (1.0 - b); +} + +template +DEVICE T tanh(const T a, const T b) { + return a * (1.0 - b * b); +} + +} // namespace backward + +template +struct Active { + typedef T (*Act)(T); + typedef T (*ActGrad)(T, T); +}; + +static DEVICE Active::Act kActFloat[] = { + &forward::sigmoid, &forward::relu, &forward::tanh, + &forward::linear}; + +static DEVICE Active::ActGrad kActGradFloat[] = { + &backward::sigmoid, &backward::relu, &backward::tanh, + &backward::linear}; + +static DEVICE Active::Act kActDouble[] = { + &forward::sigmoid, &forward::relu, &forward::tanh, + &forward::linear}; + +static DEVICE Active::ActGrad kActGradDouble[] = { + &backward::sigmoid, &backward::relu, + &backward::tanh, &backward::linear}; + +namespace forward { +inline DEVICE float activation(float a, int index) { + return kActFloat[index](a); +} + +inline DEVICE double activation(double a, int index) { + return kActDouble[index](a); +} + +} // namespace forward + +namespace backward { +inline DEVICE float activation(float a, float b, int index) { + return kActGradFloat[index](a, b); +} + +inline DEVICE double activation(double a, double b, int index) { + return kActGradDouble[index](a, b); +} +} // namespace backward + +#ifdef __AVX__ +namespace forward { +namespace avx { +__m256 relu(const __m256 a); +__m256 sigmoid(const __m256 a); +__m256 tanh(const __m256 a); +__m256 linear(const __m256 a); +} // namespace avx +} // namespace forward + +namespace backward { +namespace avx { +__m256 relu(const __m256 a, const __m256 b); +__m256 sigmoid(const __m256 a, const __m256 b); +__m256 tanh(const __m256 a, const __m256 b); +__m256 linear(const __m256 a, const __m256 b); +} // namespace avx +} // namespace backward + +static Active<__m256>::Act kActAvx[] = { + &forward::avx::sigmoid, &forward::avx::relu, &forward::avx::tanh, + &forward::avx::linear}; + +static Active<__m256>::ActGrad kActGradAvx[] = { + &backward::avx::sigmoid, &backward::avx::relu, &backward::avx::tanh, + &backward::avx::linear}; + +namespace forward { 
+inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } +} // namespace forward + +namespace backward { +inline __m256 activation(__m256 a, __m256 b, int index) { + return kActGradAvx[index](a, b); +} +} // namespace backward + +#endif + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/hl_avx_functions.cc b/paddle/operators/math/detail/avx_functions.cc similarity index 84% rename from paddle/operators/math/detail/hl_avx_functions.cc rename to paddle/operators/math/detail/avx_functions.cc index 415bac5d93..b8f014d30e 100644 --- a/paddle/operators/math/detail/hl_avx_functions.cc +++ b/paddle/operators/math/detail/avx_functions.cc @@ -13,14 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include -#include "hl_functions.h" +#include "paddle/operators/math/detail/activation_functions.h" // TODO(qingqing) refine this dependence #include "paddle/cuda/src/avx_mathfun.h" -namespace hppl { +namespace paddle { +namespace operators { +namespace math { +namespace detail { __m256 exp(__m256 a) { return exp256_ps(a); } +namespace forward { +namespace avx { __m256 relu(const __m256 a) { __m256 tmp = _mm256_set1_ps(0.0f); return _mm256_max_ps(a, tmp); @@ -50,6 +55,11 @@ __m256 tanh(const __m256 a) { __m256 linear(const __m256 a) { return a; } +} // namespace avx +} // namespace forward + +namespace backward { +namespace avx { __m256 relu(const __m256 a, const __m256 b) { return _mm256_mul_ps( a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), @@ -67,4 +77,10 @@ __m256 tanh(const __m256 a, const __m256 b) { } __m256 linear(const __m256 a, const __m256 b) { return a; } -} // namespace hppl +} // namespace avx +} // namespace backward + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/hl_activation_functions.h b/paddle/operators/math/detail/hl_activation_functions.h deleted file mode 100644 index 9d7d9914f0..0000000000 --- a/paddle/operators/math/detail/hl_activation_functions.h +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef HL_ACTIVATION_FUNCTIONS_H_ -#define HL_ACTIVATION_FUNCTIONS_H_ - -#include "hl_functions.h" -#include "paddle/operators/math/lstm_compute.h" - -/** - * Active functions: sigmoid, relu, tanh and linear. 
- */ -#define FLOAT_ACTIVE_FUNCTION \ - { \ - hppl::typef::sigmoid, hppl::typef::relu, hppl::typef::tanh, \ - hppl::typef::linear \ - } - -#define DOUBLE_ACTIVE_FUNCTION \ - { \ - hppl::typed::sigmoid, hppl::typed::relu, hppl::typed::tanh, \ - hppl::typed::linear \ - } - -#define AVX_ACTIVE_FUNCTION \ - { hppl::sigmoid, hppl::relu, hppl::tanh, hppl::linear } - -namespace hppl { - -using activation_mode_t = paddle::operators::math::activation_mode_t; - -/** - * Hppl supports sigmoid, relu, tanh, linear active functions - * for neural networks' forward and backward activation. - */ -template -class Active { - public: - typedef T (*forward)(T); - typedef T (*backward)(T, T); -}; - -template -struct ForwardActType; - -template <> -struct ForwardActType { - using type = Active::forward; -}; - -template <> -struct ForwardActType { - using type = Active::forward; -}; - -template -struct BackwardActType; - -template <> -struct BackwardActType { - using type = Active::backward; -}; - -template <> -struct BackwardActType { - using type = Active::backward; -}; - -#ifdef __NVCC__ -namespace gpu { -static __device__ Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; -static __device__ Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; - -static __device__ Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; -static __device__ Active::backward backward_d[] = - DOUBLE_ACTIVE_FUNCTION; - -template -struct ForwardAct { - __device__ typename ForwardActType::type operator()( - activation_mode_t type); -}; - -template <> -struct ForwardAct { - __device__ ForwardActType::type operator()(activation_mode_t type) { - return forward[type]; - } -}; - -template <> -struct ForwardAct { - __device__ ForwardActType::type operator()(activation_mode_t type) { - return forward_d[type]; - } -}; - -template -struct BackwardAct { - __device__ typename BackwardActType::type operator()( - activation_mode_t type); -}; - -template <> -struct BackwardAct { - __device__ BackwardActType::type operator()(activation_mode_t type) { - return backward[type]; - } -}; - -template <> -struct BackwardAct { - __device__ BackwardActType::type operator()(activation_mode_t type) { - return backward_d[type]; - } -}; - -} // namespace gpu -#else -namespace cpu { -static Active::forward forward[] = FLOAT_ACTIVE_FUNCTION; -static Active::backward backward[] = FLOAT_ACTIVE_FUNCTION; - -static Active::forward forward_d[] = DOUBLE_ACTIVE_FUNCTION; -static Active::backward backward_d[] = DOUBLE_ACTIVE_FUNCTION; - -template -struct ForwardAct { - typename ForwardActType::type operator()(activation_mode_t type); -}; - -template <> -struct ForwardAct { - ForwardActType::type operator()(activation_mode_t type) { - return forward[type]; - } -}; - -template <> -struct ForwardAct { - ForwardActType::type operator()(activation_mode_t type) { - return forward_d[type]; - } -}; - -template -struct BackwardAct { - typename BackwardActType::type operator()(activation_mode_t type); -}; - -template <> -struct BackwardAct { - BackwardActType::type operator()(activation_mode_t type) { - return backward[type]; - } -}; - -template <> -struct BackwardAct { - BackwardActType::type operator()(activation_mode_t type) { - return backward_d[type]; - } -}; - -} // namespace cpu - -#ifdef __AVX__ -namespace avx { -static Active<__m256>::forward forward[] = AVX_ACTIVE_FUNCTION; -static Active<__m256>::backward backward[] = AVX_ACTIVE_FUNCTION; -} // namespace avx -#endif -#endif - -} // namespace hppl - -#endif // HL_ACTIVATION_FUNCTIONS_H_ diff --git 
a/paddle/operators/math/detail/hl_avx_functions.h b/paddle/operators/math/detail/hl_avx_functions.h deleted file mode 100644 index 35f4eabb4c..0000000000 --- a/paddle/operators/math/detail/hl_avx_functions.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef HL_AVX_FUNCTIONS_H_ -#define HL_AVX_FUNCTIONS_H_ - -#include - -namespace hppl { -__m256 relu(const __m256 a); -__m256 sigmoid(const __m256 a); -__m256 tanh(const __m256 a); -__m256 linear(const __m256 a); - -__m256 relu(const __m256 a, const __m256 b); -__m256 sigmoid(const __m256 a, const __m256 b); -__m256 tanh(const __m256 a, const __m256 b); -__m256 linear(const __m256 a, const __m256 b); -} // namespace hppl - -#endif // HL_AVX_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_cpu_functions.cc b/paddle/operators/math/detail/hl_cpu_functions.cc deleted file mode 100644 index 21ec78f962..0000000000 --- a/paddle/operators/math/detail/hl_cpu_functions.cc +++ /dev/null @@ -1,89 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include -#include "hl_functions.h" - -namespace hppl { -namespace typef { - -float relu(const float a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -float sigmoid(const float a) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - float tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -float tanh(const float a) { - float tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -float linear(const float a) { return a; } - -float relu(const float a, const float b) { return a * (b > 0.0 ? 1.0 : 0.0); } - -float sigmoid(const float a, const float b) { - return a * b * (static_cast(1) - b); -} - -float tanh(const float a, const float b) { - return a * (static_cast(1) - b * b); -} - -float linear(const float a, const float b) { return a; } - -} // namespace typef - -namespace typed { -double relu(const double a) { - return a > static_cast(0.0) ? a : static_cast(0.0); -} - -double sigmoid(const double a) { - const double min = SIGMOID_THRESHOLD_MIN; - const double max = SIGMOID_THRESHOLD_MAX; - double tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -double tanh(const double a) { - double tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? 
EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - -double linear(const double a) { return a; } - -double relu(const double a, const double b) { - return a * (b > 0.0 ? 1.0 : 0.0); -} - -double sigmoid(const double a, const double b) { - return a * b * (static_cast(1) - b); -} - -double tanh(const double a, const double b) { - return a * (static_cast(1) - b * b); -} - -double linear(const double a, const double b) { return a; } - -} // namespace typed -} // namespace hppl diff --git a/paddle/operators/math/detail/hl_functions.h b/paddle/operators/math/detail/hl_functions.h deleted file mode 100644 index 3e2f0c9ee6..0000000000 --- a/paddle/operators/math/detail/hl_functions.h +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifndef HL_FUNCTIONS_H_ -#define HL_FUNCTIONS_H_ - -/** - * sigmoid threshold maximum - */ -#define SIGMOID_THRESHOLD_MIN -40.0 - -/** - * sigmoid threshold minimum - */ -#define SIGMOID_THRESHOLD_MAX 13.0 - -/** - * The maximum input value for exp, used to avoid overflow problem. - * currently only used for tanh function. - */ -#define EXP_MAX_INPUT 40.0 - -#ifndef __NVCC__ -namespace hppl { -namespace typef { -float relu(const float a); -float sigmoid(const float a); -float tanh(const float a); -float linear(const float a); - -float relu(const float a, const float b); -float sigmoid(const float a, const float b); -float tanh(const float a, const float b); -float linear(const float a, const float b); - -} // namespace typef - -namespace typed { -double relu(const double a); -double sigmoid(const double a); -double tanh(const double a); -double linear(const double a); - -double relu(const double a, const double b); -double sigmoid(const double a, const double b); -double tanh(const double a, const double b); -double linear(const double a, const double b); -} // namespace typed - -} // namespace hppl - -#ifdef __AVX__ -#include "hl_avx_functions.h" -#endif - -#else -#include "hl_gpu_functions.h" -#endif - -#endif // HL_FUNCTIONS_H_ diff --git a/paddle/operators/math/detail/hl_gpu_functions.h b/paddle/operators/math/detail/hl_gpu_functions.h deleted file mode 100644 index 72f2204e7b..0000000000 --- a/paddle/operators/math/detail/hl_gpu_functions.h +++ /dev/null @@ -1,93 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifndef HL_GPU_FUNCTIONS_CUH_ -#define HL_GPU_FUNCTIONS_CUH_ - -#include "hl_base.h" - -namespace hppl { -namespace typef { - -__device__ static float relu(const float a) { return a > 0.0f ? a : 0.0f; } - -__device__ static float sigmoid(const float a) { - const float min = SIGMOID_THRESHOLD_MIN; - const float max = SIGMOID_THRESHOLD_MAX; - float tmp = (a < min) ? min : ((a > max) ? max : a); - return __fdividef(1.0f, 1.0f + __expf(-tmp)); -} - -__device__ static float tanh(const float a) { - float tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return __fdividef(2.0f, (1.0f + __expf(-2.0f * tmp))) - 1.0f; -} - -__device__ static float linear(const float a) { return a; } - -__device__ static float relu(const float a, const float b) { - return a * (b > 0.0f ? 1.0f : 0.0f); -} - -__device__ static float sigmoid(const float a, const float b) { - return a * b * (1.0f - b); -} - -__device__ static float tanh(const float a, const float b) { - return a * (1.0f - b * b); -} - -__device__ static float linear(const float a, const float b) { return a; } - -} // namespace typef - -namespace typed { - -__device__ static double relu(const double a) { return a > 0.0 ? a : 0.0; } - -__device__ static double sigmoid(const double a) { - const double min = SIGMOID_THRESHOLD_MIN; - const double max = SIGMOID_THRESHOLD_MAX; - double tmp = (a < min) ? min : ((a > max) ? max : a); - return 1.0 / (1.0 + exp(-tmp)); -} - -__device__ static double tanh(const double a) { - double tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(-2.0 * a))) - 1.0; -} - -__device__ static double linear(const double a) { return a; } - -__device__ static double relu(const double a, const double b) { - return a * (b > 0.0 ? 1.0 : 0.0); -} - -__device__ static double sigmoid(const double a, const double b) { - return a * b * (1 - b); -} - -__device__ static double tanh(const double a, const double b) { - return a * (1.0 - b * b); -} - -__device__ static double linear(const double a, const double b) { return a; } - -} // namespace typef - -} // namespace hppl - -#endif // HL_GPU_FUNCTIONS_CUH_ diff --git a/paddle/operators/math/detail/lstm_cpu_kernel.h b/paddle/operators/math/detail/lstm_cpu_kernel.h index d0ed55ea16..f5b0dd85c9 100644 --- a/paddle/operators/math/detail/lstm_cpu_kernel.h +++ b/paddle/operators/math/detail/lstm_cpu_kernel.h @@ -14,7 +14,7 @@ limitations under the License. 
*/ #pragma once #include -#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/lstm_compute.h" namespace paddle { @@ -26,7 +26,10 @@ namespace detail { template void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, - int frameSize) { + int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { T rValueIn; T rValueIg; T rValueFg; @@ -58,7 +61,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, } op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO); + rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -72,7 +75,10 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue value, template void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, - LstmMetaGrad grad, int frameSize) { + LstmMetaGrad grad, int frameSize, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { T rValueIn; T rValueIg; T rValueFg; @@ -122,7 +128,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue value, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad); + rCheckOGrad, active_node, active_gate, active_state); gradIn[i] = rGradIn; gradIg[i] = rGradIg; @@ -176,8 +182,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue value, int frameSize, } op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO, hppl::avx::forward[active_node], - hppl::avx::forward[active_gate], hppl::avx::forward[active_state]); + rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); valueIn[i] = rValueIn; valueIg[i] = rValueIg; @@ -246,8 +251,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue value, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, - rCheckOGrad, hppl::avx::backward[active_node], - hppl::avx::backward[active_gate], hppl::avx::backward[active_state]); + rCheckOGrad, active_node, active_gate, active_state); gradIn[i] = rGradIn; gradIg[i] = rGradIg; @@ -274,7 +278,8 @@ void cpu_lstm_forward(Op op, LstmMetaValue value, int frameSize, avx_lstm_forward_one_sequence(op, value, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_forward_one_sequence(op, value, frameSize); + naive_lstm_forward_one_sequence(op, value, frameSize, active_node, + active_gate, active_state); } } @@ -287,7 +292,8 @@ void cpu_lstm_backward(Op op, LstmMetaValue value, LstmMetaGrad grad, avx_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, active_gate, active_state); } else { - naive_lstm_backward_one_sequence(op, value, grad, frameSize); + naive_lstm_backward_one_sequence(op, value, grad, frameSize, active_node, + active_gate, active_state); } } diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index c06f164f84..d3e5e381a5 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -13,13 +13,12 @@ See the License for 
the specific language governing permissions and limitations under the License. */ #pragma once -#include -#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/operators/math/lstm_compute.h" #include "paddle/platform/cuda_helper.h" #include "paddle/platform/device_context.h" -#include +#include namespace paddle { namespace operators { @@ -32,7 +31,9 @@ namespace detail { */ template __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, - int batchSize) { + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -69,7 +70,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, } op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv, - rOut, rCheckI, rCheckF, rCheckO); + rOut, rCheckI, rCheckF, rCheckO, active_node, active_gate, active_state); value.gateValue[frameIdx] = rValueIn; value.gateValue[frameIdx + frameSize] = rValueIg; @@ -88,7 +89,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue value, int frameSize, template __global__ void KeLstmBackward(Op op, LstmMetaValue value, LstmMetaGrad grad, int frameSize, - int batchSize) { + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; if (frameIdx >= frameSize) return; @@ -141,7 +144,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue value, op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad, - rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad); + rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad, + active_node, active_gate, active_state); grad.gateGrad[frameIdx] = rGradIn; grad.gateGrad[frameIdx + frameSize] = rGradIg; @@ -197,11 +201,13 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmForward<<>>( - op, value, frameSize, batchSize); + op, value, frameSize, batchSize, active_node, active_gate, + active_state); } else { KeLstmForward<<>>( - op, value, frameSize, batchSize); + op, value, frameSize, batchSize, active_node, active_gate, + active_state); } } @@ -230,11 +236,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, if (batchSize == 1) { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } else { KeLstmBackward<<>>( - op, value, grad, frameSize, batchSize); + op, value, grad, frameSize, batchSize, active_node, active_gate, + active_state); } } diff --git a/paddle/operators/math/detail/lstm_kernel.h b/paddle/operators/math/detail/lstm_kernel.h index 461039a4d5..9daaf91981 100644 --- a/paddle/operators/math/detail/lstm_kernel.h +++ b/paddle/operators/math/detail/lstm_kernel.h @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/detail/activation_functions.h" #include "paddle/platform/hostdevice.h" #include @@ -24,45 +24,22 @@ namespace detail { namespace forward { -template -DEVICE inline T sigmoid(const T a) { - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - T tmp = (a < min) ? min : ((a > max) ? max : a); - return static_cast(1.0) / (static_cast(1.0) + exp(-tmp)); -} - -template -DEVICE inline T tanh(const T a) { - T tmp = -2.0 * a; - tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; - return (2.0 / (1.0 + exp(tmp))) - 1.0; -} - template class lstm { public: HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg, T &prevState, T &state, T &stateAtv, T &output, - T &checkI, T &checkF, T &checkO) { -#if 0 - // TODO(qingqing) support to activation speficed by users - valueIn = actInput(valueIn); - valueIg = actGate(valueIg + prevState * checkI); - valueFg = actGate(valueFg + prevState * checkF); - state = valueIn * valueIg + prevState * valueFg; - valueOg = actGate(valueOg + state * checkO); - stateAtv = actState(state); - output = valueOg * stateAtv; -#else - valueIn = tanh(valueIn); - valueIg = sigmoid(valueIg + prevState * checkI); - valueFg = sigmoid(valueFg + prevState * checkF); + T &checkI, T &checkF, T &checkO, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + valueIn = activation(valueIn, active_node); + valueIg = activation(valueIg + prevState * checkI, active_gate); + valueFg = activation(valueFg + prevState * checkF, active_gate); state = valueIn * valueIg + prevState * valueFg; - valueOg = sigmoid(valueOg + state * checkO); - stateAtv = tanh(state); + valueOg = activation(valueOg + state * checkO, active_gate); + stateAtv = activation(state, active_state); output = valueOg * stateAtv; -#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default @@ -75,16 +52,19 @@ class lstm { __m256 &valueOg, __m256 &prevState, __m256 &state, __m256 &stateAtv, __m256 &output, __m256 &checkI, __m256 &checkF, __m256 &checkO, - hppl::Active<__m256>::forward actInput, - hppl::Active<__m256>::forward actGate, - hppl::Active<__m256>::forward actState) { - valueIn = actInput(valueIn); - valueIg = actGate(_mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI))); - valueFg = actGate(_mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF))); + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + valueIn = activation(valueIn, active_node); + valueIg = activation( + _mm256_add_ps(valueIg, _mm256_mul_ps(prevState, checkI)), active_gate); + valueFg = activation( + _mm256_add_ps(valueFg, _mm256_mul_ps(prevState, checkF)), active_gate); state = _mm256_add_ps(_mm256_mul_ps(valueIn, valueIg), _mm256_mul_ps(prevState, valueFg)); - valueOg = actGate(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO))); - stateAtv = actState(state); + valueOg = activation(_mm256_add_ps(valueOg, _mm256_mul_ps(state, checkO)), + active_gate); + stateAtv = activation(state, active_state); output = _mm256_mul_ps(valueOg, stateAtv); } #endif @@ -95,16 +75,6 @@ class lstm { namespace backward { -template -DEVICE inline T sigmoid(const T a, const T b) { - return a * b * (1.0 - b); -} - -template -DEVICE inline T tanh(const T a, const T b) { - return a * (1.0 - b * b); -} - template class lstm { public: @@ -113,29 +83,20 @@ class lstm { T &prevState, T &prevStateGrad, T &state, T &stateGrad, T &stateAtv, T &outputGrad, T &checkI, T &checkF, T &checkO, T &checkIGrad, - T &checkFGrad, T &checkOGrad) { -#if 0 - // TODO(qingqing) support to activation speficed by users - gradOg = actGate(outputGrad * stateAtv, valueOg); - stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO; - gradIn = actInput(stateGrad * valueIg, valueIn); - gradIg = actGate(stateGrad * valueIn, valueIg); - gradFg = actGate(stateGrad * prevState, valueFg); + T &checkFGrad, T &checkOGrad, + activation_mode_t active_node, + activation_mode_t active_gate, + activation_mode_t active_state) { + gradOg = activation(outputGrad * stateAtv, valueOg, active_gate); + stateGrad += activation(outputGrad * valueOg, stateAtv, active_state) + + gradOg * checkO; + gradIn = activation(stateGrad * valueIg, valueIn, active_node); + gradIg = activation(stateGrad * valueIn, valueIg, active_gate); + gradFg = activation(stateGrad * prevState, valueFg, active_gate); prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; checkIGrad = gradIg * prevState; checkFGrad = gradFg * prevState; checkOGrad = gradOg * state; -#else - gradOg = sigmoid(outputGrad * stateAtv, valueOg); - stateGrad += tanh(outputGrad * valueOg, stateAtv) + gradOg * checkO; - gradIn = tanh(stateGrad * valueIg, valueIn); - gradIg = sigmoid(stateGrad * valueIn, valueIg); - gradFg = sigmoid(stateGrad * prevState, valueFg); - prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg; - checkIGrad = gradIg * prevState; - checkFGrad = gradFg * prevState; - checkOGrad = gradOg * state; -#endif } #ifndef __NVCC__ #ifndef __AVX__ // If not compiled with AVX instructs. 
Disable AVX by default @@ -143,24 +104,26 @@ class lstm { #else // Only float support AVX optimization static const bool avx = std::is_same::value; - HOSTDEVICE void operator()(__m256 &valueIn, __m256 &valueIg, __m256 &valueFg, - __m256 &valueOg, __m256 &gradIn, __m256 &gradIg, - __m256 &gradFg, __m256 &gradOg, __m256 &prevState, - __m256 &prevStateGrad, __m256 &state, - __m256 &stateGrad, __m256 &stateAtv, - __m256 &outputGrad, __m256 &checkI, __m256 &checkF, - __m256 &checkO, __m256 &checkIGrad, - __m256 &checkFGrad, __m256 &checkOGrad, - hppl::Active<__m256>::backward actInput, - hppl::Active<__m256>::backward actGate, - hppl::Active<__m256>::backward actState) { - gradOg = actGate(_mm256_mul_ps(outputGrad, stateAtv), valueOg); + HOSTDEVICE void operator()( + __m256 &valueIn, __m256 &valueIg, __m256 &valueFg, __m256 &valueOg, + __m256 &gradIn, __m256 &gradIg, __m256 &gradFg, __m256 &gradOg, + __m256 &prevState, __m256 &prevStateGrad, __m256 &state, + __m256 &stateGrad, __m256 &stateAtv, __m256 &outputGrad, __m256 &checkI, + __m256 &checkF, __m256 &checkO, __m256 &checkIGrad, __m256 &checkFGrad, + __m256 &checkOGrad, activation_mode_t active_node, + activation_mode_t active_gate, activation_mode_t active_state) { + gradOg = + activation(_mm256_mul_ps(outputGrad, stateAtv), valueOg, active_gate); stateGrad = _mm256_add_ps( - actState(_mm256_mul_ps(outputGrad, valueOg), stateAtv), stateGrad); + activation(_mm256_mul_ps(outputGrad, valueOg), stateAtv, active_state), + stateGrad); stateGrad = _mm256_add_ps(_mm256_mul_ps(gradOg, checkO), stateGrad); - gradIn = actInput(_mm256_mul_ps(stateGrad, valueIg), valueIn); - gradIg = actGate(_mm256_mul_ps(stateGrad, valueIn), valueIg); - gradFg = actGate(_mm256_mul_ps(stateGrad, prevState), valueFg); + gradIn = + activation(_mm256_mul_ps(stateGrad, valueIg), valueIn, active_node); + gradIg = + activation(_mm256_mul_ps(stateGrad, valueIn), valueIg, active_gate); + gradFg = + activation(_mm256_mul_ps(stateGrad, prevState), valueFg, active_gate); prevStateGrad = _mm256_add_ps(_mm256_mul_ps(gradIg, checkI), _mm256_mul_ps(gradFg, checkF)); prevStateGrad = diff --git a/python/paddle/v2/framework/tests/test_lstm_op.py b/python/paddle/v2/framework/tests/test_lstm_op.py index fe7f9783e4..ff75160083 100644 --- a/python/paddle/v2/framework/tests/test_lstm_op.py +++ b/python/paddle/v2/framework/tests/test_lstm_op.py @@ -157,7 +157,7 @@ class TestLstmOp(OpTest): } def test_check_output(self): - self.check_output() + self.check_output(atol=1e-8) #TODO(qingqing) add more unit testing case def test_check_grad(self): @@ -167,7 +167,7 @@ class TestLstmOp(OpTest): self.outputs['BatchCellPreAct'] = np.zeros( (N, self.D)).astype('float64') self.check_grad( - ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=0.02) + ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) class TestLstmOpHasNoInitial(TestLstmOp): From 1f53a72f10c9d4781932d7d4a842a9993106a8d3 Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 2 Nov 2017 00:21:04 +0800 Subject: [PATCH 05/13] Reduce the threads number in the LSTM backward kernel to fix the error occurred in GPU GTX 1080. 
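A note on the launch geometry in this patch (inferred from the diff rather than stated in the message): the backward kernel's 2-D blocks shrink from 32 x 32 = 1024 threads to 32 x 16 = 512 threads, presumably to stay within the kernel's per-block resource limits on that device. The matching grid adjustment, stepping the batch dimension by 16 instead of 32, lands in patch 07 below, giving

$$\text{threads} = (32, 16), \qquad \text{grid} = \left(\left\lceil \frac{\text{frameSize}}{32} \right\rceil,\; \left\lceil \frac{\text{batchSize}}{16} \right\rceil\right)$$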
--- paddle/operators/math/detail/lstm_gpu_kernel.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index d3e5e381a5..e07655eaac 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -227,7 +227,7 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, grid = dim3(frameBlocks, 1); } else { /* framePerBlock = 32 batchPerBlock = 32 */ - threads = dim3(32, 32); + threads = dim3(32, 16); grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); } @@ -244,6 +244,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, op, value, grad, frameSize, batchSize, active_node, active_gate, active_state); } + + cudaStreamSynchronize(stream); + // TODO(qingqing): Add cuda error check for each kernel. + cudaError_t err = cudaGetLastError(); + PADDLE_ENFORCE_EQ(err, cudaGetErrorString(err)); } } // namespace detail From 5a4cdbb3dfb2de82ed6864d38a4381c52d4dba4c Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 2 Nov 2017 00:30:12 +0800 Subject: [PATCH 06/13] Fix check bug. --- paddle/operators/math/detail/lstm_gpu_kernel.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index e07655eaac..1781460c35 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -248,7 +248,7 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, cudaStreamSynchronize(stream); // TODO(qingqing): Add cuda error check for each kernel. cudaError_t err = cudaGetLastError(); - PADDLE_ENFORCE_EQ(err, cudaGetErrorString(err)); + PADDLE_ENFORCE(err, cudaGetErrorString(err)); } } // namespace detail From 29a9f9b5ea3689ec67bed5c2f39c4a33e4743b2e Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Thu, 2 Nov 2017 12:14:05 +0800 Subject: [PATCH 07/13] Refine code format and fix threads number. --- .../math/detail/activation_functions.h | 56 +++++++++---------- paddle/operators/math/detail/avx_functions.cc | 22 ++++---- .../operators/math/detail/lstm_gpu_kernel.h | 4 +- 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/paddle/operators/math/detail/activation_functions.h b/paddle/operators/math/detail/activation_functions.h index 8a186a51d6..a20c35d1d9 100644 --- a/paddle/operators/math/detail/activation_functions.h +++ b/paddle/operators/math/detail/activation_functions.h @@ -32,17 +32,17 @@ namespace detail { namespace forward { template -DEVICE T linear(const T a) { +DEVICE T Identity(const T a) { return a; } template -DEVICE T relu(const T a) { +DEVICE T Relu(const T a) { return a > static_cast(0.0) ? a : static_cast(0.0); } template -DEVICE T sigmoid(const T a) { +DEVICE T Sigmoid(const T a) { const T min = SIGMOID_THRESHOLD_MIN; const T max = SIGMOID_THRESHOLD_MAX; T tmp = (a < min) ? min : ((a > max) ? max : a); @@ -50,7 +50,7 @@ DEVICE T sigmoid(const T a) { } template -DEVICE T tanh(const T a) { +DEVICE T Tanh(const T a) { T tmp = -2.0 * a; tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp; return (2.0 / (1.0 + exp(tmp))) - 1.0; @@ -61,22 +61,22 @@ DEVICE T tanh(const T a) { namespace backward { template -DEVICE T linear(const T a, const T b) { +DEVICE T Identity(const T a, const T b) { return a; } template -DEVICE T relu(const T a, const T b) { +DEVICE T Relu(const T a, const T b) { return a * (b > 0.0 ? 
1.0 : 0.0); } template -DEVICE T sigmoid(const T a, const T b) { +DEVICE T Sigmoid(const T a, const T b) { return a * b * (1.0 - b); } template -DEVICE T tanh(const T a, const T b) { +DEVICE T Tanh(const T a, const T b) { return a * (1.0 - b * b); } @@ -89,20 +89,20 @@ struct Active { }; static DEVICE Active::Act kActFloat[] = { - &forward::sigmoid, &forward::relu, &forward::tanh, - &forward::linear}; + &forward::Sigmoid, &forward::Relu, &forward::Tanh, + &forward::Identity}; static DEVICE Active::ActGrad kActGradFloat[] = { - &backward::sigmoid, &backward::relu, &backward::tanh, - &backward::linear}; + &backward::Sigmoid, &backward::Relu, &backward::Tanh, + &backward::Identity}; static DEVICE Active::Act kActDouble[] = { - &forward::sigmoid, &forward::relu, &forward::tanh, - &forward::linear}; + &forward::Sigmoid, &forward::Relu, &forward::Tanh, + &forward::Identity}; static DEVICE Active::ActGrad kActGradDouble[] = { - &backward::sigmoid, &backward::relu, - &backward::tanh, &backward::linear}; + &backward::Sigmoid, &backward::Relu, + &backward::Tanh, &backward::Identity}; namespace forward { inline DEVICE float activation(float a, int index) { @@ -128,29 +128,29 @@ inline DEVICE double activation(double a, double b, int index) { #ifdef __AVX__ namespace forward { namespace avx { -__m256 relu(const __m256 a); -__m256 sigmoid(const __m256 a); -__m256 tanh(const __m256 a); -__m256 linear(const __m256 a); +__m256 Relu(const __m256 a); +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +__m256 Identity(const __m256 a); } // namespace avx } // namespace forward namespace backward { namespace avx { -__m256 relu(const __m256 a, const __m256 b); -__m256 sigmoid(const __m256 a, const __m256 b); -__m256 tanh(const __m256 a, const __m256 b); -__m256 linear(const __m256 a, const __m256 b); +__m256 Relu(const __m256 a, const __m256 b); +__m256 Sigmoid(const __m256 a, const __m256 b); +__m256 Tanh(const __m256 a, const __m256 b); +__m256 Identity(const __m256 a, const __m256 b); } // namespace avx } // namespace backward static Active<__m256>::Act kActAvx[] = { - &forward::avx::sigmoid, &forward::avx::relu, &forward::avx::tanh, - &forward::avx::linear}; + &forward::avx::Sigmoid, &forward::avx::Relu, &forward::avx::Tanh, + &forward::avx::Identity}; static Active<__m256>::ActGrad kActGradAvx[] = { - &backward::avx::sigmoid, &backward::avx::relu, &backward::avx::tanh, - &backward::avx::linear}; + &backward::avx::Sigmoid, &backward::avx::Relu, &backward::avx::Tanh, + &backward::avx::Identity}; namespace forward { inline __m256 activation(__m256 a, int index) { return kActAvx[index](a); } diff --git a/paddle/operators/math/detail/avx_functions.cc b/paddle/operators/math/detail/avx_functions.cc index b8f014d30e..6d9df654a4 100644 --- a/paddle/operators/math/detail/avx_functions.cc +++ b/paddle/operators/math/detail/avx_functions.cc @@ -22,61 +22,61 @@ namespace operators { namespace math { namespace detail { -__m256 exp(__m256 a) { return exp256_ps(a); } +__m256 Exp(__m256 a) { return exp256_ps(a); } namespace forward { namespace avx { -__m256 relu(const __m256 a) { +__m256 Relu(const __m256 a) { __m256 tmp = _mm256_set1_ps(0.0f); return _mm256_max_ps(a, tmp); } -__m256 sigmoid(const __m256 a) { +__m256 Sigmoid(const __m256 a) { __m256 max = _mm256_set1_ps(SIGMOID_THRESHOLD_MAX); __m256 min = _mm256_set1_ps(SIGMOID_THRESHOLD_MIN); __m256 tmp = _mm256_max_ps(a, min); tmp = _mm256_min_ps(tmp, max); tmp = _mm256_sub_ps(_mm256_set1_ps(0.0f), tmp); - tmp = exp(tmp); + tmp = Exp(tmp); tmp = 
_mm256_add_ps(_mm256_set1_ps(1.0f), tmp); tmp = _mm256_div_ps(_mm256_set1_ps(1.0f), tmp); return tmp; } -__m256 tanh(const __m256 a) { +__m256 Tanh(const __m256 a) { __m256 max = _mm256_set1_ps(EXP_MAX_INPUT); __m256 tmp = _mm256_mul_ps(_mm256_set1_ps(-2.0f), a); tmp = _mm256_min_ps(tmp, max); - tmp = exp(tmp); + tmp = Exp(tmp); return _mm256_sub_ps(_mm256_div_ps(_mm256_set1_ps(2.0f), _mm256_add_ps(_mm256_set1_ps(1.0f), tmp)), _mm256_set1_ps(1.0f)); } -__m256 linear(const __m256 a) { return a; } +__m256 Identity(const __m256 a) { return a; } } // namespace avx } // namespace forward namespace backward { namespace avx { -__m256 relu(const __m256 a, const __m256 b) { +__m256 Relu(const __m256 a, const __m256 b) { return _mm256_mul_ps( a, _mm256_and_ps(_mm256_cmp_ps(b, _mm256_set1_ps(0.0f), _CMP_GT_OS), _mm256_set1_ps(1.0f))); } -__m256 sigmoid(const __m256 a, const __m256 b) { +__m256 Sigmoid(const __m256 a, const __m256 b) { return _mm256_mul_ps(_mm256_mul_ps(a, b), _mm256_sub_ps(_mm256_set1_ps(1.0f), b)); } -__m256 tanh(const __m256 a, const __m256 b) { +__m256 Tanh(const __m256 a, const __m256 b) { return _mm256_mul_ps( a, _mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_mul_ps(b, b))); } -__m256 linear(const __m256 a, const __m256 b) { return a; } +__m256 Identity(const __m256 a, const __m256 b) { return a; } } // namespace avx } // namespace backward diff --git a/paddle/operators/math/detail/lstm_gpu_kernel.h b/paddle/operators/math/detail/lstm_gpu_kernel.h index 1781460c35..41a54a359d 100644 --- a/paddle/operators/math/detail/lstm_gpu_kernel.h +++ b/paddle/operators/math/detail/lstm_gpu_kernel.h @@ -226,9 +226,9 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, threads = dim3(framePerBlock, 1); grid = dim3(frameBlocks, 1); } else { - /* framePerBlock = 32 batchPerBlock = 32 */ + /* framePerBlock = 32 batchPerBlock = 16 */ threads = dim3(32, 16); - grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 16 - 1) / 16); } auto stream = From 0d79e9732d5215a1f68080c97675af839b5a2470 Mon Sep 17 00:00:00 2001 From: wangmeng28 Date: Thu, 2 Nov 2017 20:29:11 +0800 Subject: [PATCH 08/13] Refine the log message in fc layer --- python/paddle/trainer_config_helpers/layers.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index aebdcc134b..11809a7e98 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -1046,8 +1046,11 @@ def fc_layer(input, else: if "parameter_name" in param_attr.attr and len(input) > 1: logger.fatal( - "You should set the parameter name for each of the input item." - ) + "When the name field of param_attr is manually specified " + "and the input is a list, the param_attr should also be a " + "list with each item being the param_attr for each input " + "item. If only one named param_attr is provided, all the " + "input items would share this parameter.") param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) @@ -4869,8 +4872,11 @@ def selective_fc_layer(input, else: if "parameter_name" in param_attr.attr and len(input) > 1: logger.fatal( - "You should set the parameter name for each of the input item." 
- ) + "When the name field of param_attr is manually specified " + "and the input is a list, the param_attr should also be a " + "list with each item being the param_attr for each input " + "item. If only one named param_attr is provided, all the " + "input items would share this parameter.") param_attr = [copy.deepcopy(param_attr) for _ in range(len(input))] assert isinstance(input, collections.Sequence) From 81c7dbc5446f861489d70fece73d33418c5eab66 Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 2 Nov 2017 10:36:56 -0700 Subject: [PATCH 09/13] design doc for float16 --- doc/design/float16.md | 46 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 doc/design/float16.md diff --git a/doc/design/float16.md b/doc/design/float16.md new file mode 100644 index 0000000000..07f0d66e44 --- /dev/null +++ b/doc/design/float16.md @@ -0,0 +1,46 @@ +# Design Doc: float16 + +## Why float16 +Half precision (float16) is a binary floating-point format that occupies 16 bits / 2 bytes in memory. float16 is half the size of traditional 32-bit single precision format (float) and has lower precision and smaller range. + +When high precision computation is not required, using float16 data type could potentially + +- reduce storage space, memory bandwidth, and power usages; +- increase the chance of data fitting into a smaller cache of lower latency; +- provide arithmetic speed up if supported by hardware. + +A brief survey of float16 support on different hardwares can be found [here](https://github.com/PaddlePaddle/Paddle/issues/4853). A brief survey of existing float16 implementations can be found [here](https://github.com/Xreki/Xreki.github.io/blob/master/multi_data_types_in_dl_framework/ppt/float16_and_quantized_type.md). + +There are various natively supported float16 implementations on different hardwares/linear algebra libraries including half on cuda, __fp16/float16_t on ARM processor, and Eigen::half on Eigen. + +The goal of float16 is to serve as a key for the executor to find and run the correct version of operator kernel compute method specialized for float16. It should be compatible with half on cuda, __fp16 on ARM, and Eigen::half on Eigen to make writing customized float16 kernels easier. + +## Implementation +The float16 class holds a 2-byte uint16_t data internally. +``` +struct float16 { + uint16_t x; +}; +``` + +float16 supports the following features: + - constructors / assignment operators that take input from primitive data types including bool, integers of various length, float, and double. + - constructors / assignment operators that take input from half on cuda, __fp16 on ARM, and Eigen::half on Eigen. + - conversion operators to primitive data types and half precision data types on cuda, ARM and Eigen. + - overloaded arithmetic operators (e.g., +, -, *, /) for cuda, arm, and non-arm cpu, respectively. These operators will take advantage of the cuda and ARM intrinsics on the corresponding hardware. + +To support the above features, two fundamental conversion functions are provided: +``` +float16 float_to_half_rn(float f); // convert to half precision in round-to-nearest-even mode +float half_to_float(float16 h); +``` +which provides one-to-one conversion between float32 and float16. These twos functions will do different conversion routines based on the current hardware. CUDA/ARM instrinsics will be used when the corresonding hardware is available. 
When the hardware falls back to non-ARM cpu, software emulation will be performed to do the conversion. + +## To do +After float16 class is available, some of the future items are below: + +- Update pybind/tensor_py.h to bind c++ float16 with numpy float16. + +- Modify `IndicateDataType()` method in `framework/operator.h` to make it compatible with float16. + +- Create a type-casting operator that can convert the data type in tensor between float16 and other types. From 81ba077e7b29642ec5a4e847384c4694364a732f Mon Sep 17 00:00:00 2001 From: Kexin Zhao Date: Thu, 2 Nov 2017 10:44:23 -0700 Subject: [PATCH 10/13] small fix --- doc/design/float16.md | 46 ------ paddle/operators/activation_op.cc | 238 ++++++++++++++++++++++-------- paddle/operators/activation_op.h | 2 +- 3 files changed, 174 insertions(+), 112 deletions(-) delete mode 100644 doc/design/float16.md diff --git a/doc/design/float16.md b/doc/design/float16.md deleted file mode 100644 index 07f0d66e44..0000000000 --- a/doc/design/float16.md +++ /dev/null @@ -1,46 +0,0 @@ -# Design Doc: float16 - -## Why float16 -Half precision (float16) is a binary floating-point format that occupies 16 bits / 2 bytes in memory. float16 is half the size of traditional 32-bit single precision format (float) and has lower precision and smaller range. - -When high precision computation is not required, using float16 data type could potentially - -- reduce storage space, memory bandwidth, and power usages; -- increase the chance of data fitting into a smaller cache of lower latency; -- provide arithmetic speed up if supported by hardware. - -A brief survey of float16 support on different hardwares can be found [here](https://github.com/PaddlePaddle/Paddle/issues/4853). A brief survey of existing float16 implementations can be found [here](https://github.com/Xreki/Xreki.github.io/blob/master/multi_data_types_in_dl_framework/ppt/float16_and_quantized_type.md). - -There are various natively supported float16 implementations on different hardwares/linear algebra libraries including half on cuda, __fp16/float16_t on ARM processor, and Eigen::half on Eigen. - -The goal of float16 is to serve as a key for the executor to find and run the correct version of operator kernel compute method specialized for float16. It should be compatible with half on cuda, __fp16 on ARM, and Eigen::half on Eigen to make writing customized float16 kernels easier. - -## Implementation -The float16 class holds a 2-byte uint16_t data internally. -``` -struct float16 { - uint16_t x; -}; -``` - -float16 supports the following features: - - constructors / assignment operators that take input from primitive data types including bool, integers of various length, float, and double. - - constructors / assignment operators that take input from half on cuda, __fp16 on ARM, and Eigen::half on Eigen. - - conversion operators to primitive data types and half precision data types on cuda, ARM and Eigen. - - overloaded arithmetic operators (e.g., +, -, *, /) for cuda, arm, and non-arm cpu, respectively. These operators will take advantage of the cuda and ARM intrinsics on the corresponding hardware. - -To support the above features, two fundamental conversion functions are provided: -``` -float16 float_to_half_rn(float f); // convert to half precision in round-to-nearest-even mode -float half_to_float(float16 h); -``` -which provides one-to-one conversion between float32 and float16. These twos functions will do different conversion routines based on the current hardware. 
CUDA/ARM instrinsics will be used when the corresonding hardware is available. When the hardware falls back to non-ARM cpu, software emulation will be performed to do the conversion. - -## To do -After float16 class is available, some of the future items are below: - -- Update pybind/tensor_py.h to bind c++ float16 with numpy float16. - -- Modify `IndicateDataType()` method in `framework/operator.h` to make it compatible with float16. - -- Create a type-casting operator that can convert the data type in tensor between float16 and other types. diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 90f1535fcd..483f988897 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -43,7 +43,12 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Sigmoid operator"); AddOutput("Y", "Output of Sigmoid operator"); - AddComment("Sigmoid activation operator, sigmoid = 1 / (1 + exp(-x))"); + AddComment(R"DOC( +Sigmoid activation operator. + +$y = 1 / (1 + e^{-x})$ + +)DOC"); } }; @@ -54,8 +59,12 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of LogSigmoid operator"); AddOutput("Y", "Output of LogSigmoid operator"); - AddComment( - "Logsigmoid activation operator, logsigmoid = log (1 / (1 + exp(-x)))"); + AddComment(R"DOC( +Logsigmoid activation operator. + +$y = \log(1 / (1 + e^{-x}))$ + +)DOC"); } }; @@ -65,7 +74,12 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Exp operator"); AddOutput("Y", "Output of Exp operator"); - AddComment("Exp activation operator, exp(x) = e^x"); + AddComment(R"DOC( +Exp activation operator. + +$y = e^x$ + +)DOC"); } }; @@ -75,7 +89,12 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Relu operator"); AddOutput("Y", "Output of Relu operator"); - AddComment("Relu activation operator, relu(x) = max(x, 0)"); + AddComment(R"DOC( +Relu activation operator. + +$y = \max(x, 0)$ + +)DOC"); } }; @@ -87,11 +106,14 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of LeakyRelu operator"); AddOutput("Y", "Output of LeakyRelu operator"); - AddComment( - "LeakyRelu activation operator, " - "leaky_relu = max(x, alpha * x)"); AddAttr("alpha", "The small negative slope") .SetDefault(static_cast(0.02f)); + AddComment(R"DOC( +LeakyRelu activation operator. + +$y = \max(x, \alpha * x)$ + +)DOC"); } }; @@ -103,12 +125,20 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softshrink operator"); AddOutput("Y", "Output of Softshrink operator"); - AddComment( - "Softshrink activation operator, " - "softshrink = x - lambda, if x > lambda;" - " x + lambda, if x < lambda; 0 otherwise"); AddAttr("lambda", "non-negative offset") .SetDefault(static_cast(0.5f)); + AddComment(R"DOC( +Softshrink activation operator. 
+ +$$ +y = \begin{cases} + x - \lambda, \text{if } x > \lambda \\ + x + \lambda, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} +$$ + +)DOC"); } }; @@ -118,9 +148,12 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Tanh operator"); AddOutput("Y", "Output of Tanh operator"); - AddComment( - "Tanh activation operator, tanh = (exp(x) - exp(-x)) / (exp(x) + " - "exp(-x))"); + AddComment(R"DOC( +Tanh activation operator. + +$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ + +)DOC"); } }; @@ -131,7 +164,12 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of TanhShrink operator"); AddOutput("Y", "Output of TanhShrink operator"); - AddComment("TanhShrink activation operator, tanhshrink(x) = x - tanh(x)"); + AddComment(R"DOC( +TanhShrink activation operator. + +$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ + +)DOC"); } }; @@ -143,13 +181,20 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of HardShrink operator"); AddOutput("Y", "Output of HardShrink operator"); - AddComment( - "HardShrink activation operator, " - "hard_shrink(x) = x if x > lambda" - "hard_shrink(x) = x if x < -lambda" - "hard_shrink(x) = 0 otherwise"); AddAttr("threshold", "The value of threshold for HardShrink") .SetDefault(static_cast(0.5)); + AddComment(R"DOC( +HardShrink activation operator. + +$$ +y = \begin{cases} + x, \text{if } x > \lambda \\ + x, \text{if } x < -\lambda \\ + 0, \text{otherwise} + \end{cases} +$$ + +)DOC"); } }; @@ -159,7 +204,12 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Sqrt operator"); AddOutput("Y", "Output of Sqrt operator"); - AddComment("Sqrt activation operator, sqrt(x) = x^(1/2)"); + AddComment(R"DOC( +Sqrt activation operator. + +$y = \sqrt{x}$ + +)DOC"); } }; @@ -169,7 +219,12 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Abs operator"); AddOutput("Y", "Output of Abs operator"); - AddComment("Abs activation operator, abs(x) = |x|"); + AddComment(R"DOC( +Abs activation operator. + +$y = |x|$ + +)DOC"); } }; @@ -180,7 +235,12 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Reciprocal operator"); AddOutput("Y", "Output of Reciprocal operator"); - AddComment("Reciprocal activation operator, reciprocal(x) = 1 / x"); + AddComment(R"DOC( +Reciprocal activation operator. + +$$y = \frac{1}{x}$$ + +)DOC"); } }; @@ -190,7 +250,14 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Log operator"); AddOutput("Y", "Output of Log operator"); - AddComment("Log activation operator, log(x) = natural logarithm of x"); + AddComment(R"DOC( +Log activation operator. + +$y = \ln(x)$ + +Natural logarithm of x. + +)DOC"); } }; @@ -200,7 +267,12 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Square operator"); AddOutput("Y", "Output of Square operator"); - AddComment("Square activation operator, square(x) = x^2"); + AddComment(R"DOC( +Square activation operator. 
+ +$y = x^2$ + +)DOC"); } }; @@ -211,7 +283,12 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softplus operator"); AddOutput("Y", "Output of Softplus operator"); - AddComment("Softplus activation operator, softplus(x) = log(1 + exp(x))"); + AddComment(R"DOC( +Softplus activation operator. + +$y = \ln(1 + e^{x})$ + +)DOC"); } }; @@ -222,7 +299,12 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softsign operator"); AddOutput("Y", "Output of Softsign operator"); - AddComment("Softsign activation operator, softsign(x) = x / (1 + |x|)"); + AddComment(R"DOC( +Softsign activation operator. + +$$y = \frac{x}{1 + |x|}$$ + +)DOC"); } }; @@ -233,11 +315,16 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of BRelu operator"); AddOutput("Y", "Output of BRelu operator"); - AddComment("BRelu activation operator, brelu = max(min(x, t_min), t_max)"); AddAttr("t_min", "The min marginal value of BRelu") .SetDefault(static_cast(0)); AddAttr("t_max", "The max marginal value of BRelu") .SetDefault(static_cast(24)); + AddComment(R"DOC( +BRelu activation operator. + +$y = \max(\min(x, t_{min}), t_{max})$ + +)DOC"); } }; @@ -249,11 +336,14 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of SoftRelu operator"); AddOutput("Y", "Output of SoftRelu operator"); - AddComment( - "SoftRelu activation operator, soft_relu = log(1 + exp(max(min(x, " - "threshold), threshold)))"); AddAttr("threshold", "The threshold value of SoftRelu") .SetDefault(static_cast(40)); + AddComment(R"DOC( +SoftRelu activation operator. + +$y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$ + +)DOC"); } }; @@ -262,19 +352,19 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker { public: ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", - "(Tensor) The input of ELU operator, it shouldn't be empty. Input " - "is flattened and treated as a 1D array."); - AddOutput("Y", - "(Tensor) The output of ELU operator. It has the same shape as " - "the input."); - AddAttr( - "alpha", "(float, default 1.0) Alpha value in the elu formulation.") - .SetDefault(static_cast(1.)); + AddInput("X", "Input of ELU operator"); + AddOutput("Y", "Output of ELU operator"); + AddAttr("alpha", "The alpha value of ELU") + .SetDefault(static_cast(1.0f)); AddComment(R"DOC( - ELU activation operator. It applies this element-wise computation on - the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)). - Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"); +ELU activation operator. + +Applies the following element-wise computation on the input according to +https://arxiv.org/abs/1511.07289. + +$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$ + +)DOC"); } }; @@ -285,9 +375,14 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Relu6 operator"); AddOutput("Y", "Output of Relu6 operator"); - AddComment("Relu6 activation operator, relu6 = min(max(0, x), 6)"); AddAttr("threshold", "The threshold value of Relu6") .SetDefault(static_cast(6)); + AddComment(R"DOC( +Relu6 activation operator. 
+ +$y = \min(\max(0, x), 6)$ + +)DOC"); } }; @@ -298,9 +393,14 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Pow operator"); AddOutput("Y", "Output of Pow operator"); - AddComment("Pow activation operator, pow(x, factor) = x^factor"); AddAttr("factor", "The exponential factor of Pow") .SetDefault(static_cast(1)); + AddComment(R"DOC( +Pow activation operator. + +$y = x^{factor}$ + +)DOC"); } }; @@ -311,11 +411,16 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of STanh operator"); AddOutput("Y", "Output of STanh operator"); - AddComment("STanh activation operator, stanh = b * tanh(a * x)"); AddAttr("scale_a", "The scale parameter of a for the input") .SetDefault(static_cast(2 / 3)); AddAttr("scale_b", "The scale parameter of b for the input") .SetDefault(static_cast(1.7159)); + AddComment(R"DOC( +STanh activation operator. + +$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ + +)DOC"); } }; @@ -327,12 +432,19 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of ThresholdedRelu operator"); AddOutput("Y", "Output of ThresholdedRelu operator"); - AddComment( - "ThresholdedRelu activation operator, " - "thresholded_relu = x for x > threshold, " - "thresholded_relu = 0 otherwise."); AddAttr("threshold", "The threshold location of activation") .SetDefault(static_cast(1.0)); + AddComment(R"DOC( +ThresholdedRelu activation operator. + +$$ +y = \begin{cases} + x, \text{if } x > threshold \\ + 0, \text{otherwise} + \end{cases} +$$ + +)DOC"); } }; @@ -344,27 +456,23 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of HardSigmoid operator"); AddOutput("Y", "Output of HardSigmoid operator"); + AddAttr("slope", "Slope for linear approximation of sigmoid") + .SetDefault(static_cast(0.2)); + AddAttr("offset", "Offset for linear approximation of sigmoid") + .SetDefault(static_cast(0.5)); AddComment(R"DOC( -Hard Sigmoid activation operator. +HardSigmoid activation operator. -Segment-wise linear approximation of sigmoid[1]. -This is much faster than sigmoid. +Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391), +which is much faster than sigmoid. -hard_sigmoid = max(0, min(1, slope * x + shift)) +$y = \max(0, \min(1, slope * x + shift))$ The slope should be positive. The offset can be either positive or negative. -The default slope and shift are set from [1]. +The default slope and shift are set according to the above reference. It is recommended to use the defaults for this activation. 
-References: - [1] Noisy Activation Functions - (https://arxiv.org/abs/1603.00391) - - )DOC"); - AddAttr("slope", "Slope for linear approximation of sigmoid") - .SetDefault(static_cast(0.2)); - AddAttr("offset", "Offset for linear approximation of sigmoid") - .SetDefault(static_cast(0.5)); +)DOC"); } }; diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index ddd966e26c..ceb4b4e40b 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -232,7 +232,7 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { } }; -// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < lambda; 0 +// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0 // otherwise template struct SoftShrinkFunctor : public BaseActivationFunctor { From 1796a2ab55324eda53db0f98381edf2e7c5a9354 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 2 Nov 2017 20:11:11 -0700 Subject: [PATCH 11/13] Android build document in English (#5029) * Add English version of Android cross-compiling document * Add English version of Android cross-compiling document * Follow comments from Yi-qun and Kavya --- .../cross_compiling_for_android.md | 153 ++++++++++++++++++ .../cross_compiling_for_android_cn.md | 34 ++-- 2 files changed, 170 insertions(+), 17 deletions(-) create mode 100644 doc/howto/cross_compiling/cross_compiling_for_android.md diff --git a/doc/howto/cross_compiling/cross_compiling_for_android.md b/doc/howto/cross_compiling/cross_compiling_for_android.md new file mode 100644 index 0000000000..161863e5c0 --- /dev/null +++ b/doc/howto/cross_compiling/cross_compiling_for_android.md @@ -0,0 +1,153 @@ +# Build PaddlePaddle for Android + +There are two approaches to build PaddlePaddle for Android: using Docker and on Linux without Docker. + +## Cross-Compiling Using Docker + +Docker-based cross-compiling is the recommended approach because Docker runs on all major operating systems, including Linux, Mac OS X, and Windows. + +### Build the Docker Image + +The following steps pack all the tools that we need to build PaddlePaddle into a Docker image. + +```bash +$ git clone https://github.com/PaddlePaddle/Paddle.git +$ cd Paddle +$ docker build -t paddle:dev-android . -f Dockerfile.android +``` + +### Build the Inference Library + +We can run the Docker image we just created to build the inference library of PaddlePaddle for Android using the command below: + +```bash +$ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=armeabi-v7a" -e "ANDROID_API=21" paddle:dev-android +``` + +The Docker image accepts two arguments `ANDROID_ABI` and `ANDROID_API`: + +| Argument | Optional Values | Default | +|-----------------|-------------------------|---------| +|`ANDROID_ABI` |`armeabi-v7a, arm64-v8a` | `armeabi-v7a` | +|`ANDROID_API` |`>= 21` | `21` | + +The ARM-64 architecture (`arm64-v8a`) requires at least level 21 of Android API. + +The default entry-point of the Docker image, [`paddle/scripts/docker/build_android.sh`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build_android.sh) generates the [Android cross-compiling standalone toolchain](https://developer.android.com/ndk/guides/standalone_toolchain.html) based on the argument: `ANDROID_ABI` or `ANDROID_API`. For information about other configuration arguments, please continue reading. + +The above command generates and outputs the inference library in `$PWD/install_android` and puts third-party libraries in `$PWD/install_android/third_party`. 
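+
+For example, to target the 64-bit `arm64-v8a` ABI instead, the same image can be reused with a different `ANDROID_ABI` value (a sketch; it assumes the image was built with the name shown above):
+
+```bash
+$ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=arm64-v8a" -e "ANDROID_API=21" paddle:dev-android
+```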
+
+## Cross-Compiling on Linux
+
+The Linux-based approach to cross-compiling is to run the steps in `Dockerfile.android` manually on a Linux x64 computer.
+
+### Setup the Environment
+
+To build for Android, we need the [Android NDK](https://developer.android.com/ndk/downloads/index.html):
+
+```bash
+wget -q https://dl.google.com/android/repository/android-ndk-r14b-linux-x86_64.zip
+unzip -q android-ndk-r14b-linux-x86_64.zip
+```
+
+Android NDK includes everything we need to build the [*standalone toolchain*](https://developer.android.com/ndk/guides/standalone_toolchain.html), which is then used to build PaddlePaddle for Android. (We plan to remove the intermediate stage of building the standalone toolchain in the near future.)
+
+- To build the standalone toolchain for `armeabi-v7a` and Android API level 21:
+
+  ```bash
+  your/path/to/android-ndk-r14b-linux-x86_64/build/tools/make-standalone-toolchain.sh \
+        --arch=arm --platform=android-21 --install-dir=your/path/to/arm_standalone_toolchain
+  ```
+
+  The generated standalone toolchain will be in `your/path/to/arm_standalone_toolchain`.
+
+- To build the standalone toolchain for `arm64-v8a` and Android API level 21:
+
+  ```bash
+  your/path/to/android-ndk-r14b-linux-x86_64/build/tools/make-standalone-toolchain.sh \
+        --arch=arm64 --platform=android-21 --install-dir=your/path/to/arm64_standalone_toolchain
+  ```
+
+  The generated standalone toolchain will be in `your/path/to/arm64_standalone_toolchain`.
+
+**Please be aware that the minimum level of Android API required by PaddlePaddle is 21.**
+
+### Cross-Compiling Arguments
+
+CMake supports [choosing the toolchain](https://cmake.org/cmake/help/v3.0/manual/cmake-toolchains.7.html#cross-compiling). PaddlePaddle provides [`android.cmake`](https://github.com/PaddlePaddle/Paddle/blob/develop/cmake/cross_compiling/android.cmake), which configures the Android cross-compiling toolchain for CMake. `android.cmake` is not required for CMake >= 3.7, which supports Android cross-compiling natively. PaddlePaddle detects the CMake version; for versions newer than 3.7, it uses [the official version](https://cmake.org/cmake/help/v3.7/manual/cmake-toolchains.7.html#cross-compiling).
+
+Some other CMake arguments you need to know:
+
+- `CMAKE_SYSTEM_NAME` must be `Android`. This tells PaddlePaddle's CMake system to cross-compile third-party dependencies. This also changes some other CMake arguments like `WITH_GPU=OFF`, `WITH_AVX=OFF`, `WITH_PYTHON=OFF`, and `WITH_RDMA=OFF`.
+- `WITH_C_API` must be `ON` to build the C-based inference library for Android.
+- `WITH_SWIG_PY` must be `OFF` because the Android platform doesn't support the SWIG-based API.
+
+Some Android-specific arguments:
+
+- `ANDROID_STANDALONE_TOOLCHAIN`: the absolute path of the Android standalone toolchain, or the path relative to the CMake build directory. PaddlePaddle's CMake extensions will derive the cross-compiler, sysroot, and Android API level from this argument.
+- `ANDROID_TOOLCHAIN`: could be `gcc` or `clang`. The default value is `clang`.
+  - For CMake >= 3.7, it is always `clang`. For older versions, it could be `gcc`.
+  - Android's official `clang` requires `glibc` >= 2.15.
+- `ANDROID_ABI`: could be `armeabi-v7a` or `arm64-v8a`. The default value is `armeabi-v7a`.
+- `ANDROID_NATIVE_API_LEVEL`: could be derived from the value of `ANDROID_STANDALONE_TOOLCHAIN`.
+- `ANROID_ARM_MODE`:
+  - could be `ON` or `OFF`, and defaults to `ON`, when `ANDROID_ABI=armeabi-v7a`;
+  - no need to specify when `ANDROID_ABI=arm64-v8a`.
+- `ANDROID_ARM_NEON`: indicates whether to use NEON instructions.
+  - could be `ON` or `OFF`, and defaults to `ON`, when `ANDROID_ABI=armeabi-v7a`;
+  - no need to specify when `ANDROID_ABI=arm64-v8a`.
+
+Other useful arguments:
+
+- `USE_EIGEN_FOR_BLAS`: indicates whether to use Eigen. Could be `ON` or `OFF`, defaults to `OFF`.
+- `HOST_C/CXX_COMPILER`: specifies the host compiler, which is used to build the host-specific protoc and target-specific OpenBLAS. It defaults to the value of the environment variable `CC`, or `cc`.
+
+Some frequent configurations for your reference:
+
+```bash
+cmake -DCMAKE_SYSTEM_NAME=Android \
+      -DANDROID_STANDALONE_TOOLCHAIN=your/path/to/arm_standalone_toolchain \
+      -DANDROID_ABI=armeabi-v7a \
+      -DANDROID_ARM_NEON=ON \
+      -DANDROID_ARM_MODE=ON \
+      -DUSE_EIGEN_FOR_BLAS=ON \
+      -DCMAKE_INSTALL_PREFIX=your/path/to/install \
+      -DWITH_C_API=ON \
+      -DWITH_SWIG_PY=OFF \
+      ..
+```
+
+```bash
+cmake -DCMAKE_SYSTEM_NAME=Android \
+      -DANDROID_STANDALONE_TOOLCHAIN=your/path/to/arm64_standalone_toolchain \
+      -DANDROID_ABI=arm64-v8a \
+      -DUSE_EIGEN_FOR_BLAS=OFF \
+      -DCMAKE_INSTALL_PREFIX=your/path/to/install \
+      -DWITH_C_API=ON \
+      -DWITH_SWIG_PY=OFF \
+      ..
+```
+
+There are some other arguments you might want to configure:
+
+- `CMAKE_BUILD_TYPE=MinSizeRel` minimizes the size of the library.
+- `CMAKE_BUILD_TYPE=Release` optimizes the runtime performance.
+
+Our own tips for performance optimization when using clang and Eigen or OpenBLAS:
+- `CMAKE_BUILD_TYPE=Release`
+- `ANDROID_TOOLCHAIN=clang`
+- `USE_EIGEN_FOR_BLAS=ON` for `armeabi-v7a`, or `USE_EIGEN_FOR_BLAS=OFF` for `arm64-v8a`.
+
+### Build and Install
+
+After running `cmake`, we can run `make; make install` to build and install; a complete flow is sketched at the end of this section.
+
+Before building, you might want to remove the `third_party` and `build` directories, which may contain pre-built libraries for other architectures.
+
+After building, in the directory `CMAKE_INSTALL_PREFIX`, you will find three sub-directories:
+
+- `include`: the header files of the inference library,
+- `lib`: the inference library built for various Android ABIs,
+- `third_party`: dependent third-party libraries built for Android.
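+
+To tie the steps together, the following is a minimal end-to-end sketch for an `armeabi-v7a` build; the paths are placeholders and the flags simply repeat the example configuration above:
+
+```bash
+# optional: remove artifacts pre-built for other architectures
+rm -rf build third_party
+
+mkdir build && cd build
+cmake -DCMAKE_SYSTEM_NAME=Android \
+      -DANDROID_STANDALONE_TOOLCHAIN=your/path/to/arm_standalone_toolchain \
+      -DANDROID_ABI=armeabi-v7a \
+      -DUSE_EIGEN_FOR_BLAS=ON \
+      -DCMAKE_INSTALL_PREFIX=your/path/to/install \
+      -DWITH_C_API=ON \
+      -DWITH_SWIG_PY=OFF \
+      ..
+make
+make install
+
+# the install tree should now contain include/, lib/, and third_party/
+ls your/path/to/install
+```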
diff --git a/doc/howto/cross_compiling/cross_compiling_for_android_cn.md b/doc/howto/cross_compiling/cross_compiling_for_android_cn.md index 1fc58c37cc..58e4dd9c3f 100644 --- a/doc/howto/cross_compiling/cross_compiling_for_android_cn.md +++ b/doc/howto/cross_compiling/cross_compiling_for_android_cn.md @@ -1,7 +1,7 @@ # 构建Android平台上的PaddlePaddle库 用户可通过如下两种方式,交叉编译Android平台上适用的PaddlePaddle库: -- 基于Docker容器的编译方式 +- 基于Docker容器的编译方式 - 基于Linux交叉编译环境的编译方式 ## 基于Docker容器的编译方式 @@ -26,14 +26,14 @@ Android的Docker开发镜像向用户提供两个可配置的参数: |`ANDROID_API` |`>= 21` | `21` | - 编译`armeabi-v7a`,`Android API 21`的PaddlePaddle库 -```bash -$ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=armeabi-v7a" -e "ANDROID_API=21" username/paddle-android:dev -``` + ```bash + $ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=armeabi-v7a" -e "ANDROID_API=21" username/paddle-android:dev + ``` -- 编译`arm64-v8a`,`Android API 21`的PaddlePaddle库 -```bash -$ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=arm64-v8a" -e "ANDROID_API=21" username/paddle-android:dev -``` +- 编译`arm64-v8a`,`Android API 21`的PaddlePaddle库 + ```bash + $ docker run -it --rm -v $PWD:/paddle -e "ANDROID_ABI=arm64-v8a" -e "ANDROID_API=21" username/paddle-android:dev + ``` 执行上述`docker run`命令时,容器默认执行[paddle/scripts/docker/build_android.sh](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/scripts/docker/build_android.sh)脚本。该脚本中记录了交叉编译Android版PaddlePaddle库常用的CMake配置,并且会根据`ANDROID_ABI`和`ANDROID_API`自动构建独立工具链、进行编译和安装。由于arm64架构要求Android API不小于21。因此当`ANDROID_ABI=arm64-v8a`,`ANDROID_API<21`时,Docker容器中将默认使用`Android API 21`的编译工具链。用户可以参考下文**配置交叉编译参数**章节,根据个人的需求修改定制Docker容器所执行的脚本。编译安装结束之后,PaddlePaddle的C-API库将被安装到`$PWD/install_android`目录,所依赖的第三方库同时也被安装到`$PWD/install_android/third_party`目录。 @@ -82,16 +82,16 @@ CMake系统对交叉编译提供了支持[cmake-toolchains](https://cmake.org/cm Android平台可选配置参数: - `ANDROID_STANDALONE_TOOLCHAIN`,独立工具链所在的绝对路径,或者相对于构建目录的相对路径。PaddlePaddle的CMake系统将根据该值自动推导和设置需要使用的交叉编译器、sysroot、以及Android API级别;否则,用户需要在cmake时手动设置这些值。无默认值。 -- `ANDROID_TOOLCHAIN`,目标工具链。可设置`gcc/clang`,默认值为`clang`。 - - CMake 3.7以上,将会始终使用`clang`工具链;CMake 3.7以下,可设置`ANDROID_TOOLCHAIN=gcc`以使用`gcc`工具链。 +- `ANDROID_TOOLCHAIN`,目标工具链。可设置`gcc/clang`,默认值为`clang`。 + - CMake 3.7以上,将会始终使用`clang`工具链;CMake 3.7以下,可设置`ANDROID_TOOLCHAIN=gcc`以使用`gcc`工具链。 - Android官方提供的`clang`编译器要求系统支持`GLIBC 2.15`以上。 - `ANDROID_ABI`,目标架构ABI。目前支持`armeabi-v7a`和`arm64-v8a`,默认值为`armeabi-v7a`。 - `ANDROID_NATIVE_API_LEVEL`,工具链的Android API级别。若没有显式设置,PaddlePaddle将根据`ANDROID_STANDALONE_TOOLCHAIN`的值自动推导得到。 -- `ANROID_ARM_MODE`,是否使用ARM模式。 - - `ANDROID_ABI=armeabi-v7a`时,可设置`ON/OFF`,默认值为`ON`; +- `ANROID_ARM_MODE`,是否使用ARM模式。 + - `ANDROID_ABI=armeabi-v7a`时,可设置`ON/OFF`,默认值为`ON`; - `ANDROID_ABI=arm64-v8a`时,不需要设置。 -- `ANDROID_ARM_NEON`,是否使用NEON指令。 - - `ANDROID_ABI=armeabi-v7a`时,可设置`ON/OFF`,默认值为`ON`; +- `ANDROID_ARM_NEON`,是否使用NEON指令。 + - `ANDROID_ABI=armeabi-v7a`时,可设置`ON/OFF`,默认值为`ON`; - `ANDROID_ABI=arm64-v8a`时,不需要设置。 其他配置参数: @@ -119,7 +119,7 @@ cmake -DCMAKE_SYSTEM_NAME=Android \ -DANDROID_STANDALONE_TOOLCHAIN=your/path/to/arm64_standalone_toolchain \ -DANDROID_ABI=arm64-v8a \ -DUSE_EIGEN_FOR_BLAS=OFF \ - -DCMAKE_INSTALL_PREFIX=your/path/to/install \ + -DCMAKE_INSTALL_PREFIX=your/path/to/install \ -DWITH_C_API=ON \ -DWITH_SWIG_PY=OFF \ .. 
@@ -128,8 +128,8 @@ cmake -DCMAKE_SYSTEM_NAME=Android \ 用户还可根据自己的需求设置其他编译参数。比如希望最小化生成的库的大小,可以设置`CMAKE_BUILD_TYPE`为`MinSizeRel`;若希望最快的执行速度,则可设置`CMAKE_BUILD_TYPE`为`Release`。亦可以通过手动设置`CMAKE_C/CXX_FLAGS_MINSIZEREL/RELEASE`来影响PaddlePaddle的编译过程。 **性能TIPS**,为了达到最快的计算速度,在CMake参数配置上,有以下建议: -- 设置`CMAKE_BUILD_TYPE`为`Release` -- 使用`clang`编译工具链 +- 设置`CMAKE_BUILD_TYPE`为`Release` +- 使用`clang`编译工具链 - `armeabi-v7a`时,设置`USE_EIGEN_BLAS=ON`,使用Eigen进行矩阵计算;`arm64-v8a`时,设置`USE_EIGEN_FOR_BLAS=OFF`,使用OpenBLAS进行矩阵计算 ### 编译和安装 From 1ed5ae7a14b9a740cb0f0e892cd6c12de5a293e4 Mon Sep 17 00:00:00 2001 From: Yi Wang Date: Thu, 2 Nov 2017 20:26:54 -0700 Subject: [PATCH 12/13] Fix comparing between signed and unsigned values (#5328) --- paddle/framework/executor.cc | 2 +- paddle/gserver/evaluators/Evaluator.cpp | 2 +- paddle/operators/seq_expand_op.h | 3 ++- paddle/optimizer/parameter_optimizer_test.cpp | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index f8d32de5df..52fefe4ea3 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -83,7 +83,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, // TODO(tonyyang-svail): // - only runs on the first device (i.e. no interdevice communication) // - will change to use multiple blocks for RNN op and Cond Op - PADDLE_ENFORCE_LT(block_id, pdesc.Size()); + PADDLE_ENFORCE_LT(static_cast(block_id), pdesc.Size()); auto& block = pdesc.Block(block_id); auto& device = device_contexts_[0]; diff --git a/paddle/gserver/evaluators/Evaluator.cpp b/paddle/gserver/evaluators/Evaluator.cpp index 87cb2d2808..8e66b1f0db 100644 --- a/paddle/gserver/evaluators/Evaluator.cpp +++ b/paddle/gserver/evaluators/Evaluator.cpp @@ -407,7 +407,7 @@ real AucEvaluator::evalImp(std::vector& arguments) { // Copy label from value to a vector. 
if (nullptr == label && nullptr != labelval) { // label width is 1 - CHECK_EQ(1, labelval->getWidth()); + CHECK_EQ(1U, labelval->getWidth()); VectorPtr vec = Vector::create(labelval->getData(), insNum, output->useGpu()); label = vec->castToInt(); diff --git a/paddle/operators/seq_expand_op.h b/paddle/operators/seq_expand_op.h index 8703105385..4ef0d02cf8 100644 --- a/paddle/operators/seq_expand_op.h +++ b/paddle/operators/seq_expand_op.h @@ -32,7 +32,8 @@ class SeqExpandKernel : public framework::OpKernel { const T* x_data = x->data(); auto x_dims = x->dims(); auto* y = context.Input("Y"); - PADDLE_ENFORCE_EQ(x_dims[0], y->lod().back().size() - 1, + PADDLE_ENFORCE_EQ(static_cast(x_dims[0]), + y->lod().back().size() - 1, "The size of last lod level in Input(Y)" "must be equal to dims[0] of Input(X)."); out->set_lod(y->lod()); diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp index c88fa11748..c99b2254ac 100644 --- a/paddle/optimizer/parameter_optimizer_test.cpp +++ b/paddle/optimizer/parameter_optimizer_test.cpp @@ -85,7 +85,7 @@ public: for (size_t i = 0; i < opts_.size(); ++i) { int s = 0; float* newp = (float*)opts_[i]->get_weight(&s); - EXPECT_EQ(s, kSize); + EXPECT_EQ(static_cast(s), kSize); for (size_t j = 0; j < kSize; ++j) { EXPECT_EQ(newp[j], (*p)[j]); } From 86a3260f97d292fe014b965abe73d464efc8aa02 Mon Sep 17 00:00:00 2001 From: ranqiu Date: Fri, 3 Nov 2017 13:04:49 +0800 Subject: [PATCH 13/13] Update faq --- doc/faq/parameter/index_cn.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/faq/parameter/index_cn.rst b/doc/faq/parameter/index_cn.rst index c721b62318..6fa0c64413 100644 --- a/doc/faq/parameter/index_cn.rst +++ b/doc/faq/parameter/index_cn.rst @@ -75,7 +75,7 @@ PaddlePaddle目前支持8种learning_rate_schedule,这8种learning_rate_schedu optimizer = paddle.optimizer.Adam( learning_rate=1e-3, - learning_rate_schedule="manual", + learning_rate_schedule="pass_manual", learning_rate_args="1:1.0,2:0.9,3:0.8",) 在该示例中,当已训练pass数小于等于1时,学习率为 :code:`1e-3 * 1.0`;当已训练pass数大于1小于等于2时,学习率为 :code:`1e-3 * 0.9`;当已训练pass数大于2时,学习率为 :code:`1e-3 * 0.8`。