Add supports to ACosGrad and AsinGrad

pull/7841/head
zhouyuanshen 4 years ago
parent 78f795971b
commit 8481fd59d8

@ -15,6 +15,7 @@
*/
#include "unary_op_grad_impl.cuh"
template <typename T>
__global__ void SqrtGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@ -36,7 +37,44 @@ __global__ void RsqrtGradKernel(const T *input, const T *dout, T *output, const
}
return;
}
template <typename T>
__global__ void AsinGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T one = 1;
T sqt = sqrtf(one - input[i] * input[i]);
output[i] = dout[i] / sqt;
}
return;
}
template <>
__global__ void AsinGradKernel(const half *input, const half *dout, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
half one = 1;
half sqt = hsqrt(one - input[i] * input[i]);
output[i] = dout[i] / sqt;
}
return;
}
template <typename T>
__global__ void ACosGradKernel(const T *input, const T *dout, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T neg_one = -1;
T one = 1;
T sqt = sqrtf(one - input[i] * input[i]);
output[i] = neg_one * dout[i] / sqt;
}
return;
}
template <>
__global__ void ACosGradKernel(const half *input, const half *dout, half *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
half neg_one = -1;
half one = 1;
half sqt = hsqrt(one - input[i] * input[i]);
output[i] = neg_one * dout[i] / sqt;
}
return;
}
template <typename T>
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
SqrtGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
@ -48,11 +86,31 @@ void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cud
return;
}
template <typename T>
void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
AsinGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template <typename T>
void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream) {
ACosGradKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, dout, output, count);
return;
}
template void SqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void ACosGrad<float>(const float *input, const float *dout, float *output, const size_t count,
cudaStream_t cuda_stream);
template void SqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void RsqrtGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void AsinGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);
template void ACosGrad<half>(const half *input, const half *dout, half *output, const size_t count,
cudaStream_t cuda_stream);

@ -22,5 +22,9 @@ template <typename T>
void SqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void RsqrtGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void AsinGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
template <typename T>
void ACosGrad(const T *input, const T *dout, T *output, const size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOP_GRAD_IMPL_H_

@ -34,5 +34,21 @@ MS_REG_GPU_KERNEL_ONE(
RsqrtGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AsinGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AsinGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
ACosGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
UnaryGradOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
ACosGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
UnaryGradOpGpuKernel, half)
} // namespace kernel
} // namespace mindspore

@ -27,9 +27,17 @@
namespace mindspore {
namespace kernel {
enum UnaryGradOptype { UNARY_OP_SQRT_GRAD = 0, UNARY_OP_RSQRT_GRAD, UNARY_OP_GRAD_INVALID_TYPE = 255 };
enum UnaryGradOptype {
UNARY_OP_SQRT_GRAD = 0,
UNARY_OP_RSQRT_GRAD = 1,
UNARY_OP_ASIN_GRAD = 2,
UNARY_OP_ACOS_GRAD = 3,
UNARY_OP_GRAD_INVALID_TYPE = 255
};
static const std::map<std::string, UnaryGradOptype> kUnaryGradOpTypeMap = {{"SqrtGrad", UNARY_OP_SQRT_GRAD},
{"RsqrtGrad", UNARY_OP_RSQRT_GRAD}};
{"RsqrtGrad", UNARY_OP_RSQRT_GRAD},
{"AsinGrad", UNARY_OP_ASIN_GRAD},
{"ACosGrad", UNARY_OP_ACOS_GRAD}};
template <typename T>
class UnaryGradOpGpuKernel : public GpuKernel {
public:
@ -59,6 +67,16 @@ class UnaryGradOpGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ASIN_GRAD: {
AsinGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_ACOS_GRAD: {
ACosGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case UNARY_OP_RSQRT_GRAD: {
RsqrtGrad(input_x_addr, input_dx_addr, output_y_addr, inputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acosgrad_fp32():
error = np.ones(4) * 1.0e-7
x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float32)
dout_np = np.array([1, 1, 1, 1]).astype(np.float32)
output_ms = P.ACosGrad()(Tensor(x_np), Tensor(dout_np))
expect = np.array([-1, -1.0327955, -1.1547005, -1.0482849])
diff = output_ms.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_acosgrad_fp16():
error = np.ones(4) * 1.0e-3
x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float16)
dout_np = np.array([1, 1, 1, 1]).astype(np.float16)
output_ms = P.ACosGrad()(Tensor(x_np), Tensor(dout_np))
expect = np.array([-1, -1.033, -1.154, -1.048])
diff = output_ms.asnumpy() - expect
assert np.all(diff < error)

@ -0,0 +1,46 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations._grad_ops as P
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asingrad_fp32():
error = np.ones(4) * 1.0e-7
x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float32)
dout_np = np.array([1, 1, 1, 1]).astype(np.float32)
output_ms = P.AsinGrad()(Tensor(x_np), Tensor(dout_np))
expect = np.array([1, 1.0327955, 1.1547005, 1.0482849])
diff = output_ms.asnumpy() - expect
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_asingrad_fp16():
error = np.ones(4) * 1.0e-3
x_np = np.array([0, -0.25, 0.5, 0.3]).astype(np.float16)
dout_np = np.array([1, 1, 1, 1]).astype(np.float16)
output_ms = P.AsinGrad()(Tensor(x_np), Tensor(dout_np))
expect = np.array([1, 1.033, 1.154, 1.048])
diff = output_ms.asnumpy() - expect
assert np.all(diff < error)
Loading…
Cancel
Save