Merge pull request #10934 from tensor-tang/mklml_funcs

Speed up vInvSqrt, vLog1p and vTanh with MKLML
release/0.13.0
Tao Luo 7 years ago committed by GitHub
commit 25aa45394b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "MathFunctions.h" #include "paddle/math/MathFunctions.h"
#include "hl_matrix_apply.cuh" #include "hl_matrix_apply.cuh"
#include "hl_matrix_ops.cuh" #include "hl_matrix_ops.cuh"
#include "paddle/utils/DynamicLoader.h" #include "paddle/utils/DynamicLoader.h"
@ -240,6 +240,36 @@ template <>
void vAdd<double>(const int n, const double* a, const double* b, double* r) { void vAdd<double>(const int n, const double* a, const double* b, double* r) {
vdAdd(n, a, b, r); vdAdd(n, a, b, r);
} }
// Explicit specialization of vTanh for float: forwards n, a and r unchanged
// to vsTanh.
// NOTE(review): vsTanh is presumably the MKLML vector-math single-precision
// tanh (per the commit title); assumes a holds n inputs and r receives n
// outputs — confirm against the library header.
template <>
void vTanh<float>(const int n, const float* a, float* r) {
vsTanh(n, a, r);
}
// Explicit specialization of vTanh for double: forwards n, a and r unchanged
// to vdTanh.
// NOTE(review): vdTanh is presumably the MKLML vector-math double-precision
// tanh; assumes a holds n inputs and r receives n outputs — confirm against
// the library header.
template <>
void vTanh<double>(const int n, const double* a, double* r) {
vdTanh(n, a, r);
}
// Explicit specialization of vInvSqrt for float: forwards n, a and r
// unchanged to vsInvSqrt.
// NOTE(review): vsInvSqrt is presumably the MKLML vector-math single-precision
// inverse square root; assumes a holds n inputs and r receives n outputs —
// confirm against the library header.
template <>
void vInvSqrt<float>(const int n, const float* a, float* r) {
vsInvSqrt(n, a, r);
}
// Explicit specialization of vInvSqrt for double: forwards n, a and r
// unchanged to vdInvSqrt.
// NOTE(review): vdInvSqrt is presumably the MKLML vector-math double-precision
// inverse square root; assumes a holds n inputs and r receives n outputs —
// confirm against the library header.
template <>
void vInvSqrt<double>(const int n, const double* a, double* r) {
vdInvSqrt(n, a, r);
}
// Explicit specialization of vLog1p for float: forwards n, a and r unchanged
// to vsLog1p.
// NOTE(review): vsLog1p is presumably the MKLML vector-math single-precision
// log(1 + x); assumes a holds n inputs and r receives n outputs — confirm
// against the library header.
template <>
void vLog1p<float>(const int n, const float* a, float* r) {
vsLog1p(n, a, r);
}
// Explicit specialization of vLog1p for double: forwards n, a and r unchanged
// to vdLog1p.
// NOTE(review): vdLog1p is presumably the MKLML vector-math double-precision
// log(1 + x); assumes a holds n inputs and r receives n outputs — confirm
// against the library header.
template <>
void vLog1p<double>(const int n, const double* a, double* r) {
vdLog1p(n, a, r);
}
#else #else
DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a)); DEFINE_MATRIX_BINARY_OP(vExp, b = std::exp(a));
@ -277,17 +307,6 @@ void vAdd(const int n, const T* a, const T* b, T* r) {
n); n);
} }
// Explicit instantiations of the vExp, vLog, vPow and vAdd function templates
// for float and double element types.
template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);
#endif
DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a)); DEFINE_MATRIX_BINARY_OP(vInvSqrt, b = 1.0f / std::sqrt(a));
template <class T> template <class T>
void vInvSqrt(const int n, const T* a, T* r) { void vInvSqrt(const int n, const T* a, T* r) {
@ -311,11 +330,19 @@ void vTanh(const int n, const T* a, T* r) {
binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n); binary::vTanh<T>(), const_cast<T*>(a), r, 1, n, n, n);
} }
// Explicit instantiations of the vExp, vLog, vPow and vAdd function templates
// for float and double element types.
template void vExp(const int n, const float* a, float* r);
template void vExp(const int n, const double* a, double* r);
template void vLog(const int n, const float* a, float* r);
template void vLog(const int n, const double* a, double* r);
template void vPow(const int n, const float* a, const float b, float* r);
template void vPow(const int n, const double* a, const double b, double* r);
template void vAdd(const int n, const float* a, const float* b, float* r);
template void vAdd(const int n, const double* a, const double* b, double* r);
template void vInvSqrt(const int n, const double* a, double* r); template void vInvSqrt(const int n, const double* a, double* r);
template void vInvSqrt(const int n, const float* a, float* r); template void vInvSqrt(const int n, const float* a, float* r);
template void vLog1p(const int n, const float* a, float* r); template void vLog1p(const int n, const float* a, float* r);
template void vLog1p(const int n, const double* a, double* r); template void vLog1p(const int n, const double* a, double* r);
template void vTanh(const int n, const float* a, float* r); template void vTanh(const int n, const float* a, float* r);
template void vTanh(const int n, const double* a, double* r); template void vTanh(const int n, const double* a, double* r);
#endif
} // namespace paddle } // namespace paddle

Loading…
Cancel
Save