diff --git a/mindspore/ops/_grad/grad_math_ops.py b/mindspore/ops/_grad/grad_math_ops.py index d89de1924c..f66ec3df21 100755 --- a/mindspore/ops/_grad/grad_math_ops.py +++ b/mindspore/ops/_grad/grad_math_ops.py @@ -420,6 +420,19 @@ def get_bprop_sqrt(self): return bprop +@bprop_getters.register(G.SqrtGrad) +def get_bprop_sqrt_grad(self): + """Grad definition for `SqrtGrad` operation.""" + + def bprop(y, grad, out, dout): + gy = dout / y + dy = -gy * out + dgrad = 0.5 * gy + return dy, dgrad + + return bprop + + @bprop_getters.register(P.Rsqrt) def get_bprop_rsqrt(self): """Grad definition for `Rsqrt` operation.""" @@ -962,6 +975,19 @@ def get_bprop_asinh(self): return bprop +@bprop_getters.register(G.AsinhGrad) +def get_bprop_asinh_grad(self): + """Grad definition for `AsinhGrad` operation.""" + input_grad = G.AsinhGrad() + tanh = P.Tanh() + + def bprop(y, grad, out, dout): + dy = dout * out * -1.0 * tanh(y) + dgrad = input_grad(y, dout) + return dy, dgrad + return bprop + + @bprop_getters.register(P.Sinh) def get_bprop_sinh(self): """Grad definition for `Sinh` operation.""" @@ -1026,6 +1052,20 @@ def get_bprop_acosh(self): return bprop +@bprop_getters.register(G.AcoshGrad) +def get_bprop_acosh_grad(self): + """Grad definition for `AcoshGrad` operation.""" + input_grad = G.AcoshGrad() + tanh = P.Tanh() + + def bprop(y, grad, out, dout): + dy = dout * out * -1.0 / tanh(y) + dgrad = input_grad(y, dout) + return dy, dgrad + + return bprop + + @bprop_getters.register(P.Cosh) def get_bprop_cosh(self): """Grad definition for `Cosh` operation.""" @@ -1150,6 +1190,18 @@ def get_bprop_atan(self): return bprop +@bprop_getters.register(G.AtanGrad) +def get_bprop_atan_grad(self): + """Grad definition for `AtanGrad` operation.""" + input_grad = G.AtanGrad() + + def bprop(x, grad, out, dout): + dgrad = input_grad(x, dout) + dx = out * dgrad * -2.0 * x + return dx, dgrad + return bprop + + @bprop_getters.register(P.Tan) def get_bprop_tan(self): """Grad definition for `Tan` operation.""" diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 2ed2e3ad58..e419fdd4f1 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -422,6 +422,18 @@ def get_bprop_relu(self): return bprop +@bprop_getters.register(G.ReluGrad) +def get_bprop_relu_grad(self): + """Grad definition for `ReLUGrad` operation.""" + input_grad = G.ReluGrad() + + def bprop(grad, y, out, dout): + dgrad = input_grad(dout, y) + return dgrad, zeros_like(y) + + return bprop + + @bprop_getters.register(P.ReLU6) def get_bprop_relu6(self): """Grad definition for `ReLU6` operation.""" @@ -501,9 +513,9 @@ def get_bprop_sigmoid_grad(self): sigmoid_grad = G.SigmoidGrad() def bprop(y, grad, out, dout): - ddy = dout * grad * (1. - 2 * y) - d2x = sigmoid_grad(y, dout) - return (ddy, d2x) + dy = dout * grad * (1. - 2 * y) + dgrad = sigmoid_grad(y, dout) + return dy, dgrad return bprop @@ -598,6 +610,19 @@ def get_bprop_tanh(self): return bprop +@bprop_getters.register(G.TanhGrad) +def get_bprop_tanh_grad(self): + """Grad definition for `TanhGrad` operation.""" + tanh_grad = G.TanhGrad() + + def bprop(y, grad, out, dout): + dy = dout * -2.0 * grad * y + dgrad = tanh_grad(y, dout) + return dy, dgrad + + return bprop + + @bprop_getters.register(P.Gelu) def get_bprop_gelu(self): """Grad definition for `Gelu` operation.""" diff --git a/tests/st/ops/gpu/test_acosh_grad_grad_op.py b/tests/st/ops/gpu/test_acosh_grad_grad_op.py new file mode 100644 index 0000000000..001ede6b80 --- /dev/null +++ b/tests/st/ops/gpu/test_acosh_grad_grad_op.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetAcoshGrad(nn.Cell): + def __init__(self): + super(NetAcoshGrad, self).__init__() + self.acosh_grad = G.AcoshGrad() + + def construct(self, y, grad): + return self.acosh_grad(y, grad) + + +class NetAcoshGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetAcoshGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, y, grad, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(y, grad, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def acosh_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + y_np = (np.random.rand(*shape) * 10).astype(dtype) + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + y_np_32 = y_np.astype(np.float32) + grad_np_32 = grad_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + out_np_32 = grad_np_32 / np.sinh(y_np_32) + dy_np = (dout_np_32 * out_np_32 * (-1.0) / np.tanh(y_np_32)).astype(dtype) + dgrad_np = (dout_np_32 / np.sinh(y_np_32)).astype(dtype) + + y_ms = Tensor(y_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + forward_net = NetAcoshGrad() + net = NetAcoshGradGrad(forward_net) + dy_ms, dgrad_ms = net(y_ms, grad_ms, dout_ms) + + assert np.allclose(dy_ms.asnumpy(), dy_np, loss, loss) + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_acosh_grad_grad_float16(): + acosh_grad_grad_base(np.float16, 2e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_acosh_grad_grad_float32(): + acosh_grad_grad_base(np.float32, 1e-4) diff --git a/tests/st/ops/gpu/test_asinh_grad_grad_op.py b/tests/st/ops/gpu/test_asinh_grad_grad_op.py new file mode 100644 index 0000000000..3dae2c8306 --- /dev/null +++ b/tests/st/ops/gpu/test_asinh_grad_grad_op.py @@ -0,0 +1,86 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetAsinhGrad(nn.Cell): + def __init__(self): + super(NetAsinhGrad, self).__init__() + self.asinh_grad = G.AsinhGrad() + + def construct(self, y, grad): + return self.asinh_grad(y, grad) + + +class NetAsinhGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetAsinhGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, y, grad, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(y, grad, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def asinh_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + y_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + y_np_32 = y_np.astype(np.float32) + grad_np_32 = grad_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + out_np_32 = grad_np_32 / np.cosh(y_np_32) + dy_np = (dout_np_32 * out_np_32 * (-1.0) * np.tanh(y_np_32)).astype(dtype) + dgrad_np = (dout_np_32 / np.cosh(y_np_32)).astype(dtype) + + y_ms = Tensor(y_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + forward_net = NetAsinhGrad() + net = NetAsinhGradGrad(forward_net) + dy_ms, dgrad_ms = net(y_ms, grad_ms, dout_ms) + + assert np.allclose(dy_ms.asnumpy(), dy_np, loss, loss) + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_asinh_grad_grad_float16(): + asinh_grad_grad_base(np.float16, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_asinh_grad_grad_float32(): + asinh_grad_grad_base(np.float32, 1e-4) diff --git a/tests/st/ops/gpu/test_atan_grad_grad_op.py b/tests/st/ops/gpu/test_atan_grad_grad_op.py new file mode 100644 index 0000000000..d0373e5207 --- /dev/null +++ b/tests/st/ops/gpu/test_atan_grad_grad_op.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetAtanGrad(nn.Cell): + def __init__(self): + super(NetAtanGrad, self).__init__() + self.atan_grad = G.AtanGrad() + + def construct(self, x, grad): + return self.atan_grad(x, grad) + + +class NetAtanGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetAtanGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, x, grad, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(x, grad, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def atan_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + x_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + x_np_32 = x_np.astype(np.float32) + grad_np_32 = grad_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + out_np_32 = grad_np_32 / (1 + x_np_32 * x_np_32) + dgrad_np_32 = dout_np_32 / (1 + x_np_32 * x_np_32) + dx_np = (out_np_32 * dgrad_np_32 * (-2.0) * x_np_32).astype(dtype) + dgrad_np = dgrad_np_32.astype(dtype) + + x_ms = Tensor(x_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + forward_net = NetAtanGrad() + net = NetAtanGradGrad(forward_net) + dx_ms, dgrad_ms = net(x_ms, grad_ms, dout_ms) + + assert np.allclose(dx_ms.asnumpy(), dx_np, loss, loss) + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_atan_grad_grad_float16(): + atan_grad_grad_base(np.float16, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_atan_grad_grad_float32(): + atan_grad_grad_base(np.float32, 1e-4) diff --git a/tests/st/ops/gpu/test_relu_grad_grad_op.py b/tests/st/ops/gpu/test_relu_grad_grad_op.py new file mode 100644 index 0000000000..e8f8b0c210 --- /dev/null +++ b/tests/st/ops/gpu/test_relu_grad_grad_op.py @@ -0,0 +1,115 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetReluGrad(nn.Cell): + def __init__(self): + super(NetReluGrad, self).__init__() + self.reluGrad = G.ReluGrad() + + def construct(self, grad, y): + return self.reluGrad(grad, y) + + +class NetReluGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetReluGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, grad, y, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(grad, y, dout) + + +def relu_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + y_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + y_np[y_np < 0] = 0 + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + y_np_32 = y_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + dgrad_np = ((y_np_32 > 0) * dout_np_32).astype(dtype) + + y_ms = Tensor(y_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + relu_grad = NetReluGrad() + net = NetReluGradGrad(relu_grad) + dgrad_ms, _ = net(grad_ms, y_ms, dout_ms) + + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_float16(): + relu_grad_grad_base(np.float16, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_float32(): + relu_grad_grad_base(np.float32, 1e-4) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_float64(): + relu_grad_grad_base(np.float64, 1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_int8(): + relu_grad_grad_base(np.int8, 1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_int16(): + relu_grad_grad_base(np.int16, 1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_int32(): + relu_grad_grad_base(np.int32, 1e-5) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_relu_grad_grad_int64(): + relu_grad_grad_base(np.int64, 1e-5) diff --git a/tests/st/ops/gpu/test_sqrt_grad_grad_op.py b/tests/st/ops/gpu/test_sqrt_grad_grad_op.py new file mode 100644 index 0000000000..0f7363a555 --- /dev/null +++ b/tests/st/ops/gpu/test_sqrt_grad_grad_op.py @@ -0,0 +1,87 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetSqrtGrad(nn.Cell): + def __init__(self): + super(NetSqrtGrad, self).__init__() + self.sqrt_grad = G.SqrtGrad() + + def construct(self, y, grad): + return self.sqrt_grad(y, grad) + + +class NetSqrtGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetSqrtGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, y, grad, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(y, grad, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def sqrt_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + y_np = (np.random.rand(*shape) * 10).astype(dtype) + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + y_np_32 = y_np.astype(np.float32) + grad_np_32 = grad_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + gy_np_32 = dout_np_32 / y_np_32 + out_np_32 = 0.5 * grad_np_32 / y_np_32 + dy_np = (-gy_np_32 * out_np_32).astype(dtype) + dgrad_np = (0.5 * gy_np_32).astype(dtype) + + y_ms = Tensor(y_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + forward_net = NetSqrtGrad() + net = NetSqrtGradGrad(forward_net) + dy_ms, dgrad_ms = net(y_ms, grad_ms, dout_ms) + + assert np.allclose(dy_ms.asnumpy(), dy_np, loss, loss) + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sqrt_grad_grad_float16(): + sqrt_grad_grad_base(np.float16, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_sqrt_grad_grad_float32(): + sqrt_grad_grad_base(np.float32, 1e-4) diff --git a/tests/st/ops/gpu/test_tanh_grad_grad_op.py b/tests/st/ops/gpu/test_tanh_grad_grad_op.py new file mode 100644 index 0000000000..f702820676 --- /dev/null +++ b/tests/st/ops/gpu/test_tanh_grad_grad_op.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops.operations import _grad_ops as G +from mindspore.ops import composite as C + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + +class NetTanhGrad(nn.Cell): + def __init__(self): + super(NetTanhGrad, self).__init__() + self.tanh_grad = G.TanhGrad() + + def construct(self, y, grad): + return self.tanh_grad(y, grad) + + +class NetTanhGradGrad(nn.Cell): + def __init__(self, forward_net): + super(NetTanhGradGrad, self).__init__() + self.forward_net = forward_net + self.gradOps = C.GradOperation(get_all=True, sens_param=True) + + def construct(self, y, grad, dout): + backward_net = self.gradOps(self.forward_net) + return backward_net(y, grad, dout) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def tanh_grad_grad_base(dtype, loss): + np.random.seed(1) + shape = (4, 2) + y_np = (np.random.rand(*shape) * 2 - 1).astype(dtype) + grad_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + dout_np = (np.random.rand(*shape) * 20 - 10).astype(dtype) + + y_np_32 = y_np.astype(np.float32) + grad_np_32 = grad_np.astype(np.float32) + dout_np_32 = dout_np.astype(np.float32) + dy_np = (dout_np_32 * grad_np_32 * (-2.0) * y_np_32).astype(dtype) + dgrad_np = (dout_np_32 * (1 - y_np_32 * y_np_32)).astype(dtype) + + y_ms = Tensor(y_np) + grad_ms = Tensor(grad_np) + dout_ms = Tensor(dout_np) + forward_net = NetTanhGrad() + net = NetTanhGradGrad(forward_net) + dy_ms, dgrad_ms = net(y_ms, grad_ms, dout_ms) + + assert np.allclose(dy_ms.asnumpy(), dy_np, loss, loss) + assert np.allclose(dgrad_ms.asnumpy(), dgrad_np, loss, loss) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tanh_grad_grad_float16(): + tanh_grad_grad_base(np.float16, 1e-3) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_tanh_grad_grad_float32(): + tanh_grad_grad_base(np.float32, 1e-4)