From 7a6fcc7d30ab8dd8f452a9974e16798dbbe05dfe Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 17:39:50 -0700 Subject: [PATCH 1/5] move EigenDeviceConverter to device_context.h --- paddle/framework/operator.cc | 4 ++-- paddle/framework/operator.h | 19 ++----------------- paddle/platform/device_context.cc | 15 ++++++++------- paddle/platform/device_context.h | 19 +++++++++++++++++-- 4 files changed, 29 insertions(+), 28 deletions(-) diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index d7beff5bc1..8b5560ffa1 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -22,14 +22,14 @@ namespace framework { template <> Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< platform::CPUPlace, Eigen::DefaultDevice>() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #ifndef PADDLE_ONLY_CPU template <> Eigen::GpuDevice& ExecutionContext::GetEigenDevice() const { - return *device_context_.get_eigen_device(); + return *device_context_.GetEigenDevice(); } #endif diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index ba697a43e9..310d68d7c1 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -296,21 +296,6 @@ template <> std::vector InferShapeContext::MultiOutput( const std::string& name) const; -template -struct EigenDeviceConverter; - -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::DefaultDevice; -}; - -#ifndef PADDLE_ONLY_CPU -template <> -struct EigenDeviceConverter { - using EigenDeviceType = Eigen::GpuDevice; -}; -#endif - class ExecutionContext : public InferShapeContext { public: ExecutionContext(const OperatorBase& op, const Scope& scope, @@ -318,8 +303,8 @@ class ExecutionContext : public InferShapeContext { : InferShapeContext(op, scope), device_context_(device_context) {} template ::EigenDeviceType> + typename DeviceType = typename platform::EigenDeviceConverter< + PlaceType>::EigenDeviceType> DeviceType& GetEigenDevice() const; platform::Place GetPlace() const { return device_context_.GetPlace(); } diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index 93b472b41c..36af1ac677 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -16,8 +16,8 @@ namespace paddle { namespace platform { template <> -Eigen::DefaultDevice* DeviceContext::get_eigen_device() - const { +Eigen::DefaultDevice* DeviceContext::GetEigenDevice< + platform::CPUPlace, Eigen::DefaultDevice>() const { return reinterpret_cast(this)->eigen_device(); } @@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } #ifndef PADDLE_ONLY_CPU +template <> +Eigen::GpuDevice* +DeviceContext::GetEigenDevice() const { + return reinterpret_cast(this)->eigen_device(); +} + class EigenCudaStreamDevice : public Eigen::StreamInterface { public: EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { @@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { mutable unsigned int* semaphore_; }; -template <> -Eigen::GpuDevice* DeviceContext::get_eigen_device() const { - return reinterpret_cast(this)->eigen_device(); -} - CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); PADDLE_ENFORCE(cudaStreamCreate(&stream_)); diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index f6a39a8e26..d805d2ab08 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -27,13 +27,23 @@ limitations under the License. */ namespace paddle { namespace platform { +template +struct EigenDeviceConverter; + +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::DefaultDevice; +}; + class DeviceContext { public: virtual ~DeviceContext() {} virtual Place GetPlace() const = 0; - template - DeviceType* get_eigen_device() const; + template ::EigenDeviceType> + DeviceType* GetEigenDevice() const; virtual void Wait() const {} }; @@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext { }; #ifndef PADDLE_ONLY_CPU +template <> +struct EigenDeviceConverter { + using EigenDeviceType = Eigen::GpuDevice; +}; + class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { From c634a8480addf2e3cbbd271853f4c8aa4b10832b Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 17:53:50 -0700 Subject: [PATCH 2/5] add SetConstant method in math_function.h --- paddle/operators/math/CMakeLists.txt | 3 ++- paddle/operators/math/math_function.h | 8 ++++++++ paddle/operators/math/math_function_test.cc | 21 +++++++++++++++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 91ae3d49f1..6bea9817f1 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,6 +1,7 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) + nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_library(softmax_function SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu @@ -8,9 +9,9 @@ if(WITH_GPU) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_library(softmax_function SRCS softmax.cc DEPS operator) cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) endif() -nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 43306fca73..473eff4d19 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include +#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context, const framework::Tensor& matrix_b, bool trans_b, T alpha, framework::Tensor* matrix_out, T beta); +template +void SetConstant(const platform::DeviceContext& context, + framework::Tensor* tensor, T num) { + auto t = framework::EigenVector::Flatten(*tensor); + t.device(*context.GetEigenDevice()) = t.constant(static_cast(num)); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index f272f7e513..22468a0c4a 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) { EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[7], 99); } + +TEST(math_function, zero) { + paddle::framework::Tensor tensor; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* t = tensor.mutable_data({2, 2}, *cpu_place); + paddle::platform::CPUDeviceContext context(*cpu_place); + paddle::operators::math::SetConstant( + context, &tensor, 0); + EXPECT_EQ(t[0], 0); + EXPECT_EQ(t[1], 0); + EXPECT_EQ(t[2], 0); + EXPECT_EQ(t[3], 0); + + paddle::operators::math::SetConstant( + context, &tensor, 1); + + EXPECT_EQ(t[0], 1); + EXPECT_EQ(t[1], 1); + EXPECT_EQ(t[2], 1); + EXPECT_EQ(t[3], 1); +} From 79def5e6347228773a9d77966108653fb7a16c60 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 18:03:36 -0700 Subject: [PATCH 3/5] refine CrossEntropyFunctor --- paddle/operators/cross_entropy_op.cu | 14 +------------ paddle/operators/cross_entropy_op.h | 6 +++--- paddle/operators/math/cross_entropy.cc | 6 +++--- paddle/operators/math/cross_entropy.cu | 20 +++++++++---------- paddle/operators/math/cross_entropy.h | 4 +--- .../operators/softmax_with_cross_entropy_op.h | 3 ++- 6 files changed, 19 insertions(+), 34 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 76d63f77ad..04ae66de91 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -18,14 +18,6 @@ namespace paddle { namespace operators { namespace { -// TODO(qingqing): make zero setting a common function. -template -__global__ void Zero(T* X, const int N) { - for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; - i += blockDim.x * gridDim.x) { - X[i] = 0.0; - } -} template __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, @@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { .stream()>>>(dx_data, dy_data, x_data, label_data, batch_size, class_num); } else { - Zero<<( - ctx.device_context()) - .stream()>>>(dx_data, batch_size * class_num); - + math::SetConstant(ctx.device_context(), dx, 0); auto* label_data = label->data(); grid = (batch_size + block - 1) / block; CrossEntropyGradientKernel<<< diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index fa81d3b431..d2d321aa7e 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/cross_entropy.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { @@ -37,7 +38,7 @@ class CrossEntropyOpKernel : public framework::OpKernel { y->mutable_data(ctx.GetPlace()); math::CrossEntropyFunctor()( - ctx, y, x, labels, ctx.Attr("softLabel")); + ctx.device_context(), y, x, labels, ctx.Attr("softLabel")); } }; @@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { const T* x_data = x->data(); const int* label_data = label->data(); - // TODO(qingqing): make zero setting a common function. - memset(dx_data, 0, sizeof(T) * batch_size * class_num); + math::SetConstant(ctx.device_context(), dx, 0); for (int i = 0; i < batch_size; ++i) { PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); diff --git a/paddle/operators/math/cross_entropy.cc b/paddle/operators/math/cross_entropy.cc index a5a426bc7b..150a65f275 100644 --- a/paddle/operators/math/cross_entropy.cc +++ b/paddle/operators/math/cross_entropy.cc @@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix; template class CrossEntropyFunctor { public: - void operator()(const framework::ExecutionContext& ctx, - framework::Tensor* out, const framework::Tensor* prob, + void operator()(const platform::DeviceContext& ctx, framework::Tensor* out, + const framework::Tensor* prob, const framework::Tensor* labels, const bool softLabel) { const int batch_size = prob->dims()[0]; if (softLabel) { @@ -35,7 +35,7 @@ class CrossEntropyFunctor { auto lbl = EigenMatrix::From(*labels); auto loss = EigenMatrix::From(*out); - loss.device(ctx.GetEigenDevice()) = + loss.device(*ctx.GetEigenDevice()) = -((lbl * in.log().unaryExpr(math::TolerableValue())) .sum(Eigen::DSizes(1)) .reshape(Eigen::DSizes(batch_size, 1))); diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index d14a75a30c..2c589521c1 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -74,8 +74,8 @@ using Tensor = framework::Tensor; template class CrossEntropyFunctor { public: - void operator()(const framework::ExecutionContext& ctx, - framework::Tensor* out, const framework::Tensor* prob, + void operator()(const framework::DeviceContext& ctx, framework::Tensor* out, + const framework::Tensor* prob, const framework::Tensor* labels, bool softLabel) { const T* prob_data = prob->data(); T* loss_data = out->mutable_data(ctx.GetPlace()); @@ -87,20 +87,18 @@ class CrossEntropyFunctor { const T* label_data = labels->data(); int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); - SoftCrossEntropyKernel< - T><<( - ctx.device_context()) - .stream()>>>(loss_data, prob_data, label_data, class_num); + SoftCrossEntropyKernel<<< + batch_size, block, block * sizeof(T), + reinterpret_cast(ctx).stream()>>>( + loss_data, prob_data, label_data, class_num); } else { const int* label_data = labels->data(); int block = 512; int grid = (batch_size + block - 1) / block; CrossEntropyKernel<<< - grid, block, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>(loss_data, prob_data, label_data, - batch_size, class_num); + grid, block, 0, + reinterpret_cast(ctx).stream()>>>( + loss_data, prob_data, label_data, batch_size, class_num); } } }; diff --git a/paddle/operators/math/cross_entropy.h b/paddle/operators/math/cross_entropy.h index 18e637cf91..0ab6827ffa 100644 --- a/paddle/operators/math/cross_entropy.h +++ b/paddle/operators/math/cross_entropy.h @@ -37,9 +37,7 @@ struct TolerableValue { template class CrossEntropyFunctor { public: - // (TODO caoying) it is much better to use DeviceContext as the first - // parameter. - void operator()(const framework::ExecutionContext& context, + void operator()(const platform::DeviceContext& context, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels, const bool softLabel); }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index a8b18504e1..7dcb6ad9b4 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -42,7 +42,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { math::SoftmaxFunctor()(context, logits, softmax); math::CrossEntropyFunctor()( - context, loss, softmax, labels, context.Attr("softLabel")); + context.device_context(), loss, softmax, labels, + context.Attr("softLabel")); } }; From 84ff7e97842890e70f1baf6bf41ef54513d1a4a3 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 20:05:15 -0700 Subject: [PATCH 4/5] refine SoftmaxFunctor --- paddle/operators/math/softmax.h | 6 +++--- paddle/operators/softmax_op.h | 2 +- paddle/operators/softmax_with_cross_entropy_op.cu | 3 ++- paddle/operators/softmax_with_cross_entropy_op.h | 3 ++- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/paddle/operators/math/softmax.h b/paddle/operators/math/softmax.h index 3d2f0d0aec..225323f05a 100644 --- a/paddle/operators/math/softmax.h +++ b/paddle/operators/math/softmax.h @@ -36,7 +36,7 @@ struct ValueClip { template class SoftmaxFunctor { public: - void operator()(const framework::ExecutionContext& context, + void operator()(const platform::DeviceContext& context, const framework::Tensor* X, framework::Tensor* Y) { auto logits = EigenMatrix::From(*X); auto softmax = EigenMatrix::From(*Y); @@ -58,8 +58,8 @@ class SoftmaxFunctor { .broadcast(one_by_class)) .unaryExpr(ValueClip()); - softmax.device(context.GetEigenDevice()) = shifted_logits.exp(); - softmax.device(context.GetEigenDevice()) = + softmax.device(*context.GetEigenDevice()) = shifted_logits.exp(); + softmax.device(*context.GetEigenDevice()) = (softmax * softmax.sum(along_class) .inverse() diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 9996536454..8fdda8b1df 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -35,7 +35,7 @@ class SoftmaxKernel : public framework::OpKernel { // allocate memory on device. Y->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, X, Y); + math::SoftmaxFunctor()(context.device_context(), X, Y); } }; diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index c3086e729e..b5a7cda734 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -66,7 +66,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context, loss, softmax, labels, context.Attr("softLabel")); } diff --git a/paddle/operators/softmax_with_cross_entropy_op.h b/paddle/operators/softmax_with_cross_entropy_op.h index 7dcb6ad9b4..cffd422f18 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.h +++ b/paddle/operators/softmax_with_cross_entropy_op.h @@ -40,7 +40,8 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { softmax->mutable_data(context.GetPlace()); loss->mutable_data(context.GetPlace()); - math::SoftmaxFunctor()(context, logits, softmax); + math::SoftmaxFunctor()(context.device_context(), + logits, softmax); math::CrossEntropyFunctor()( context.device_context(), loss, softmax, labels, context.Attr("softLabel")); From b611a479fcf687367c9a6808242f6a348854c645 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 28 Sep 2017 20:48:55 -0700 Subject: [PATCH 5/5] fix gpu build error --- paddle/operators/cross_entropy_op.cu | 2 +- paddle/operators/math/cross_entropy.cu | 2 +- paddle/operators/softmax_with_cross_entropy_op.cu | 3 ++- paddle/platform/device_context_test.cc | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 04ae66de91..5e2024e0ea 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -56,7 +56,7 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { y->mutable_data(ctx.GetPlace()); math::CrossEntropyFunctor()( - ctx, y, x, label, ctx.Attr("softLabel")); + ctx.device_context(), y, x, label, ctx.Attr("softLabel")); } }; diff --git a/paddle/operators/math/cross_entropy.cu b/paddle/operators/math/cross_entropy.cu index 2c589521c1..367190e6b0 100644 --- a/paddle/operators/math/cross_entropy.cu +++ b/paddle/operators/math/cross_entropy.cu @@ -74,7 +74,7 @@ using Tensor = framework::Tensor; template class CrossEntropyFunctor { public: - void operator()(const framework::DeviceContext& ctx, framework::Tensor* out, + void operator()(const platform::DeviceContext& ctx, framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* labels, bool softLabel) { const T* prob_data = prob->data(); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cu b/paddle/operators/softmax_with_cross_entropy_op.cu index b5a7cda734..2bc53ecf87 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/operators/softmax_with_cross_entropy_op.cu @@ -69,7 +69,8 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { math::SoftmaxFunctor()(context.device_context(), logits, softmax); math::CrossEntropyFunctor()( - context, loss, softmax, labels, context.Attr("softLabel")); + context.device_context(), loss, softmax, labels, + context.Attr("softLabel")); } }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 5883a55272..f4b00c57de 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -24,7 +24,7 @@ TEST(Device, Init) { for (int i = 0; i < count; i++) { DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); Eigen::GpuDevice* gpu_device = - device_context->template get_eigen_device(); + device_context->template GetEigenDevice(); ASSERT_NE(nullptr, gpu_device); delete device_context; }