From 22dac40c3aab587fce717a07d46e1ba61712694c Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 3 Aug 2017 18:52:51 +0800 Subject: [PATCH 01/19] add gemm for both cpu and gpu --- paddle/framework/operator.h | 4 + paddle/operators/CMakeLists.txt | 4 +- paddle/operators/math/CMakeLists.txt | 5 + paddle/operators/math/math_function.cc | 121 +++++++++++++++ paddle/operators/math/math_function.cu | 146 ++++++++++++++++++ paddle/operators/math/math_function.h | 78 ++++++++++ paddle/operators/mean_op.h | 2 +- paddle/operators/mul_op.cc | 1 + paddle/operators/mul_op.cu | 2 + paddle/operators/mul_op.h | 32 ++-- .../paddle/v2/framework/tests/op_test_util.py | 2 +- 11 files changed, 385 insertions(+), 12 deletions(-) create mode 100644 paddle/operators/math/CMakeLists.txt create mode 100644 paddle/operators/math/math_function.cc create mode 100644 paddle/operators/math/math_function.cu create mode 100644 paddle/operators/math/math_function.h diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5543510348..6a9057e5db 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -257,6 +257,10 @@ class ExecutionContext : public OperatorContext { platform::Place GetPlace() const { return device_context_.GetPlace(); } + const platform::DeviceContext& device_context() const { + return device_context_; + }; + const platform::DeviceContext& device_context_; }; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 6465deeec9..6be90d9124 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,13 +41,15 @@ function(op_library TARGET) endif() endfunction() +add_subdirectory(math) + op_library(add_op SRCS add_op.cc add_op.cu) cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) op_library(mean_op SRCS mean_op.cc mean_op.cu) cc_test(mean_op_test SRCS mean_op_test.cc DEPS mean_op) -op_library(mul_op SRCS mul_op.cc mul_op.cu) +op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt new file mode 100644 index 0000000000..586347668e --- /dev/null +++ b/paddle/operators/math/CMakeLists.txt @@ -0,0 +1,5 @@ +if (WITH_GPU) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) +else() + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) +endif() diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc new file mode 100644 index 0000000000..0532e8f034 --- /dev/null +++ b/paddle/operators/math/math_function.cc @@ -0,0 +1,121 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const int ldb, + const float beta, + float* C, + const int ldc, + const platform::DeviceContext* context) { + cblas_sgemm(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const double alpha, + const double* A, + const int lda, + const double* B, + const int ldb, + const double beta, + double* C, + const int ldc, + const platform::DeviceContext* context) { + cblas_dgemm(CblasRowMajor, + transA, + transB, + M, + N, + K, + alpha, + A, + lda, + B, + ldb, + beta, + C, + ldc); +} + +template <> +void axpy(const int n, + const float alpha, + const float* x, + float* y, + const platform::DeviceContext* context) { + cblas_saxpy(n, alpha, x, 1, y, 1); +} + +template <> +void axpy(const int n, + const double alpha, + const double* x, + double* y, + const platform::DeviceContext* context) { + cblas_daxpy(n, alpha, x, 1, y, 1); +} + +template <> +float dotProduct( + const int n, + const float* x, + const float* y, + const platform::DeviceContext* context) { + return cblas_sdot(n, x, 1, y, 1); +} + +template <> +double dotProduct( + const int n, + const double* x, + const double* y, + const platform::DeviceContext* context) { + return cblas_ddot(n, x, 1, y, 1); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu new file mode 100644 index 0000000000..46301df8f9 --- /dev/null +++ b/paddle/operators/math/math_function.cu @@ -0,0 +1,146 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const int ldb, + const float beta, + float* C, + const int ldc, + const platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (TransB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context)-> + cublas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc)); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const double alpha, + const double* A, + const int lda, + const double* B, + const int ldb, + const double beta, + double* C, + const int ldc, + const platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context)-> + cublas_handle(), + cuTransB, + cuTransA, + N, + M, + K, + &alpha, + B, + ldb, + A, + lda, + &beta, + C, + ldc)); +} + + +template <> +void axpy(const int n, + const float alpha, + const float* x, + float* y, + const platform::DeviceContext* context) { + CUBLAS_ENFORCE(platform::dynload::cublasSaxpy( + reinterpret_cast(context)-> + cublas_handle(), N, &alpha, X, 1, Y, 1)); +} + +template <> +void axpy(const int n, + const double alpha, + const double* x, + double* y, + const platform::DeviceContext* context) { + CUBLAS_ENFORCE(platform::dynload::cublasDaxpy( + reinterpret_cast(context)-> + cublas_handle(), N, &alpha, X, 1, Y, 1)); +} + +template <> +float dotProduct(const int n, + const float* x, + const float* y, + const platform::DeviceContext* context) { + CUBLAS_ENFORCE(platform::dynload::cublasSdot( + reinterpret_cast(context)-> + cublas_handle(), n, a, 1, b, 1, &result)); +} + +template <> +double dotProduct(const int n, + const double* x, + const double* y, + const platform::DeviceContext* context) { + CUBLAS_ENFORCE(platform::dynload::cublasDdot( + reinterpret_cast(context)-> + cublas_handle(), n, a, 1, b, 1, &result)); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h new file mode 100644 index 0000000000..c5b7fe8793 --- /dev/null +++ b/paddle/operators/math/math_function.h @@ -0,0 +1,78 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#include +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +template +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const T alpha, + const T* A, + const int lda, + const T* B, + const int ldb, + const T beta, + T* C, + const int ldc, + const platform::DeviceContext* context); + +template +void axpy(const int n, + const T alpha, + const T* x, + T* y, + const platform::DeviceContext* context); + +template +T dotProduct(const int n, + const T* x, + const T* y, + const platform::DeviceContext* context); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h index a89cb422f9..e712dee6a7 100644 --- a/paddle/operators/mean_op.h +++ b/paddle/operators/mean_op.h @@ -47,7 +47,7 @@ public: T ig_size = (T)framework::product(IG->dims()); - EigenVector::Flatten(*IG).device(*(context.GetEigenDevice())) = + EigenVector::Flatten(*IG).device(context.GetEigenDevice()) = EigenScalar::From(*OG) / ig_size; } }; diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index d127f3a302..eaf1d3266c 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index dc92367016..ba04605503 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -15,4 +15,6 @@ #define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" + + REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); \ No newline at end of file diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index c7b78ad390..e1759d00c5 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/operators/math/math_function.h" #include "paddle/operators/type_alias.h" namespace paddle { @@ -23,22 +24,35 @@ template class MulKernel : public OpKernel { public: void Compute(const ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto input0 = context.Input("X"); auto input1 = context.Input("Y"); auto output = context.Output(0); output->mutable_data(context.GetPlace()); - auto X = EigenMatrix::From(*input0); - auto Y = EigenMatrix::From(*input1); - auto Z = EigenMatrix::From(*output); - auto place = context.GetEigenDevice(); - - Z.device(place) = X.contract(Y, dim_pair); + auto out_dim = output->dims(); + auto in0_dim = input0->dims(); + + int M = out_dim[0]; + int N = out_dim[1]; + int K = in0_dim[1]; + + paddle::operators::math::template gemm(CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1, + input0->data(), + K, + input1->data(), + N, + 0, + output->data(), + N, + &context.device_context()); } }; + } // namespace operators } // namespace paddle diff --git a/python/paddle/v2/framework/tests/op_test_util.py b/python/paddle/v2/framework/tests/op_test_util.py index 98fae1b975..35d285e2e6 100644 --- a/python/paddle/v2/framework/tests/op_test_util.py +++ 
b/python/paddle/v2/framework/tests/op_test_util.py @@ -61,7 +61,7 @@ class OpTestMeta(type): for out_name in func.all_output_args: actual = numpy.array(scope.find_var(out_name).get_tensor()) expect = getattr(self, out_name) - # TODO(qijun) The default decimal is 7, but numpy.dot and eigen.mul + # TODO(qijun) The default decimal is 7, but numpy.dot and blas.gemm # has some diff, and could not pass unittest. So I set decimal 3 here. # And I will check this in future. numpy.testing.assert_almost_equal(actual, expect, decimal=3) From f190a795382b4bf3926455ce52beda7157e4ec2e Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 3 Aug 2017 12:29:04 +0000 Subject: [PATCH 02/19] fix gpu build error --- paddle/operators/math/math_function.cc | 40 +----------- paddle/operators/math/math_function.cu | 84 +++++++------------------- paddle/operators/math/math_function.h | 15 +---- paddle/operators/mul_op.h | 29 ++++----- 4 files changed, 39 insertions(+), 129 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 0532e8f034..c678b37616 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -32,7 +32,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const float beta, float* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { cblas_sgemm(CblasRowMajor, transA, transB, @@ -63,7 +63,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const double beta, double* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { cblas_dgemm(CblasRowMajor, transA, transB, @@ -80,42 +80,6 @@ void gemm(const CBLAS_TRANSPOSE transA, ldc); } -template <> -void axpy(const int n, - const float alpha, - const float* x, - float* y, - const platform::DeviceContext* context) { - cblas_saxpy(n, alpha, x, 1, y, 1); -} - -template <> -void axpy(const int n, - const double alpha, - const double* x, - double* y, - const platform::DeviceContext* context) { - cblas_daxpy(n, alpha, x, 1, y, 1); -} - -template <> -float dotProduct( - const int n, - const float* x, - const float* y, - const platform::DeviceContext* context) { - return cblas_sdot(n, x, 1, y, 1); -} - -template <> -double dotProduct( - const int n, - const double* x, - const double* y, - const platform::DeviceContext* context) { - return cblas_ddot(n, x, 1, y, 1); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 46301df8f9..190312e59d 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -20,29 +20,29 @@ namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - const platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const int ldb, + const float beta, + float* C, + const int ldc, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasSgemm( - reinterpret_cast(context)-> + reinterpret_cast(context)-> cublas_handle(), cuTransB, cuTransA, @@ -73,15 +73,15 @@ void gemm(const CBLAS_TRANSPOSE transA, const double beta, double* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( - reinterpret_cast(context)-> + reinterpret_cast(context)-> cublas_handle(), cuTransB, cuTransA, @@ -99,48 +99,6 @@ void gemm(const CBLAS_TRANSPOSE transA, } -template <> -void axpy(const int n, - const float alpha, - const float* x, - float* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasSaxpy( - reinterpret_cast(context)-> - cublas_handle(), N, &alpha, X, 1, Y, 1)); -} - -template <> -void axpy(const int n, - const double alpha, - const double* x, - double* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasDaxpy( - reinterpret_cast(context)-> - cublas_handle(), N, &alpha, X, 1, Y, 1)); -} - -template <> -float dotProduct(const int n, - const float* x, - const float* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasSdot( - reinterpret_cast(context)-> - cublas_handle(), n, a, 1, b, 1, &result)); -} - -template <> -double dotProduct(const int n, - const double* x, - const double* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasDdot( - reinterpret_cast(context)-> - cublas_handle(), n, a, 1, b, 1, &result)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c5b7fe8793..f1f87ac5f2 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -58,20 +58,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const T beta, T* C, const int ldc, - const platform::DeviceContext* context); - -template -void axpy(const int n, - const T alpha, - const T* x, - T* y, - const platform::DeviceContext* context); - -template -T dotProduct(const int n, - const T* x, - const T* y, - const platform::DeviceContext* context); + platform::DeviceContext* context); } // namespace math } // namespace operators diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index e1759d00c5..0bffe79a1e 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -37,20 +37,21 @@ public: int N = out_dim[1]; int K = in0_dim[1]; - paddle::operators::math::template gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1, - input0->data(), - K, - input1->data(), - N, - 0, - output->data(), - N, - &context.device_context()); + paddle::operators::math::template gemm( + CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1, + input0->data(), + K, + input1->data(), + N, + 0, + output->data(), + N, + &const_cast(context.device_context())); } }; From 
97d8175a5e19dbd60ea55cb21640cd7187d60974 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 7 Aug 2017 11:45:00 +0800 Subject: [PATCH 03/19] add global matmul function for Tensor --- paddle/framework/tensor.h | 2 + paddle/operators/math/math_function.cc | 93 ++++++++++++++++++++++++++ paddle/operators/math/math_function.cu | 73 ++++++++++++++++++++ paddle/operators/math/math_function.h | 12 ++++ paddle/operators/mul_op.h | 31 +++------ 5 files changed, 189 insertions(+), 22 deletions(-) diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 4c3b14b83d..2aac8a128a 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -103,6 +103,8 @@ class Tensor { template inline Tensor Slice(const int& begin_idx, const int& end_idx) const; + platform::Place place() const { return holder_->place(); } + private: template inline void check_memory_size() const; diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index c678b37616..1bfbc75573 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -80,6 +80,99 @@ void gemm(const CBLAS_TRANSPOSE transA, ldc); } +template <> +void matmul(const framework::Tensor& in1, + bool in1_T, + const framework::Tensor& in2, + bool in2_T, + float alpha, + framework::Tensor* out, + float beta, + platform::DeviceContext* context) { + auto in1_dim = in1.dims(); + auto in2_dim = in2.dims(); + auto out_dim = out->dims(); + PADDLE_ENFORCE( + in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); + + PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && + platform::is_cpu_place(in2.place()) && + platform::is_cpu_place(out->place()), + "Matrix must all be in CPUPlace"); + + int M = out_dim[0]; + int N = out_dim[1]; + int K = in1_dim[1]; + + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + + gemm(in1_Trans, + in2_Trans, + M, + N, + K, + alpha, + in1.data(), + K, + in2.data(), + N, + beta, + out->data(), + N, + context); +} + +template <> +void matmul(const framework::Tensor& in1, + bool in1_T, + const framework::Tensor& in2, + bool in2_T, + float alpha, + framework::Tensor* out, + float beta, + platform::DeviceContext* context) { + auto in1_dim = in1.dims(); + auto in2_dim = in2.dims(); + auto out_dim = out->dims(); + PADDLE_ENFORCE( + in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); + + PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && + platform::is_cpu_place(in2.place()) && + platform::is_cpu_place(out->place()), + "Matrix must all be in CPUPlace"); + + int M = out_dim[0]; + int N = out_dim[1]; + int K = in1_dim[1]; + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; + + gemm(in1_Trans, + in2_Trans, + M, + N, + K, + alpha, + in1.data(), + K, + in2.data(), + N, + beta, + out->data(), + N, + context); +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 190312e59d..e1ac856082 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -98,6 +98,79 @@ void gemm(const CBLAS_TRANSPOSE transA, ldc)); } +template <> +void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha, +framework::Tensor* out, float beta, platform::DeviceContext* context) { + auto in1_dim = in1.dims(); + auto in2_dim = in2.dims(); + auto out_dim = out->dims(); + PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); + + PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace"); + + int M = out_dim[0]; + int N = out_dim[1]; + int K = in1_dim[1]; + + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + + gemm(in1_Trans, + in2_Trans, + M, + N, + K, + alpha, + in1.data(), + K, + in2.data(), + N, + beta, + out->data(), + N, + context); + +} + + +template <> +void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha, +framework::Tensor* out, float beta, platform::DeviceContext* context) { + auto in1_dim = in1.dims(); + auto in2_dim = in2.dims(); + auto out_dim = out->dims(); + PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); + + PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace"); + + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; + + gemm(in1_Trans, + in2_Trans, + M, + N, + K, + alpha, + in1.data(), + K, + in2.data(), + N, + beta, + out->data(), + N, + context); + +} + } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index f1f87ac5f2..f068f4a15e 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -38,6 +38,7 @@ extern "C" { #endif #include +#include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" namespace paddle { @@ -60,6 +61,17 @@ void gemm(const CBLAS_TRANSPOSE transA, const int ldc, platform::DeviceContext* context); +// matrix multiply with continous memory +template +void matmul(const framework::Tensor& in1, + bool in1_T, + const framework::Tensor& in2, + bool in2_T, + float alpha, + framework::Tensor* out, + float beta, + platform::DeviceContext* context); + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 0bffe79a1e..d5d8e220ab 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -24,33 +24,20 @@ template class MulKernel : public OpKernel { public: void Compute(const ExecutionContext& context) const override { - auto input0 = context.Input("X"); - auto input1 = context.Input("Y"); - auto output = context.Output(0); + auto* input0 = context.Input("X"); + auto* input1 = context.Input("Y"); + auto* output = context.Output(0); output->mutable_data(context.GetPlace()); - auto out_dim = output->dims(); - auto in0_dim = input0->dims(); - - int M = out_dim[0]; - int N = out_dim[1]; - int K = in0_dim[1]; - - paddle::operators::math::template gemm( - CblasNoTrans, - CblasNoTrans, - M, - N, - K, + paddle::operators::math::template matmul( + *input0, + false, + *input1, + false, 1, - input0->data(), - K, - input1->data(), - N, + output, 0, - output->data(), - N, &const_cast(context.device_context())); } }; From 5703eb50fa32b1ae141aaf58d4a46f8b06e24478 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 7 Aug 2017 05:04:22 +0000 Subject: [PATCH 04/19] add .clang-format file --- paddle/operators/math/.clang-format | 5 + paddle/operators/math/math_function.cu | 165 +++++++++---------------- 2 files changed, 61 insertions(+), 109 deletions(-) create mode 100644 paddle/operators/math/.clang-format diff --git a/paddle/operators/math/.clang-format b/paddle/operators/math/.clang-format new file mode 100644 index 0000000000..47b8a85206 --- /dev/null +++ b/paddle/operators/math/.clang-format @@ -0,0 +1,5 @@ +--- +Language: Cpp +BasedOnStyle: Google +Standard: Cpp11 +... diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index e1ac856082..3e2aeea1da 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -14,66 +14,34 @@ limitations under the License. 
*/ #include "paddle/operators/math/math_function.h" - namespace paddle { namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - platform::DeviceContext* context) { +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const float alpha, const float* A, const int lda, + const float* B, const int ldb, const float beta, float* C, const int ldc, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - + PADDLE_ENFORCE(platform::dynload::cublasSgemm( - reinterpret_cast(context)-> - cublas_handle(), - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc)); + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); } template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const double alpha, - const double* A, - const int lda, - const double* B, - const int ldb, - const double beta, - double* C, - const int ldc, - platform::DeviceContext* context) { +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const double alpha, const double* A, + const int lda, const double* B, const int ldb, const double beta, double* C, + const int ldc, platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = @@ -81,97 +49,76 @@ void gemm(const CBLAS_TRANSPOSE transA, cublasOperation_t cuTransB = (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( - reinterpret_cast(context)-> - cublas_handle(), - cuTransB, - cuTransA, - N, - M, - K, - &alpha, - B, - ldb, - A, - lda, - &beta, - C, - ldc)); + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); } template <> -void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha, -framework::Tensor* out, float beta, platform::DeviceContext* context) { +void matmul(const framework::Tensor& in1, bool in1_T, + const framework::Tensor& in2, bool in2_T, + float alpha, framework::Tensor* out, + float beta, + platform::DeviceContext* context) { auto in1_dim = in1.dims(); auto in2_dim = in2.dims(); auto out_dim = out->dims(); - PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); - PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace"); + PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && + platform::is_gpu_place(in2.place()) && + platform::is_gpu_place(out->place()), + "Matrix must all be in GPUPlace"); - int M = out_dim[0]; + int M = out_dim[0]; int N = out_dim[1]; int K = in1_dim[1]; - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); - + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } - template <> -void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha, -framework::Tensor* out, float beta, platform::DeviceContext* context) { +void matmul(const framework::Tensor& in1, + bool in1_T, + const framework::Tensor& in2, + bool in2_T, float alpha, + framework::Tensor* out, float beta, + platform::DeviceContext* context) { auto in1_dim = in1.dims(); auto in2_dim = in2.dims(); auto out_dim = out->dims(); - PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, + "The input and output of matmul be matrix"); + PADDLE_ENFORCE( + in1_dim[1] == in2_dim[0], + "First matrix's width must be equal with second matrix's height."); - PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace"); + PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && + platform::is_gpu_place(in2.place()) && + platform::is_gpu_place(out->place()), + "Matrix must all be in GPUPlace"); - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; + int M = out_dim[0]; + int N = out_dim[1]; + int K = in1_dim[1]; + CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); - + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } - } // namespace math } // namespace operators } // namespace paddle From 081593591642c4c21e0a7daaa6e6bc3999abc856 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 7 Aug 2017 05:45:02 +0000 Subject: [PATCH 05/19] fix typo error --- paddle/operators/math/math_function.cc | 121 ++++++------------------- 1 file changed, 26 insertions(+), 95 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1bfbc75573..5833fc90a7 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -19,74 +19,29 @@ namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - platform::DeviceContext* context) { - cblas_sgemm(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const float alpha, const float* A, const int lda, + const float* B, const int ldb, const float beta, float* C, const int ldc, + platform::DeviceContext* context) { + cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const double alpha, - const double* A, - const int lda, - const double* B, - const int ldb, - const double beta, - double* C, - const int ldc, - platform::DeviceContext* context) { - cblas_dgemm(CblasRowMajor, - transA, - transB, - M, - N, - K, - alpha, - A, - lda, - B, - ldb, - beta, - C, - ldc); +void gemm( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const double alpha, const double* A, + const int lda, const double* B, const int ldb, const double beta, double* C, + const int ldc, platform::DeviceContext* context) { + cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); } template <> -void matmul(const framework::Tensor& in1, - bool in1_T, - const framework::Tensor& in2, - bool in2_T, - float alpha, - framework::Tensor* out, +void matmul(const framework::Tensor& in1, bool in1_T, + const framework::Tensor& in2, bool in2_T, + float alpha, framework::Tensor* out, float beta, platform::DeviceContext* context) { auto in1_dim = in1.dims(); @@ -111,30 +66,17 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } template <> -void matmul(const framework::Tensor& in1, +void matmul(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, - bool in2_T, - float alpha, - framework::Tensor* out, - float beta, + bool in2_T, float alpha, + framework::Tensor* out, float beta, platform::DeviceContext* context) { auto in1_dim = in1.dims(); auto in2_dim = in2.dims(); @@ -157,20 +99,9 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, - in2_Trans, - M, - N, - K, - alpha, - in1.data(), - K, - in2.data(), - N, - beta, - out->data(), - N, - context); + gemm(in1_Trans, in2_Trans, M, N, K, alpha, + in1.data(), K, in2.data(), N, + beta, out->data(), N, context); } } // namespace math From 6b12c697ff3e2a86e555fafa53ab5b1017e982ce Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 7 Aug 2017 14:35:50 +0800 Subject: [PATCH 06/19] handle mkl --- paddle/operators/math/CMakeLists.txt | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 586347668e..d34bc92594 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,5 +1,13 @@ if (WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + if (WITH_MKLML) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS mklml device_context) + else() + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) + endif() else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + if (WITH_MKLML) + cc_library(math_function SRCS math_function.cc DEPS mklml device_context) + else() + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) + endif() endif() From cabcf7bcfd4a4a02aface02da11b278e10124117 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 7 Aug 2017 18:17:34 +0800 Subject: [PATCH 07/19] format code --- paddle/framework/operator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 5a9b7dd914..7242b6418d 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -268,7 +268,7 @@ class ExecutionContext : public OperatorContext { const platform::DeviceContext* device_context() const { return device_context_; - }; + } const platform::DeviceContext* device_context_; }; From 7307b439e1b92f7afebdadfec884bdbfc6f024b9 Mon Sep 17 00:00:00 2001 From: qijun Date: Wed, 9 Aug 2017 13:03:35 +0000 Subject: [PATCH 08/19] fix gpu build error --- CMakeLists.txt | 4 ++-- paddle/operators/math/math_function.cu | 6 ++++-- paddle/operators/math/math_function.h | 16 +++++++++++++++- paddle/operators/mul_op.cu | 1 + paddle/operators/mul_op.h | 3 --- 5 files changed, 22 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b174831109..c7d743e193 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,8 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" 
${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) -option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) -option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND}) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 3e2aeea1da..2cc3c24fb3 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/math_function.h" - namespace paddle { namespace operators { namespace math { @@ -26,6 +25,8 @@ void gemm( platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -44,6 +45,8 @@ void gemm( const int ldc, platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. + lda = (transA == CblasNoTrans) ? K : M; + ldb = (transB == CblasNoTrans) ? N : K; cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -118,7 +121,6 @@ void matmul(const framework::Tensor& in1, in1.data(), K, in2.data(), N, beta, out->data(), N, context); } - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index f068f4a15e..1ecca60403 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -37,6 +37,20 @@ extern "C" { #include #endif +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf( + int matrix_layout, int m, int n, float* a, int lda, int* ipiv); +int LAPACKE_dgetrf( + int matrix_layout, int m, int n, double* a, int lda, int* ipiv); +int LAPACKE_sgetri( + int matrix_layout, int n, float* a, int lda, const int* ipiv); +int LAPACKE_dgetri( + int matrix_layout, int n, double* a, int lda, const int* ipiv); +} +#endif + #include #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -61,7 +75,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const int ldc, platform::DeviceContext* context); -// matrix multiply with continous memory +// matrix multiply with continuous memory template void matmul(const framework::Tensor& in1, bool in1_T, diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 7435b74bd8..346a7e505d 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -15,4 +15,5 @@ #define EIGEN_USE_GPU #include "paddle/operators/mul_op.h" +namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index 2087e98901..98c54f1dfb 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -31,9 +31,6 @@ template class MulKernel : public framework::OpKernel { 
public: void Compute(const framework::ExecutionContext& context) const override { - Eigen::array, 1> dim_pair = { - {Eigen::IndexPair(1, 0)}}; - auto input0 = context.Input("X"); auto input1 = context.Input("Y"); auto output = context.Output(0); From 8de4e3bdd6b24f55a1a6c9acb97233d7a18b021c Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 12:24:49 +0800 Subject: [PATCH 09/19] disable gpu implementation temporarily --- paddle/operators/math/math_function.cu | 6 ++++++ paddle/operators/math/math_function.h | 29 +++++++------------------- paddle/operators/mul_op.cu | 3 ++- 3 files changed, 16 insertions(+), 22 deletions(-) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 3e2aeea1da..b7d2c48a5f 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -26,6 +26,7 @@ void gemm( platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. + /* cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -34,6 +35,8 @@ void gemm( PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + */ + PADDLE_THROW("not implemented now"); } template <> @@ -44,6 +47,7 @@ void gemm( const int ldc, platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. + /* cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -51,6 +55,8 @@ void gemm( PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + */ + PADDLE_THROW("not implemented now"); } template <> diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index f068f4a15e..7a214e3a5a 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -40,36 +40,23 @@ extern "C" { #include #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace operators { namespace math { template -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const T alpha, - const T* A, - const int lda, - const T* B, - const int ldb, - const T beta, - T* C, - const int ldc, - platform::DeviceContext* context); +void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, const T* A, + const int lda, const T* B, const int ldb, const T beta, T* C, + const int ldc, platform::DeviceContext* context); // matrix multiply with continous memory template -void matmul(const framework::Tensor& in1, - bool in1_T, - const framework::Tensor& in2, - bool in2_T, - float alpha, - framework::Tensor* out, - float beta, +void matmul(const framework::Tensor& in1, bool in1_T, + const framework::Tensor& in2, bool in2_T, float alpha, + framework::Tensor* out, float beta, platform::DeviceContext* context); } // namespace math diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 7435b74bd8..aac5a6936e 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -15,4 +15,5 @@ #define EIGEN_USE_GPU #include 
"paddle/operators/mul_op.h" -REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); From de967fcefe4dc778769d61f50c8ba00661c64c8c Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 14:25:26 +0800 Subject: [PATCH 10/19] set gemm support continuous memory now --- paddle/operators/math/math_function.cc | 37 ++++++++++++++++---------- paddle/operators/math/math_function.cu | 29 ++++++++++---------- paddle/operators/math/math_function.h | 4 +-- paddle/operators/mul_op.cu | 3 +-- 4 files changed, 40 insertions(+), 33 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 5833fc90a7..7827c213fe 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -19,21 +19,30 @@ namespace operators { namespace math { template <> -void gemm( - const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, - const int N, const int K, const float alpha, const float* A, const int lda, - const float* B, const int ldb, const float beta, float* C, const int ldc, - platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } template <> -void gemm( - const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, - const int N, const int K, const double alpha, const double* A, - const int lda, const double* B, const int ldb, const double beta, double* C, - const int ldc, platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc); } @@ -67,8 +76,8 @@ void matmul(const framework::Tensor& in1, bool in1_T, CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), K, in2.data(), N, - beta, out->data(), N, context); + in1.data(), in2.data(), beta, + out->data(), context); } template <> @@ -100,8 +109,8 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), K, in2.data(), N, - beta, out->data(), N, context); + in1.data(), in2.data(), beta, + out->data(), context); } } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index eb07bc8996..12ddd2146f 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -18,14 +18,16 @@ namespace operators { namespace math { template <> -void gemm( - const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, - const int N, const int K, const float alpha, const float* A, const int lda, - const float* B, const int ldb, const float beta, float* C, const int ldc, - platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - /* + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -34,8 +36,6 @@ void gemm( PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); - */ - PADDLE_THROW("not implemented now"); } template <> @@ -46,7 +46,8 @@ void gemm( const int ldc, platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. - /* + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; cublasOperation_t cuTransA = (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = @@ -54,8 +55,6 @@ void gemm( PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); - */ - PADDLE_THROW("not implemented now"); } template <> @@ -87,8 +86,8 @@ void matmul(const framework::Tensor& in1, bool in1_T, CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), K, in2.data(), N, - beta, out->data(), N, context); + in1.data(), in2.data(), beta, + out->data(), context); } template <> @@ -120,8 +119,8 @@ void matmul(const framework::Tensor& in1, CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), K, in2.data(), N, - beta, out->data(), N, context); + in1.data(), in2.data(), beta, + out->data(), context); } } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 0f8e7169f7..12d1706afb 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -60,11 +60,11 @@ namespace paddle { namespace operators { namespace math { +// support continuous memory now template void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, - const int lda, const T* B, const int ldb, const T beta, T* C, - const int ldc, platform::DeviceContext* context); + const T* B, const T beta, T* C, platform::DeviceContext* context); // matrix multiply with continuous memory template diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 35975865c9..346a7e505d 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -16,5 +16,4 @@ #include "paddle/operators/mul_op.h" namespace ops = paddle::operators; -// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); +REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); From 8b7d48bc0ef4ee029f8cea087500624cf4dc01c1 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 06:47:56 +0000 Subject: [PATCH 11/19] fix gpu build error --- paddle/operators/math/math_function.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 12ddd2146f..d36e6e6a2c 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -35,15 +35,15 @@ void gemm(const CBLAS_TRANSPOSE transA, PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> void gemm( const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const double alpha, const double* A, - const int lda, const double* B, const int ldb, const double beta, double* C, - const int ldc, platform::DeviceContext* context) { + const double* B, const double beta, double* C, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -54,7 +54,7 @@ void gemm( (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> From 52b52ba80cc1ddd47ed6c4e1a89d747f13fec283 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 14:50:02 +0800 Subject: [PATCH 12/19] fix gpu build error --- paddle/operators/math/math_function.cu | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 12ddd2146f..50fc9939b1 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -35,15 +35,17 @@ void gemm(const CBLAS_TRANSPOSE transA, PADDLE_ENFORCE(platform::dynload::cublasSgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> -void gemm( - const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, - const int N, const int K, const double alpha, const double* A, - const int lda, const double* B, const int ldb, const double beta, double* C, - const int ldc, platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. int lda = (transA == CblasNoTrans) ? K : M; @@ -54,7 +56,7 @@ void gemm( (transB == CblasNoTrans) ? 
CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( reinterpret_cast(context)->cublas_handle(), - cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc)); + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); } template <> From c5a7471e93ec94ed20a03b2fc40d174b23dcb691 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 15:41:48 +0800 Subject: [PATCH 13/19] add math_function_test --- paddle/operators/math/CMakeLists.txt | 3 ++ paddle/operators/math/math_function_test.cc | 34 +++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 paddle/operators/math/math_function_test.cc diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index d34bc92594..bae11905b7 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -11,3 +11,6 @@ else() cc_library(math_function SRCS math_function.cc DEPS cblas device_context) endif() endif() + + +nv_test(math_function_test SRCS math_function_test.cc DEPS math_function) diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc new file mode 100644 index 0000000000..f7b453a20c --- /dev/null +++ b/paddle/operators/math/math_function_test.cc @@ -0,0 +1,34 @@ +#include "paddle/operators/math/math_function.h" +#include "gtest/gtest.h" + +#ifndef PADDLE_ONLY_CPU +TEST(math_function, GPU) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 2}, *cpu_place); + float arr[4] = {0, 1, 2, 3}; + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = new CUDADeviceContext(gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + out_gpu.CopyFrom(input1, *gpu_place); + + matmul(input1_gpu, false, input2_gpu, + false, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 2); + EXPECT_EQ(out_ptr[1], 3); + EXPECT_EQ(out_ptr[2], 6); + EXPECT_EQ(out_ptr[3], 11); +} +#endif \ No newline at end of file From 5f1081d83d2d699ad8519d55174cf9e2f1861a3c Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 08:54:05 +0000 Subject: [PATCH 14/19] fix bug in dynload --- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function_test.cc | 11 +++++++---- paddle/platform/dynload/cublas.h | 12 ++++++------ 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index bae11905b7..b1d0bc8f87 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -13,4 +13,4 @@ else() endif() -nv_test(math_function_test SRCS math_function_test.cc DEPS math_function) +nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index f7b453a20c..d0f0acab91 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -12,16 +12,19 @@ TEST(math_function, GPU) { auto* cpu_place = new paddle::platform::CPUPlace(); float* input1_ptr = input1.mutable_data({2, 2}, *cpu_place); float arr[4] = {0, 1, 2, 3}; + memcpy(input1_ptr, 
arr, 4 * sizeof(int)); auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::DeviceContext* context = new CUDADeviceContext(gpu_place); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); input1_gpu.CopyFrom(input1, *gpu_place); input2_gpu.CopyFrom(input1, *gpu_place); out_gpu.CopyFrom(input1, *gpu_place); - matmul(input1_gpu, false, input2_gpu, - false, 1, &out_gpu, 0, context); + paddle::operators::math::matmul( + input1_gpu, false, input2_gpu, + false, 1, &out_gpu, 0, context); out.CopyFrom(out_gpu, *cpu_place); @@ -31,4 +34,4 @@ TEST(math_function, GPU) { EXPECT_EQ(out_ptr[2], 6); EXPECT_EQ(out_ptr[3], 11); } -#endif \ No newline at end of file +#endif diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index c44b7240a8..617866d17c 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,12 +62,12 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv); \ - __macro(cublasDgemv); \ - __macro(cublasSgemm); \ - __macro(cublasDgemm); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ + __macro(cublasSgemv_v2); \ + __macro(cublasDgemv_v2); \ + __macro(cublasSgemm_v2); \ + __macro(cublasDgemm_v2); \ + __macro(cublasSgeam_v2); \ + __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ From 688c43b10458400440c9a434ccf6d61530e356b9 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 10 Aug 2017 09:27:02 +0000 Subject: [PATCH 15/19] format code --- paddle/operators/math/math_function_test.cc | 5 ++--- paddle/platform/dynload/cublas.h | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index d0f0acab91..a7a6881a5c 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -16,15 +16,14 @@ TEST(math_function, GPU) { auto* gpu_place = new paddle::platform::GPUPlace(0); paddle::platform::DeviceContext* context = - new paddle::platform::CUDADeviceContext(*gpu_place); + new paddle::platform::CUDADeviceContext(*gpu_place); input1_gpu.CopyFrom(input1, *gpu_place); input2_gpu.CopyFrom(input1, *gpu_place); out_gpu.CopyFrom(input1, *gpu_place); paddle::operators::math::matmul( - input1_gpu, false, input2_gpu, - false, 1, &out_gpu, 0, context); + input1_gpu, false, input2_gpu, false, 1, &out_gpu, 0, context); out.CopyFrom(out_gpu, *cpu_place); diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index 617866d17c..6b00b2aa48 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,12 +62,12 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv_v2); \ - __macro(cublasDgemv_v2); \ - __macro(cublasSgemm_v2); \ - __macro(cublasDgemm_v2); \ - __macro(cublasSgeam_v2); \ - __macro(cublasDgeam_v2); \ + __macro(cublasSgemv_v2); \ + __macro(cublasDgemv_v2); \ + __macro(cublasSgemm_v2); \ + __macro(cublasDgemm_v2); \ + __macro(cublasSgeam_v2); \ + __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \ From c2631ebf6f7a7a0d4c1c2f149b3d8a37d492d52a Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 11 Aug 2017 13:06:01 +0800 Subject: [PATCH 16/19] add unittest --- 
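Besides adding tests, this patch makes the shape checks in matmul transpose-aware: the single `in1_dim[1] == in2_dim[0]` enforce is replaced by one branch per combination of `in1_T`/`in2_T` in the hunks below. The intent of those branches is the usual GEMM constraint that the inner dimensions of op(A) and op(B) agree. A minimal stand-alone sketch of that rule (the helper name and signature are illustrative, not taken from the patch):

```cpp
#include <cassert>
#include <tuple>
#include <utility>

// Shapes are (rows, cols) of the matrices as stored (row-major, contiguous).
// Returns (M, N, K) for C(MxN) = op(A)(MxK) * op(B)(KxN), where op(X) is X or
// X^T according to the transpose flag, and checks that the inner dimensions
// contributed by A and B agree.
std::tuple<int, int, int> GemmShape(std::pair<int, int> dim_a, bool trans_a,
                                    std::pair<int, int> dim_b, bool trans_b) {
  const int M = trans_a ? dim_a.second : dim_a.first;
  const int K_a = trans_a ? dim_a.first : dim_a.second;  // inner dim from A
  const int K_b = trans_b ? dim_b.second : dim_b.first;  // inner dim from B
  const int N = trans_b ? dim_b.first : dim_b.second;
  assert(K_a == K_b && "inner dimensions of op(A) and op(B) must match");
  return std::make_tuple(M, N, K_a);
}

// Example: GemmShape({2, 3}, false, {2, 3}, true) yields (2, 2, 3),
// i.e. a 2x3 matrix times its own transpose, as in the tests below.
```

Two patches later ([PATCH 17/19]) these per-branch enforces are removed again and only `K` is derived from the transpose flag, so the sketch above describes the checking behaviour this patch introduces rather than the series' final code.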
paddle/operators/math/math_function.cc | 25 ++++++--- paddle/operators/math/math_function.cu | 24 ++++++--- paddle/operators/math/math_function_test.cc | 59 +++++++++++++++++---- 3 files changed, 86 insertions(+), 22 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 1038221143..fa4c298fe4 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -59,9 +59,16 @@ void matmul(const framework::Tensor& in1, bool in1_T, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + + if (!in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else if (in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); + } else if (!in1_T && in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); + } PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && platform::is_cpu_place(in2.place()) && @@ -93,9 +100,15 @@ void matmul(const framework::Tensor& in1, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + if (!in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else if (in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); + } else if (!in1_T && in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); + } PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && platform::is_cpu_place(in2.place()) && diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index f4d238e8ab..d2c8aec548 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -71,9 +71,15 @@ void matmul(const framework::Tensor& in1, bool in1_T, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + if (!in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else if (in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); + } else if (!in1_T && in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); + } PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place()) && @@ -105,9 +111,15 @@ void matmul(const framework::Tensor& in1, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - PADDLE_ENFORCE( - in1_dim[1] == in2_dim[0], - "First matrix's width must be equal with second matrix's height."); + if (!in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else if (in1_T && !in2_T) { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); + } else if (!in1_T && in2_T) { + PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); + } else { + PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); + } PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place()) && diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index a7a6881a5c..4de0eab6ce 100644 
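The rewritten test that follows drops the square 2x2 case in favour of a rectangular A = [[0, 1, 2], [3, 4, 5]], multiplied once as A * A^T (test `N_T`) and once as A^T * A (test `T_N`). A small host-side reference, independent of Paddle, that reproduces the EXPECT_EQ values used below (purely a cross-check of the arithmetic, not part of the patch):

```cpp
#include <cassert>
#include <vector>

// Naive row-major reference: C(MxN) = A(MxK) * B(KxN).
std::vector<float> MatMulRef(const std::vector<float>& A,
                             const std::vector<float>& B, int M, int N, int K) {
  std::vector<float> C(M * N, 0.0f);
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n)
      for (int k = 0; k < K; ++k) C[m * N + n] += A[m * K + k] * B[k * N + n];
  return C;
}

int main() {
  const std::vector<float> A = {0, 1, 2, 3, 4, 5};  // 2 x 3, row-major
  // A^T written out as an explicit 3 x 2 matrix, since the reference
  // has no transpose flags.
  const std::vector<float> At = {0, 3, 1, 4, 2, 5};

  // A * A^T : 2 x 2, matches the EXPECT_EQ values in TEST(math_function, N_T).
  auto nt = MatMulRef(A, At, 2, 2, 3);
  assert(nt[0] == 5 && nt[1] == 14 && nt[2] == 14 && nt[3] == 50);

  // A^T * A : 3 x 3, matches the EXPECT_EQ values in TEST(math_function, T_N).
  auto tn = MatMulRef(At, A, 3, 3, 2);
  assert(tn[0] == 9 && tn[1] == 12 && tn[2] == 15 && tn[3] == 12 &&
         tn[4] == 17 && tn[5] == 22 && tn[6] == 15 && tn[7] == 22 &&
         tn[8] == 29);
  return 0;
}
```

Both products are symmetric, which is why the off-diagonal expectations mirror each other (14/14 in the 2x2 result, 12/15/22 in the 3x3 result).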
--- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -2,7 +2,7 @@ #include "gtest/gtest.h" #ifndef PADDLE_ONLY_CPU -TEST(math_function, GPU) { +TEST(math_function, N_T) { paddle::framework::Tensor input1; paddle::framework::Tensor input1_gpu; paddle::framework::Tensor input2_gpu; @@ -10,9 +10,9 @@ TEST(math_function, GPU) { paddle::framework::Tensor out; auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data({2, 2}, *cpu_place); - float arr[4] = {0, 1, 2, 3}; - memcpy(input1_ptr, arr, 4 * sizeof(int)); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); auto* gpu_place = new paddle::platform::GPUPlace(0); paddle::platform::DeviceContext* context = @@ -20,17 +20,56 @@ TEST(math_function, GPU) { input1_gpu.CopyFrom(input1, *gpu_place); input2_gpu.CopyFrom(input1, *gpu_place); - out_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({2, 2}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 5); + EXPECT_EQ(out_ptr[1], 14); + EXPECT_EQ(out_ptr[2], 14); + EXPECT_EQ(out_ptr[3], 50); +} + +TEST(math_function, T_N) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({3, 3}, *gpu_place); paddle::operators::math::matmul( - input1_gpu, false, input2_gpu, false, 1, &out_gpu, 0, context); + input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context); out.CopyFrom(out_gpu, *cpu_place); float* out_ptr = out.data(); - EXPECT_EQ(out_ptr[0], 2); - EXPECT_EQ(out_ptr[1], 3); - EXPECT_EQ(out_ptr[2], 6); - EXPECT_EQ(out_ptr[3], 11); + EXPECT_EQ(out_ptr[0], 9); + EXPECT_EQ(out_ptr[1], 12); + EXPECT_EQ(out_ptr[2], 15); + EXPECT_EQ(out_ptr[3], 12); + EXPECT_EQ(out_ptr[4], 17); + EXPECT_EQ(out_ptr[5], 22); + EXPECT_EQ(out_ptr[6], 15); + EXPECT_EQ(out_ptr[7], 22); + EXPECT_EQ(out_ptr[8], 29); } #endif From 37aa4b98ff85f16ce70ee6349d4e4e1acd340906 Mon Sep 17 00:00:00 2001 From: qijun Date: Fri, 11 Aug 2017 05:26:13 +0000 Subject: [PATCH 17/19] refine unittest --- paddle/operators/math/math_function.cc | 24 ++---------------------- paddle/operators/math/math_function.cu | 23 ++--------------------- 2 files changed, 4 insertions(+), 43 deletions(-) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index fa4c298fe4..e5eefedde0 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -60,16 +60,6 @@ void matmul(const framework::Tensor& in1, bool in1_T, in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - if (!in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else if (in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[0] == 
in2_dim[0]); - } else if (!in1_T && in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); - } - PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && platform::is_cpu_place(in2.place()) && platform::is_cpu_place(out->place()), @@ -77,7 +67,7 @@ void matmul(const framework::Tensor& in1, bool in1_T, int M = out_dim[0]; int N = out_dim[1]; - int K = in1_dim[1]; + int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; @@ -100,16 +90,6 @@ void matmul(const framework::Tensor& in1, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - if (!in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else if (in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); - } else if (!in1_T && in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); - } - PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && platform::is_cpu_place(in2.place()) && platform::is_cpu_place(out->place()), @@ -117,7 +97,7 @@ void matmul(const framework::Tensor& in1, int M = out_dim[0]; int N = out_dim[1]; - int K = in1_dim[1]; + int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index d2c8aec548..ff02c6ad7e 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -71,15 +71,6 @@ void matmul(const framework::Tensor& in1, bool in1_T, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - if (!in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else if (in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); - } else if (!in1_T && in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); - } PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place()) && @@ -88,7 +79,7 @@ void matmul(const framework::Tensor& in1, bool in1_T, int M = out_dim[0]; int N = out_dim[1]; - int K = in1_dim[1]; + int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; @@ -111,16 +102,6 @@ void matmul(const framework::Tensor& in1, PADDLE_ENFORCE( in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, "The input and output of matmul be matrix"); - if (!in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else if (in1_T && !in2_T) { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[0]); - } else if (!in1_T && in2_T) { - PADDLE_ENFORCE(in1_dim[1] == in2_dim[0]); - } else { - PADDLE_ENFORCE(in1_dim[0] == in2_dim[1]); - } - PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place()) && platform::is_gpu_place(out->place()), @@ -128,7 +109,7 @@ void matmul(const framework::Tensor& in1, int M = out_dim[0]; int N = out_dim[1]; - int K = in1_dim[1]; + int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? 
CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; From 2ec8dab4c78eceb81122783b54c9366473c3f62d Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 14 Aug 2017 14:59:41 +0800 Subject: [PATCH 18/19] follow comments --- paddle/operators/math/.clang-format | 5 - paddle/operators/math/CMakeLists.txt | 21 ++-- paddle/operators/math/math_function.cc | 127 +++++++++++++++--------- paddle/operators/math/math_function.cu | 129 ++++++++++++++++--------- paddle/operators/math/math_function.h | 51 ++-------- 5 files changed, 187 insertions(+), 146 deletions(-) delete mode 100644 paddle/operators/math/.clang-format diff --git a/paddle/operators/math/.clang-format b/paddle/operators/math/.clang-format deleted file mode 100644 index 47b8a85206..0000000000 --- a/paddle/operators/math/.clang-format +++ /dev/null @@ -1,5 +0,0 @@ ---- -Language: Cpp -BasedOnStyle: Google -Standard: Cpp11 -... diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index b1d0bc8f87..84fffe6843 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,16 +1,13 @@ -if (WITH_GPU) - if (WITH_MKLML) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS mklml device_context) - else() - nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) - endif() +if(WITH_MKLML) + set(BLAS_LIB mklml) else() - if (WITH_MKLML) - cc_library(math_function SRCS math_function.cc DEPS mklml device_context) - else() - cc_library(math_function SRCS math_function.cc DEPS cblas device_context) - endif() -endif() + set(BLAS_LIB cblas) +endif() +if(WITH_GPU) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) +else() + cc_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) +endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index e5eefedde0..03a63d063f 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -12,6 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include #include "paddle/operators/math/math_function.h" namespace paddle { @@ -48,62 +86,65 @@ void gemm(const CBLAS_TRANSPOSE transA, } template <> -void matmul(const framework::Tensor& in1, bool in1_T, - const framework::Tensor& in2, bool in2_T, - float alpha, framework::Tensor* out, +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, float beta, platform::DeviceContext* context) { - auto in1_dim = in1.dims(); - auto in2_dim = in2.dims(); - auto out_dim = out->dims(); - PADDLE_ENFORCE( - in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && - platform::is_cpu_place(in2.place()) && - platform::is_cpu_place(out->place()), + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), "Matrix must all be in CPUPlace"); - int M = out_dim[0]; - int N = out_dim[1]; - int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? 
CblasNoTrans : CblasTrans; - gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), in2.data(), beta, - out->data(), context); + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); } template <> -void matmul(const framework::Tensor& in1, - bool in1_T, - const framework::Tensor& in2, - bool in2_T, float alpha, - framework::Tensor* out, float beta, +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, platform::DeviceContext* context) { - auto in1_dim = in1.dims(); - auto in2_dim = in2.dims(); - auto out_dim = out->dims(); - PADDLE_ENFORCE( - in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); - PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) && - platform::is_cpu_place(in2.place()) && - platform::is_cpu_place(out->place()), + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), "Matrix must all be in CPUPlace"); - int M = out_dim[0]; - int N = out_dim[1]; - int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), in2.data(), beta, - out->data(), context); + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); } } // namespace math diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index ff02c6ad7e..c1ec2d93ed 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,7 +12,46 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include #include "paddle/operators/math/math_function.h" + namespace paddle { namespace operators { namespace math { @@ -60,63 +99,67 @@ void gemm(const CBLAS_TRANSPOSE transA, } template <> -void matmul(const framework::Tensor& in1, bool in1_T, - const framework::Tensor& in2, bool in2_T, - float alpha, framework::Tensor* out, +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, float beta, platform::DeviceContext* context) { - auto in1_dim = in1.dims(); - auto in2_dim = in2.dims(); - auto out_dim = out->dims(); - PADDLE_ENFORCE( - in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); - - PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && - platform::is_gpu_place(in2.place()) && - platform::is_gpu_place(out->place()), + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), "Matrix must all be in GPUPlace"); - int M = out_dim[0]; - int N = out_dim[1]; - int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? 
CblasNoTrans : CblasTrans; - gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), in2.data(), beta, - out->data(), context); + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); } template <> -void matmul(const framework::Tensor& in1, - bool in1_T, - const framework::Tensor& in2, - bool in2_T, float alpha, - framework::Tensor* out, float beta, +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, platform::DeviceContext* context) { - auto in1_dim = in1.dims(); - auto in2_dim = in2.dims(); - auto out_dim = out->dims(); - PADDLE_ENFORCE( - in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2, - "The input and output of matmul be matrix"); - PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && - platform::is_gpu_place(in2.place()) && - platform::is_gpu_place(out->place()), + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), "Matrix must all be in GPUPlace"); - int M = out_dim[0]; - int N = out_dim[1]; - int K = (in1_T == false) ? in1_dim[1] : in1_dim[0]; - CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans; + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm(in1_Trans, in2_Trans, M, N, K, alpha, - in1.data(), in2.data(), beta, - out->data(), context); + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); } + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 12d1706afb..c20e6a3b39 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -14,44 +14,6 @@ limitations under the License. 
*/ #pragma once -#ifdef PADDLE_USE_MKLML -#include -#include -#include -#endif - -#ifdef PADDLE_USE_MKL -#include -#include -#endif - -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - -#ifdef PADDLE_USE_OPENBLAS -#include -#include -#endif - -#ifndef LAPACK_FOUND -extern "C" { -#include -int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, - int* ipiv); -int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, - int* ipiv); -int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, - const int* ipiv); -int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, - const int* ipiv); -} -#endif - -#include #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -60,17 +22,20 @@ namespace paddle { namespace operators { namespace math { -// support continuous memory now -template +// Support continuous memory now +// If transA = N, and transB = N +// Then matrixA: M * K, matrixB: K * N matrixC : M * N +// For more detailed info, please refer to +// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C, platform::DeviceContext* context); // matrix multiply with continuous memory template -void matmul(const framework::Tensor& in1, bool in1_T, - const framework::Tensor& in2, bool in2_T, float alpha, - framework::Tensor* out, float beta, +void matmul(const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, float alpha, + framework::Tensor* matrix_out, float beta, platform::DeviceContext* context); } // namespace math From 960a52555064d0496c8b76ce726c604d3fba66d4 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 14 Aug 2017 07:20:16 +0000 Subject: [PATCH 19/19] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function.cc | 38 ----------------------- paddle/operators/math/math_function.cu | 38 ----------------------- paddle/operators/math/math_function.h | 43 ++++++++++++++++++++++++-- 4 files changed, 42 insertions(+), 79 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 84fffe6843..abcaf940ab 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -7,7 +7,7 @@ endif() if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) else() - cc_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) + cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 03a63d063f..affdd1ac2c 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#ifdef PADDLE_USE_MKLML -#include -#include -#include -#endif - -#ifdef PADDLE_USE_MKL -#include -#include -#endif - -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - -#ifdef PADDLE_USE_OPENBLAS -#include -#include -#endif - -#ifndef LAPACK_FOUND -extern "C" { -#include -int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, - int* ipiv); -int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, - int* ipiv); -int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, - const int* ipiv); -int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, - const int* ipiv); -} -#endif - -#include #include "paddle/operators/math/math_function.h" namespace paddle { diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index c1ec2d93ed..da40b27c94 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_USE_MKLML -#include -#include -#include -#endif - -#ifdef PADDLE_USE_MKL -#include -#include -#endif - -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - -#ifdef PADDLE_USE_OPENBLAS -#include -#include -#endif - -#ifndef LAPACK_FOUND -extern "C" { -#include -int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, - int* ipiv); -int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, - int* ipiv); -int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, - const int* ipiv); -int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, - const int* ipiv); -} -#endif - -#include #include "paddle/operators/math/math_function.h" namespace paddle { diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c20e6a3b39..155589fadb 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -13,6 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -27,6 +65,7 @@ namespace math { // Then matrixA: M * K, matrixB: K * N matrixC : M * N // For more detailed info, please refer to // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html +template void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C, platform::DeviceContext* context); @@ -34,8 +73,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, // matrix multiply with continuous memory template void matmul(const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float alpha, - framework::Tensor* matrix_out, float beta, + const framework::Tensor& matrix_b, bool trans_b, T alpha, + framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); } // namespace math
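After this final cleanup, gemm and matmul take no leading dimensions at all: the declarations above assume contiguous row-major tensors. The GPU specializations shown earlier in the series derive lda/ldb from (M, N, K) and the transpose flags and pass N as the leading dimension of C; a CPU CBLAS path can do the same. A stand-alone sketch of that mapping, under those assumptions — the function name is illustrative, and linking requires some CBLAS implementation:

```cpp
#include <cblas.h>  // any CBLAS implementation (OpenBLAS, MKL, ATLAS, ...)

// Row-major, contiguous gemm with implicit leading dimensions, in the spirit
// of the declaration above: C(MxN) = alpha * op(A) * op(B) + beta * C.
// The leading dimensions follow from how the *untransposed* operands are
// stored, the same derivation the CUDA path in this series uses:
//   lda = (transA == NoTrans) ? K : M,
//   ldb = (transB == NoTrans) ? N : K,
//   ldc = N.
void GemmRowMajorContiguous(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                            int M, int N, int K, float alpha, const float* A,
                            const float* B, float beta, float* C) {
  const int lda = (transA == CblasNoTrans) ? K : M;
  const int ldb = (transB == CblasNoTrans) ? N : K;
  const int ldc = N;
  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
              beta, C, ldc);
  // cuBLAS has no row-major mode; the kernels in this series instead compute
  // C^T = op(B)^T * op(A)^T in column-major storage, which is why cublasSgemm
  // is invoked as (cuTransB, cuTransA, N, M, K, ..., B, ldb, A, lda, ..., C, N).
}
```

With contiguous row-major buffers, `GemmRowMajorContiguous(CblasNoTrans, CblasTrans, 2, 2, 3, 1.f, A, B, 0.f, C)` reproduces the A * A^T case exercised by the unit tests above.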