parent
54481c30c8
commit
cc80c76687
@@ -0,0 +1,63 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""hsigmoid"""
import _akg.topi as topi
import _akg.tvm as tvm
from _akg.topi import tag


@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_hsigmoid(x):
    """
    topi hsigmoid: 0 if x <= -3, 1 if x >= 3, (x + 3) / 6 otherwise.
    Args:
        x: input tvm.Tensor.

    Returns:
        tvm.Tensor, hard sigmoid of x.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, 1,
                                                                             (x(*i) + 3) / 6)))


def Hsigmoid(x):
    """
    Hsigmoid activation.
    Args:
        x: input tvm.Tensor.

    Returns:
        tvm.Tensor, hard sigmoid of x.
    """
    return topi_nn_hsigmoid(x)


def gpu_schedule_Hsigmoid(outs):
    """
    gpu schedule Hsigmoid
    Args:
        outs: output tensors.

    Returns:
        Schedule, the created CUDA schedule.
    """
    device = 'cuda'
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with tvm.target.create(device):
        sch = topi.cuda.schedule_elemwise(outs)
    return sch
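For context, a minimal usage sketch of the Hsigmoid op above. This is not part of the change; it assumes the _akg-bundled TVM is importable as in the file, that Hsigmoid and gpu_schedule_Hsigmoid are in scope, and that a CUDA device is available.

# Hypothetical driver for the Hsigmoid op defined above.
import numpy as np
import _akg.tvm as tvm

x = tvm.placeholder((4, 16), name="x", dtype="float32")
y = Hsigmoid(x)                        # piecewise: 0 for x <= -3, 1 for x >= 3, (x + 3) / 6 otherwise
sch = gpu_schedule_Hsigmoid([y])       # element-wise CUDA schedule
mod = tvm.build(sch, [x, y], "cuda")   # compile the kernel

ctx = tvm.context("cuda", 0)
a = tvm.nd.array(np.random.uniform(-5, 5, (4, 16)).astype("float32"), ctx)
b = tvm.nd.array(np.zeros((4, 16), dtype="float32"), ctx)
mod(a, b)                              # b now holds hsigmoid(a)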
@@ -0,0 +1,51 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Hsigmoid grad"""
import _akg.topi as topi
import _akg.tvm as tvm


def HsigmoidGrad(y_grad, x):
    """
    HsigmoidGrad: 0 where x <= -3 or x >= 3, y_grad / 6 otherwise.
    Args:
        y_grad: gradient with respect to the Hsigmoid output.
        x: forward input tvm.Tensor.

    Returns:
        tvm.Tensor, gradient with respect to x.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, 0,
                                                                             y_grad(*i) / 6)))


def gpu_schedule_HsigmoidGrad(outs):
    """
    gpu schedule HsigmoidGrad
    Args:
        outs: output tensors.

    Returns:
        Schedule, the created CUDA schedule.
    """
    device = 'cuda'
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        raise SystemError("Skip because %s is not enabled" % device)

    with tvm.target.create(device):
        sch = topi.cuda.schedule_elemwise(outs)
    return sch
@@ -0,0 +1,63 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""hswish"""
import _akg.topi as topi
import _akg.tvm as tvm
from _akg.topi import tag


@tvm.tag_scope(tag=tag.ELEMWISE)
def topi_nn_hswish(x):
    """
    topi hswish: 0 if x <= -3, x if x >= 3, x * (x + 3) / 6 otherwise.
    Args:
        x: input tvm.Tensor.

    Returns:
        tvm.Tensor, hard swish of x.
    """
    return tvm.compute(x.shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0,
                                                            tvm.if_then_else(x(*i) >= 3, x(*i),
                                                                             x(*i) * (x(*i) + 3) / 6)))


def Hswish(x):
    """
    Hswish activation.
    Args:
        x: input tvm.Tensor.

    Returns:
        tvm.Tensor, hard swish of x.
    """
    return topi_nn_hswish(x)


def gpu_schedule_Hswish(outs):
    """
    gpu schedule Hswish
    Args:
        outs: output tensors.

    Returns:
        Schedule, the created CUDA schedule.
    """
    device = 'cuda'
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        raise SystemError("Skip because %s is not enabled" % device)
    with tvm.target.create(device):
        sch = topi.cuda.schedule_elemwise(outs)
    return sch
@@ -0,0 +1,53 @@
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""HswishGrad"""
import _akg.topi as topi
import _akg.tvm as tvm


def HswishGrad(y_grad, x):
    """
    HswishGrad: 0 where x <= -3, y_grad where x >= 3, y_grad * (2x + 3) / 6 otherwise.
    Args:
        y_grad: gradient with respect to the Hswish output.
        x: forward input tvm.Tensor.

    Returns:
        tvm.Tensor, gradient with respect to x.
    """
    shape = x.shape

    res0 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) <= -3, 0, y_grad(*i) * (2 * x(*i) + 3) / 6))
    res6 = tvm.compute(shape, lambda *i: tvm.if_then_else(x(*i) >= 3, y_grad(*i), res0(*i)))
    return res6


def gpu_schedule_HswishGrad(outs):
    """
    gpu schedule HswishGrad
    Args:
        outs: output tensors.

    Returns:
        Schedule, the created CUDA schedule.
    """
    device = 'cuda'
    ctx = tvm.context(device, 0)
    if not ctx.exist:
        raise SystemError("Skip because %s is not enabled" % device)

    with tvm.target.create(device):
        sch = topi.cuda.schedule_elemwise(outs)
    return sch
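As a sanity check on the two gradient ops above, the piecewise derivatives they implement are (restating the compute expressions, not new behavior):

\frac{\partial\,\mathrm{hsigmoid}(x)}{\partial x} =
\begin{cases} 0, & x \le -3 \\ 1/6, & -3 < x < 3 \\ 0, & x \ge 3 \end{cases}
\qquad
\frac{\partial\,\mathrm{hswish}(x)}{\partial x} =
\begin{cases} 0, & x \le -3 \\ (2x+3)/6, & -3 < x < 3 \\ 1, & x \ge 3 \end{cases}

Hence HsigmoidGrad returns y_grad / 6 on the middle branch, and HswishGrad returns y_grad * (2x + 3) / 6 there while passing y_grad through unchanged for x >= 3.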
@@ -0,0 +1,169 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include "batchnorm_fold2_impl.cuh"
|
||||
#include "batchnorm_fold_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2Kernel(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
|
||||
const T *running_std, const T *running_mean, const int *global_step, T *y,
|
||||
int freeze_bn, size_t N, size_t C, size_t H, size_t W) {
|
||||
int c = 0;
|
||||
size_t num_count = N * C * H * W;
|
||||
if (*global_step < freeze_bn) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
||||
c = i / (H * W) % C;
|
||||
y[i] = x[i] * running_std[c] / batch_std[c] + beta[c] - gamma[c] * batch_mean[c] / batch_std[c];
|
||||
}
|
||||
} else {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
||||
c = i / (H * W) % C;
|
||||
y[i] = x[i] + beta[c] - gamma[c] * running_mean[c] / running_std[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2GradReduce1(const T *dout, T *tmp, const T *x, T *tmp2, size_t N, size_t C, size_t HW) {
|
||||
int n = 0;
|
||||
int c = 0;
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < N * C; i += blockDim.x * gridDim.x) {
|
||||
n = i / C;
|
||||
c = i % C;
|
||||
tmp[c * N + n] = thrust::reduce(thrust::seq, dout + i * HW, dout + (i + 1) * HW, 0.f, thrust::plus<T>());
|
||||
tmp2[c * N + n] = thrust::reduce(thrust::seq, x + i * HW, x + (i + 1) * HW, 0.f, thrust::plus<T>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2GradReduce2(const T *tmp, T *d_beta, const T *tmp2, T *reduce_x, size_t N, size_t C) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
|
||||
d_beta[i] = thrust::reduce(thrust::seq, tmp + i * N, tmp + (i + 1) * N, 0.f, thrust::plus<T>());
|
||||
reduce_x[i] = thrust::reduce(thrust::seq, tmp2 + i * N, tmp2 + (i + 1) * N, 0.f, thrust::plus<T>());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
|
||||
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
|
||||
T *d_batch_mean, T *d_batch_std, size_t C) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
|
||||
d_gamma[i] = -d_beta[i] * batch_mean[i] / batch_std[i];
|
||||
d_batch_mean[i] = -d_beta[i] * gamma[i] / batch_std[i];
|
||||
d_batch_std[i] =
|
||||
(d_beta[i] * gamma[i] * batch_mean[i] - reduce_x[i] * running_std[i]) / batch_std[i] / batch_std[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2GradFreeze(const T *d_beta, const T *running_mean, const T *running_std, T *d_gamma,
|
||||
size_t C) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < C; i += blockDim.x * gridDim.x) {
|
||||
d_gamma[i] = -d_beta[i] * running_mean[i] / running_std[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchNormFold2GradMul(const T *dout, const T *x, T *tmp_x, size_t NCHW) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < NCHW; i += blockDim.x * gridDim.x) {
|
||||
tmp_x[i] = dout[i] * x[i];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void DxMul(size_t N, size_t C, size_t HW, const T *batch_std, const T *running_std, T *d_x) {
|
||||
int c = 0;
|
||||
size_t num_count = N * C * HW;
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
||||
c = (i / HW) % C;
|
||||
d_x[i] = d_x[i] * running_std[c] / batch_std[c];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
|
||||
const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
|
||||
size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream) {
|
||||
auto num_count = N * C * H * W;
|
||||
BatchNormFold2Kernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(
|
||||
x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, y, freeze_bn, N, C, H, W);
|
||||
}
|
||||
|
||||
template void BatchNormFold2Forward<float>(const float *x, const float *beta, const float *gamma,
|
||||
const float *batch_std, const float *batch_mean, const float *running_std,
|
||||
const float *running_mean, const int *global_step, float *y, int freeze_bn,
|
||||
size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
|
||||
size_t C, size_t H, size_t W, cudaStream_t cuda_stream) {
|
||||
auto hw = H * W;
|
||||
auto num_count = N * C * H * W;
|
||||
BatchNormFold2GradMul<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dout, x, tmp_x, num_count);
|
||||
BatchNormFold2GradReduce1<<<GET_BLOCKS(N * C), GET_THREADS, 0, cuda_stream>>>(dout, tmp, tmp_x, tmp2, N, C, hw);
|
||||
BatchNormFold2GradReduce2<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(tmp, d_beta, tmp2, reduce_x, N, C);
|
||||
}
|
||||
|
||||
template void BatchNormFold2GradReduce<float>(const float *dout, const float *x, float *d_beta, float *tmp,
|
||||
float *reduce_x, float *tmp2, float *tmp_x, size_t N, size_t C, size_t H,
|
||||
size_t W, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
|
||||
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
|
||||
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) {
|
||||
BatchNormFold2GradNotFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(
|
||||
d_beta, reduce_x, batch_mean, batch_std, running_mean, running_std, gamma, d_gamma, d_batch_mean, d_batch_std, C);
|
||||
}
|
||||
|
||||
template void CalBatchNormFold2GradNotFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean,
|
||||
const float *batch_std, const float *running_mean,
|
||||
const float *running_std, const float *gamma, float *d_gamma,
|
||||
float *d_batch_mean, float *d_batch_std, size_t C,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
|
||||
const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
|
||||
T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream) {
|
||||
BatchNormFold2GradFreeze<<<GET_BLOCKS(C), GET_THREADS, 0, cuda_stream>>>(d_beta, running_mean, running_std, d_gamma,
|
||||
C);
|
||||
ThrustFillWith(d_batch_mean, C, (T)0.f, cuda_stream);
|
||||
ThrustFillWith(d_batch_std, C, (T)0.f, cuda_stream);
|
||||
}
|
||||
|
||||
template void CalBatchNormFold2GradFreeze<float>(const float *d_beta, const float *reduce_x, const float *batch_mean,
|
||||
const float *batch_std, const float *running_mean,
|
||||
const float *running_std, const float *gamma, float *d_gamma,
|
||||
float *d_batch_mean, float *d_batch_std, size_t C,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H,
|
||||
size_t W, cudaStream_t cuda_stream) {
|
||||
DxMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N, C, H * W, batch_std, running_std, d_x);
|
||||
}
|
||||
|
||||
template void CalBatchNormFold2GradNotFreezeDxMul<float>(const float *batch_std, const float *running_std, float *d_x,
|
||||
size_t N, size_t C, size_t H, size_t W,
|
||||
cudaStream_t cuda_stream);
|
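The forward kernel in batchnorm_fold2_impl above computes, per element with channel index c (this is just the kernel's arithmetic written out, with \mu = mean and \sigma = std):

y = \begin{cases}
x \cdot \dfrac{\sigma_{run}[c]}{\sigma_{batch}[c]} + \beta[c] - \gamma[c]\,\dfrac{\mu_{batch}[c]}{\sigma_{batch}[c]}, & \text{global\_step} < \text{freeze\_bn} \\[1ex]
x + \beta[c] - \gamma[c]\,\dfrac{\mu_{run}[c]}{\sigma_{run}[c]}, & \text{otherwise}
\end{cases}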
@@ -0,0 +1,40 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_

#include "device/gpu/cuda_common.h"
template <typename T>
void BatchNormFold2Forward(const T *x, const T *beta, const T *gamma, const T *batch_std, const T *batch_mean,
                           const T *running_std, const T *running_mean, const int *global_step, T *y, int freeze_bn,
                           size_t N, size_t C, size_t H, size_t W, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradNotFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
                                    const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
                                    T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
template <typename T>
void CalBatchNormFold2GradFreeze(const T *d_beta, const T *reduce_x, const T *batch_mean, const T *batch_std,
                                 const T *running_mean, const T *running_std, const T *gamma, T *d_gamma,
                                 T *d_batch_mean, T *d_batch_std, size_t C, cudaStream_t cuda_stream);
template <typename T>
void BatchNormFold2GradReduce(const T *dout, const T *x, T *d_beta, T *tmp, T *reduce_x, T *tmp2, T *tmp_x, size_t N,
                              size_t C, size_t H, size_t W, cudaStream_t cuda_stream);

template <typename T>
void CalBatchNormFold2GradNotFreezeDxMul(const T *batch_std, const T *running_std, T *d_x, size_t N, size_t C, size_t H,
                                         size_t W, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORMFOLD2_H_
@@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <thrust/device_ptr.h>
|
||||
#include <thrust/fill.h>
|
||||
#include <thrust/system/cuda/execution_policy.h>
|
||||
#include "batchnorm_fold_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void UpdateRunningStd(int channel_size, const double epsilon, T* running_std) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
|
||||
running_std[i] = sqrtf(running_std[i] + epsilon);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void UpdateBatchStd(int channel_size, T* batch_std) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_size; i += blockDim.x * gridDim.x) {
|
||||
batch_std[i] = 1 / batch_std[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalDx(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean, const T* batch_std,
|
||||
int batch_size, int channel_size, int height, int width, T* dx) {
|
||||
int n = batch_size * channel_size * height * width;
|
||||
int normal_size = batch_size * height * width;
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += blockDim.x * gridDim.x) {
|
||||
int channel_index = i / (height * width) % channel_size;
|
||||
dx[i] = d_batch_mean[channel_index] / normal_size +
|
||||
d_batch_std[channel_index] * (x[i] - batch_mean[channel_index]) / batch_std[channel_index] / normal_size;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream) {
|
||||
UpdateRunningStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, epsilon, running_std);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalUpdateRunningStd<float>(int channel_size, double epsilon, float* running_std,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream) {
|
||||
UpdateBatchStd<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(channel_size, batch_std);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalUpdateBatchStd<float>(int channel_size, float* batch_std, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean,
|
||||
const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx,
|
||||
cudaStream_t cuda_stream) {
|
||||
CalDx<<<GET_BLOCKS(batch_size * channel_size * height * width), GET_THREADS, 0, cuda_stream>>>(
|
||||
d_batch_mean, d_batch_std, x, batch_mean, batch_std, batch_size, channel_size, height, width, dx);
|
||||
}
|
||||
|
||||
template void CalBatchNormFoldGrad<float>(const float* d_batch_mean, const float* d_batch_std, const float* x,
|
||||
const float* batch_mean, const float* batch_std, int batch_size,
|
||||
int channel_size, int height, int width, float* dx, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream) {
|
||||
thrust::device_ptr<T> dev_ptr(array);
|
||||
thrust::fill(thrust::cuda::par.on(cuda_stream), dev_ptr, dev_ptr + size, tofill);
|
||||
}
|
||||
|
||||
template void ThrustFillWith<float>(float* array, int size, float tofill, cudaStream_t cuda_stream);
|
||||
|
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_

template <typename T>
void CalUpdateRunningStd(int channel_size, double epsilon, T* running_std, cudaStream_t cuda_stream);

template <typename T>
void CalUpdateBatchStd(int channel_size, T* batch_std, cudaStream_t cuda_stream);

template <typename T>
void CalBatchNormFoldGrad(const T* d_batch_mean, const T* d_batch_std, const T* x, const T* batch_mean,
                          const T* batch_std, int batch_size, int channel_size, int height, int width, T* dx,
                          cudaStream_t cuda_stream);
template <typename T>
void ThrustFillWith(T* array, int size, T tofill, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHNORM_FOLD_H_
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <thrust/reduce.h>
|
||||
#include "correction_mul_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void CorrectionMul(const T* weight, const T* gamma, const T* running_std, const int batchsize, const int chw,
|
||||
T* output) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < batchsize * chw; i += blockDim.x * gridDim.x) {
|
||||
int n = i / chw;
|
||||
output[i] = weight[i] * gamma[n] / running_std[n];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Mul(int N, const T* a, const T* b, T* c) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
|
||||
c[i] = a[i] * b[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void Reduce(int N, int CHW, const T* tmp, const T* running_std, T* d_gamma) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N; i += blockDim.x * gridDim.x) {
|
||||
d_gamma[i] = thrust::reduce(thrust::seq, tmp + i * CHW, tmp + (i + 1) * CHW, 0.f, thrust::plus<T>());
|
||||
d_gamma[i] = d_gamma[i] / running_std[i];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int N, int C, int H, int W, T* output,
|
||||
cudaStream_t cuda_stream) {
|
||||
CorrectionMul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(weight, gamma, running_std, N, C * H * W,
|
||||
output);
|
||||
}
|
||||
|
||||
template void CalCorrectionMul<float>(const float* weight, const float* gamma, const float* running_std, int N, int C,
|
||||
int H, int W, float* output, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int N, int C, int H, int W, T* d_gamma,
|
||||
T* tmp, cudaStream_t cuda_stream) {
|
||||
Mul<<<GET_BLOCKS(N * C * H * W), GET_THREADS, 0, cuda_stream>>>(N * C * H * W, d_out, weight, tmp);
|
||||
Reduce<<<GET_BLOCKS(N), GET_THREADS, 0, cuda_stream>>>(N, C * H * W, tmp, running_std, d_gamma);
|
||||
}
|
||||
|
||||
template void CalCorrectionMulGrad<float>(const float* d_out, const float* weight, const float* running_std, int N,
|
||||
int C, int H, int W, float* d_gamma, float* tmp, cudaStream_t cuda_stream);
|
@@ -0,0 +1,27 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_

template <typename T>
void CalCorrectionMul(const T* weight, const T* gamma, const T* running_std, int batch_size, int channel_size,
                      int height, int width, T* output, cudaStream_t cuda_stream);

template <typename T>
void CalCorrectionMulGrad(const T* d_out, const T* weight, const T* running_std, int batch_size, int channel_size,
                          int height, int width, T* d_gamma, T* tmp, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CORRECTIONMUL_H_
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include "cross_entropy_cuda_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
|
||||
__global__ void CalCrossEntropyWithGradKernel(const float *softmax_logits, const float *log_softmax_logits,
|
||||
const float *labels, const int batch_size, const int num_classes,
|
||||
float *loss, float *dx) {
|
||||
extern __shared__ float loss_shared[];
|
||||
const float mean_scale = 1.0f / static_cast<float>(batch_size);
|
||||
|
||||
loss_shared[threadIdx.x] = 0;
|
||||
for (int i = threadIdx.x * num_classes; i < (threadIdx.x + 1) * num_classes; ++i) {
|
||||
loss_shared[threadIdx.x] -= log_softmax_logits[i] * labels[i];
|
||||
dx[i] = (softmax_logits[i] - labels[i]) * mean_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
*loss = 0;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
*loss += loss_shared[i];
|
||||
}
|
||||
*loss *= mean_scale;
|
||||
}
|
||||
}
|
||||
|
||||
void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
|
||||
const int batch_size, const int num_classes, float *loss, float *dx,
|
||||
cudaStream_t cuda_stream) {
|
||||
CalCrossEntropyWithGradKernel<<<1, batch_size, batch_size * sizeof(float), cuda_stream>>>(
|
||||
softmax_logits, log_softmax_logits, labels, batch_size, num_classes, loss, dx);
|
||||
}
|
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_

#include "device/gpu/cuda_common.h"

void CalCrossEntropyWithGrad(const float *softmax_logits, const float *log_softmax_logits, const float *labels,
                             const int batch_size, const int num_classes, float *loss, float *dx,
                             cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CROSSENTROPYCUDAIMPL_H_
@@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include "dropout_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
|
||||
__global__ void DropoutForwardKernel(const float *input, float *mask, float *output, size_t num_count,
|
||||
float drop_prob) {
|
||||
float scale = 1.f / (1.f - drop_prob);
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
||||
mask[i] = mask[i] > drop_prob;
|
||||
output[i] = scale * input[i] * mask[i];
|
||||
}
|
||||
}
|
||||
|
||||
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
|
||||
cudaStream_t cuda_stream) {
|
||||
DropoutForwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(input, mask, output, num_count,
|
||||
drop_prob);
|
||||
}
|
||||
|
||||
__global__ void DropoutBackwardKernel(const float *dy, const float *mask, float *dx, size_t num_count,
|
||||
float drop_prob) {
|
||||
float scale = 1.f / (1.f - drop_prob);
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_count; i += blockDim.x * gridDim.x) {
|
||||
dx[i] = scale * dy[i] * mask[i];
|
||||
}
|
||||
}
|
||||
|
||||
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
|
||||
cudaStream_t cuda_stream) {
|
||||
DropoutBackwardKernel<<<GET_BLOCKS(num_count), GET_THREADS, 0, cuda_stream>>>(dy, mask, dx, num_count, drop_prob);
|
||||
}
|
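For reference, the dropout kernels above implement standard inverted dropout. A NumPy sketch of the same forward/backward math follows; it assumes, based on how the mask buffer is thresholded in place, that the caller pre-fills mask with uniform [0, 1) random numbers before DropoutForward is launched.

import numpy as np

def dropout_forward(x, mask_rand, drop_prob):
    # mask_rand: uniform [0, 1) randoms with x's shape, analogous to the pre-filled mask buffer
    scale = 1.0 / (1.0 - drop_prob)
    mask = (mask_rand > drop_prob).astype(x.dtype)   # keep each element with probability 1 - drop_prob
    return scale * x * mask, mask

def dropout_backward(dy, mask, drop_prob):
    scale = 1.0 / (1.0 - drop_prob)
    return scale * dy * mask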
@@ -0,0 +1,26 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_

#include "device/gpu/cuda_common.h"
void DropoutForward(const float *input, float *mask, float *output, size_t num_count, float drop_prob,
                    cudaStream_t cuda_stream);
void DropoutBackward(const float *dy, const float *mask, float *dx, size_t num_count, float drop_prob,
                     cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_DROPOUT_H_
@@ -0,0 +1,133 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/pair.h>
|
||||
#include "device/gpu/cuda_common.h"
|
||||
#include "fake_quant_impl.cuh"
|
||||
|
||||
__global__ void FakeQuantize(const float* input, float* output, const int size, const float* nudge_min,
|
||||
const float* nudge_max, const float* scale, bool symmetric) {
|
||||
float input_x = 0.f;
|
||||
int nudge_input = 0;
|
||||
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
input_x = input[i];
|
||||
// clamp input x
|
||||
if (input_x < nudge_min[0]) {
|
||||
input_x = nudge_min[0];
|
||||
}
|
||||
if (input_x > nudge_max[0]) {
|
||||
input_x = nudge_max[0];
|
||||
}
|
||||
// clamp shift
|
||||
nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f);
|
||||
|
||||
// quantize
|
||||
output[i] = nudge_input * scale[0] + nudge_min[0];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void FakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
|
||||
const float* nudge_min, const float* nudge_max) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) {
|
||||
if (input[i] < nudge_min[0] || input[i] > nudge_max[0]) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
output[i] = gradient[i];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void NudgeMinMax(const float* input_min, const float* input_max, const float quant_min,
|
||||
const float quant_max, float* nudge_min, float* nudge_max, float* scale) {
|
||||
float zp_from_min = 0.f;
|
||||
if ((quant_max - quant_min) == 0 || (*input_max - *input_min) == 0) {
|
||||
*scale = 0.f;
|
||||
zp_from_min = 0.f;
|
||||
} else {
|
||||
*scale = (*input_max - *input_min) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - *input_min / *scale;
|
||||
}
|
||||
|
||||
float nudge_zp = 0.f;
|
||||
if (zp_from_min <= quant_min) {
|
||||
nudge_zp = quant_min;
|
||||
} else if (zp_from_min >= quant_max) {
|
||||
nudge_zp = quant_max;
|
||||
} else {
|
||||
nudge_zp = round(zp_from_min);
|
||||
}
|
||||
|
||||
*nudge_min = (quant_min - nudge_zp) * (*scale);
|
||||
*nudge_max = (quant_max - nudge_zp) * (*scale);
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMaxWithEMA(float* input_min, float* input_max, const float min, const float max,
|
||||
const float decay) {
|
||||
*input_min = decay * (min) + (1 - decay) * (*input_min);
|
||||
*input_min = *input_min > 0 ? 0 : *input_min;
|
||||
*input_max = decay * (max) + (1 - decay) * (*input_max);
|
||||
*input_max = *input_max < 0 ? 0 : *input_max;
|
||||
return;
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMax(float* input_min, float* input_max, const float min, const float max) {
|
||||
*input_min = min;
|
||||
*input_max = max;
|
||||
}
|
||||
|
||||
void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
|
||||
const float* scale, bool symmetric, cudaStream_t cuda_stream) {
|
||||
FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
|
||||
symmetric);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
|
||||
const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream) {
|
||||
FakeQuantizeGrad<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, gradient, output, size, nudge_min,
|
||||
nudge_max);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
|
||||
float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream) {
|
||||
NudgeMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale);
|
||||
return;
|
||||
}
|
||||
|
||||
void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
|
||||
cudaStream_t cuda_stream) {
|
||||
float minel = 0.f;
|
||||
float maxel = 0.f;
|
||||
thrust::pair<thrust::device_ptr<float>, thrust::device_ptr<float>> tuple;
|
||||
tuple = thrust::minmax_element(thrust::device_pointer_cast(input), thrust::device_pointer_cast(input) + size);
|
||||
minel = tuple.first[0];
|
||||
maxel = tuple.second[0];
|
||||
|
||||
if (ema) {
|
||||
UpdateInputMinMaxWithEMA<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel, ema_decay);
|
||||
} else {
|
||||
UpdateInputMinMax<<<1, 1, 0, cuda_stream>>>(input_min, input_max, minel, maxel);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
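The NudgeMinMax kernel above (and its per-channel variant later in this diff) follows the usual fake-quantization range nudging; restating the kernel's arithmetic:

s = \frac{x_{max} - x_{min}}{q_{max} - q_{min}}, \qquad
z = \mathrm{clip}\!\left(\mathrm{round}\!\left(q_{min} - \frac{x_{min}}{s}\right),\ q_{min},\ q_{max}\right)

with nudged bounds \hat{x}_{min} = (q_{min} - z)\,s and \hat{x}_{max} = (q_{max} - z)\,s. FakeQuantize then clamps x to [\hat{x}_{min}, \hat{x}_{max}] and outputs \mathrm{round}((x - \hat{x}_{min})/s)\,s + \hat{x}_{min}.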
@@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_

void CalFakeQuantize(const float* input, float* output, const int size, const float* nudge_min, const float* nudge_max,
                     const float* scale, bool symmetric, cudaStream_t cuda_stream);

void CalFakeQuantizeGrad(const float* input, const float* gradient, float* output, const int size,
                         const float* nudge_min, const float* nudge_max, cudaStream_t cuda_stream);

void CalNudge(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
              float* nudge_min, float* nudge_max, float* scale, cudaStream_t cuda_stream);

void CalMinMax(float* input, float* input_min, float* input_max, const int size, const float ema_decay, const bool ema,
               cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZE_H_
@@ -0,0 +1,174 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <thrust/extrema.h>
|
||||
#include <thrust/device_vector.h>
|
||||
#include <thrust/execution_policy.h>
|
||||
#include <thrust/reduce.h>
|
||||
#include <thrust/pair.h>
|
||||
#include "fake_quant_per_channel_impl.cuh"
|
||||
#include "device/gpu/cuda_common.h"
|
||||
|
||||
/**
|
||||
* Find the nudge min, max and scale value as output.
|
||||
* @param input_min array
|
||||
* @param input_max array
|
||||
* @param quant_min lower bound of the quantized value range (typically 0)
* @param quant_max upper bound of the quantized value range (typically (1 << bit_width) - 1)
|
||||
* @param nudge_min array
|
||||
* @param nudge_max array
|
||||
* @param scale array
|
||||
* @param channel_num
|
||||
* @return
|
||||
*/
|
||||
__global__ void NudgeMinMaxPerChannel(const float* input_min, const float* input_max, const float quant_min,
|
||||
const float quant_max, float* nudge_min, float* nudge_max, float* scale,
|
||||
int channel_num) {
|
||||
float zp_from_min = 0.f;
|
||||
float nudge_zp = 0.f;
|
||||
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channel_num; i += blockDim.x * gridDim.x) {
|
||||
if ((quant_max - quant_min) == 0 || (input_max[i] - input_min[i]) == 0) {
|
||||
scale[i] = 0.f;
|
||||
zp_from_min = 0.f;
|
||||
} else {
|
||||
scale[i] = (input_max[i] - input_min[i]) / (quant_max - quant_min);
|
||||
zp_from_min = quant_min - input_min[i] / scale[i];
|
||||
}
|
||||
|
||||
if (zp_from_min <= quant_min) {
|
||||
nudge_zp = quant_min;
|
||||
} else if (zp_from_min >= quant_max) {
|
||||
nudge_zp = quant_max;
|
||||
} else {
|
||||
nudge_zp = round(zp_from_min);
|
||||
}
|
||||
|
||||
nudge_min[i] = (quant_min - nudge_zp) * (scale[i]);
|
||||
nudge_max[i] = (quant_max - nudge_zp) * (scale[i]);
|
||||
}
|
||||
}
|
||||
|
||||
void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
|
||||
float* nudge_min, float* nudge_max, float* scale, const int channel_num,
|
||||
cudaStream_t cuda_stream) {
|
||||
NudgeMinMaxPerChannel<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
|
||||
input_min, input_max, quant_min, quant_max, nudge_min, nudge_max, scale, channel_num);
|
||||
}
|
||||
|
||||
/**
|
||||
* Calculate the fake quant output according to nudge min, nudge max and nudge scale.
|
||||
* @param input - array
|
||||
* @param output - array
|
||||
* @param total_size - int, total number of elements; used to derive the per-channel element count
* @param channel_size - int, number of channels in the filter
|
||||
* @param nudge_min - array
|
||||
* @param nudge_max - array
|
||||
* @param scale - array
|
||||
* @return
|
||||
*/
|
||||
__global__ void FakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
|
||||
const float* nudge_min, const float* nudge_max, const float* scale,
|
||||
bool symmetric) {
|
||||
float input_x = 0.f;
|
||||
int nudge_input = 0;
|
||||
int channel_idx = 0;
|
||||
int per_channel_num = total_size / channel_size;
|
||||
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
|
||||
input_x = input[i];
|
||||
channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num));
|
||||
// clamp input x
|
||||
if (input_x < nudge_min[channel_idx]) {
|
||||
input_x = nudge_min[channel_idx];
|
||||
}
|
||||
if (input_x > nudge_max[channel_idx]) {
|
||||
input_x = nudge_max[channel_idx];
|
||||
}
|
||||
// clamp shift
|
||||
nudge_input = floor((input_x - nudge_min[channel_idx]) / scale[channel_idx] + 0.5f);
|
||||
|
||||
// quantize
|
||||
output[i] = nudge_input * scale[channel_idx] + nudge_min[channel_idx];
|
||||
}
|
||||
}
|
||||
|
||||
void CalFakeQuantizePerChannel(const float* input, float* output, const int total_size, const int channel_size,
|
||||
const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
|
||||
cudaStream_t cuda_stream) {
|
||||
FakeQuantizePerChannel<<<GET_BLOCKS(total_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
input, output, total_size, channel_size, nudge_min, nudge_max, scale, symmetric);
|
||||
}
|
||||
|
||||
/**
|
||||
* UpdateInputMinMaxPerChannel or UpdateInputMinMaxPerChannelWithEMA.
|
||||
* @param input_min
|
||||
* @param input_max
|
||||
* @param min
|
||||
* @param max
|
||||
* @return
|
||||
*/
|
||||
__global__ void UpdateInputMinMaxPerChannel(float* input_min, float* input_max, float* input, int channels,
|
||||
int per_channel_nums, bool ema, float ema_decay) {
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < channels; i += blockDim.x * gridDim.x) {
|
||||
thrust::pair<float*, float*> sum =
|
||||
thrust::minmax_element(thrust::device, input + i * per_channel_nums, input + per_channel_nums * (i + 1));
|
||||
if (ema) {
|
||||
input_min[i] = ema_decay * sum.first[0] + (1 - ema_decay) * input_min[i];
|
||||
input_max[i] = ema_decay * sum.second[0] + (1 - ema_decay) * input_max[i];
|
||||
} else {
|
||||
input_min[i] = sum.first[0];
|
||||
input_max[i] = sum.second[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void UpdateInputMinMaxPerChannelWithEMA(float* input_min, float* input_max, float min, float max,
|
||||
const float decay) {
|
||||
*input_min = decay * (min) + (1 - decay) * (*input_min);
|
||||
*input_max = decay * (max) + (1 - decay) * (*input_max);
|
||||
}
|
||||
|
||||
void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_size, const int channel_size,
|
||||
const float ema_decay, const bool ema, cudaStream_t cuda_stream) {
|
||||
int per_channel_num = total_size / channel_size;
|
||||
UpdateInputMinMaxPerChannel<<<GET_BLOCKS(channel_size), GET_THREADS, 0, cuda_stream>>>(
|
||||
input_min, input_max, input, channel_size, per_channel_num, ema, ema_decay);
|
||||
}
|
||||
|
||||
__global__ void FakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output,
|
||||
const int total_size, const int channel_size, const float* nudge_min,
|
||||
const float* nudge_max) {
|
||||
int channel_idx = 0;
|
||||
int per_channel_num = total_size / channel_size;
|
||||
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < total_size; i += blockDim.x * gridDim.x) {
|
||||
channel_idx = floor(static_cast<double>(i) / static_cast<double>(per_channel_num));
|
||||
if (input[i] < nudge_min[channel_idx] || input[i] > nudge_max[channel_idx]) {
|
||||
output[i] = 0;
|
||||
} else {
|
||||
output[i] = gradient[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
|
||||
const int channel_num, const float* nudge_min, const float* nudge_max,
|
||||
cudaStream_t cuda_stream) {
|
||||
FakeQuantizePerChannelGrad<<<GET_BLOCKS(channel_num), GET_THREADS, 0, cuda_stream>>>(
|
||||
input, gradient, output, total_num, channel_num, nudge_min, nudge_max);
|
||||
}
|
||||
|
@@ -0,0 +1,35 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZEPERCHANNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZEPERCHANNEL_H_

void CalNudgePerChannel(const float* input_min, const float* input_max, const float quant_min, const float quant_max,
                        float* nudge_min, float* nudge_max, float* scale, const int channel_num,
                        cudaStream_t cuda_stream);

void CalFakeQuantizePerChannel(const float* input, float* output, const int total_num, const int channel_num,
                               const float* nudge_min, const float* nudge_max, const float* scale, bool symmetric,
                               cudaStream_t cuda_stream);

void CalMinMaxPerChannel(float* input, float* input_min, float* input_max, const int total_num, const int channel_num,
                         const float ema_decay, const bool ema, cudaStream_t cuda_stream);

void CalFakeQuantizePerChannelGrad(const float* input, const float* gradient, float* output, const int total_num,
                                   const int channel_num, const float* nudge_min, const float* nudge_max,
                                   cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKEQUANTIZEPERCHANNEL_H_
@@ -0,0 +1,77 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <stdint.h>
|
||||
#include "sparse_cross_entropy_cuda_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalCrossEntropyKernel(const float *logits, T *labels, const int batch_size, const int class_num,
|
||||
float *loss) {
|
||||
float total_loss = 0.0;
|
||||
float epsilon = 1e-6;
|
||||
for (int i = 0; i < batch_size; ++i) {
|
||||
float logit = logits[i * class_num + labels[i]];
|
||||
if (logit <= 0) {
|
||||
logit += epsilon;
|
||||
}
|
||||
float single_loss = -logf(logit);
|
||||
total_loss += single_loss;
|
||||
}
|
||||
|
||||
total_loss /= batch_size;
|
||||
loss[0] = total_loss;
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void CalCrossEntropyGradKernel(const float *logits, T *labels, const int batch_size, const int class_num,
|
||||
float *grad) {
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = blockIdx.x * blockDim.x + threadIdx.x; j < class_num; j += blockDim.x * gridDim.x) {
|
||||
if (labels[i] == j) {
|
||||
grad[i * class_num + j] = (logits[i * class_num + j] - 1) / batch_size;
|
||||
} else {
|
||||
grad[i * class_num + j] = logits[i * class_num + j] / batch_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss,
|
||||
cudaStream_t cuda_stream) {
|
||||
CalCrossEntropyKernel<<<1, 1, 0, cuda_stream>>>(logits, labels, batch_size, class_num, loss);
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad,
|
||||
cudaStream_t cuda_stream) {
|
||||
CalCrossEntropyGradKernel<<<GET_BLOCKS(class_num), GET_THREADS, 0, cuda_stream>>>(logits, labels, batch_size,
|
||||
class_num, grad);
|
||||
return;
|
||||
}
|
||||
|
||||
template void CalCrossEntropy<int>(const float *logits, int *labels, const int batch_size, const int class_num,
|
||||
float *loss, cudaStream_t cuda_stream);
|
||||
template void CalCrossEntropy<uint64_t>(const float *logits, uint64_t *labels, const int batch_size,
|
||||
const int class_num, float *loss, cudaStream_t cuda_stream);
|
||||
template void CalCrossEntropyGrad<int>(const float *logits, int *labels, const int batch_size, const int class_num,
|
||||
float *grad, cudaStream_t cuda_stream);
|
||||
template void CalCrossEntropyGrad<uint64_t>(const float *logits, uint64_t *labels, const int batch_size,
|
||||
const int class_num, float *grad, cudaStream_t cuda_stream);
|
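Restating what the two sparse cross-entropy kernels above compute: with batch size B, integer labels t_i, and the `logits` buffer holding softmax probabilities p (implied by the -log and the (p - 1) gradient form),

\mathcal{L} = \frac{1}{B}\sum_{i=1}^{B} -\log p_{i,t_i}, \qquad
\frac{\partial \mathcal{L}}{\partial z_{i,j}} = \frac{p_{i,j} - \mathbf{1}[j = t_i]}{B}

where z are the pre-softmax logits, which is exactly the grad array filled by CalCrossEntropyGradKernel.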
@@ -0,0 +1,30 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_

#include "device/gpu/cuda_common.h"

template <typename T>
void CalCrossEntropy(const float *logits, T *labels, const int batch_size, const int class_num, float *loss,
                     cudaStream_t cuda_stream);

template <typename T>
void CalCrossEntropyGrad(const float *logits, T *labels, const int batch_size, const int class_num, float *grad,
                         cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPARSECROSSENTROPYCUDAIMPL_H_