From 1a968b4f64567d1281dd278a6b412cd823663e43 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 10 Jul 2017 20:39:48 +0800 Subject: [PATCH 01/10] init --- paddle/framework/ddim.h | 10 ++++ paddle/framework/tensor.h | 27 ++++++++-- paddle/framework/tensor_types.h | 91 +++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 4 deletions(-) create mode 100644 paddle/framework/tensor_types.h diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 223c4180be..053a09d63a 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -6,6 +6,7 @@ #include #include "paddle/framework/dim.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace framework { @@ -91,6 +92,15 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); +template +Eigen::DSizes ToEigenDSizes(DDim dims) const { + Eigen::DSizes dsizes; + for (int d = 0; d < paddle::framework::arity(dims); d++) { + dsizes[d] = dims[d]; + } + return dsizes; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index ce5d98b04e..81af430611 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -18,8 +18,10 @@ limitations under the License. */ #include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" +#include "paddle/framework/tensor_types.h" #include "paddle/memory/memory.h" #include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace framework { @@ -33,6 +35,13 @@ class Tensor { return static_cast(holder_->Ptr()); } + template + T* data() const { + PADDLE_ENFORCE(holder_ != nullptr, + "Tensor::data must be called after Tensor::mutable_data."); + return static_cast(holder_->Ptr()); + } + template ::value>::type* = nullptr> T* mutable_data(DDim dims, paddle::platform::Place place) { @@ -41,14 +50,23 @@ class Tensor { place) /* some versions of boost::variant don't have operator!= */ || holder_->Size() < product(dims) * sizeof(T)) { holder_.reset(new PlaceholderImpl(place, product(dims) * sizeof(T))); + dims_ = dims; } return static_cast(holder_->Ptr()); } - template ::value>::type* = nullptr> - T* mutable_data(DDim dims) { - return mutable_data(dims, paddle::platform::get_place()); + DDim dim() const { return dims_; } + + template + typename TTypes::ConstantTensor Tensor::tensor() { + return typename TTypes::Tensor( + data(), paddle::framework::ToEigenDSizes(dims_)); + } + + template + typename TTypes::Tensor Tensor::tensor() { + return typename TTypes::Tensor( + data(), paddle::framework::ToEigenDSizes(dims_)); } private: @@ -92,6 +110,7 @@ class Tensor { }; std::shared_ptr holder_; // holds the memory block if allocated. + DDim dims_; }; } // namespace framework diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h new file mode 100644 index 0000000000..b68697108c --- /dev/null +++ b/paddle/framework/tensor_types.h @@ -0,0 +1,91 @@ +#pragma once + +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace framework { + +// Helper to define Tensor types given that the scalar is of type T. +template +struct TTypes { + // Rank- tensor of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstTensor; + + // Unaligned Rank- tensor of scalar type T. + typedef Eigen::TensorMap> + UnalignedTensor; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstTensor; + + typedef Eigen::TensorMap, + Eigen::Aligned> + Tensor32Bit; + + // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, + Eigen::Aligned> + Scalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>, + Eigen::Aligned> + ConstScalar; + + // Unaligned Scalar tensor of scalar type T. + typedef Eigen::TensorMap< + Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>> + UnalignedScalar; + typedef Eigen::TensorMap, + Eigen::RowMajor, IndexType>> + UnalignedConstScalar; + + // Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Flat; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstFlat; + typedef Eigen::TensorMap, + Eigen::Aligned> + Vec; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstVec; + + // Unaligned Rank-1 tensor (vector) of scalar type T. + typedef Eigen::TensorMap> + UnalignedFlat; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstFlat; + typedef Eigen::TensorMap> + UnalignedVec; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstVec; + + // Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap, + Eigen::Aligned> + Matrix; + typedef Eigen::TensorMap< + Eigen::Tensor, Eigen::Aligned> + ConstMatrix; + + // Unaligned Rank-2 tensor (matrix) of scalar type T. + typedef Eigen::TensorMap> + UnalignedMatrix; + typedef Eigen::TensorMap< + Eigen::Tensor> + UnalignedConstMatrix; +}; + +} // namespace framework +} // namespace paddle From d6f7c3535d0907af4e2d955451e9a872d6b857a3 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 12:52:07 +0800 Subject: [PATCH 02/10] move unaligned tensor types --- paddle/framework/tensor_types.h | 38 --------------------------------- 1 file changed, 38 deletions(-) diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h index b68697108c..26de25b7c2 100644 --- a/paddle/framework/tensor_types.h +++ b/paddle/framework/tensor_types.h @@ -16,17 +16,6 @@ struct TTypes { Eigen::Tensor, Eigen::Aligned> ConstTensor; - // Unaligned Rank- tensor of scalar type T. - typedef Eigen::TensorMap> - UnalignedTensor; - typedef Eigen::TensorMap< - Eigen::Tensor> - UnalignedConstTensor; - - typedef Eigen::TensorMap, - Eigen::Aligned> - Tensor32Bit; - // Scalar tensor (implemented as a rank-0 tensor) of scalar type T. typedef Eigen::TensorMap< Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>, @@ -37,14 +26,6 @@ struct TTypes { Eigen::Aligned> ConstScalar; - // Unaligned Scalar tensor of scalar type T. - typedef Eigen::TensorMap< - Eigen::TensorFixedSize, Eigen::RowMajor, IndexType>> - UnalignedScalar; - typedef Eigen::TensorMap, - Eigen::RowMajor, IndexType>> - UnalignedConstScalar; - // Rank-1 tensor (vector) of scalar type T. typedef Eigen::TensorMap, Eigen::Aligned> @@ -59,18 +40,6 @@ struct TTypes { Eigen::Tensor, Eigen::Aligned> ConstVec; - // Unaligned Rank-1 tensor (vector) of scalar type T. - typedef Eigen::TensorMap> - UnalignedFlat; - typedef Eigen::TensorMap< - Eigen::Tensor> - UnalignedConstFlat; - typedef Eigen::TensorMap> - UnalignedVec; - typedef Eigen::TensorMap< - Eigen::Tensor> - UnalignedConstVec; - // Rank-2 tensor (matrix) of scalar type T. typedef Eigen::TensorMap, Eigen::Aligned> @@ -78,13 +47,6 @@ struct TTypes { typedef Eigen::TensorMap< Eigen::Tensor, Eigen::Aligned> ConstMatrix; - - // Unaligned Rank-2 tensor (matrix) of scalar type T. - typedef Eigen::TensorMap> - UnalignedMatrix; - typedef Eigen::TensorMap< - Eigen::Tensor> - UnalignedConstMatrix; }; } // namespace framework From 958511160bc42fee48c9ad775dfb08e5198bf3e9 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 13:40:44 +0800 Subject: [PATCH 03/10] add simple add_op_functor --- paddle/framework/ddim.cc | 12 ++++++++ paddle/framework/ddim.h | 8 +----- paddle/framework/tensor.h | 47 +++++++++++++++++++++++++++++-- paddle/framework/tensor_types.h | 14 +++++++++ paddle/operators/add_op_functor.h | 35 +++++++++++++++++++++++ 5 files changed, 107 insertions(+), 9 deletions(-) create mode 100644 paddle/operators/add_op_functor.h diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 3f949a6595..9431645cf5 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -1,4 +1,5 @@ #include "paddle/framework/ddim.h" +#include "paddle/framework/enforce.h" namespace paddle { namespace framework { @@ -220,5 +221,16 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } +template +Eigen::DSizes ToEigenDSizes(DDim dims) const { + int rank = paddle::framework::arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") + Eigen::DSizes dsizes; + for (int d = 0; d < paddle::framework::arity(dims); d++) { + dsizes[d] = dims[d]; + } + return dsizes; +} + } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 053a09d63a..a83a367196 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,13 +93,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); template -Eigen::DSizes ToEigenDSizes(DDim dims) const { - Eigen::DSizes dsizes; - for (int d = 0; d < paddle::framework::arity(dims); d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} +Eigen::DSizes ToEigenDSizes(DDim dims) const; } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 81af430611..0fa74e7ab1 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -57,18 +57,61 @@ class Tensor { DDim dim() const { return dims_; } + size_t NumElements() const { return product(dims_); } + + template + typename TTypes::Tensor Tensor::shaped(DDim new_dims) { + Eigen::array dims = + paddle::framework::ToEigenDSizes(new_dims); + return typename TTypes::Tensor(data(), dims); + } + template - typename TTypes::ConstantTensor Tensor::tensor() { + typename TTypes::Tensor Tensor::tensor() { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } + // flat to rank = 1 + template + typename TTypes::Flat flat() { + return shaped({NumElements()}); + } + + // to TensorType Vec + template + typename TTypes::Vec vec() { + return tensor(); + } + + // to TensorType Matrix + template + typename TTypes::Matrix matrix() { + return tensor(); + } + + // const versions of all the methods above. template - typename TTypes::Tensor Tensor::tensor() { + typename TTypes::ConstantTensor Tensor::tensor() const { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } + template + typename TTypes::ConstFlat flat() const { + return shaped({NumElements()}); + } + + template + typename TTypes::ConstVec vec() const { + return tensor(); + } + + template + typename TTypes::ConstMatrix matrix() const { + return tensor(); + } + private: // Placeholder hides type T, so it doesn't appear as a template // parameter of Variable. diff --git a/paddle/framework/tensor_types.h b/paddle/framework/tensor_types.h index 26de25b7c2..4bf27a377e 100644 --- a/paddle/framework/tensor_types.h +++ b/paddle/framework/tensor_types.h @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include "unsupported/Eigen/CXX11/Tensor" diff --git a/paddle/operators/add_op_functor.h b/paddle/operators/add_op_functor.h new file mode 100644 index 0000000000..904f24b030 --- /dev/null +++ b/paddle/operators/add_op_functor.h @@ -0,0 +1,35 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/tensor_types.h" +#include "unsupported/Eigen/CXX11/Tensor" + +namespace paddle { +namespace operators { +namespace functor { + +template +struct Add { + void Operator()(const Device& d, + typename TTypes::ConstTensor input1, + typename TTypes::ConstTensor input2, + typename TTypes::Tensor output) { + output.device(d) = input1 + input2; + } +}; +} // namespace functor +} // namespace operators +} // namespace paddle From d607f0b70308c61e5399773a475b8e8c640e63c1 Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 14:15:45 +0800 Subject: [PATCH 04/10] use cached rank --- paddle/framework/ddim.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 9431645cf5..3fd3e538e8 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -226,7 +226,7 @@ Eigen::DSizes ToEigenDSizes(DDim dims) const { int rank = paddle::framework::arity(dims); PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") Eigen::DSizes dsizes; - for (int d = 0; d < paddle::framework::arity(dims); d++) { + for (int d = 0; d < rank; d++) { dsizes[d] = dims[d]; } return dsizes; From a5eb1d8fabe15b1e59e09fb0ed18de76de8843be Mon Sep 17 00:00:00 2001 From: qijun Date: Tue, 11 Jul 2017 16:21:19 +0800 Subject: [PATCH 05/10] fix build error --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/ddim.cc | 6 +++--- paddle/framework/ddim.h | 2 +- paddle/framework/tensor.h | 15 ++++----------- 4 files changed, 9 insertions(+), 16 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4409c6feae..2d26a62d0f 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -1,5 +1,5 @@ # ddim lib -cc_library(ddim SRCS ddim.cc) +cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) cc_test(tensor_test SRCS tensor_test.cc DEPS ddim) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index 3fd3e538e8..fe8f79abd4 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -222,9 +222,9 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { } template -Eigen::DSizes ToEigenDSizes(DDim dims) const { - int rank = paddle::framework::arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same") +Eigen::DSizes ToEigenDSizes(const DDim& dims) { + int rank = arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); Eigen::DSizes dsizes; for (int d = 0; d < rank; d++) { dsizes[d] = dims[d]; diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index a83a367196..18395c3636 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -93,7 +93,7 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); template -Eigen::DSizes ToEigenDSizes(DDim dims) const; +Eigen::DSizes ToEigenDSizes(const DDim& dims); } // namespace framework } // namespace paddle diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 0fa74e7ab1..21818937e8 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -28,13 +28,6 @@ namespace framework { class Tensor { public: - template - const T* data() const { - PADDLE_ENFORCE(holder_ != nullptr, - "Tensor::data must be called after Tensor::mutable_data."); - return static_cast(holder_->Ptr()); - } - template T* data() const { PADDLE_ENFORCE(holder_ != nullptr, @@ -60,14 +53,14 @@ class Tensor { size_t NumElements() const { return product(dims_); } template - typename TTypes::Tensor Tensor::shaped(DDim new_dims) { + typename TTypes::Tensor shaped(DDim new_dims) { Eigen::array dims = - paddle::framework::ToEigenDSizes(new_dims); + paddle::framework::ToEigenDSizes(new_dims); return typename TTypes::Tensor(data(), dims); } template - typename TTypes::Tensor Tensor::tensor() { + typename TTypes::Tensor tensor() { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } @@ -92,7 +85,7 @@ class Tensor { // const versions of all the methods above. template - typename TTypes::ConstantTensor Tensor::tensor() const { + typename TTypes::ConstantTensor tensor() const { return typename TTypes::Tensor( data(), paddle::framework::ToEigenDSizes(dims_)); } From bac1426d47727a9ea101dd42135a0800c2c5e023 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 14 Jul 2017 16:57:03 +0800 Subject: [PATCH 06/10] add_op kernel implementation --- paddle/framework/operator.cc | 12 +++++++ paddle/framework/operator.h | 67 +++++++++++++++++++++++------------- paddle/framework/tensor.h | 16 ++++++++- paddle/operators/add_op.cc | 11 +++--- paddle/operators/add_op.cu | 8 +++-- paddle/operators/add_op.h | 21 +++++++---- 6 files changed, 97 insertions(+), 38 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 8f7adff8b3..25d120c9a9 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -17,6 +17,18 @@ limitations under the License. */ namespace paddle { namespace framework { +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} + +#ifndef PADDLE_ONLY_CPU +template <> +DeviceType* KernelContext::get_eigen_device() { + return device_context_.get_eigen_device(); +} +#endif + std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index d3c55e0ceb..48cfeeb731 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -29,6 +29,21 @@ limitations under the License. */ namespace paddle { namespace framework { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + +#ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; +#endif + class OperatorBase; /** @@ -72,33 +87,39 @@ class OperatorBase { AttributeMap attrs_; }; -class OpKernel { +/** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ +class KernelContext { public: - /** - * KernelContext is the only parameter of Kernel Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. - */ - class KernelContext { - public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } - const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; - }; + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } + + platform::DeviceContext& device_context() const { return device_context_; } + template ::EigenDeviceType> + DeviceType* get_eigen_device(); + + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; +}; + +class OpKernel { + public: virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index e14b75d0e0..01244f617c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -35,7 +35,7 @@ class Tensor { template - T* data() const { + const T* data() const { PADDLE_ENFORCE( holder_ != nullptr, "Tenosr has not been initialized. Call Tensor::mutable_data first."); @@ -58,6 +58,20 @@ class Tensor { offset_); } + template ::value>::type* = nullptr> + T* mutable_data(paddle::platform::Place place) { + if (holder_ == nullptr || + !(holder_->Place() == + place) /* some versions of boost::variant don't have operator!= */ + || holder_->Size() < product(dims_) * sizeof(T) + offset_) { + holder_.reset(new PlaceholderImpl(place, product(dims_) * sizeof(T))); + offset_ = 0; + } + return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + offset_); + } + size_t NumElements() const { return product(dims_); } template diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 2766f0bf25..ef39e426fd 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -1,6 +1,6 @@ -#include -#include -#include +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" +#include "paddle/framework/tensor.h" namespace paddle { namespace operators { @@ -36,9 +36,10 @@ The equation is: Out = X + Y )DOC"); } }; -} // namespace op +} // namespace operators } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); REGISTER_OP_CPU_KERNEL( - add_two, ::paddle::operators::AddKernel<::paddle::platform::CPUPlace>); \ No newline at end of file + add_two, + ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); \ No newline at end of file diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index 5979345fff..f4a4fb16a6 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,5 +1,7 @@ -#include -#include +#define EIGEN_USE_GPU + +#include "paddle/operators/add_op.h" +#include "paddle/framework/op_registry.h" REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + paddle::operators::AddKernel); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 17d459dbc8..27a477a3ac 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -1,17 +1,26 @@ #pragma once -#include -#include +#include "glog/logging.h" +#include "paddle/framework/operator.h" +//#include "paddle/operators/add_op_functor.h" namespace paddle { namespace operators { -template +// Place can be CPUPlace or GPUPlace +template class AddKernel : public framework::OpKernel { public: - void Compute(const KernelContext &context) const override { - LOG(INFO) << "Add kernel in " << typeid(Place).name(); + void Compute(const KernelContext& context) const override { + auto* input0 = context.Input(0); + auto* input1 = context.Input(1); + + auto* output = context.Output(0); + output->mutable_data(Place()); + + output->flat().device(*(context.get_eigen_device())) = + input0->flat() + input1->flat(); } }; -} // namespace op +} // namespace operators } // namespace paddle From d649dbf442bd7ba4ce63a2a4e479a27c8d40ca8d Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 09:40:06 +0800 Subject: [PATCH 07/10] implement add_op kernel --- paddle/framework/operator.cc | 8 +++-- paddle/framework/operator.h | 59 +++++++++++++++---------------- paddle/framework/tensor.h | 6 ++-- paddle/operators/add_op.cc | 6 ++-- paddle/operators/add_op.cu | 5 ++- paddle/operators/add_op.h | 13 ++++--- paddle/platform/device_context.cc | 9 ++--- paddle/platform/device_context.h | 13 +++---- 8 files changed, 58 insertions(+), 61 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 25d120c9a9..3c6376c150 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -18,13 +18,15 @@ namespace paddle { namespace framework { template <> -DeviceType* KernelContext::get_eigen_device() { - return device_context_.get_eigen_device(); +Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< + platform::CPUPlace, Eigen::DefaultDevice>() const { + return device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -DeviceType* KernelContext::get_eigen_device() { +DeviceType* OpKernel::KernelContext::get_eigen_device() + const { return device_context_.get_eigen_device(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 48cfeeb731..558d4a0b67 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -33,13 +33,13 @@ template struct EigenDeviceConverter; template <> -struct EigenDeviceConverter { +struct EigenDeviceConverter { using EigenDeviceType = Eigen::DefaultDevice; }; #ifndef PADDLE_ONLY_CPU template <> -struct EigenDeviceConverter { +struct EigenDeviceConverter { using EigenDeviceType = Eigen::GpuDevice; }; #endif @@ -87,39 +87,38 @@ class OperatorBase { AttributeMap attrs_; }; -/** - * KernelContext is the only parameter of Kernel Run function. - * Run will get input/output variables, state such as momentum and - * device resource such as CUDA stream, cublas handle, etc. from - * KernelContext. User should construct it before run the Operator. - */ -class KernelContext { +class OpKernel { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, - const platform::DeviceContext& device_context) - : op_(*op), scope_(scope), device_context_(device_context) {} - - const Variable* Input(int index) const { - return scope_->GetVariable(op_.inputs_[index]); - } - - Variable* Output(int index) const { - return scope_->GetVariable(op_.outputs_[index]); - } + /** + * KernelContext is the only parameter of Kernel Run function. + * Run will get input/output variables, state such as momentum and + * device resource such as CUDA stream, cublas handle, etc. from + * KernelContext. User should construct it before run the Operator. + */ + class KernelContext { + public: + KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + const platform::DeviceContext& device_context) + : op_(*op), scope_(scope), device_context_(device_context) {} + + const Variable* Input(int index) const { + return scope_->GetVariable(op_.inputs_[index]); + } - platform::DeviceContext& device_context() const { return device_context_; } + Variable* Output(int index) const { + return scope_->GetVariable(op_.outputs_[index]); + } - template ::EigenDeviceType> - DeviceType* get_eigen_device(); + template ::EigenDeviceType> + DeviceType* get_eigen_device() const; - const OperatorBase& op_; - const std::shared_ptr& scope_; - const platform::DeviceContext& device_context_; -}; + const OperatorBase& op_; + const std::shared_ptr& scope_; + const platform::DeviceContext& device_context_; + }; -class OpKernel { - public: virtual void Compute(const KernelContext& context) const = 0; virtual ~OpKernel() {} diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 01244f617c..784d52cc42 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -35,7 +35,7 @@ class Tensor { template - const T* data() const { + T* data() const { PADDLE_ENFORCE( holder_ != nullptr, "Tenosr has not been initialized. Call Tensor::mutable_data first."); @@ -90,7 +90,7 @@ class Tensor { // flat to rank = 1 template typename TTypes::Flat flat() { - return shaped({NumElements()}); + return shaped(make_ddim({static_cast(NumElements())})); } // to TensorType Vec @@ -114,7 +114,7 @@ class Tensor { template typename TTypes::ConstFlat flat() const { - return shaped({NumElements()}); + return shaped(make_ddim({static_cast(NumElements())})); } template diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index ef39e426fd..7dc6414af2 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -40,6 +40,6 @@ The equation is: Out = X + Y } // namespace paddle REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); -REGISTER_OP_CPU_KERNEL( - add_two, - ::paddle::operators::AddKernel<::paddle::platform::CPUPlace, float>); \ No newline at end of file +typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float> + AddKernel_CPU_float; +REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float); \ No newline at end of file diff --git a/paddle/operators/add_op.cu b/paddle/operators/add_op.cu index f4a4fb16a6..0edf142ee4 100644 --- a/paddle/operators/add_op.cu +++ b/paddle/operators/add_op.cu @@ -1,7 +1,6 @@ -#define EIGEN_USE_GPU - #include "paddle/operators/add_op.h" #include "paddle/framework/op_registry.h" +typedef paddle::operators::AddKernel<::paddle::platform::GPUPlace, float> AddKernel_GPU_float; REGISTER_OP_GPU_KERNEL(add_two, - paddle::operators::AddKernel); \ No newline at end of file + AddKernel_GPU_float); \ No newline at end of file diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 27a477a3ac..568cb19742 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -6,19 +6,18 @@ namespace paddle { namespace operators { -// Place can be CPUPlace or GPUPlace -template +template class AddKernel : public framework::OpKernel { public: void Compute(const KernelContext& context) const override { - auto* input0 = context.Input(0); - auto* input1 = context.Input(1); + auto input0 = context.Input(0)->Get(); + auto input1 = context.Input(1)->Get(); + auto* output = context.Output(0)->GetMutable(); - auto* output = context.Output(0); - output->mutable_data(Place()); + output->mutable_data(Place()); output->flat().device(*(context.get_eigen_device())) = - input0->flat() + input1->flat(); + input0.flat() + input1.flat(); } }; diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 960ef0a595..9c1d94e9e7 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -15,14 +15,15 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::DefaultDevice* DeviceContext::get_eigen_device() + const { + return reinterpret_cast(this)->eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() { - return reinterpret_cast(this)->eigen_device(); +Eigen::GpuDevice* DeviceContext::get_eigen_device() const { + return reinterpret_cast(this)->eigen_device(); } #endif diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 7de07d06be..2ec7b05599 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -32,17 +32,14 @@ class DeviceContext { virtual Place GetPlace() const = 0; template - DeviceType* get_eigen_device(); + DeviceType* get_eigen_device() const; }; class CPUDeviceContext : public DeviceContext { public: - Eigen::DefaultDevice* eigen_device() { - if (!eigen_device_) { - eigen_device_.reset(new Eigen::DefaultDevice()); - } - return eigen_device_.get(); - } + CPUDeviceContext() { eigen_device_.reset(new Eigen::DefaultDevice()); } + + Eigen::DefaultDevice* eigen_device() const { return eigen_device_.get(); } Place GetPlace() const override { Place retv = CPUPlace(); @@ -91,7 +88,7 @@ class CUDADeviceContext : public DeviceContext { cudaStream_t stream() { return stream_; } - Eigen::GpuDevice* eigen_device() { return eigen_device_.get(); } + Eigen::GpuDevice* eigen_device() const { return eigen_device_.get(); } cublasHandle_t cublas_handle() { if (!blas_handle_) { From 65dbeb6a24a0362fb696e9f67b3effc1691d4d9e Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 03:01:33 +0000 Subject: [PATCH 08/10] fix gpu build error --- paddle/framework/operator.cc | 6 +++--- paddle/function/RowConvOpGpu.cu | 21 +++++++++++++-------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index aa859591f0..946bde5734 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -25,9 +25,9 @@ Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< #ifndef PADDLE_ONLY_CPU template <> -DeviceType* OpKernel::KernelContext::get_eigen_device() - const { - return device_context_.get_eigen_device(); +Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device< + platform::GPUPlace, Eigen::GpuDevice>() const { + return device_context_.get_eigen_device(); } #endif diff --git a/paddle/function/RowConvOpGpu.cu b/paddle/function/RowConvOpGpu.cu index c0b947e224..d9dcc7d59d 100644 --- a/paddle/function/RowConvOpGpu.cu +++ b/paddle/function/RowConvOpGpu.cu @@ -32,7 +32,7 @@ __global__ void KeRowConv(real* y, const real* x, const real* w, for (int i = tidy; i < context; i += blky) { sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; } - + __syncthreads(); for (int i = 0; i < numSeq; ++i) { @@ -144,12 +144,15 @@ __global__ void KeRowConvBwWeight(real* dw, const real* x, const real* dy, int yoff = start + j; // transpose - sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; - sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? dy[yoff * width + xoff] : 0.0; + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? + x[yoff * width + xoff] : 0.0; + sh_dy[tidx][tidy + context - 1] = (xoff < width && yoff < end) ? + dy[yoff * width + xoff] : 0.0; __syncthreads(); if (tidy < (context - 1)) { yoff = yoff - context + 1; - sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? dy[yoff * width + xoff] : 0.0; + sh_dy[tidx][tidy] = (xoff < width && yoff >= start) ? + dy[yoff * width + xoff] : 0.0; } __syncthreads(); @@ -199,11 +202,13 @@ __global__ void KeRowConvBwWeight2(real* dw, const real* x, const real* dy, int yoff = start + j; // transpose - sh_x[tidx][tidy] = (xoff < width && yoff < end) ? x[yoff * width + xoff] : 0.0; + sh_x[tidx][tidy] = (xoff < width && yoff < end) ? + x[yoff * width + xoff] : 0.0; __syncthreads(); for (int t = 0; t < context; t++) { - sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0; + sh_dy[tidx][tidy] = (xoff < width && (yoff - t) >= start && + yoff - t < end) ? dy[(yoff - t) * width + xoff] : 0.0; __syncthreads(); real val = sh_x[tidy][tidx] * sh_dy[tidy][tidx]; @@ -239,7 +244,7 @@ __global__ void KeRowConvBwData(real* dx, const real* w, const real* dy, for (int i = tidy; i < context; i += blky) { sw[i][tidx] = gidx + tidx < width ? w[i*width + gidx + tidx] : 0.0; } - + __syncthreads(); for (int i = 0; i < numSeq; ++i) { @@ -312,7 +317,7 @@ void RowConvGrad(const GpuMatrix& outG, dim3 dimBlock(32, 32); dim3 dimGrid(DIVUP(width, dimBlock.x), 1); real* dw = filterG.getData(); - if (contextLength <= 32) { + if (contextLength <= 32) { KeRowConvBwWeight<32, 32, 32> <<>> (dw, x, dy, starts, height, width, numSeq, contextLength); From 5017b154689bd8cb595c1d37a54cb2fd072488bc Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 15:37:42 +0800 Subject: [PATCH 09/10] refactor tensor mutable_data --- paddle/framework/operator.h | 14 +++++++------- paddle/framework/tensor.h | 22 ++++++++++------------ paddle/platform/device_context.h | 4 ++-- 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index c48d990eb2..e6cae9c32b 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -14,17 +14,17 @@ limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include -#include #include #include #include #include +#include "paddle/framework/attr_checker.h" +#include "paddle/framework/op_desc.pb.h" +#include "paddle/framework/scope.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" +#include "paddle/utils/Error.h" namespace paddle { namespace framework { diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 30e00d0e0f..7ba4b29e7c 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -62,21 +62,19 @@ class Tensor { !(holder_->place() == place) /* some versions of boost::variant don't have operator!= */ || holder_->size() < numel_ * sizeof(T) + offset_) { + if (platform::is_cpu_place(place)) { + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); + } #ifdef __CUDACC__ - switch (place.which()) { - case 0: - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); - break; - - case 1: - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); - break; + else if (platform::is_gpu_place(place)) { + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); } #else - holder_.reset(new PlaceholderImpl( - boost::get(place), numel_ * sizeof(T))); + else if (platform::is_gpu_place(place)) { + PADDLE_ENFORCE(true, "GPU not support!"); + } #endif offset_ = 0; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 5f8ad15951..f226a75c20 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -20,9 +20,9 @@ limitations under the License. */ #include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif -#include #include -#include +#include "paddle/platform/place.h" +#include "unsupported/Eigen/CXX11/Tensor" namespace paddle { namespace platform { From 2a03e3808d48257a71366f5802aeec052914e1cc Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 17 Jul 2017 16:45:42 +0800 Subject: [PATCH 10/10] set correct place for output tensor --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 4 +++- paddle/operators/add_op.h | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 946bde5734..1a7e332227 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -18,14 +18,14 @@ namespace paddle { namespace framework { template <> -Eigen::DefaultDevice* OpKernel::KernelContext::get_eigen_device< +Eigen::DefaultDevice* OpKernel::KernelContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { return device_context_.get_eigen_device(); } #ifndef PADDLE_ONLY_CPU template <> -Eigen::GpuDevice* OpKernel::KernelContext::get_eigen_device< +Eigen::GpuDevice* OpKernel::KernelContext::GetEigenDevice< platform::GPUPlace, Eigen::GpuDevice>() const { return device_context_.get_eigen_device(); } diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e6cae9c32b..b8c5098e49 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -109,7 +109,9 @@ class OpKernel { template ::EigenDeviceType> - DeviceType* get_eigen_device() const; + DeviceType* GetEigenDevice() const; + + platform::Place GetPlace() const { return device_context_.GetPlace(); } const OperatorBase& op_; const ScopePtr& scope_; diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index e8c718669a..e9a793d23b 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -27,9 +27,9 @@ public: auto input1 = context.Input(1)->Get(); auto* output = context.Output(0)->GetMutable(); - output->mutable_data(Place()); + output->mutable_data(context.GetPlace()); - output->flat().device(*(context.get_eigen_device())) = + output->flat().device(*(context.GetEigenDevice())) = input0.flat() + input1.flat(); } };