fix gpu build error

revert-3824-remove_grad_op_type
qijun 8 years ago
parent 22dac40c3a
commit f190a79538

@@ -32,7 +32,7 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                      const float beta,
                                      float* C,
                                      const int ldc,
-                                     const platform::DeviceContext* context) {
+                                     platform::DeviceContext* context) {
   cblas_sgemm(CblasRowMajor,
               transA,
               transB,
@@ -63,7 +63,7 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
                                       const double beta,
                                       double* C,
                                       const int ldc,
-                                      const platform::DeviceContext* context) {
+                                      platform::DeviceContext* context) {
   cblas_dgemm(CblasRowMajor,
               transA,
               transB,
@@ -80,42 +80,6 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
               ldc);
 }
-
-template <>
-void axpy<platform::CPUPlace, float>(const int n,
-                                     const float alpha,
-                                     const float* x,
-                                     float* y,
-                                     const platform::DeviceContext* context) {
-  cblas_saxpy(n, alpha, x, 1, y, 1);
-}
-
-template <>
-void axpy<platform::CPUPlace, double>(const int n,
-                                      const double alpha,
-                                      const double* x,
-                                      double* y,
-                                      const platform::DeviceContext* context) {
-  cblas_daxpy(n, alpha, x, 1, y, 1);
-}
-
-template <>
-float dotProduct<platform::CPUPlace, float>(
-    const int n,
-    const float* x,
-    const float* y,
-    const platform::DeviceContext* context) {
-  return cblas_sdot(n, x, 1, y, 1);
-}
-
-template <>
-double dotProduct<platform::CPUPlace, double>(
-    const int n,
-    const double* x,
-    const double* y,
-    const platform::DeviceContext* context) {
-  return cblas_ddot(n, x, 1, y, 1);
-}
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle

@@ -20,7 +20,7 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::GPUPlace float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                     const CBLAS_TRANSPOSE transB,
                                     const int M,
                                     const int N,
@@ -33,16 +33,16 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                     const float beta,
                                     float* C,
                                     const int ldc,
-                                    const platform::DeviceContext* context) {
+                                    platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   cublasOperation_t cuTransA =
-      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
-      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+      reinterpret_cast<platform::CUDADeviceContext*>(context)->
           cublas_handle(),
       cuTransB,
       cuTransA,
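
A note on the "fortran order" comment in the hunk above: cuBLAS expects column-major storage, while the surrounding code stores matrices row-major. The call therefore computes row-major C = A * B by asking cuBLAS for C^T = B^T * A^T, which is why it passes cuTransB before cuTransA and N before M. A minimal, CPU-only sketch of the same trick (stand-alone C++, not Paddle code):

#include <cstdio>

// Plain column-major GEMM: C[m x n] = A[m x k] * B[k x n], no transposes.
void gemm_colmajor(int m, int n, int k, const float* A, int lda,
                   const float* B, int ldb, float* C, int ldc) {
  for (int j = 0; j < n; ++j)
    for (int i = 0; i < m; ++i) {
      float s = 0.f;
      for (int p = 0; p < k; ++p) s += A[i + p * lda] * B[p + j * ldb];
      C[i + j * ldc] = s;
    }
}

int main() {
  // Row-major A (2x3) and B (3x2); we want row-major C = A * B (2x2).
  const float A[6] = {1, 2, 3, 4, 5, 6};
  const float B[6] = {7, 8, 9, 10, 11, 12};
  float C[4];
  // Swap the operands and dimensions: reinterpreted as column-major, the
  // buffers hold B^T and A^T, so this computes C^T = B^T * A^T, and C^T in
  // column-major storage is exactly C in row-major storage.
  gemm_colmajor(2, 2, 3, B, 2, A, 3, C, 2);
  std::printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // 58 64 / 139 154
  return 0;
}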
@@ -73,15 +73,15 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
                                      const double beta,
                                      double* C,
                                      const int ldc,
-                                     const platform::DeviceContext* context) {
+                                     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   cublasOperation_t cuTransA =
-      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
-      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+      reinterpret_cast<platform::CUDADeviceContext*>(context)->
          cublas_handle(),
       cuTransB,
       cuTransA,
@@ -99,48 +99,6 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
-
-template <>
-void axpy<platform::GPUPlace, float>(const int n,
-                                     const float alpha,
-                                     const float* x,
-                                     float* y,
-                                     const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasSaxpy(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
-template <>
-void axpy<platform::GPUPlace, double>(const int n,
-                                      const double alpha,
-                                      const double* x,
-                                      double* y,
-                                      const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasDaxpy(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
-template <>
-float dotProduct<platform::GPUPlace, float>(const int n,
-                                            const float* x,
-                                            const float* y,
-                                            const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasSdot(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), n, a, 1, b, 1, &result));
-}
-
-template <>
-double dotProduct<platform::GPUPlace, double>(const int n,
-                                              const double* x,
-                                              const double* y,
-                                              const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasDdot(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), n, a, 1, b, 1, &result));
-}
 
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
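
The deleted GPU specializations above are what broke the nvcc build: their bodies reference N, X, and Y while the parameters in scope are n, x, and y, pass a, b, and a never-declared result to the dot calls, and wrap everything in CUBLAS_ENFORCE rather than the PADDLE_ENFORCE used elsewhere in the file. For reference only, a self-contained sketch of a correct saxpy call against the raw cuBLAS API (not Paddle's dynload wrappers; error checking elided):

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int n = 4;
  const float alpha = 2.0f;
  float hx[4] = {1, 2, 3, 4};
  float hy[4] = {10, 20, 30, 40};
  float *dx = nullptr, *dy = nullptr;
  cudaMalloc(&dx, n * sizeof(float));
  cudaMalloc(&dy, n * sizeof(float));
  cudaMemcpy(dx, hx, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(dy, hy, n * sizeof(float), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  // y = alpha * x + y, using the lowercase names actually in scope,
  // unlike the deleted body's N, X, Y.
  cublasSaxpy(handle, n, &alpha, dx, 1, dy, 1);
  cublasDestroy(handle);

  cudaMemcpy(hy, dy, n * sizeof(float), cudaMemcpyDeviceToHost);
  for (int i = 0; i < n; ++i) std::printf("%g ", hy[i]);  // 12 24 36 48
  std::printf("\n");
  cudaFree(dx);
  cudaFree(dy);
  return 0;
}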

@@ -58,20 +58,7 @@ void gemm(const CBLAS_TRANSPOSE transA,
           const T beta,
           T* C,
           const int ldc,
-          const platform::DeviceContext* context);
+          platform::DeviceContext* context);
-
-template <typename Place, typename T>
-void axpy(const int n,
-          const T alpha,
-          const T* x,
-          T* y,
-          const platform::DeviceContext* context);
-
-template <typename Place, typename T>
-T dotProduct(const int n,
-             const T* x,
-             const T* y,
-             const platform::DeviceContext* context);
 
 }  // namespace math
 }  // namespace operators
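
After this change the header declares only the gemm template. The Place type parameter acts as a compile-time tag: the CPU and GPU source files each provide explicit specializations, and the call site picks one purely through the template arguments. A toy sketch of that dispatch mechanism (stand-in names, not Paddle's):

#include <iostream>

struct CPUPlace {};
struct GPUPlace {};

// Declared once, like the header above; defined per backend.
template <typename Place, typename T>
void gemm_stub();

template <>
void gemm_stub<CPUPlace, float>() { std::cout << "cblas path\n"; }

template <>
void gemm_stub<GPUPlace, float>() { std::cout << "cublas path\n"; }

int main() {
  gemm_stub<CPUPlace, float>();  // resolves at compile time to the CPU definition
  gemm_stub<GPUPlace, float>();  // resolves to the GPU definition
  return 0;
}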

@@ -37,7 +37,8 @@ public:
     int N = out_dim[1];
     int K = in0_dim[1];
-    paddle::operators::math::template gemm<Place, T>(CblasNoTrans,
+    paddle::operators::math::template gemm<Place, T>(
+        CblasNoTrans,
         CblasNoTrans,
         M,
         N,
@@ -50,7 +51,7 @@
         0,
         output->data<T>(),
         N,
-        &context.device_context());
+        &const_cast<platform::DeviceContext&>(context.device_context()));
   }
 };
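
The const_cast in the last hunk exists because gemm now takes a mutable platform::DeviceContext*, while the operator's context.device_context() accessor returns a const reference. A minimal stand-alone illustration of that call-site pattern (simplified stand-in types, not Paddle's actual classes):

#include <iostream>

struct DeviceContext {
  void Wait() { std::cout << "wait\n"; }  // non-const member, like a handle getter
};

struct ExecutionContext {
  const DeviceContext& device_context() const { return ctx_; }  // const access only
  DeviceContext ctx_;
};

// The callee wants a mutable pointer, like the new gemm signature.
void gemm_like(DeviceContext* ctx) { ctx->Wait(); }

int main() {
  ExecutionContext ec;
  // device_context() is const, so the caller casts away constness to satisfy
  // the non-const parameter -- the same shape as the diff's
  // &const_cast<platform::DeviceContext&>(context.device_context()).
  gemm_like(&const_cast<DeviceContext&>(ec.device_context()));
  return 0;
}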
