From 906f3fa22230f102bce0c926578852f67f63d122 Mon Sep 17 00:00:00 2001 From: chenfei Date: Fri, 19 Mar 2021 15:52:03 +0800 Subject: [PATCH] add test case for while grad --- tests/st/control/test_while_grad.py | 48 +++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 tests/st/control/test_while_grad.py diff --git a/tests/st/control/test_while_grad.py b/tests/st/control/test_while_grad.py new file mode 100644 index 0000000000..870d46a1ef --- /dev/null +++ b/tests/st/control/test_while_grad.py @@ -0,0 +1,48 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import pytest +from mindspore.ops import composite as C +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.context as context +from mindspore.common.tensor import Tensor + +class Net(nn.Cell): + def construct(self, x, y): + while x < y: + x = x * x + 1 + return x + + +class GradNet(nn.Cell): + def __init__(self, net): + super().__init__() + self.net = net + self.grad_op = C.GradOperation(get_all=True) + + def construct(self, x, y): + gradient_function = self.grad_op(self.net) + return gradient_function(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_while_grad(): + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True) + x = Tensor([2.0], dtype=mstype.float32) + y = Tensor([2.0], dtype=mstype.float32) + GradNet(Net())(x, y)