From bdbdc291f592146e6fbc25806a50a6eb082bf0c1 Mon Sep 17 00:00:00 2001 From: Daniel Date: Thu, 17 Dec 2020 15:22:26 +0800 Subject: [PATCH] high grad Signed-off-by: Daniel --- tests/st/high_grad/test_highgrad_param.py | 75 +++++++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/st/high_grad/test_highgrad_param.py diff --git a/tests/st/high_grad/test_highgrad_param.py b/tests/st/high_grad/test_highgrad_param.py new file mode 100644 index 0000000000..a91f87c40d --- /dev/null +++ b/tests/st/high_grad/test_highgrad_param.py @@ -0,0 +1,75 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Test high order grad with respect to parameter first, then input.""" + +import pytest +import numpy as np +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor, context +from mindspore import ParameterTuple, Parameter + + +class Net(nn.Cell): + def __init__(self): + super(Net, self).__init__() + self.mul = ops.Mul() + weight_np = np.array([2, 2]).astype(np.float32) + self.weight = Parameter(Tensor(weight_np), name="weight", requires_grad=True) + + def construct(self, x): + x_square = self.mul(x, x) + x_square_z = self.mul(x_square, self.weight) + output = self.mul(x_square_z, self.weight) + return output + + +class Grad(nn.Cell): + def __init__(self, network): + super(Grad, self).__init__() + self.grad = ops.GradOperation(get_by_list=True, sens_param=False) + self.network = network + self.params = ParameterTuple(network.trainable_params()) + + def construct(self, x): + output = self.grad(self.network, self.params)(x) + return output + + +class GradSec(nn.Cell): + def __init__(self, network): + super(GradSec, self).__init__() + self.grad = ops.GradOperation(get_all=True, sens_param=False) + self.network = network + + def construct(self, x): + output = self.grad(self.network)(x) + return output + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_x86_cpu_training +@pytest.mark.env_onecard +def test_sit_high_order_grad_params(): + context.set_context(mode=context.GRAPH_MODE) + x = Tensor(np.array([1, 1]).astype(np.float32)) + net = Net() + first_grad = Grad(net) + second_grad = GradSec(first_grad) + grad = second_grad(x) + assert (grad[0].asnumpy() == np.array([8, 8])).all()