From 3c49e7b1e4b7b9f8f67fa4b12b05cf648808a40c Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 13 Sep 2017 14:17:51 +0800 Subject: [PATCH 01/13] move EigenDeviceConverter to device_context.h --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 19 ++----------------- paddle/operators/math/activation.h | 20 ++++++++++++++++++++ paddle/platform/device_context.cc | 7 ++++--- paddle/platform/device_context.h | 19 ++++++++++++++++++- paddle/platform/device_context_test.cc | 2 +- 6 files changed, 47 insertions(+), 24 deletions(-) create mode 100644 paddle/operators/math/activation.h diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index e1e122091f..25c545d3f9 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_->get_eigen_device(); + return *device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_->get_eigen_device(); + return *device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 4600b06009..bfa2190557 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -331,21 +331,6 @@ class InferShapeContext { const Scope& scope_; }; -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, @@ -353,8 +338,8 @@ class ExecutionContext : public InferShapeContext { : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> + typename DeviceType = typename platform::EigenDeviceConverter< + PlaceType>::EigenDeviceType> DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_->GetPlace(); } diff --git a/paddle/operators/math/activation.h b/paddle/operators/math/activation.h new file mode 100644 index 0000000000..b6af478d82 --- /dev/null +++ b/paddle/operators/math/activation.h @@ -0,0 +1,20 @@ +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct sigmoid { + void operator()(const platform::DeviceContext& deice_context, + const framework::Tensor& input, framework::Tensor* output) { + auto x = framework::EigenVector::Flatten(*output); + auto y = framework::EigenVector::Flatten(input); + auto* place = device_context.get_eigen_device(); + y.device(*place) = 1. / (1. 
+ (-x).exp()); + } +}; +} +} +} diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index ad212c5b2c..cf5c3eec81 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -16,8 +16,8 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() - const { +Eigen::DefaultDevice* +DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } @@ -91,7 +91,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { }; template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() const { +Eigen::GpuDevice* DeviceContext::get_eigen_device() + const { return reinterpret_cast(this)->eigen_device(); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 11528e1194..a46ba4c703 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -27,12 +27,29 @@ limitations under the License. */ namespace paddle { namespace platform { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; - template + template ::EigenDeviceType> DeviceType* get_eigen_device() const; }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 5883a55272..d71e0aae58 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -24,7 +24,7 @@ TEST(Device, Init) { for (int i = 0; i < count; i++) { DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); Eigen::GpuDevice* gpu_device = - device_context->template get_eigen_device(); + device_context->template get_eigen_device(); ASSERT_NE(nullptr, gpu_device); delete device_context; } From d736fc0e00108384853a996aef9d51dbe81f1564 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 13 Sep 2017 17:33:36 +0800 Subject: [PATCH 02/13] add activation macro --- paddle/framework/operator.h | 6 +- paddle/operators/activation_op.cc | 115 ++++++++++++++++++ .../{sigmoid_op.cu => activation_op.cu} | 11 +- paddle/operators/activation_op.h | 71 +++++++++++ paddle/operators/math/activation.h | 20 --- paddle/operators/math/activation_functor.h | 96 +++++++++++++++ paddle/operators/sigmoid_op.cc | 61 ---------- paddle/operators/sigmoid_op.h | 62 ---------- paddle/pybind/pybind.cc | 4 +- .../paddle/v2/framework/tests/test_exp_op.py | 22 ++++ .../paddle/v2/framework/tests/test_relu_op.py | 22 ++++ 11 files changed, 342 insertions(+), 148 deletions(-) create mode 100644 paddle/operators/activation_op.cc rename paddle/operators/{sigmoid_op.cu => activation_op.cu} (66%) create mode 100644 paddle/operators/activation_op.h delete mode 100644 paddle/operators/math/activation.h create mode 100644 paddle/operators/math/activation_functor.h delete mode 100644 paddle/operators/sigmoid_op.cc delete mode 100644 paddle/operators/sigmoid_op.h create mode 100644 python/paddle/v2/framework/tests/test_exp_op.py create mode 100644 python/paddle/v2/framework/tests/test_relu_op.py diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index bfa2190557..0970797e02 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -139,9 +139,9 @@ class OperatorBase { // Macro for define a clone 
method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(cls) \ - std::unique_ptr Clone() const final { \ - return std::unique_ptr(new cls(*this)); \ +#define DEFINE_OP_CLONE_METHOD(cls) \ + std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \ + return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \ } // Macro for define a default constructor for Operator. diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc new file mode 100644 index 0000000000..d2c2378fef --- /dev/null +++ b/paddle/operators/activation_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/activation_op.h" + +#define FILL_ACTIVATION_OP \ + public: \ + using framework::OperatorWithKernel::OperatorWithKernel; \ + \ + protected: \ + void InferShape(const framework::InferShapeContext &ctx) const override { \ + ctx.Output("Y")->Resize( \ + ctx.Input("X")->dims()); \ + } + +#define FILL_ACTIVATION_GRAD_OP \ + public: \ + using framework::OperatorWithKernel::OperatorWithKernel; \ + \ + protected: \ + void InferShape(const framework::InferShapeContext &ctx) const override { \ + ctx.Output(framework::GradVarName("X")) \ + ->Resize(ctx.Input("Y")->dims()); \ + } + +namespace paddle { +namespace operators { + +class SigmoidOp : public framework::OperatorWithKernel { + FILL_ACTIVATION_OP +}; + +class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SigmoidOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Sigmoid operator"); + AddOutput("Y", "Output of Sigmoid operator"); + AddComment("Sigmoid activation operator"); + } +}; + +class SigmoidOpGrad : public framework::OperatorWithKernel { + FILL_ACTIVATION_GRAD_OP +}; + +class ExpOp : public framework::OperatorWithKernel { + FILL_ACTIVATION_OP +}; + +class ExpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Exp operator"); + AddOutput("Y", "Output of Exp operator"); + AddComment("Exp activation operator"); + } +}; + +class ExpOpGrad : public framework::OperatorWithKernel { + FILL_ACTIVATION_GRAD_OP +}; + +class ReluOp : public framework::OperatorWithKernel { + FILL_ACTIVATION_OP +}; + +class ReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Relu operator"); + AddOutput("Y", "Output of Relu operator"); + AddComment("Relu activation operator"); + } +}; + +class ReluOpGrad : public framework::OperatorWithKernel { + FILL_ACTIVATION_GRAD_OP 
+}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::SigmoidOpGrad); +REGISTER_OP_CPU_KERNEL(sigmoid, + ops::SigmoidKernel); +REGISTER_OP_CPU_KERNEL( + sigmoid_grad, ops::SigmoidGradKernel); + +REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad); +REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel); +REGISTER_OP_CPU_KERNEL(exp_grad, + ops::ExpGradKernel); + +REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad); +REGISTER_OP_CPU_KERNEL(relu, + ops::ReluKernel); +REGISTER_OP_CPU_KERNEL(relu_grad, + ops::ReluGradKernel); diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/activation_op.cu similarity index 66% rename from paddle/operators/sigmoid_op.cu rename to paddle/operators/activation_op.cu index 1a50dfe14a..55d9f52124 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/activation_op.cu @@ -13,7 +13,7 @@ limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/operators/sigmoid_op.h" +#include "paddle/operators/activation_op.h" namespace ops = paddle::operators; @@ -21,3 +21,12 @@ REGISTER_OP_GPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_GPU_KERNEL( sigmoid_grad, ops::SigmoidGradKernel); + +REGISTER_OP_GPU_KERNEL(exp, ops::ExpKernel); +REGISTER_OP_GPU_KERNEL(exp_grad, + ops::ExpGradKernel); + +REGISTER_OP_GPU_KERNEL(relu, + ops::ReluKernel); +REGISTER_OP_GPU_KERNEL(relu_grad, + ops::ReluGradKernel); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h new file mode 100644 index 0000000000..9e4101805e --- /dev/null +++ b/paddle/operators/activation_op.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/activation_functor.h" + +#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel + +#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \ + template \ + class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \ + public: \ + void Compute(const framework::ExecutionContext& context) const override { \ + auto* X = context.Input("X"); \ + auto* Y = context.Output("Y"); \ + Y->mutable_data(context.GetPlace()); \ + math::ACTIVATION_NAME functor; \ + auto* device_context = context.device_context(); \ + functor(*device_context, *X, Y); \ + } \ + }; + +#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \ + template \ + class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \ + : public framework::OpKernel { \ + public: \ + void Compute(const framework::ExecutionContext& context) const override { \ + auto* X = context.Input("X"); \ + auto* Y = context.Input("Y"); \ + auto* dY = \ + context.Input(framework::GradVarName("Y")); \ + auto* dX = \ + context.Output(framework::GradVarName("X")); \ + dX->mutable_data(context.GetPlace()); \ + math::ACTIVATION_GRAD_NAME functor; \ + auto* device_context = context.device_context(); \ + functor(*device_context, *X, *Y, *dY, dX); \ + } \ + }; + +namespace paddle { +namespace operators { + +DEFINE_ACTIVATION_KERNEL(Sigmoid); + +DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad); + +DEFINE_ACTIVATION_KERNEL(Exp); + +DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad); + +DEFINE_ACTIVATION_KERNEL(Relu); + +DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad); + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/activation.h b/paddle/operators/math/activation.h deleted file mode 100644 index b6af478d82..0000000000 --- a/paddle/operators/math/activation.h +++ /dev/null @@ -1,20 +0,0 @@ -#include "paddle/framework/eigen.h" -#include "paddle/framework/tensor.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct sigmoid { - void operator()(const platform::DeviceContext& deice_context, - const framework::Tensor& input, framework::Tensor* output) { - auto x = framework::EigenVector::Flatten(*output); - auto y = framework::EigenVector::Flatten(input); - auto* place = device_context.get_eigen_device(); - y.device(*place) = 1. / (1. + (-x).exp()); - } -}; -} -} -} diff --git a/paddle/operators/math/activation_functor.h b/paddle/operators/math/activation_functor.h new file mode 100644 index 0000000000..7e15607f46 --- /dev/null +++ b/paddle/operators/math/activation_functor.h @@ -0,0 +1,96 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct Sigmoid { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& X, framework::Tensor* Y) { + auto x = framework::EigenVector::Flatten(X); + auto y = framework::EigenVector::Flatten(*Y); + auto* place = device_context.template get_eigen_device(); + y.device(*place) = 1. / (1. + (-x).exp()); + } +}; + +template +struct SigmoidGrad { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& X, const framework::Tensor& Y, + const framework::Tensor& dY, framework::Tensor* dX) { + auto dx = framework::EigenVector::Flatten(*dX); + auto y = framework::EigenVector::Flatten(Y); + auto dy = framework::EigenVector::Flatten(dY); + auto* place = device_context.template get_eigen_device(); + dx.device(*place) = dy * y * (1. - y); + } +}; + +template +struct Exp { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& input, framework::Tensor* output) { + auto x = framework::EigenVector::Flatten(input); + auto y = framework::EigenVector::Flatten(*output); + auto* place = device_context.template get_eigen_device(); + y.device(*place) = x.exp(); + } +}; + +template +struct ExpGrad { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& X, const framework::Tensor& Y, + const framework::Tensor& dY, framework::Tensor* dX) { + auto dx = framework::EigenVector::Flatten(*dX); + auto dy = framework::EigenVector::Flatten(dY); + auto* place = device_context.template get_eigen_device(); + dx.device(*place) = dy.exp(); + } +}; + +template +struct Relu { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& input, framework::Tensor* output) { + auto x = framework::EigenVector::Flatten(input); + auto y = framework::EigenVector::Flatten(*output); + auto* place = device_context.template get_eigen_device(); + y.device(*place) = x.cwiseMax(static_cast(0)); + } +}; + +template +struct ReluGrad { + void operator()(const platform::DeviceContext& device_context, + const framework::Tensor& X, const framework::Tensor& Y, + const framework::Tensor& dY, framework::Tensor* dX) { + auto dx = framework::EigenVector::Flatten(*dX); + auto dy = framework::EigenVector::Flatten(dY); + auto x = framework::EigenVector::Flatten(X); + auto* place = device_context.template get_eigen_device(); + dx.device(*place) = dy * (x > static_cast(0)).template cast(); + } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc deleted file mode 100644 index 761c6de8d4..0000000000 --- a/paddle/operators/sigmoid_op.cc +++ /dev/null @@ -1,61 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#include "paddle/operators/sigmoid_op.h" - -namespace paddle { -namespace operators { - -class SigmoidOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output("Y")->Resize(ctx.Input("X")->dims()); - } -}; - -class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SigmoidOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "sigmoid input"); - AddOutput("Y", "sigmoid output"); - AddComment("Sigmoid function"); - } -}; - -class SigmoidOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("Y")->dims()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, - ops::SigmoidOpGrad); -REGISTER_OP_CPU_KERNEL(sigmoid, - ops::SigmoidKernel); -REGISTER_OP_CPU_KERNEL( - sigmoid_grad, ops::SigmoidGradKernel); diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h deleted file mode 100644 index b01a9b3f23..0000000000 --- a/paddle/operators/sigmoid_op.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; - -template -class SigmoidKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto input = context.Input("X"); - auto output = context.Output("Y"); - output->mutable_data(context.GetPlace()); - - // The clipping is used in Paddle's raw implenmention - auto X = EigenVector::Flatten(*input); - auto Y = EigenVector::Flatten(*output); - auto place = context.GetEigenDevice(); - - Y.device(place) = 1. / (1. + (-X).exp()); - } -}; - -template -class SigmoidGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto Y_t = context.Input("Y"); - auto dY_t = context.Input(framework::GradVarName("Y")); - auto dX_t = context.Output(framework::GradVarName("X")); - - dX_t->mutable_data(context.GetPlace()); - - auto dX = EigenVector::Flatten(*dX_t); - auto Y = EigenVector::Flatten(*Y_t); - auto dY = EigenVector::Flatten(*dY_t); - dX.device(context.GetEigenDevice()) = dY * Y * (1. 
- Y); - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 16a2368aae..bd964c5d07 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -36,7 +36,6 @@ USE_OP(onehot_cross_entropy); USE_OP(sgd); USE_OP(mul); USE_OP(mean); -USE_OP(sigmoid); USE_OP(softmax); USE_OP(rowwise_add); USE_OP(fill_zeros_like); @@ -55,6 +54,9 @@ USE_OP(top_k); USE_OP(squared_l2_distance); USE_OP(sum); USE_OP(reshape); +USE_OP(sigmoid); +USE_OP(exp); +USE_OP(relu); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/test_exp_op.py b/python/paddle/v2/framework/tests/test_exp_op.py new file mode 100644 index 0000000000..5a004f6fe2 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_exp_op.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestExp(OpTest): + def setUp(self): + self.op_type = "exp" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.exp(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", max_relative_error=0.007) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_relu_op.py b/python/paddle/v2/framework/tests/test_relu_op.py new file mode 100644 index 0000000000..07b7113d79 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_relu_op.py @@ -0,0 +1,22 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestExp(OpTest): + def setUp(self): + self.op_type = "exp" + self.inputs = { + 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(["X"], "Y", max_relative_error=0.007) + + +if __name__ == '__main__': + unittest.main() From b50a50761760d124aa4a38c81599a1069bc6fbf0 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 13 Sep 2017 17:45:11 +0800 Subject: [PATCH 03/13] add activation operator python test --- paddle/operators/math/activation_functor.h | 4 ++-- python/paddle/v2/framework/tests/test_relu_op.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/math/activation_functor.h b/paddle/operators/math/activation_functor.h index 7e15607f46..1e9bdd142e 100644 --- a/paddle/operators/math/activation_functor.h +++ b/paddle/operators/math/activation_functor.h @@ -61,9 +61,9 @@ struct ExpGrad { const framework::Tensor& X, const framework::Tensor& Y, const framework::Tensor& dY, framework::Tensor* dX) { auto dx = framework::EigenVector::Flatten(*dX); - auto dy = framework::EigenVector::Flatten(dY); + auto y = framework::EigenVector::Flatten(Y); auto* place = device_context.template get_eigen_device(); - dx.device(*place) = dy.exp(); + dx.device(*place) = y; } }; diff --git a/python/paddle/v2/framework/tests/test_relu_op.py b/python/paddle/v2/framework/tests/test_relu_op.py index 07b7113d79..58a0872db4 100644 --- a/python/paddle/v2/framework/tests/test_relu_op.py +++ b/python/paddle/v2/framework/tests/test_relu_op.py @@ -3,9 +3,9 @@ import numpy as np from op_test import OpTest -class TestExp(OpTest): +class TestRelu(OpTest): def setUp(self): - self.op_type = "exp" + self.op_type = "relu" self.inputs = { 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") } From 4e173527c1650ed86df714392e53801a498b0078 Mon Sep 17 00:00:00 
2001 From: qijun Date: Wed, 13 Sep 2017 17:57:41 +0800 Subject: [PATCH 04/13] fix op python tests --- python/paddle/v2/framework/tests/test_exp_op.py | 4 ++-- python/paddle/v2/framework/tests/test_relu_op.py | 8 +++----- python/paddle/v2/framework/tests/test_sigmoid_op.py | 4 ++-- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_exp_op.py b/python/paddle/v2/framework/tests/test_exp_op.py index 5a004f6fe2..0ec41e56a0 100644 --- a/python/paddle/v2/framework/tests/test_exp_op.py +++ b/python/paddle/v2/framework/tests/test_exp_op.py @@ -15,8 +15,8 @@ class TestExp(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Y", max_relative_error=0.007) + self.check_grad(['X'], 'Y', max_relative_error=0.007) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/framework/tests/test_relu_op.py b/python/paddle/v2/framework/tests/test_relu_op.py index 58a0872db4..c9af0c2ba7 100644 --- a/python/paddle/v2/framework/tests/test_relu_op.py +++ b/python/paddle/v2/framework/tests/test_relu_op.py @@ -6,17 +6,15 @@ from op_test import OpTest class TestRelu(OpTest): def setUp(self): self.op_type = "relu" - self.inputs = { - 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") - } + self.inputs = {'X': np.random.uniform(-1, 1, [4, 4]).astype("float32")} self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Y", max_relative_error=0.007) + self.check_grad(['X'], 'Y', max_relative_error=0.007) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py index 2316e49eff..cf05e934d5 100644 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ b/python/paddle/v2/framework/tests/test_sigmoid_op.py @@ -15,8 +15,8 @@ class TestSigmoid(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(["X"], "Y", max_relative_error=0.007) + self.check_grad(['X'], 'Y', max_relative_error=0.007) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() From c18ebc3022961f404265a80400fcc29d216b4534 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 07:10:43 +0800 Subject: [PATCH 05/13] remove macros --- paddle/operators/activation_op.cc | 134 ++++++++++++++---------- paddle/operators/activation_op.h | 162 ++++++++++++++++++++++-------- paddle/pybind/pybind.cc | 2 +- 3 files changed, 203 insertions(+), 95 deletions(-) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index d2c2378fef..e713b5a211 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -14,33 +14,55 @@ #include "paddle/operators/activation_op.h" -#define FILL_ACTIVATION_OP \ - public: \ - using framework::OperatorWithKernel::OperatorWithKernel; \ - \ - protected: \ - void InferShape(const framework::InferShapeContext &ctx) const override { \ - ctx.Output("Y")->Resize( \ - ctx.Input("X")->dims()); \ - } - -#define FILL_ACTIVATION_GRAD_OP \ - public: \ - using framework::OperatorWithKernel::OperatorWithKernel; \ - \ - protected: \ - void InferShape(const framework::InferShapeContext &ctx) const override { \ - ctx.Output(framework::GradVarName("X")) \ - ->Resize(ctx.Input("Y")->dims()); \ - } +// #define FILL_ACTIVATION_OP \ +// public: \ +// using framework::OperatorWithKernel::OperatorWithKernel; \ +// 
\ +// protected: \ +// void InferShape(const framework::InferShapeContext &ctx) const override { \ +// ctx.Output("Y")->Resize( \ +// ctx.Input("X")->dims()); \ +// } + +// #define FILL_ACTIVATION_GRAD_OP \ +// public: \ +// using framework::OperatorWithKernel::OperatorWithKernel; \ +// \ +// protected: \ +// void InferShape(const framework::InferShapeContext &ctx) const override { \ +// ctx.Output(framework::GradVarName("X")) \ +// ->Resize(ctx.Input("Y")->dims()); \ +// } namespace paddle { namespace operators { -class SigmoidOp : public framework::OperatorWithKernel { - FILL_ACTIVATION_OP +class ActivationOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + ctx.Output("Y")->Resize( + ctx.Input("X")->dims()); + } }; +class ActivationOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("Y")->dims()); + } +}; + +// class SigmoidOp : public framework::OperatorWithKernel { +// FILL_ACTIVATION_OP +// }; + class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { public: SigmoidOpMaker(framework::OpProto *proto, @@ -52,13 +74,13 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class SigmoidOpGrad : public framework::OperatorWithKernel { - FILL_ACTIVATION_GRAD_OP -}; +// class SigmoidOpGrad : public framework::OperatorWithKernel { +// FILL_ACTIVATION_GRAD_OP +// }; -class ExpOp : public framework::OperatorWithKernel { - FILL_ACTIVATION_OP -}; +// class ExpOp : public framework::OperatorWithKernel { +// FILL_ACTIVATION_OP +// }; class ExpOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -70,13 +92,13 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class ExpOpGrad : public framework::OperatorWithKernel { - FILL_ACTIVATION_GRAD_OP -}; +// class ExpOpGrad : public framework::OperatorWithKernel { +// FILL_ACTIVATION_GRAD_OP +// }; -class ReluOp : public framework::OperatorWithKernel { - FILL_ACTIVATION_OP -}; +// class ReluOp : public framework::OperatorWithKernel { +// FILL_ACTIVATION_OP +// }; class ReluOpMaker : public framework::OpProtoAndCheckerMaker { public: @@ -88,28 +110,36 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class ReluOpGrad : public framework::OperatorWithKernel { - FILL_ACTIVATION_GRAD_OP -}; +// class ReluOpGrad : public framework::OperatorWithKernel { +// FILL_ACTIVATION_GRAD_OP +// }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, - ops::SigmoidOpGrad); -REGISTER_OP_CPU_KERNEL(sigmoid, - ops::SigmoidKernel); +REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + sigmoid, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL(sigmoid_grad, + ops::ActivationGradKernel); + +REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + exp, ops::ActivationKernel); REGISTER_OP_CPU_KERNEL( - sigmoid_grad, ops::SigmoidGradKernel); - -REGISTER_OP(exp, ops::ExpOp, ops::ExpOpMaker, exp_grad, ops::ExpOpGrad); -REGISTER_OP_CPU_KERNEL(exp, ops::ExpKernel); -REGISTER_OP_CPU_KERNEL(exp_grad, - 
ops::ExpGradKernel); - -REGISTER_OP(relu, ops::ReluOp, ops::ReluOpMaker, relu_grad, ops::ReluOpGrad); -REGISTER_OP_CPU_KERNEL(relu, - ops::ReluKernel); -REGISTER_OP_CPU_KERNEL(relu_grad, - ops::ReluGradKernel); + exp_grad, + ops::ActivationGradKernel); + +// REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, +// ops::ActivationOpGrad); +// REGISTER_OP_CPU_KERNEL(relu, +// ops::ReluKernel); +// REGISTER_OP_CPU_KERNEL(relu_grad, +// ops::ReluGradKernel); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 9e4101805e..7d5c5bb26f 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -15,57 +15,135 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -#include "paddle/operators/math/activation_functor.h" - -#define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel - -#define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \ - template \ - class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \ - public: \ - void Compute(const framework::ExecutionContext& context) const override { \ - auto* X = context.Input("X"); \ - auto* Y = context.Output("Y"); \ - Y->mutable_data(context.GetPlace()); \ - math::ACTIVATION_NAME functor; \ - auto* device_context = context.device_context(); \ - functor(*device_context, *X, Y); \ - } \ - }; - -#define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \ - template \ - class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \ - : public framework::OpKernel { \ - public: \ - void Compute(const framework::ExecutionContext& context) const override { \ - auto* X = context.Input("X"); \ - auto* Y = context.Input("Y"); \ - auto* dY = \ - context.Input(framework::GradVarName("Y")); \ - auto* dX = \ - context.Output(framework::GradVarName("X")); \ - dX->mutable_data(context.GetPlace()); \ - math::ACTIVATION_GRAD_NAME functor; \ - auto* device_context = context.device_context(); \ - functor(*device_context, *X, *Y, *dY, dX); \ - } \ - }; +// #include "paddle/operators/math/activation_functor.h" + +// #define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel + +// #define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \ +// template \ +// class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \ +// public: \ +// void Compute(const framework::ExecutionContext& context) const override { \ +// auto* X = context.Input("X"); \ +// auto* Y = context.Output("Y"); \ +// Y->mutable_data(context.GetPlace()); \ +// math::ACTIVATION_NAME functor; \ +// auto* device_context = context.device_context(); \ +// functor(*device_context, *X, Y); \ +// } \ +// }; + +// #define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \ +// template \ +// class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \ +// : public framework::OpKernel { \ +// public: \ +// void Compute(const framework::ExecutionContext& context) const override { \ +// auto* X = context.Input("X"); \ +// auto* Y = context.Input("Y"); \ +// auto* dY = \ +// context.Input(framework::GradVarName("Y")); \ +// auto* dX = \ +// context.Output(framework::GradVarName("X")); \ +// dX->mutable_data(context.GetPlace()); \ +// math::ACTIVATION_GRAD_NAME functor; \ +// auto* device_context = context.device_context(); \ +// functor(*device_context, *X, *Y, *dY, dX); \ +// } \ +// }; namespace paddle { namespace operators { -DEFINE_ACTIVATION_KERNEL(Sigmoid); +template +class ActivationKernel : public framework::OpKernel { + public: + void Compute(const 
framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + Functor functor; + functor(place, x, y); + } +}; + +template +class ActivationGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + Functor functor; + functor(place, x, y, dy, dx); + } +}; + +struct Sigmoid { + template + void operator()(Device d, X x, Y y) { + y.device(d) = 1. / (1. + (-x).exp()); + } +}; + +struct SigmoidGrad { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * y * (1. - y); + } +}; + +struct Exp { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.exp(); + } +}; + +struct ExpGrad { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = y; + } +}; + +// template +// struct Relu { +// void operator()(Device d, X x, Y y) { +// y.device(d) = x.cwiseMax(static_cast(0)); +// } +// }; + +// template +// struct ReluGrad { +// void operator()(Device d, X x, Y y, dY dy, dX dx) { +// dx.device(d) = dy * (x > static_cast(0)).template cast(); +// } +// }; + +// DEFINE_ACTIVATION_KERNEL(Sigmoid); -DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad); +// DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad); -DEFINE_ACTIVATION_KERNEL(Exp); +// DEFINE_ACTIVATION_KERNEL(Exp); -DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad); +// DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad); -DEFINE_ACTIVATION_KERNEL(Relu); +// DEFINE_ACTIVATION_KERNEL(Relu); -DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad); +// DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad); } // namespace operators } // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index bd964c5d07..bed35d7822 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,7 +56,7 @@ USE_OP(sum); USE_OP(reshape); USE_OP(sigmoid); USE_OP(exp); -USE_OP(relu); +// USE_OP(relu); namespace paddle { namespace framework { From 0957fa7b3c8b8929aa3a8fd94e33a75af3c314dc Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 07:33:07 +0800 Subject: [PATCH 06/13] fix relu functor and revert some codes --- paddle/framework/operator.cc | 4 +- paddle/framework/operator.h | 25 ++++-- paddle/operators/activation_op.cc | 79 ++++-------------- paddle/operators/activation_op.cu | 22 +++-- paddle/operators/activation_op.h | 82 ++++-------------- paddle/operators/math/activation_functor.h | 96 ---------------------- paddle/platform/device_context.cc | 7 +- paddle/platform/device_context.h | 19 +---- paddle/platform/device_context_test.cc | 2 +- paddle/pybind/pybind.cc | 2 +- 10 files changed, 78 insertions(+), 260 deletions(-) delete mode 100644 paddle/operators/math/activation_functor.h diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 25c545d3f9..e1e122091f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 
+22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_->get_eigen_device(); + return *device_context_->get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_->get_eigen_device(); + return *device_context_->get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 0970797e02..4600b06009 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -139,9 +139,9 @@ class OperatorBase { // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you // register it. i.e. `Clone` method is not needed to define by yourself. -#define DEFINE_OP_CLONE_METHOD(cls) \ - std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \ - return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \ +#define DEFINE_OP_CLONE_METHOD(cls) \ + std::unique_ptr Clone() const final { \ + return std::unique_ptr(new cls(*this)); \ } // Macro for define a default constructor for Operator. @@ -331,6 +331,21 @@ class InferShapeContext { const Scope& scope_; }; +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, @@ -338,8 +353,8 @@ class ExecutionContext : public InferShapeContext { : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> + typename DeviceType = + typename EigenDeviceConverter::EigenDeviceType> DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_->GetPlace(); } diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index e713b5a211..ffa5c26da3 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -14,26 +14,6 @@ #include "paddle/operators/activation_op.h" -// #define FILL_ACTIVATION_OP \ -// public: \ -// using framework::OperatorWithKernel::OperatorWithKernel; \ -// \ -// protected: \ -// void InferShape(const framework::InferShapeContext &ctx) const override { \ -// ctx.Output("Y")->Resize( \ -// ctx.Input("X")->dims()); \ -// } - -// #define FILL_ACTIVATION_GRAD_OP \ -// public: \ -// using framework::OperatorWithKernel::OperatorWithKernel; \ -// \ -// protected: \ -// void InferShape(const framework::InferShapeContext &ctx) const override { \ -// ctx.Output(framework::GradVarName("X")) \ -// ->Resize(ctx.Input("Y")->dims()); \ -// } - namespace paddle { namespace operators { @@ -59,10 +39,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel { } }; -// class SigmoidOp : public framework::OperatorWithKernel { -// FILL_ACTIVATION_OP -// }; - class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { public: SigmoidOpMaker(framework::OpProto *proto, @@ -74,14 +50,6 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { } }; -// class SigmoidOpGrad : public framework::OperatorWithKernel { -// FILL_ACTIVATION_GRAD_OP -// }; - -// class ExpOp : public framework::OperatorWithKernel { -// FILL_ACTIVATION_OP -// }; - class ExpOpMaker : public 
framework::OpProtoAndCheckerMaker { public: ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -92,14 +60,6 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker { } }; -// class ExpOpGrad : public framework::OperatorWithKernel { -// FILL_ACTIVATION_GRAD_OP -// }; - -// class ReluOp : public framework::OperatorWithKernel { -// FILL_ACTIVATION_OP -// }; - class ReluOpMaker : public framework::OpProtoAndCheckerMaker { public: ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) @@ -110,36 +70,33 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { } }; -// class ReluOpGrad : public framework::OperatorWithKernel { -// FILL_ACTIVATION_GRAD_OP -// }; - } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(sigmoid, + ops::ActivationKernel); REGISTER_OP_CPU_KERNEL( - sigmoid, - ops::ActivationKernel); -REGISTER_OP_CPU_KERNEL(sigmoid_grad, - ops::ActivationGradKernel); + sigmoid_grad, ops::ActivationGradKernel); REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); REGISTER_OP_CPU_KERNEL( - exp, ops::ActivationKernel); + exp, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL(exp_grad, + ops::ActivationGradKernel); + +REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(relu, + ops::ActivationKernel>); REGISTER_OP_CPU_KERNEL( - exp_grad, - ops::ActivationGradKernel); - -// REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, -// ops::ActivationOpGrad); -// REGISTER_OP_CPU_KERNEL(relu, -// ops::ReluKernel); -// REGISTER_OP_CPU_KERNEL(relu_grad, -// ops::ReluGradKernel); + relu_grad, ops::ActivationGradKernel>); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 55d9f52124..3b2c147f46 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -18,15 +18,21 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(sigmoid, - ops::SigmoidKernel); + ops::ActivationKernel); REGISTER_OP_GPU_KERNEL( - sigmoid_grad, ops::SigmoidGradKernel); + sigmoid_grad, ops::ActivationGradKernel); -REGISTER_OP_GPU_KERNEL(exp, ops::ExpKernel); +REGISTER_OP_GPU_KERNEL( + exp, + ops::ActivationKernel); REGISTER_OP_GPU_KERNEL(exp_grad, - ops::ExpGradKernel); - + ops::ActivationGradKernel); REGISTER_OP_GPU_KERNEL(relu, - ops::ReluKernel); -REGISTER_OP_GPU_KERNEL(relu_grad, - ops::ReluGradKernel); + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + relu_grad, ops::ActivationGradKernel>); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 7d5c5bb26f..0b7e171e72 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -15,42 +15,6 @@ #pragma once #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" -// #include "paddle/operators/math/activation_functor.h" - -// #define ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) ACTIVATION_NAME##Kernel - -// #define DEFINE_ACTIVATION_KERNEL(ACTIVATION_NAME) \ -// template \ -// class ACTIVATION_KERNEL_NAME(ACTIVATION_NAME) : public framework::OpKernel { \ -// public: \ -// void Compute(const framework::ExecutionContext& context) const override { \ -// auto* X = context.Input("X"); \ -// auto* Y = context.Output("Y"); \ -// Y->mutable_data(context.GetPlace()); \ -// math::ACTIVATION_NAME functor; \ -// auto* 
device_context = context.device_context(); \ -// functor(*device_context, *X, Y); \ -// } \ -// }; - -// #define DEFINE_ACTIVATION_GRAD_KERNEL(ACTIVATION_GRAD_NAME) \ -// template \ -// class ACTIVATION_KERNEL_NAME(ACTIVATION_GRAD_NAME) \ -// : public framework::OpKernel { \ -// public: \ -// void Compute(const framework::ExecutionContext& context) const override { \ -// auto* X = context.Input("X"); \ -// auto* Y = context.Input("Y"); \ -// auto* dY = \ -// context.Input(framework::GradVarName("Y")); \ -// auto* dX = \ -// context.Output(framework::GradVarName("X")); \ -// dX->mutable_data(context.GetPlace()); \ -// math::ACTIVATION_GRAD_NAME functor; \ -// auto* device_context = context.device_context(); \ -// functor(*device_context, *X, *Y, *dY, dX); \ -// } \ -// }; namespace paddle { namespace operators { @@ -91,59 +55,49 @@ class ActivationGradKernel : public framework::OpKernel { } }; -struct Sigmoid { +struct SigmoidFunctor { template void operator()(Device d, X x, Y y) { y.device(d) = 1. / (1. + (-x).exp()); } }; -struct SigmoidGrad { +struct SigmoidGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { dx.device(d) = dy * y * (1. - y); } }; -struct Exp { +struct ExpFunctor { template void operator()(Device d, X x, Y y) { y.device(d) = x.exp(); } }; -struct ExpGrad { +struct ExpGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { dx.device(d) = y; } }; -// template -// struct Relu { -// void operator()(Device d, X x, Y y) { -// y.device(d) = x.cwiseMax(static_cast(0)); -// } -// }; - -// template -// struct ReluGrad { -// void operator()(Device d, X x, Y y, dY dy, dX dx) { -// dx.device(d) = dy * (x > static_cast(0)).template cast(); -// } -// }; - -// DEFINE_ACTIVATION_KERNEL(Sigmoid); - -// DEFINE_ACTIVATION_GRAD_KERNEL(SigmoidGrad); - -// DEFINE_ACTIVATION_KERNEL(Exp); - -// DEFINE_ACTIVATION_GRAD_KERNEL(ExpGrad); - -// DEFINE_ACTIVATION_KERNEL(Relu); +template +struct ReluFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.cwiseMax(static_cast(0)); + } +}; -// DEFINE_ACTIVATION_GRAD_KERNEL(ReluGrad); +template +struct ReluGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (x > static_cast(0)).template cast(); + } +}; } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/activation_functor.h b/paddle/operators/math/activation_functor.h deleted file mode 100644 index 1e9bdd142e..0000000000 --- a/paddle/operators/math/activation_functor.h +++ /dev/null @@ -1,96 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/tensor.h" - -namespace paddle { -namespace operators { -namespace math { - -template -struct Sigmoid { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& X, framework::Tensor* Y) { - auto x = framework::EigenVector::Flatten(X); - auto y = framework::EigenVector::Flatten(*Y); - auto* place = device_context.template get_eigen_device(); - y.device(*place) = 1. / (1. + (-x).exp()); - } -}; - -template -struct SigmoidGrad { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& X, const framework::Tensor& Y, - const framework::Tensor& dY, framework::Tensor* dX) { - auto dx = framework::EigenVector::Flatten(*dX); - auto y = framework::EigenVector::Flatten(Y); - auto dy = framework::EigenVector::Flatten(dY); - auto* place = device_context.template get_eigen_device(); - dx.device(*place) = dy * y * (1. - y); - } -}; - -template -struct Exp { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& input, framework::Tensor* output) { - auto x = framework::EigenVector::Flatten(input); - auto y = framework::EigenVector::Flatten(*output); - auto* place = device_context.template get_eigen_device(); - y.device(*place) = x.exp(); - } -}; - -template -struct ExpGrad { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& X, const framework::Tensor& Y, - const framework::Tensor& dY, framework::Tensor* dX) { - auto dx = framework::EigenVector::Flatten(*dX); - auto y = framework::EigenVector::Flatten(Y); - auto* place = device_context.template get_eigen_device(); - dx.device(*place) = y; - } -}; - -template -struct Relu { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& input, framework::Tensor* output) { - auto x = framework::EigenVector::Flatten(input); - auto y = framework::EigenVector::Flatten(*output); - auto* place = device_context.template get_eigen_device(); - y.device(*place) = x.cwiseMax(static_cast(0)); - } -}; - -template -struct ReluGrad { - void operator()(const platform::DeviceContext& device_context, - const framework::Tensor& X, const framework::Tensor& Y, - const framework::Tensor& dY, framework::Tensor* dX) { - auto dx = framework::EigenVector::Flatten(*dX); - auto dy = framework::EigenVector::Flatten(dY); - auto x = framework::EigenVector::Flatten(X); - auto* place = device_context.template get_eigen_device(); - dx.device(*place) = dy * (x > static_cast(0)).template cast(); - } -}; - -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index cf5c3eec81..ad212c5b2c 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -16,8 +16,8 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* -DeviceContext::get_eigen_device() const { +Eigen::DefaultDevice* DeviceContext::get_eigen_device() + const { return reinterpret_cast(this)->eigen_device(); } @@ -91,8 +91,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { }; template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() - const { +Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index a46ba4c703..11528e1194 100644 --- a/paddle/platform/device_context.h +++ 
b/paddle/platform/device_context.h @@ -27,29 +27,12 @@ limitations under the License. */ namespace paddle { namespace platform { -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; - template ::EigenDeviceType> + template DeviceType* get_eigen_device() const; }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index d71e0aae58..5883a55272 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -24,7 +24,7 @@ TEST(Device, Init) { for (int i = 0; i < count; i++) { DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); Eigen::GpuDevice* gpu_device = - device_context->template get_eigen_device(); + device_context->template get_eigen_device(); ASSERT_NE(nullptr, gpu_device); delete device_context; } diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index bed35d7822..bd964c5d07 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,7 +56,7 @@ USE_OP(sum); USE_OP(reshape); USE_OP(sigmoid); USE_OP(exp); -// USE_OP(relu); +USE_OP(relu); namespace paddle { namespace framework { From 3110bf9a9aaa8cbc3f52bd171b12340b299481db Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 07:37:31 +0800 Subject: [PATCH 07/13] merge activation operator python tests --- .../v2/framework/tests/test_activation_op.py | 50 +++++++++++++++++++ .../paddle/v2/framework/tests/test_exp_op.py | 22 -------- .../paddle/v2/framework/tests/test_relu_op.py | 20 -------- .../v2/framework/tests/test_sigmoid_op.py | 22 -------- 4 files changed, 50 insertions(+), 64 deletions(-) create mode 100644 python/paddle/v2/framework/tests/test_activation_op.py delete mode 100644 python/paddle/v2/framework/tests/test_exp_op.py delete mode 100644 python/paddle/v2/framework/tests/test_relu_op.py delete mode 100644 python/paddle/v2/framework/tests/test_sigmoid_op.py diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py new file mode 100644 index 0000000000..23ff584396 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -0,0 +1,50 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestExp(OpTest): + def setUp(self): + self.op_type = "exp" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.exp(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestRelu(OpTest): + def setUp(self): + self.op_type = "relu" + self.inputs = {'X': np.random.uniform(-1, 1, [4, 4]).astype("float32")} + self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestSigmoid(OpTest): + def setUp(self): + self.op_type = "sigmoid" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', 
max_relative_error=0.007) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_exp_op.py b/python/paddle/v2/framework/tests/test_exp_op.py deleted file mode 100644 index 0ec41e56a0..0000000000 --- a/python/paddle/v2/framework/tests/test_exp_op.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestExp(OpTest): - def setUp(self): - self.op_type = "exp" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Y': np.exp(self.inputs['X'])} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_relu_op.py b/python/paddle/v2/framework/tests/test_relu_op.py deleted file mode 100644 index c9af0c2ba7..0000000000 --- a/python/paddle/v2/framework/tests/test_relu_op.py +++ /dev/null @@ -1,20 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestRelu(OpTest): - def setUp(self): - self.op_type = "relu" - self.inputs = {'X': np.random.uniform(-1, 1, [4, 4]).astype("float32")} - self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) - - -if __name__ == "__main__": - unittest.main() diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py deleted file mode 100644 index cf05e934d5..0000000000 --- a/python/paddle/v2/framework/tests/test_sigmoid_op.py +++ /dev/null @@ -1,22 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -class TestSigmoid(OpTest): - def setUp(self): - self.op_type = "sigmoid" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) - - -if __name__ == "__main__": - unittest.main() From e515f18dd857d2f9f986955cd76208a965eb5c5c Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 10:26:41 +0800 Subject: [PATCH 08/13] add tanh and sqrt activation operators --- paddle/operators/activation_op.h | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 0b7e171e72..4421c10957 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -99,5 +99,36 @@ struct ReluGradFunctor { } }; +struct TanhFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.tanh(); + } +}; + +template +struct TanhGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (T(1) - y * y); + } +}; + +struct SqrtFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.sqrt(); + } +}; + +template +struct SqrtGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + const T y_conj = Eigen::numext::conj(y); + dx.device(d) = static_cast(0.5) * dy / y_conj; + } +}; + } // namespace operators } // namespace paddle From dadace3178ab1f038bec7d8fcdfb849e8fc6963f Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 14:02:29 +0800 Subject: [PATCH 09/13] add more activation 
functors --- paddle/operators/activation_op.h | 62 +++++++++++++++++++++++++++++++- 1 file changed, 61 insertions(+), 1 deletion(-) diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 4421c10957..9bf340f2ed 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -55,6 +55,8 @@ class ActivationGradKernel : public framework::OpKernel { } }; +// sigmoid = 1 / (1 + exp(-x) +template struct SigmoidFunctor { template void operator()(Device d, X x, Y y) { @@ -69,6 +71,7 @@ struct SigmoidGradFunctor { } }; +// exp(x) = e^x struct ExpFunctor { template void operator()(Device d, X x, Y y) { @@ -79,10 +82,11 @@ struct ExpFunctor { struct ExpGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = y; + dx.device(d) = dy * y; } }; +// relu(x) = max(x, 0) template struct ReluFunctor { template @@ -99,6 +103,7 @@ struct ReluGradFunctor { } }; +// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) struct TanhFunctor { template void operator()(Device d, X x, Y y) { @@ -114,6 +119,7 @@ struct TanhGradFunctor { } }; +// sqrt(x) = x^(1/2) struct SqrtFunctor { template void operator()(Device d, X x, Y y) { @@ -130,5 +136,59 @@ struct SqrtGradFunctor { } }; +// abs(x) = |x| +struct AbsFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.abs(); + } +}; + +// reciprocal(x) = 1 / x +template +struct ReciprocalFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = 1. / x; + } +}; + +struct ReciprocalGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (-1.0) * y * y; + } +}; + +// log(x) = natural logarithm of x +struct LogFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.log(); + } +}; + +struct LogGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (1. 
/ x); + } +}; + +// square(x) = x^2 +struct SquareFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.square(); + } +} + +struct SquareGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * 2 * x; + } +}; + } // namespace operators } // namespace paddle From 5824d850012e0c802e90f2ad7d23f4b8e3fc00d2 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 18:19:13 +0800 Subject: [PATCH 10/13] add activation operators and python unittests --- paddle/operators/activation_op.cc | 214 +++++++++++++++++- paddle/operators/activation_op.cu | 82 +++++++ paddle/operators/activation_op.h | 181 ++++++++++++++- paddle/pybind/pybind.cc | 2 - python/paddle/v2/framework/tests/op_test.py | 2 +- .../v2/framework/tests/test_activation_op.py | 165 +++++++++++++- 6 files changed, 626 insertions(+), 20 deletions(-) diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index ffa5c26da3..8ada158ff3 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -46,7 +46,7 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Sigmoid operator"); AddOutput("Y", "Output of Sigmoid operator"); - AddComment("Sigmoid activation operator"); + AddComment("Sigmoid activation operator, sigmoid = 1 / (1 + exp(-x))"); } }; @@ -56,7 +56,7 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Exp operator"); AddOutput("Y", "Output of Exp operator"); - AddComment("Exp activation operator"); + AddComment("Exp activation operator, exp(x) = e^x"); } }; @@ -66,7 +66,129 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Relu operator"); AddOutput("Y", "Output of Relu operator"); - AddComment("Relu activation operator"); + AddComment("Relu activation operator, relu(x) = max(x, 0)"); + } +}; + +class TanhOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Tanh operator"); + AddOutput("Y", "Output of Tanh operator"); + AddComment( + "Tanh activation operator, tanh = (exp(x) - exp(-x)) / (exp(x) + " + "exp(-x))"); + } +}; + +class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Sqrt operator"); + AddOutput("Y", "Output of Sqrt operator"); + AddComment("Sqrt activation operator, sqrt(x) = x^(1/2)"); + } +}; + +class AbsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Abs operator"); + AddOutput("Y", "Output of Abs operator"); + AddComment("Abs activation operator, abs(x) = |x|"); + } +}; + +class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReciprocalOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Reciprocal operator"); + AddOutput("Y", "Output of Reciprocal operator"); + AddComment("Reciprocal activation operator, reciprocal(x) = 1 / x"); 
+ } +}; + +class LogOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Log operator"); + AddOutput("Y", "Output of Log operator"); + AddComment("Log activation operator, log(x) = natural logarithm of x"); + } +}; + +class SquareOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Square operator"); + AddOutput("Y", "Output of Square operator"); + AddComment("Square activation operator, square(x) = x^2"); + } +}; + +template +class BReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of BRelu operator"); + AddOutput("Y", "Output of BRelu operator"); + AddComment("BRelu activation operator, brelu = max(min(x, t_min), t_max)"); + AddAttr("t_min", "The min marginal value of BRelu") + .SetDefault(static_cast(0)); + AddAttr("t_max", "The max marginal value of BRelu") + .SetDefault(static_cast(24)); + } +}; + +template +class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SoftReluOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of SoftRelu operator"); + AddOutput("Y", "Output of SoftRelu operator"); + AddComment( + "SoftRelu activation operator, soft_relu = log(1 + exp(max(min(x, " + "threshold), threshold)))"); + AddAttr("threshold", "The threshold value of SoftRelu") + .SetDefault(static_cast(40)); + } +}; + +template +class PowOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Pow operator"); + AddOutput("Y", "Output of Pow operator"); + AddComment("Pow activation operator, pow(x, factor) = x^factor"); + AddAttr("factor", "The exponential factor of Pow") + .SetDefault(static_cast(1)); + } +}; + +template +class STanhOpMaker : public framework::OpProtoAndCheckerMaker { + public: + STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of STanh operator"); + AddOutput("Y", "Output of STanh operator"); + AddComment("STanh activation operator, stanh = b * tanh(a * x)"); + AddAttr("scale_a", "The scale parameter of a for the input") + .SetDefault(static_cast(2 / 3)); + AddAttr("scale_b", "The scale parameter of b for the input") + .SetDefault(static_cast(1.7159)); } }; @@ -78,10 +200,10 @@ REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, ops::ActivationOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::ActivationKernel); + ops::SigmoidFunctor>); REGISTER_OP_CPU_KERNEL( sigmoid_grad, ops::ActivationGradKernel); + ops::SigmoidGradFunctor>); REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, ops::ActivationOpGrad); @@ -100,3 +222,85 @@ REGISTER_OP_CPU_KERNEL(relu, REGISTER_OP_CPU_KERNEL( relu_grad, ops::ActivationGradKernel>); + +REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + tanh, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + 
tanh_grad, ops::ActivationGradKernel>); + +REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + sqrt, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + sqrt_grad, ops::ActivationGradKernel>); + +REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + abs, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL(abs_grad, + ops::ActivationGradKernel); + +REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, + reciprocal_grad, ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(reciprocal, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + reciprocal_grad, + ops::ActivationGradKernel>); + +REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + log, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + log_grad, ops::ActivationGradKernel>); + +REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(square, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + square_grad, ops::ActivationGradKernel>); + +REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(brelu, + ops::BReluKernel); +REGISTER_OP_CPU_KERNEL(brelu_grad, + ops::BReluGradKernel); + +REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, + soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(soft_relu, + ops::SoftReluKernel); +REGISTER_OP_CPU_KERNEL( + soft_relu_grad, ops::SoftReluGradKernel); + +REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel); +REGISTER_OP_CPU_KERNEL(pow_grad, + ops::PowGradKernel); + +REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(stanh, + ops::STanhKernel); +REGISTER_OP_CPU_KERNEL(stanh_grad, + ops::STanhGradKernel); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 3b2c147f46..112b33d225 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -36,3 +36,85 @@ REGISTER_OP_GPU_KERNEL(relu, REGISTER_OP_GPU_KERNEL( relu_grad, ops::ActivationGradKernel>); + +REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(tanh, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + tanh_grad, ops::ActivationGradKernel>); + +REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(sqrt, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + sqrt_grad, ops::ActivationGradKernel>); + +REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(abs, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + abs_grad, ops::ActivationGradKernel>); + +REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, + reciprocal_grad, ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(reciprocal, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + reciprocal_grad, + ops::ActivationGradKernel>); + +REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(log, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + log_grad, ops::ActivationGradKernel>); + +REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, + ops::ActivationOpGrad); 
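
For reference, the parameterized activations registered in this patch (brelu, soft_relu, pow, stanh) reduce to short closed-form expressions. The NumPy sketch below mirrors the forward/backward formulas of the BRelu/SoftRelu/Pow/STanh kernels and uses the default attribute values declared in the op makers; the helper names are illustrative only and are not part of the C++ sources.

import numpy as np

def brelu(x, t_min=0.0, t_max=24.0):
    # brelu clips x into [t_min, t_max]
    return np.minimum(np.maximum(x, t_min), t_max)

def brelu_grad(x, dy, t_min=0.0, t_max=24.0):
    # gradient passes through only where x lies strictly inside (t_min, t_max)
    return dy * ((x > t_min) & (x < t_max))

def soft_relu(x, threshold=40.0):
    # soft_relu = log(1 + exp(clip(x, -threshold, threshold)))
    t = np.clip(x, -threshold, threshold)
    return np.log(1.0 + np.exp(t))

def soft_relu_grad(x, y, dy, threshold=40.0):
    # y is the forward output; the clipped region has zero gradient
    return dy * (1.0 - np.exp(-y)) * ((x > -threshold) & (x < threshold))

def pow_forward(x, factor=1.0):
    return np.power(x, factor)

def pow_grad(x, dy, factor=1.0):
    return dy * factor * np.power(x, factor - 1.0)

def stanh(x, scale_a=2.0 / 3.0, scale_b=1.7159):
    return scale_b * np.tanh(scale_a * x)

def stanh_grad(x, dy, scale_a=2.0 / 3.0, scale_b=1.7159):
    return dy * scale_a * scale_b * (1.0 - np.tanh(scale_a * x) ** 2)

# quick sanity check against a central difference (stanh is smooth, so this is tight)
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
np.testing.assert_allclose(stanh_grad(x, np.ones_like(x)),
                           (stanh(x + 5e-3) - stanh(x - 5e-3)) / 1e-2, rtol=1e-2)

The simpler elementwise functors (abs, reciprocal, log, square) follow the same forward/backward pairing, just without attributes.
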
+REGISTER_OP_GPU_KERNEL(square, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + square_grad, ops::ActivationGradKernel>); + +REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(brelu, + ops::BReluKernel); +REGISTER_OP_GPU_KERNEL(brelu_grad, + ops::BReluGradKernel); + +REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, + soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(soft_relu, + ops::SoftReluKernel); +REGISTER_OP_GPU_KERNEL( + soft_relu_grad, ops::SoftReluGradKernel); + +REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel); +REGISTER_OP_GPU_KERNEL(pow_grad, + ops::PowGradKernel); + +REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_GPU_KERNEL(stanh, + ops::STanhKernel); +REGISTER_OP_GPU_KERNEL(stanh_grad, + ops::STanhGradKernel); \ No newline at end of file diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 9bf340f2ed..15f8afb4ba 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -55,19 +55,20 @@ class ActivationGradKernel : public framework::OpKernel { } }; -// sigmoid = 1 / (1 + exp(-x) +// sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor { template void operator()(Device d, X x, Y y) { - y.device(d) = 1. / (1. + (-x).exp()); + y.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; +template struct SigmoidGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = dy * y * (1. - y); + dx.device(d) = dy * y * (static_cast(1) - y); } }; @@ -103,7 +104,7 @@ struct ReluGradFunctor { } }; -// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) struct TanhFunctor { template void operator()(Device d, X x, Y y) { @@ -115,7 +116,7 @@ template struct TanhGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = dy * (T(1) - y * y); + dx.device(d) = dy * (static_cast(1) - y * y); } }; @@ -131,7 +132,7 @@ template struct SqrtGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - const T y_conj = Eigen::numext::conj(y); + const Y y_conj = Eigen::numext::conj(y); dx.device(d) = static_cast(0.5) * dy / y_conj; } }; @@ -144,19 +145,27 @@ struct AbsFunctor { } }; +struct AbsGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * x.sign(); + } +}; + // reciprocal(x) = 1 / x template struct ReciprocalFunctor { template void operator()(Device d, X x, Y y) { - y.device(d) = 1. / x; + y.device(d) = static_cast(1) / x; } }; +template struct ReciprocalGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = dy * (-1.0) * y * y; + dx.device(d) = dy * static_cast(-1) * y * y; } }; @@ -168,10 +177,11 @@ struct LogFunctor { } }; +template struct LogGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = dy * (1. 
/ x); + dx.device(d) = dy * (static_cast(1) / x); } }; @@ -181,12 +191,161 @@ struct SquareFunctor { void operator()(Device d, X x, Y y) { y.device(d) = x.square(); } -} +}; +template struct SquareGradFunctor { template void operator()(Device d, X x, Y y, dY dy, dX dx) { - dx.device(d) = dy * 2 * x; + dx.device(d) = dy * static_cast(2) * x; + } +}; + +template +class BReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto t_min = static_cast(context.Attr("t_min")); + auto t_max = static_cast(context.Attr("t_max")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max); + } +}; + +template +class BReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto t_min = static_cast(context.Attr("t_min")); + auto t_max = static_cast(context.Attr("t_max")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast(); + } +}; + +template +class SoftReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto threshold = static_cast(context.Attr("threshold")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval(); + y.device(place) = (static_cast(1) + temp.exp()).log(); + } +}; + +template +class SoftReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto threshold = static_cast(context.Attr("threshold")); + dX->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto dy = framework::EigenVector::Flatten(*dY); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + auto temp = ((x > -threshold) * (x < threshold)).template cast().eval(); + dx.device(place) = dy * (static_cast(1) - (-y).exp()) * temp; + } +}; + +template +class PowKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto factor = static_cast(context.Attr("factor")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = x.pow(factor); + } +}; + +template +class PowGradKernel : public 
framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto factor = static_cast(context.Attr("factor")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + dx.device(place) = dy * factor * x.pow(factor - static_cast(1)); + } +}; + +template +class STanhKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto scale_a = static_cast(context.Attr("scale_a")); + auto scale_b = static_cast(context.Attr("scale_b")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = scale_b * (scale_a * x).tanh(); + } +}; + +template +class STanhGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto scale_a = static_cast(context.Attr("scale_a")); + auto scale_b = static_cast(context.Attr("scale_b")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); + dx.device(place) = dy * scale_a * scale_b * (static_cast(1) - temp); } }; diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index bd964c5d07..28195b1b0a 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -55,8 +55,6 @@ USE_OP(squared_l2_distance); USE_OP(sum); USE_OP(reshape); USE_OP(sigmoid); -USE_OP(exp); -USE_OP(relu); namespace paddle { namespace framework { diff --git a/python/paddle/v2/framework/tests/op_test.py b/python/paddle/v2/framework/tests/op_test.py index 4fec4c9109..899d3ae991 100644 --- a/python/paddle/v2/framework/tests/op_test.py +++ b/python/paddle/v2/framework/tests/op_test.py @@ -196,7 +196,7 @@ class OpTest(unittest.TestCase): self.assertTrue( np.allclose( actual, expect, atol=1e-05), - "output name: " + out_name + "has diff") + "output name: " + out_name + " has diff") def check_output(self): places = [core.CPUPlace()] diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 23ff584396..7cd39dfe91 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -21,7 +21,9 @@ class TestExp(OpTest): class TestRelu(OpTest): def setUp(self): self.op_type = "relu" - self.inputs = {'X': np.random.uniform(-1, 1, [4, 4]).astype("float32")} + x = np.random.uniform(-1, 1, [11, 17]).astype("float32") + x = np.sign(x) * np.exp(np.abs(x)) + self.inputs = {'X': x} self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} def test_check_output(self): @@ -42,6 +44,167 @@ class TestSigmoid(OpTest): def test_check_output(self): self.check_output() + def 
test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.008) + + +class TestTanh(OpTest): + def setUp(self): + self.op_type = "tanh" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.tanh(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestSqrt(OpTest): + def setUp(self): + self.op_type = "sqrt" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.sqrt(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestAbs(OpTest): + def setUp(self): + self.op_type = "abs" + x = np.random.uniform(-1, 1, [11, 17]).astype("float32") + x = np.sign(x) * np.exp(np.abs(x)) + self.inputs = {'X': x} + self.outputs = {'Y': np.abs(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestReciprocal(OpTest): + def setUp(self): + self.op_type = "reciprocal" + self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} + self.outputs = {'Y': np.reciprocal(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.01) + + +class TestLog(OpTest): + def setUp(self): + self.op_type = "log" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestSquare(OpTest): + def setUp(self): + self.op_type = "square" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.square(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestBRelu(OpTest): + def setUp(self): + self.op_type = "brelu" + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + x = 2 * np.sign(x) * np.exp(np.abs(x)) + self.inputs = {'X': x} + t_min = 0 + t_max = 4 + self.attrs = {'t_min': t_min, 't_max': t_max} + t = np.copy(x) + t[t < t_min] = t_min + t[t > t_max] = t_max + self.outputs = {'Y': t} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + +class TestSoftRelu(OpTest): + def setUp(self): + self.op_type = "soft_relu" + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + x = 2 * np.sign(x) * np.exp(np.abs(x)) + self.inputs = {'X': x} + threshold = 4 + self.attrs = {'threshold': threshold} + t = np.copy(x) + t[t < -threshold] = -threshold + t[t > threshold] = threshold + self.outputs = {'Y': np.log((np.exp(t) + 1))} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.02) + + +class TestPow(OpTest): + def setUp(self): + self.op_type = "pow" + self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} + self.attrs = {'factor': 3} + self.outputs = {'Y': np.power(self.inputs['X'], 3)} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + 
self.check_grad(['X'], 'Y', max_relative_error=0.02) + + +class TestSTanh(OpTest): + def setUp(self): + self.op_type = "stanh" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + scale_a = 2.0 / 3.0 + scale_b = 1.7159 + self.attrs = {'scale_a': scale_a, 'scale_b': scale_b} + self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)} + + def test_check_output(self): + self.check_output() + def test_check_grad(self): self.check_grad(['X'], 'Y', max_relative_error=0.007) From 41271f03cb609a9a772c3ff720a011ff3b1a1b93 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 14 Sep 2017 19:36:52 +0800 Subject: [PATCH 11/13] fix gpu build error --- paddle/operators/activation_op.cu | 56 ++++++------------- .../paddle/trainer_config_helpers/networks.py | 4 +- 2 files changed, 20 insertions(+), 40 deletions(-) diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu index 112b33d225..feed1302b2 100644 --- a/paddle/operators/activation_op.cu +++ b/paddle/operators/activation_op.cu @@ -19,10 +19,10 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(sigmoid, ops::ActivationKernel); + ops::SigmoidFunctor>); REGISTER_OP_GPU_KERNEL( sigmoid_grad, ops::ActivationGradKernel); + ops::SigmoidGradFunctor>); REGISTER_OP_GPU_KERNEL( exp, @@ -37,35 +37,27 @@ REGISTER_OP_GPU_KERNEL( relu_grad, ops::ActivationGradKernel>); -REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, - ops::ActivationOpGrad); -REGISTER_OP_GPU_KERNEL(tanh, - ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + tanh, + ops::ActivationKernel); REGISTER_OP_GPU_KERNEL( tanh_grad, ops::ActivationGradKernel>); -REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, - ops::ActivationOpGrad); -REGISTER_OP_GPU_KERNEL(sqrt, - ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + sqrt, + ops::ActivationKernel); REGISTER_OP_GPU_KERNEL( sqrt_grad, ops::ActivationGradKernel>); -REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, - ops::ActivationOpGrad); -REGISTER_OP_GPU_KERNEL(abs, - ops::ActivationKernel>); REGISTER_OP_GPU_KERNEL( - abs_grad, ops::ActivationGradKernel>); + abs, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL(abs_grad, + ops::ActivationGradKernel); -REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, - reciprocal_grad, ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(reciprocal, ops::ActivationKernel>); @@ -74,47 +66,35 @@ REGISTER_OP_GPU_KERNEL( ops::ActivationGradKernel>); -REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, - ops::ActivationOpGrad); -REGISTER_OP_GPU_KERNEL(log, - ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + log, + ops::ActivationKernel); REGISTER_OP_GPU_KERNEL( log_grad, ops::ActivationGradKernel>); -REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, - ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(square, ops::ActivationKernel>); + ops::SquareFunctor>); REGISTER_OP_GPU_KERNEL( square_grad, ops::ActivationGradKernel>); -REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad, - ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(brelu, ops::BReluKernel); REGISTER_OP_GPU_KERNEL(brelu_grad, ops::BReluGradKernel); -REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, - soft_relu_grad, ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(soft_relu, ops::SoftReluKernel); REGISTER_OP_GPU_KERNEL( soft_relu_grad, ops::SoftReluGradKernel); -REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, - ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(pow, 
ops::PowKernel); REGISTER_OP_GPU_KERNEL(pow_grad, ops::PowGradKernel); -REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, - ops::ActivationOpGrad); REGISTER_OP_GPU_KERNEL(stanh, ops::STanhKernel); REGISTER_OP_GPU_KERNEL(stanh_grad, - ops::STanhGradKernel); \ No newline at end of file + ops::STanhGradKernel); diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 34be203ee2..28a71cf788 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1406,7 +1406,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(*[l.name for l in layers]) + Inputs(* [l.name for l in layers]) def outputs(layers, *args): @@ -1456,7 +1456,7 @@ def outputs(layers, *args): assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(*[l.name for l in layers]) + Outputs(* [l.name for l in layers]) return # just return outputs. if len(layers) != 1: From 34ecfcad4a182f8d5c5feae03f290242adcbc313 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 15 Sep 2017 03:20:36 +0000 Subject: [PATCH 12/13] fix code style --- python/paddle/trainer_config_helpers/networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 28a71cf788..34be203ee2 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -1406,7 +1406,7 @@ def inputs(layers, *args): if len(args) != 0: layers.extend(args) - Inputs(* [l.name for l in layers]) + Inputs(*[l.name for l in layers]) def outputs(layers, *args): @@ -1456,7 +1456,7 @@ def outputs(layers, *args): assert len(layers) > 0 if HasInputsSet(): # input already set - Outputs(* [l.name for l in layers]) + Outputs(*[l.name for l in layers]) return # just return outputs. 
if len(layers) != 1: From 48f5f6bdd071736df63d7bdcf6a3740c8ae06240 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 15 Sep 2017 11:23:19 +0800 Subject: [PATCH 13/13] refine some operators' python unittests --- .../v2/framework/tests/test_activation_op.py | 124 ++++++++++-------- 1 file changed, 67 insertions(+), 57 deletions(-) diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py index 7cd39dfe91..003f6d50b6 100644 --- a/python/paddle/v2/framework/tests/test_activation_op.py +++ b/python/paddle/v2/framework/tests/test_activation_op.py @@ -18,21 +18,6 @@ class TestExp(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) -class TestRelu(OpTest): - def setUp(self): - self.op_type = "relu" - x = np.random.uniform(-1, 1, [11, 17]).astype("float32") - x = np.sign(x) * np.exp(np.abs(x)) - self.inputs = {'X': x} - self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) - - class TestSigmoid(OpTest): def setUp(self): self.op_type = "sigmoid" @@ -81,8 +66,12 @@ class TestSqrt(OpTest): class TestAbs(OpTest): def setUp(self): self.op_type = "abs" - x = np.random.uniform(-1, 1, [11, 17]).astype("float32") - x = np.sign(x) * np.exp(np.abs(x)) + x = np.random.uniform(-1, 1, [4, 4]).astype("float32") + # Because we set delta = 0.005 in caculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is unaccurate. + # we should avoid this + x[np.abs(x) < 0.005] = 0.02 self.inputs = {'X': x} self.outputs = {'Y': np.abs(self.inputs['X'])} @@ -93,41 +82,14 @@ class TestAbs(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.007) -class TestReciprocal(OpTest): - def setUp(self): - self.op_type = "reciprocal" - self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} - self.outputs = {'Y': np.reciprocal(self.inputs['X'])} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.01) - - -class TestLog(OpTest): - def setUp(self): - self.op_type = "log" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Y': np.log(self.inputs['X'])} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) - - -class TestSquare(OpTest): +class TestRelu(OpTest): def setUp(self): - self.op_type = "square" - self.inputs = { - 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") - } - self.outputs = {'Y': np.square(self.inputs['X'])} + self.op_type = "relu" + x = np.random.uniform(-1, 1, [11, 17]).astype("float32") + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + self.inputs = {'X': x} + self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} def test_check_output(self): self.check_output() @@ -140,10 +102,13 @@ class TestBRelu(OpTest): def setUp(self): self.op_type = "brelu" x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - x = 2 * np.sign(x) * np.exp(np.abs(x)) - self.inputs = {'X': x} - t_min = 0 + t_min = 1 t_max = 4 + # The same with TestAbs + x[np.abs(x - t_min) < 0.005] = t_min + 0.02 + x[np.abs(x - t_max) < 0.005] = t_min + 0.02 + + self.inputs = {'X': x} self.attrs = {'t_min': t_min, 't_max': t_max} t = np.copy(x) t[t < t_min] = t_min @@ -160,10 +125,12 @@ class 
TestBRelu(OpTest): class TestSoftRelu(OpTest): def setUp(self): self.op_type = "soft_relu" - x = np.random.uniform(-1, 1, [4, 4]).astype("float32") - x = 2 * np.sign(x) * np.exp(np.abs(x)) + x = np.random.uniform(-3, 3, [4, 4]).astype("float32") + threshold = 2 + # The same reason with TestAbs + x[np.abs(x - threshold) < 0.005] = threshold + 0.02 + x[np.abs(x + threshold) < 0.005] = -threshold + 0.02 self.inputs = {'X': x} - threshold = 4 self.attrs = {'threshold': threshold} t = np.copy(x) t[t < -threshold] = -threshold @@ -177,6 +144,49 @@ class TestSoftRelu(OpTest): self.check_grad(['X'], 'Y', max_relative_error=0.02) +class TestReciprocal(OpTest): + def setUp(self): + self.op_type = "reciprocal" + self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} + self.outputs = {'Y': np.reciprocal(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.01) + + +class TestLog(OpTest): + def setUp(self): + self.op_type = "log" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.log(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + +class TestSquare(OpTest): + def setUp(self): + self.op_type = "square" + self.inputs = { + 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") + } + self.outputs = {'Y': np.square(self.inputs['X'])} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Y', max_relative_error=0.007) + + class TestPow(OpTest): def setUp(self): self.op_type = "pow"
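
The input adjustments in this last patch (moving values off 0 for abs/relu, off t_min/t_max for brelu, off the +/- threshold points for soft_relu) exist because, as the comments added above describe, the gradient check uses a finite-difference estimate with delta = 0.005, which is unreliable within delta of a kink. A standalone NumPy sketch of the effect, independent of the OpTest harness (the helper below is illustrative, not part of the framework):

import numpy as np

def numeric_grad(f, x, delta=0.005):
    # central-difference estimate, conceptually what the gradient check computes
    return (f(x + delta) - f(x - delta)) / (2.0 * delta)

# Near the kink of abs at 0, the stencil straddles both branches and the
# estimate is badly off; 0.02 (the value the tests substitute) is safe.
print(numeric_grad(np.abs, 0.002))  # ~0.4, while the analytic gradient is 1.0
print(numeric_grad(np.abs, 0.02))   # 1.0, matches the analytic gradient
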