diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
index f8377fd721..e887b98eca 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_grad_impl.cu
@@ -18,10 +18,21 @@
 #include
 #include
 #include "kernel/gpu/cuda_impl/layer_norm_grad_impl.cuh"
+#include "kernel/gpu/cuda_impl/layer_norm_impl.cuh"
 
 constexpr int NUM_PER_THREAD_REDUCE = 4;
 constexpr int WARP_SIZE = 32;
 
+template <typename T>
+inline __device__ T my_pow(T a, double b) {
+  return pow(a, static_cast<T>(b));
+}
+
+template <>
+inline __device__ half my_pow(half a, double b) {
+  return __float2half(pow(__half2float(a), static_cast<float>(b)));
+}
+
 template <typename T>
 inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_dim, const int& col_dim,
                                                 const T& epsilon, const T* dy, const T* x, const T* mean, const T* var,
@@ -35,7 +46,7 @@ inline __device__ void GammaAndBetaThreadReduce(const int& col, const int& row_d
     }
 
     int pos = row * col_dim + col;
-    dg[0] += dy[pos] * pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
+    dg[0] += dy[pos] * my_pow(var[row] + epsilon, -0.5) * (x[pos] - mean[row]);
     db[0] += dy[pos];
   }
 }
@@ -58,26 +69,26 @@ inline __device__ void GammaAndBetaBlockReduce(const int& col, const int& row_di
 
   // load data to share memory
   // thread(0, 32, 64, 96, ...) keep the data
-  extern __shared__ T share_mem[];
+  DynamicSharedMem<T> share_mem;
   if (threadIdx.x % WARP_SIZE == 0) {
     int offset = threadIdx.x / WARP_SIZE * 2;
-    share_mem[offset] = dg[0];
-    share_mem[offset + 1] = db[0];
+    share_mem.addr()[offset] = dg[0];
+    share_mem.addr()[offset + 1] = db[0];
   }
   __syncthreads();
 
   for (int stride = blockDim.x / WARP_SIZE / 2; stride > 0; stride >>= 1) {
     if (threadIdx.x < stride) {
       int offset = (threadIdx.x + stride) * 2;
-      share_mem[threadIdx.x * 2] += share_mem[offset];
-      share_mem[threadIdx.x * 2 + 1] += share_mem[offset + 1];
+      share_mem.addr()[threadIdx.x * 2] += share_mem.addr()[offset];
+      share_mem.addr()[threadIdx.x * 2 + 1] += share_mem.addr()[offset + 1];
     }
   }
   __syncthreads();
 
   if (threadIdx.x == 0) {
-    dg_addr[col] = share_mem[0];
-    db_addr[col] = share_mem[1];
+    dg_addr[col] = share_mem.addr()[0];
+    db_addr[col] = share_mem.addr()[1];
   }
 }
@@ -114,13 +125,37 @@ inline __device__ void InputThreadReduce(const int& row, const int& col_dim, con
       T v1 = dy[pos] * gamma[gamma_offset];
       T v2 = x[pos] - mean[row];
 
-      sum1[0] += -0.5 * v1 * v2 * pow(var[row] + epsilon, -1.5);
+      sum1[0] += -0.5 * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
       sum2[0] += v1;
       sum3[0] += -2.0 * v2;
     }
   }
 }
 
+template <>
+inline __device__ void InputThreadReduce(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
+                                         half* sum1, half* sum2, half* sum3, const half* dy, const half* x,
+                                         const half* mean, const half* var, const half* gamma) {
+  int loop_num = (col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE;
+  for (int i = threadIdx.x; i < loop_num; i += blockDim.x) {
+    for (int j = 0; j < NUM_PER_THREAD_REDUCE; j++) {
+      int col = NUM_PER_THREAD_REDUCE * i + j;
+      if (col >= col_dim) {
+        return;
+      }
+
+      int pos = row * col_dim + col;
+      int gamma_offset = pos % param_dim;
+      half v1 = dy[pos] * gamma[gamma_offset];
+      half v2 = x[pos] - mean[row];
+
+      sum1[0] += __float2half(-0.5) * v1 * v2 * my_pow(var[row] + epsilon, -1.5);
+      sum2[0] += v1;
+      sum3[0] += __float2half(-2.0) * v2;
+    }
+  }
+}
+
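Note on the `my_pow` helper added above: CUDA's math library overloads `pow` for `float` and `double` but not for `__half`, so the `half` path has to round-trip through `float`. A minimal standalone sketch of the same pattern; the `Check` kernel and driver are illustrative test scaffolding, not part of the patch:

#include <cuda_fp16.h>
#include <cstdio>

template <typename T>
inline __device__ T my_pow(T a, double b) {
  return pow(a, static_cast<T>(b));
}

template <>
inline __device__ half my_pow(half a, double b) {
  // No half-precision pow intrinsic: promote, compute, demote.
  return __float2half(powf(__half2float(a), static_cast<float>(b)));
}

__global__ void Check(float x) {
  float ref = my_pow(x, -0.5);                              // float path
  float got = __half2float(my_pow(__float2half(x), -0.5));  // half path
  printf("float: %f  half: %f\n", ref, got);
}

int main() {
  Check<<<1, 1>>>(4.0f);  // expect ~0.5 from both paths
  cudaDeviceSynchronize();
  return 0;
}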
 template <typename T>
 inline __device__ void InputWarpReduce(T* sum1, T* sum2, T* sum3) {
   for (int delta = (WARP_SIZE >> 1); delta > 0; delta >>= 1) {
@@ -166,12 +201,28 @@ inline __device__ void InputProp(const int& row, const int& col_dim, const int&
     int gamma_offset = pos % param_dim;
     T v1 = dy[pos] * gamma[gamma_offset];
     T v2 = x[pos] - mean[row];
-    T v3 = pow(var[row] + epsilon, -0.5);
+    T v3 = my_pow(var[row] + epsilon, -0.5);
     dx[pos] = v1 * v3 + share_mem[0] * (2.0 / col_dim) * v2 +
               (-1.0 * v3 * share_mem[1] + (1.0 / col_dim) * share_mem[0] * share_mem[2]) * (1.0 / col_dim);
   }
 }
 
+template <>
+inline __device__ void InputProp(const int& row, const int& col_dim, const int& param_dim, const half& epsilon,
+                                 const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                                 half* dx, const half* share_mem) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = (row * col_dim + col);
+    int gamma_offset = pos % param_dim;
+    half v1 = dy[pos] * gamma[gamma_offset];
+    half v2 = x[pos] - mean[row];
+    half v3 = my_pow(var[row] + epsilon, -0.5);
+    dx[pos] = v1 * v3 + share_mem[0] * __float2half(2.0 / col_dim) * v2 +
+              (__float2half(-1.0) * v3 * share_mem[1] + __float2half(1.0 / col_dim) * share_mem[0] * share_mem[2]) *
+                __float2half(1.0 / col_dim);
+  }
+}
+
 template <typename T>
 __global__ void InputPropKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon,
                                 const T* dy, const T* x, const T* mean, const T* var, const T* gamma, T* dx) {
@@ -179,27 +230,30 @@ __global__ void InputPropKernel(const int row_dim, const int col_dim, const int
     T sum1 = 0;
     T sum2 = 0;
     T sum3 = 0;
-    extern __shared__ T share_mem[];
+    DynamicSharedMem<T> share_mem;
     InputThreadReduce(row, col_dim, param_dim, epsilon, &sum1, &sum2, &sum3, dy, x, mean, var, gamma);
     InputWarpReduce(&sum1, &sum2, &sum3);
-    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem);
-    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem);
+    InputBlockReduce(col_dim, &sum1, &sum2, &sum3, share_mem.addr());
+    InputProp(row, col_dim, param_dim, epsilon, dy, x, mean, var, gamma, dx, share_mem.addr());
   }
 }
 
 template <typename T>
 void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const T& epsilon, const T* dy,
                    const T* x, const T* mean, const T* var, const T* gamma, T* dx, T* dg, T* db, cudaStream_t stream) {
-  int share_mem =
+  int share_mem_size =
     ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  InputPropKernel<<<row_dim, 256, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var, gamma,
-                                                       dx);
+  InputPropKernel<<<row_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, dy, x, mean, var,
+                                                            gamma, dx);
 
-  share_mem =
+  share_mem_size =
     ((row_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 2 * sizeof(T);
-  GammaAndBetaPropKernel<<<col_dim, 256, share_mem, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg, db);
+  GammaAndBetaPropKernel<<<col_dim, 256, share_mem_size, stream>>>(row_dim, col_dim, epsilon, dy, x, mean, var, dg,
+                                                                   db);
 }
 
 template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const float& epsilon,
                             const float* dy, const float* x, const float* mean, const float* var, const float* gamma,
                             float* dx, float* dg, float* db, cudaStream_t stream);
+template void LayerNormGrad(const int& row_dim, const int& col_dim, const int& param_dim, const half& epsilon,
+                            const half* dy, const half* x, const half* mean, const half* var, const half* gamma,
+                            half* dx, half* dg, half* db, cudaStream_t stream);
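For reviewers tracing the backward arithmetic in InputThreadReduce/InputProp: writing $N$ for col_dim, $\mu$ for mean[row], and $\sigma^2$ for var[row], the three per-row reductions (share_mem[0..2] after the block reduce) and the dx[pos] expression implement the standard layer-norm input gradient:

\[ \texttt{sum1} = \frac{\partial L}{\partial \sigma^2} = -\frac{1}{2} \sum_i \frac{\partial L}{\partial y_i}\,\gamma_i\,(x_i-\mu)\,(\sigma^2+\varepsilon)^{-3/2}, \quad \texttt{sum2} = \sum_i \frac{\partial L}{\partial y_i}\,\gamma_i, \quad \texttt{sum3} = -2\sum_i (x_i-\mu) \]

\[ \frac{\partial L}{\partial \mu} = -(\sigma^2+\varepsilon)^{-1/2}\,\texttt{sum2} + \frac{\partial L}{\partial \sigma^2}\cdot\frac{\texttt{sum3}}{N} \]

\[ \frac{\partial L}{\partial x_i} = \frac{\partial L}{\partial y_i}\,\gamma_i\,(\sigma^2+\varepsilon)^{-1/2} + \frac{\partial L}{\partial \sigma^2}\cdot\frac{2(x_i-\mu)}{N} + \frac{1}{N}\cdot\frac{\partial L}{\partial \mu} \]

One observation on the half specialization: sum1/sum2/sum3 accumulate in fp16, so rounding error grows with col_dim; that is a property of the patch as written, not of the formulas.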
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
index cef74dc8ba..cfb60f0ba6 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cu
@@ -35,7 +35,8 @@ inline __device__ void MeanAndVarAccumulation(T *mean, T *var, T *num, const T &
 
 template <typename T>
 inline __device__ void MeanAndVarMerge(T *m1, T *v1, T *n1, const T &m2, const T &v2, const T &n2) {
-  if (n2 == 0) {
+  T zero = 0;
+  if (n2 == zero) {
     return;
   }
@@ -112,6 +113,17 @@ inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &
   }
 }
 
+template <>
+inline __device__ void LayerNorm(const int &row, const int &col_dim, const int &param_dim, const half *x,
+                                 const half *share_mem, const half *gamma, const half *beta, const half epsilon,
+                                 half *y) {
+  for (int col = threadIdx.x; col < col_dim; col += blockDim.x) {
+    int pos = row * col_dim + col;
+    int i = pos % param_dim;
+    y[pos] = (x[pos] - share_mem[0]) / hsqrt(share_mem[1] + epsilon) * gamma[i] + beta[i];
+  }
+}
+
 template <typename T>
 __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int param_dim, const T epsilon, const T *x,
                                 const T *gamma, const T *beta, T *y, T *mean_addr, T *var_addr) {
@@ -120,14 +132,14 @@ __global__ void LayerNormKernel(const int row_dim, const int col_dim, const int
     T var = 0;
     T num = 0;
     const T *block_addr = x + row * col_dim;
-    extern __shared__ T share_mem[];
+    DynamicSharedMem<T> share_mem;
 
     ThreadReduce(col_dim, block_addr, &mean, &var, &num);
     WarpReduce(&mean, &var, &num);
-    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem);
+    BlockReduce(col_dim, &mean, &var, &num, mean_addr, var_addr, share_mem.addr());
 
     __syncthreads();
-    LayerNorm(row, col_dim, param_dim, x, share_mem, gamma, beta, epsilon, y);
+    LayerNorm(row, col_dim, param_dim, x, share_mem.addr(), gamma, beta, epsilon, y);
   }
 }
 
@@ -137,12 +149,15 @@ void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, con
   const dim3 block(row_dim);
   const dim3 thread(256);
   // keep the mean/var/num after warp reduce
-  int share_mem =
+  int share_mem_size =
     ((col_dim + NUM_PER_THREAD_REDUCE - 1) / NUM_PER_THREAD_REDUCE + WARP_SIZE - 1) / WARP_SIZE * 3 * sizeof(T);
-  LayerNormKernel<<<block, thread, share_mem, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y, mean,
-                                                        var);
+  LayerNormKernel<<<block, thread, share_mem_size, stream>>>(row_dim, col_dim, param_dim, epsilon, x, gamma, beta, y,
+                                                             mean, var);
 }
 
 template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const float &epsilon,
                         const float *x, const float *gamma, const float *beta, float *y, float *mean, float *var,
                         cudaStream_t stream);
+template void LayerNorm(const int &row_dim, const int &col_dim, const int &param_dim, const half &epsilon,
+                        const half *x, const half *gamma, const half *beta, half *y, half *mean, half *var,
+                        cudaStream_t stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
index 4832b08746..c06a698384 100644
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/layer_norm_impl.cuh
@@ -19,6 +19,23 @@
 
 #include "device/gpu/cuda_common.h"
 
+template <typename T>
+struct DynamicSharedMem;
+template <>
+struct DynamicSharedMem<float> {
+  __device__ float *addr() {
+    extern __shared__ float addr_float[];
+    return addr_float;
+  }
+};
+template <>
+struct DynamicSharedMem<half> {
+  __device__ half *addr() {
+    extern __shared__ half addr_half[];
+    return addr_half;
+  }
+};
+
 template <typename T>
 void LayerNorm(const int& outer, const int& inner, const int& param_dim, const T& epsilon, const T* x, const T* gamma,
               const T* beta, T* y, T* mean, T* var, cudaStream_t stream);
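Why DynamicSharedMem exists: `extern __shared__ T buf[]` inside a template maps every instantiation onto the same shared-memory symbol, so instantiating a kernel for both float and half makes nvcc reject the second declaration as incompatible with the first. The specialized struct sidesteps this by giving each concrete type its own named extern array. A minimal self-contained sketch of the pattern; the `Fill` kernel is hypothetical, only the struct mirrors the patch:

#include <cuda_fp16.h>

// template <typename T>
// __global__ void Bad(T *out) {
//   extern __shared__ T buf[];  // error once Bad<float> and Bad<half> both exist:
//   ...                         // "declaration is incompatible with previous declaration"
// }

template <typename T>
struct DynamicSharedMem;
template <>
struct DynamicSharedMem<float> {
  __device__ float *addr() {
    extern __shared__ float addr_float[];
    return addr_float;
  }
};
template <>
struct DynamicSharedMem<half> {
  __device__ half *addr() {
    extern __shared__ half addr_half[];
    return addr_half;
  }
};

template <typename T>
__global__ void Fill(T *out, int n) {
  DynamicSharedMem<T> smem;  // holds no storage; just returns the extern array
  T *buf = smem.addr();
  if (threadIdx.x < static_cast<unsigned>(n)) {
    buf[threadIdx.x] = static_cast<T>(1.0f);
    out[threadIdx.x] = buf[threadIdx.x];
  }
}

// Both instantiations now compile and link cleanly:
//   Fill<float><<<1, 32, 32 * sizeof(float)>>>(d_f, 32);
//   Fill<half><<<1, 32, 32 * sizeof(half)>>>(d_h, 32);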
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu
index ae24a8dec9..5a1c9eb687 100755
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cu
@@ -15,25 +15,38 @@
  */
 
 #include "momentum_impl.cuh"
-template <typename T>
-__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const T *learning_rate,
-                                             const T *gradient, const T *momentum) {
+template <typename T, typename S>
+__global__ void MomentumUpdateVariableKernel(const size_t size, T *variable, T *accumulation, const S *learning_rate,
+                                             const T *gradient, const S *momentum) {
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
     accumulation[i] = momentum[0] * accumulation[i] + gradient[i];
     variable[i] -= learning_rate[0] * accumulation[i];
   }
   return;
 }
-template <typename T>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate, const T *gradient,
-                            const T *momentum, cudaStream_t cuda_stream) {
+template <>
+__global__ void MomentumUpdateVariableKernel(const size_t size, half *variable, half *accumulation,
+                                             const float *learning_rate, const half *gradient,
+                                             const float *momentum) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) {
+    accumulation[i] = __float2half(momentum[0]) * accumulation[i] + gradient[i];
+    variable[i] -= __float2half(learning_rate[0]) * accumulation[i];
+  }
+  return;
+}
+template <typename T, typename S>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
+                            const S *momentum, cudaStream_t cuda_stream) {
   MomentumUpdateVariableKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, variable, accumulation,
                                                                                   learning_rate, gradient, momentum);
   return;
 }
-template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation,
-                                     const float *learning_rate, const float *gradient, const float *momentum,
-                                     cudaStream_t cuda_stream);
-template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation,
-                                     const half *learning_rate, const half *gradient, const half *momentum,
-                                     cudaStream_t cuda_stream);
+template void MomentumUpdateVariable(const size_t size, float *variable, float *accumulation,
+                                     const float *learning_rate, const float *gradient,
+                                     const float *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation,
+                                     const half *learning_rate, const half *gradient,
+                                     const half *momentum, cudaStream_t cuda_stream);
+template void MomentumUpdateVariable(const size_t size, half *variable, half *accumulation,
+                                     const float *learning_rate, const half *gradient,
+                                     const float *momentum, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh
index 2993e04ff3..5405f5ef1d 100755
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/momentum_impl.cuh
@@ -18,8 +18,8 @@
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
 #include "device/gpu/cuda_common.h"
 
-template <typename T>
-void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const T *learning_rate, const T *gradient,
-                            const T *momentum, cudaStream_t cuda_stream);
+template <typename T, typename S>
+void MomentumUpdateVariable(const size_t size, T *variable, T *accumulation, const S *learning_rate, const T *gradient,
+                            const S *momentum, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_MOMENTUMIMPL_H_
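Usage of the new mixed-precision entry point: the parameters and accumulator stay fp16 while the scalar hyper-parameters stay fp32, converted once per use with `__float2half` inside the kernel instead of being quantized up front. A minimal host-side sketch; the driver scaffolding (sizes, initialization) is illustrative, since MindSpore reaches this through MomentumGpuKernel:

#include <cuda_fp16.h>
#include "momentum_impl.cuh"

int main() {
  constexpr size_t n = 1024;
  half *var, *accum, *grad;
  float *lr, *momentum;
  cudaMalloc(&var, n * sizeof(half));
  cudaMalloc(&accum, n * sizeof(half));
  cudaMalloc(&grad, n * sizeof(half));
  cudaMalloc(&lr, sizeof(float));
  cudaMalloc(&momentum, sizeof(float));
  // ... fill var/accum/grad; set lr = 0.01f, momentum = 0.9f on the device ...

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // Deduces MomentumUpdateVariable<half, float>, the instantiation added above.
  MomentumUpdateVariable(n, var, accum, lr, grad, momentum, stream);
  cudaStreamSynchronize(stream);
  return 0;
}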
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
index e67b745ab3..19e4dc17a6 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_gpu_kernel.cc
@@ -27,5 +27,14 @@ MS_REG_GPU_KERNEL_ONE(LayerNorm,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       LayerNormGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNorm,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
index e268161349..7991d42499 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/layer_norm_grad_gpu_kernel.cc
@@ -29,5 +29,16 @@ MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
                         .AddOutputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
                       LayerNormGradGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(LayerNormGrad,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      LayerNormGradGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc
index 4a77f7342b..e8b2b17706 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.cc
@@ -18,7 +18,7 @@
 
 namespace mindspore {
 namespace kernel {
-MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
@@ -26,8 +26,8 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddInputAttr(kNumberTypeFloat32)
                         .AddOutputAttr(kNumberTypeFloat32),
-                      MomentumGpuKernel, float)
-MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
+                      MomentumGpuKernel, float, float)
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
                       KernelAttr()
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
@@ -35,6 +35,15 @@ MS_REG_GPU_KERNEL_ONE(ApplyMomentum,
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddInputAttr(kNumberTypeFloat16)
                         .AddOutputAttr(kNumberTypeFloat16),
-                      MomentumGpuKernel, half)
+                      MomentumGpuKernel, half, half)
+MS_REG_GPU_KERNEL_TWO(ApplyMomentum,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      MomentumGpuKernel, half, float)
 }  // namespace kernel
 }  // namespace mindspore
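The ONE-to-TWO macro switch above tracks the second template parameter added to MomentumGpuKernel in the next file: judging from the existing `float, float` and `half, half` lines, the trailing macro arguments are the concrete types bound to T and S in order, so the mixed registration effectively selects the following instantiation (the slot-to-parameter mapping is inferred from the kernel header, not spelled out by the macro itself):

// KernelAttr slot -> MomentumGpuKernel<T = half, S = float> argument
//   input 0: variable       Float16 -> T*
//   input 1: accumulation   Float16 -> T*
//   input 2: learning_rate  Float32 -> S*
//   input 3: gradient       Float16 -> T*
//   input 4: momentum       Float32 -> S*
//   output 0: variable      Float16 -> T*
MomentumGpuKernel<half, float>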
diff --git a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h
index 8452c177db..5abfb9e97b 100644
--- a/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/nn/momentum_gpu_kernel.h
@@ -23,7 +23,7 @@
 #include "kernel/gpu/cuda_impl/momentum_impl.cuh"
 namespace mindspore {
 namespace kernel {
-template <typename T>
+template <typename T, typename S>
 class MomentumGpuKernel : public GpuKernel {
  public:
   MomentumGpuKernel()
@@ -37,9 +37,9 @@ class MomentumGpuKernel : public GpuKernel {
             void *stream_ptr) override {
     T *variable = GetDeviceAddress<T>(inputs, 0);
     T *accumulation = GetDeviceAddress<T>(inputs, 1);
-    T *learning_rate = GetDeviceAddress<T>(inputs, 2);
+    S *learning_rate = GetDeviceAddress<S>(inputs, 2);
     T *gradient = GetDeviceAddress<T>(inputs, 3);
-    T *momentum = GetDeviceAddress<T>(inputs, 4);
+    S *momentum = GetDeviceAddress<S>(inputs, 4);
     MomentumUpdateVariable(inputs[0]->size / sizeof(T), variable, accumulation, learning_rate, gradient, momentum,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
@@ -53,9 +53,9 @@ class MomentumGpuKernel : public GpuKernel {
 
     variable_size_ = sizeof(T);
     accumulation_size_ = sizeof(T);
-    learning_rate_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(S);
     gradient_size_ = sizeof(T);
-    momentum_size_ = sizeof(T);
+    momentum_size_ = sizeof(S);
 
     auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < variable_shape.size(); i++) {