follow comments

revert-3824-remove_grad_op_type
qijun 8 years ago
parent 37aa4b98ff
commit 2ec8dab4c7

.clang-format
@@ -1,5 +0,0 @@
----
-Language: Cpp
-BasedOnStyle: Google
-Standard: Cpp11
-...

@ -1,16 +1,13 @@
if (WITH_GPU) if(WITH_MKLML)
if (WITH_MKLML) set(BLAS_LIB mklml)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS mklml device_context)
else()
nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context)
endif()
else() else()
if (WITH_MKLML) set(BLAS_LIB cblas)
cc_library(math_function SRCS math_function.cc DEPS mklml device_context)
else()
cc_library(math_function SRCS math_function.cc DEPS cblas device_context)
endif()
endif() endif()
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context)
else()
cc_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)

paddle/operators/math/math_function.cc
@@ -12,6 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#ifdef PADDLE_USE_MKLML
+#include <mkl_cblas.h>
+#include <mkl_lapacke.h>
+#include <mkl_vml_functions.h>
+#endif
+
+#ifdef PADDLE_USE_MKL
+#include <mkl.h>
+#include <mkl_lapacke.h>
+#endif
+
+#ifdef PADDLE_USE_ATLAS
+extern "C" {
+#include <cblas.h>
+#include <clapack.h>
+}
+#endif
+
+#ifdef PADDLE_USE_OPENBLAS
+#include <cblas.h>
+#include <lapacke.h>
+#endif
+
+#ifndef LAPACK_FOUND
+extern "C" {
+#include <cblas.h>
+int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
+                   int* ipiv);
+int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
+                   int* ipiv);
+int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
+                   const int* ipiv);
+int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
+                   const int* ipiv);
+}
+#endif
+
+#include <cmath>
+
 #include "paddle/operators/math/math_function.h"
 
 namespace paddle {
@@ -48,62 +86,65 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
 
 template <>
-void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
-                                       const framework::Tensor& in2, bool in2_T,
-                                       float alpha, framework::Tensor* out,
-                                       float beta,
-                                       platform::DeviceContext* context) {
-  auto in1_dim = in1.dims();
-  auto in2_dim = in2.dims();
-  auto out_dim = out->dims();
-  PADDLE_ENFORCE(
-      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
-      "The input and output of matmul be matrix");
-
-  PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
-                     platform::is_cpu_place(in2.place()) &&
-                     platform::is_cpu_place(out->place()),
+void matmul<platform::CPUPlace, float>(const framework::Tensor& matrix_a,
+                                       bool trans_a,
+                                       const framework::Tensor& matrix_b,
+                                       bool trans_b, float alpha,
+                                       framework::Tensor* matrix_out,
+                                       float beta,
+                                       platform::DeviceContext* context) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul be matrix");
+
+  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
+                     platform::is_cpu_place(matrix_b.place()) &&
+                     platform::is_cpu_place(matrix_out->place()),
                  "Matrix must all be in CPUPlace");
 
-  int M = out_dim[0];
-  int N = out_dim[1];
-  int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
-  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
-  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
-
-  gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), in2.data<float>(), beta,
-                                  out->data<float>(), context);
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::CPUPlace, float>(
+      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
 }
 
 template <>
-void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
-                                        bool in1_T,
-                                        const framework::Tensor& in2,
-                                        bool in2_T, float alpha,
-                                        framework::Tensor* out, float beta,
-                                        platform::DeviceContext* context) {
-  auto in1_dim = in1.dims();
-  auto in2_dim = in2.dims();
-  auto out_dim = out->dims();
-  PADDLE_ENFORCE(
-      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
-      "The input and output of matmul be matrix");
-
-  PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
-                     platform::is_cpu_place(in2.place()) &&
-                     platform::is_cpu_place(out->place()),
+void matmul<platform::CPUPlace, double>(const framework::Tensor& matrix_a,
+                                        bool trans_a,
+                                        const framework::Tensor& matrix_b,
+                                        bool trans_b, double alpha,
+                                        framework::Tensor* matrix_out,
+                                        double beta,
+                                        platform::DeviceContext* context) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul be matrix");
+
+  PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) &&
+                     platform::is_cpu_place(matrix_b.place()) &&
+                     platform::is_cpu_place(matrix_out->place()),
                  "Matrix must all be in CPUPlace");
 
-  int M = out_dim[0];
-  int N = out_dim[1];
-  int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
-  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
-  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
-
-  gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), in2.data<double>(), beta,
-                                   out->data<double>(), context);
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::CPUPlace, double>(
+      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
 }
 
 }  // namespace math
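For readers skimming the new #ifndef LAPACK_FOUND branch above: when the build finds no LAPACK at configure time, the file falls back to declaring by hand the four LAPACKE routines it needs. Below is a rough, self-contained sketch of how such a getrf/getri pair is used: LU-factorize a matrix in place, then invert it from the factors. Illustrative only, not part of this commit; the 2x2 data and the kRowMajor constant are assumptions made for the example.

#include <cstdio>

extern "C" {
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
                   int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
                   const int* ipiv);
}

int main() {
  const int kRowMajor = 101;          // the value of LAPACK_ROW_MAJOR
  float a[4] = {4.f, 7.f, 2.f, 6.f};  // 2x2 matrix, row-major
  int ipiv[2];                        // pivot indices produced by the LU step
  if (LAPACKE_sgetrf(kRowMajor, 2, 2, a, 2, ipiv) != 0) return 1;
  if (LAPACKE_sgetri(kRowMajor, 2, a, 2, ipiv) != 0) return 1;
  // a now holds the inverse: 0.6 -0.7 / -0.2 0.4
  std::printf("%g %g\n%g %g\n", a[0], a[1], a[2], a[3]);
  return 0;
}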

@ -12,7 +12,46 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
@@ -60,63 +99,67 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
 
 template <>
-void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
-                                       const framework::Tensor& in2, bool in2_T,
-                                       float alpha, framework::Tensor* out,
-                                       float beta,
-                                       platform::DeviceContext* context) {
-  auto in1_dim = in1.dims();
-  auto in2_dim = in2.dims();
-  auto out_dim = out->dims();
-  PADDLE_ENFORCE(
-      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
-      "The input and output of matmul be matrix");
-
-  PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
-                     platform::is_gpu_place(in2.place()) &&
-                     platform::is_gpu_place(out->place()),
+void matmul<platform::GPUPlace, float>(const framework::Tensor& matrix_a,
+                                       bool trans_a,
+                                       const framework::Tensor& matrix_b,
+                                       bool trans_b, float alpha,
+                                       framework::Tensor* matrix_out,
+                                       float beta,
+                                       platform::DeviceContext* context) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul be matrix");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
+                     platform::is_gpu_place(matrix_b.place()) &&
+                     platform::is_gpu_place(matrix_out->place()),
                  "Matrix must all be in GPUPlace");
 
-  int M = out_dim[0];
-  int N = out_dim[1];
-  int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
-  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
-  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
-
-  gemm<platform::GPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), in2.data<float>(), beta,
-                                  out->data<float>(), context);
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::GPUPlace, float>(
+      transA, transB, M, N, K, alpha, matrix_a.data<float>(),
+      matrix_b.data<float>(), beta, matrix_out->data<float>(), context);
 }
 
 template <>
-void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
-                                        bool in1_T,
-                                        const framework::Tensor& in2,
-                                        bool in2_T, float alpha,
-                                        framework::Tensor* out, float beta,
-                                        platform::DeviceContext* context) {
-  auto in1_dim = in1.dims();
-  auto in2_dim = in2.dims();
-  auto out_dim = out->dims();
-  PADDLE_ENFORCE(
-      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
-      "The input and output of matmul be matrix");
-
-  PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
-                     platform::is_gpu_place(in2.place()) &&
-                     platform::is_gpu_place(out->place()),
+void matmul<platform::GPUPlace, double>(const framework::Tensor& matrix_a,
+                                        bool trans_a,
+                                        const framework::Tensor& matrix_b,
+                                        bool trans_b, double alpha,
+                                        framework::Tensor* matrix_out,
+                                        double beta,
+                                        platform::DeviceContext* context) {
+  auto dim_a = matrix_a.dims();
+  auto dim_b = matrix_b.dims();
+  auto dim_out = matrix_out->dims();
+  PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2,
+                 "The input and output of matmul be matrix");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
+                     platform::is_gpu_place(matrix_b.place()) &&
+                     platform::is_gpu_place(matrix_out->place()),
                  "Matrix must all be in GPUPlace");
 
-  int M = out_dim[0];
-  int N = out_dim[1];
-  int K = (in1_T == false) ? in1_dim[1] : in1_dim[0];
-  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
-  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
-
-  gemm<platform::GPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), in2.data<double>(), beta,
-                                   out->data<double>(), context);
+  int M = dim_out[0];
+  int N = dim_out[1];
+  int K = (trans_a == false) ? dim_a[1] : dim_a[0];
+
+  CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::GPUPlace, double>(
+      transA, transB, M, N, K, alpha, matrix_a.data<double>(),
+      matrix_b.data<double>(), beta, matrix_out->data<double>(), context);
 }
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
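The GPU specializations above keep the same row-major M/N/K bookkeeping even though cuBLAS itself is column-major. The standard trick is to compute C^T = B^T * A^T by swapping the operands and the M/N extents. A sketch of that mapping follows; it assumes the unchanged gemm<GPUPlace, T> body in this file wraps cublasSgemm, and the helper name and handle plumbing are illustrative, not part of this commit.

#include <cublas_v2.h>

// A, B, C are device pointers holding row-major data. Computes
// C = alpha * op(A) * op(B) + beta * C with C stored M x N row-major.
void gpu_gemm_sketch(cublasHandle_t handle, bool trans_a, bool trans_b,
                     int M, int N, int K, float alpha, const float* A,
                     const float* B, float beta, float* C) {
  cublasOperation_t cuTransA = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t cuTransB = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
  int lda = trans_a ? M : K;  // row width of A as stored (row-major)
  int ldb = trans_b ? K : N;  // row width of B as stored (row-major)
  // Operands swapped (B first) and N/M swapped: cuBLAS then writes the
  // column-major N x M result, which is exactly C in M x N row-major.
  cublasSgemm(handle, cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda,
              &beta, C, N);
}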

paddle/operators/math/math_function.h
@@ -14,44 +14,6 @@ limitations under the License. */
 
 #pragma once
-#ifdef PADDLE_USE_MKLML
-#include <mkl_cblas.h>
-#include <mkl_lapacke.h>
-#include <mkl_vml_functions.h>
-#endif
-
-#ifdef PADDLE_USE_MKL
-#include <mkl.h>
-#include <mkl_lapacke.h>
-#endif
-
-#ifdef PADDLE_USE_ATLAS
-extern "C" {
-#include <cblas.h>
-#include <clapack.h>
-}
-#endif
-
-#ifdef PADDLE_USE_OPENBLAS
-#include <cblas.h>
-#include <lapacke.h>
-#endif
-
-#ifndef LAPACK_FOUND
-extern "C" {
-#include <cblas.h>
-int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
-                   int* ipiv);
-int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
-                   int* ipiv);
-int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
-                   const int* ipiv);
-int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
-                   const int* ipiv);
-}
-#endif
-
-#include <cmath>
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
@@ -60,17 +22,20 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-// support continuous memory now
+// Support continuous memory now
+// If transA = N, and transB = N
+// Then matrixA: M * K, matrixB: K * N, matrixC : M * N
+// For more detailed info, please refer to
+// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
 template <typename Place, typename T>
 void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
           const int M, const int N, const int K, const T alpha, const T* A,
           const T* B, const T beta, T* C, platform::DeviceContext* context);
 
 // matrix multiply with continuous memory
 template <typename Place, typename T>
-void matmul(const framework::Tensor& in1, bool in1_T,
-            const framework::Tensor& in2, bool in2_T, float alpha,
-            framework::Tensor* out, float beta,
+void matmul(const framework::Tensor& matrix_a, bool trans_a,
+            const framework::Tensor& matrix_b, bool trans_b, float alpha,
+            framework::Tensor* matrix_out, float beta,
            platform::DeviceContext* context);
 
 }  // namespace math
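The new header comment spells out the gemm convention: with transA = N and transB = N, matrixA is M x K, matrixB is K x N, and matrixC is M x N, matching the netlib sgemm reference it links. For a concrete picture of how that convention maps onto a CBLAS call, here is a sketch assuming a cblas_sgemm-based CPU backend (such as the OpenBLAS path); gemm_sketch is an illustrative name, not a Paddle symbol:

#include <cblas.h>

// Row-major C = alpha * op(A) * op(B) + beta * C. The leading dimension is
// the row width of each matrix as stored: an untransposed A is M x K
// (lda = K), while a transposed A is stored K x M (lda = M); likewise for B.
// C is always M x N, so ldc = N.
void gemm_sketch(CBLAS_TRANSPOSE transA, CBLAS_TRANSPOSE transB,
                 int M, int N, int K, float alpha, const float* A,
                 const float* B, float beta, float* C) {
  int lda = (transA == CblasNoTrans) ? K : M;
  int ldb = (transB == CblasNoTrans) ? N : K;
  cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda,
              B, ldb, beta, C, N);
}

The same storage logic explains the K = (trans_a == false) ? dim_a[1] : dim_a[0] line in the matmul bodies above: K is whichever stored extent of matrix_a lines up with the inner dimension of the product.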
