From 1be85d011df85829149d94be43dd2e155accba4b Mon Sep 17 00:00:00 2001
From: tensor-tang <tangjian03@baidu.com>
Date: Tue, 13 Nov 2018 08:09:58 +0000
Subject: [PATCH] add mkl vsqr and vpow

---
 paddle/fluid/operators/math/blas.h      | 16 +++++++++
 paddle/fluid/operators/math/blas_impl.h | 48 +++++++++++++++++++++++++
 paddle/fluid/platform/dynload/mklml.h   |  4 +++
 3 files changed, 68 insertions(+)

diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h
index da185d93c0..5d0d562030 100644
--- a/paddle/fluid/operators/math/blas.h
+++ b/paddle/fluid/operators/math/blas.h
@@ -152,6 +152,12 @@ class Blas {
   template <typename T>
   void VEXP(int n, const T* x, T* y) const;
 
+  template <typename T>
+  void VSQR(int n, const T* x, T* y) const;
+
+  template <typename T>
+  void VPOW(int n, const T* x, T alpha, T* y) const;
+
   template <typename T>
   void GEMV(bool trans_a, int M, int N, T alpha, const T* A, const T* B, T beta,
             T* C) const;
@@ -238,6 +244,16 @@ class BlasT : private Blas<DeviceContext> {
     Base()->template VEXP<T>(args...);
   }
 
+  template <typename... ARGS>
+  void VSQR(ARGS... args) const {
+    Base()->template VSQR<T>(args...);
+  }
+
+  template <typename... ARGS>
+  void VPOW(ARGS... args) const {
+    Base()->template VPOW<T>(args...);
+  }
+
   template <typename... ARGS>
   void GEMV(ARGS... args) const {
     Base()->template GEMV<T>(args...);
diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h
index e1df78d11e..59454669be 100644
--- a/paddle/fluid/operators/math/blas_impl.h
+++ b/paddle/fluid/operators/math/blas_impl.h
@@ -12,6 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 #pragma once
+#include <cmath>
 #include <limits>
 #include <vector>
 #include "paddle/fluid/operators/math/math_function.h"
@@ -102,6 +103,16 @@ struct CBlas<float> {
   static void VEXP(ARGS... args) {
     platform::dynload::vsExp(args...);
   }
+
+  template <typename... ARGS>
+  static void VSQR(ARGS... args) {
+    platform::dynload::vsSqr(args...);
+  }
+
+  template <typename... ARGS>
+  static void VPOW(ARGS... args) {
+    platform::dynload::vsPowx(args...);
+  }
 };
 
 template <>
@@ -182,6 +193,16 @@ struct CBlas<double> {
   static void VEXP(ARGS... args) {
     platform::dynload::vdExp(args...);
   }
+
+  template <typename... ARGS>
+  static void VSQR(ARGS... args) {
+    platform::dynload::vdSqr(args...);
+  }
+
+  template <typename... ARGS>
+  static void VPOW(ARGS... args) {
+    platform::dynload::vdPowx(args...);
+  }
 };
 
 #else
@@ -241,6 +262,8 @@ struct CBlas<platform::float16> {
   }
   static void VMUL(...) { PADDLE_THROW("float16 VMUL not supported on CPU"); }
   static void VEXP(...) { PADDLE_THROW("float16 VEXP not supported on CPU"); }
+  static void VSQR(...) { PADDLE_THROW("float16 VSQR not supported on CPU"); }
+  static void VPOW(...) { PADDLE_THROW("float16 VPOW not supported on CPU"); }
   static void DOT(...) { PADDLE_THROW("float16 DOT not supported on CPU"); };
   static void SCAL(...) { PADDLE_THROW("float16 SCAL not supported on CPU"); };
 #ifdef PADDLE_WITH_MKLML
@@ -398,6 +421,31 @@ void Blas<platform::CPUDeviceContext>::VEXP(int n, const T *x, T *y) const {
 #endif
 }
 
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VSQR(int n, const T *x, T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VSQR(n, x, y);
+#else
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::sqrt(x[i]);
+  }
+#endif
+}
+
+template <>
+template <typename T>
+void Blas<platform::CPUDeviceContext>::VPOW(int n, const T *x, T a,
+                                            T *y) const {
+#ifdef PADDLE_WITH_MKLML
+  CBlas<T>::VPOW(n, x, a, y);
+#else
+  for (int i = 0; i < n; ++i) {
+    y[i] = std::pow(x[i], a);
+  }
+#endif
+}
+
 template <>
 template <typename T>
 T Blas<platform::CPUDeviceContext>::DOT(int n, const T *x, const T *y) const {
diff --git a/paddle/fluid/platform/dynload/mklml.h b/paddle/fluid/platform/dynload/mklml.h
index aa20553cef..9273e9b1e7 100644
--- a/paddle/fluid/platform/dynload/mklml.h
+++ b/paddle/fluid/platform/dynload/mklml.h
@@ -76,6 +76,10 @@ extern void* mklml_dso_handle;
   __macro(vdMul);                   \
   __macro(vsExp);                   \
   __macro(vdExp);                   \
+  __macro(vsSqr);                   \
+  __macro(vdSqr);                   \
+  __macro(vsPowx);                  \
+  __macro(vdPowx);                  \
   __macro(MKL_Set_Num_Threads)
 
 MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);