Merge pull request #9231 from kexinzhao/elementwise_add_fp16

Add float16 support to Elementwise Add op
helinwang-patch-1
Kexin Zhao 7 years ago committed by GitHub
commit c1e9b1e37e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -14,19 +14,20 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/fluid/operators/elementwise_add_op.h" #include "paddle/fluid/operators/elementwise_add_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_add, elementwise_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CUDADeviceContext, int64_t>); ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad, elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, float>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, double>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, int>, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CUDADeviceContext, ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>);
int64_t>);

@@ -600,7 +600,7 @@ HOSTDEVICE inline bool operator>=(const float16& a, const float16& b) {
// Arithmetic operators for float16 on ARMv8.2-A CPU // Arithmetic operators for float16 on ARMv8.2-A CPU
#elif defined(PADDLE_WITH_NATIVE_FP16) #elif defined(PADDLE_WITH_NATIVE_FP16)
HOST inline float16 operator+(const float16& a, const float16& b) { inline float16 operator+(const float16& a, const float16& b) {
float16 res; float16 res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -616,7 +616,7 @@ HOST inline float16 operator+(const float16& a, const float16& b) {
return res; return res;
} }
HOST inline float16 operator-(const float16& a, const float16& b) { inline float16 operator-(const float16& a, const float16& b) {
float16 res; float16 res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -632,7 +632,7 @@ HOST inline float16 operator-(const float16& a, const float16& b) {
return res; return res;
} }
HOST inline float16 operator*(const float16& a, const float16& b) { inline float16 operator*(const float16& a, const float16& b) {
float16 res; float16 res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -648,7 +648,7 @@ HOST inline float16 operator*(const float16& a, const float16& b) {
return res; return res;
} }
HOST inline float16 operator/(const float16& a, const float16& b) { inline float16 operator/(const float16& a, const float16& b) {
float16 res; float16 res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -664,7 +664,7 @@ HOST inline float16 operator/(const float16& a, const float16& b) {
return res; return res;
} }
HOST inline float16 operator-(const float16& a) { inline float16 operator-(const float16& a) {
float16 res; float16 res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -679,27 +679,27 @@ HOST inline float16 operator-(const float16& a) {
return res; return res;
} }
HOST inline float16& operator+=(float16& a, const float16& b) { inline float16& operator+=(float16& a, const float16& b) {
a = a + b; a = a + b;
return a; return a;
} }
HOST inline float16& operator-=(float16& a, const float16& b) { inline float16& operator-=(float16& a, const float16& b) {
a = a - b; a = a - b;
return a; return a;
} }
HOST inline float16& operator*=(float16& a, const float16& b) { inline float16& operator*=(float16& a, const float16& b) {
a = a * b; a = a * b;
return a; return a;
} }
HOST inline float16& operator/=(float16& a, const float16& b) { inline float16& operator/=(float16& a, const float16& b) {
a = a / b; a = a / b;
return a; return a;
} }
HOST inline bool operator==(const float16& a, const float16& b) { inline bool operator==(const float16& a, const float16& b) {
uint16_t res; uint16_t res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -715,11 +715,9 @@ HOST inline bool operator==(const float16& a, const float16& b) {
return (res & 0xffff) != 0; return (res & 0xffff) != 0;
} }
HOST inline bool operator!=(const float16& a, const float16& b) { inline bool operator!=(const float16& a, const float16& b) { return !(a == b); }
return !(a == b);
}
HOST inline bool operator<(const float16& a, const float16& b) { inline bool operator<(const float16& a, const float16& b) {
uint16_t res; uint16_t res;
asm volatile( asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n" "ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -735,7 +733,7 @@ HOST inline bool operator<(const float16& a, const float16& b) {
return (res & 0xffff) != 0; return (res & 0xffff) != 0;
} }
HOST inline bool operator<=(const float16& a, const float16& b) { inline bool operator<=(const float16& a, const float16& b) {
uint16_t res; uint16_t res;
asm volatile( asm volatile(
"ld1 {v1.h}[0], [%[a_ptr]]\n" "ld1 {v1.h}[0], [%[a_ptr]]\n"
@@ -751,7 +749,7 @@ HOST inline bool operator<=(const float16& a, const float16& b) {
return (res & 0xffff) != 0; return (res & 0xffff) != 0;
} }
HOST inline bool operator>(const float16& a, const float16& b) { inline bool operator>(const float16& a, const float16& b) {
uint16_t res; uint16_t res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -767,7 +765,7 @@ HOST inline bool operator>(const float16& a, const float16& b) {
return (res & 0xffff) != 0; return (res & 0xffff) != 0;
} }
HOST inline bool operator>=(const float16& a, const float16& b) { inline bool operator>=(const float16& a, const float16& b) {
uint16_t res; uint16_t res;
asm volatile( asm volatile(
"ld1 {v0.h}[0], [%[a_ptr]]\n" "ld1 {v0.h}[0], [%[a_ptr]]\n"
@@ -785,69 +783,69 @@ HOST inline bool operator>=(const float16& a, const float16& b) {
// Arithmetic operators for float16, software emulated on other CPU // Arithmetic operators for float16, software emulated on other CPU
#else #else
HOST inline float16 operator+(const float16& a, const float16& b) { inline float16 operator+(const float16& a, const float16& b) {
return float16(float(a) + float(b)); return float16(float(a) + float(b));
} }
HOST inline float16 operator-(const float16& a, const float16& b) { inline float16 operator-(const float16& a, const float16& b) {
return float16(float(a) - float(b)); return float16(float(a) - float(b));
} }
HOST inline float16 operator*(const float16& a, const float16& b) { inline float16 operator*(const float16& a, const float16& b) {
return float16(float(a) * float(b)); return float16(float(a) * float(b));
} }
HOST inline float16 operator/(const float16& a, const float16& b) { inline float16 operator/(const float16& a, const float16& b) {
return float16(float(a) / float(b)); return float16(float(a) / float(b));
} }
HOST inline float16 operator-(const float16& a) { inline float16 operator-(const float16& a) {
float16 res; float16 res;
res.x = a.x ^ 0x8000; res.x = a.x ^ 0x8000;
return res; return res;
} }
HOST inline float16& operator+=(float16& a, const float16& b) { inline float16& operator+=(float16& a, const float16& b) {
a = float16(float(a) + float(b)); a = float16(float(a) + float(b));
return a; return a;
} }
HOST inline float16& operator-=(float16& a, const float16& b) { inline float16& operator-=(float16& a, const float16& b) {
a = float16(float(a) - float(b)); a = float16(float(a) - float(b));
return a; return a;
} }
HOST inline float16& operator*=(float16& a, const float16& b) { inline float16& operator*=(float16& a, const float16& b) {
a = float16(float(a) * float(b)); a = float16(float(a) * float(b));
return a; return a;
} }
HOST inline float16& operator/=(float16& a, const float16& b) { inline float16& operator/=(float16& a, const float16& b) {
a = float16(float(a) / float(b)); a = float16(float(a) / float(b));
return a; return a;
} }
HOST inline bool operator==(const float16& a, const float16& b) { inline bool operator==(const float16& a, const float16& b) {
return float(a) == float(b); return float(a) == float(b);
} }
HOST inline bool operator!=(const float16& a, const float16& b) { inline bool operator!=(const float16& a, const float16& b) {
return float(a) != float(b); return float(a) != float(b);
} }
HOST inline bool operator<(const float16& a, const float16& b) { inline bool operator<(const float16& a, const float16& b) {
return float(a) < float(b); return float(a) < float(b);
} }
HOST inline bool operator<=(const float16& a, const float16& b) { inline bool operator<=(const float16& a, const float16& b) {
return float(a) <= float(b); return float(a) <= float(b);
} }
HOST inline bool operator>(const float16& a, const float16& b) { inline bool operator>(const float16& a, const float16& b) {
return float(a) > float(b); return float(a) > float(b);
} }
HOST inline bool operator>=(const float16& a, const float16& b) { inline bool operator>=(const float16& a, const float16& b) {
return float(a) >= float(b); return float(a) >= float(b);
} }
#endif #endif

Loading…
Cancel
Save