remove global grad ops

pull/5011/head
panyifeng 5 years ago
parent eb2437d517
commit 637e812347
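
Note: this commit removes the six module-level gradient ops (grad, grad_all, grad_by_list, grad_with_sens, grad_all_with_sens, grad_by_list_with_sens) that mindspore.ops.composite exported, so every call site now builds its own GradOperation instance, as the hunks below show. A minimal before/after sketch of the migration (net stands for any Cell; the snippet is illustrative, not taken from the diff):

    # before: the global was exported by mindspore.ops.composite
    # from mindspore.ops import composite as C
    # grads = C.grad_all(net)(x, y)

    # after: construct the equivalent GradOperation locally
    from mindspore.ops import composite as C
    grad_all = C.GradOperation('get_all', get_all=True)  # same flags the removed global used
    grads = grad_all(net)(x, y)                          # tuple of gradients, one per input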

@@ -21,7 +21,6 @@ Pre-defined combination of operators.
 from .base import GradOperation, HyperMap, Map, MultitypeFuncGraph, add_flags, \
-    grad, grad_all, grad_all_with_sens, grad_by_list, grad_by_list_with_sens, grad_with_sens, \
     core, env_get, tail, zip_operation
 from .clip_ops import clip_by_value
 from .multitype_ops.add_impl import hyper_add
@@ -31,12 +30,6 @@ from .random_ops import set_seed, normal, uniform, gamma, poisson, multinomial
 __all__ = [
-    'grad',
-    'grad_by_list_with_sens',
-    'grad_all',
-    'grad_by_list',
-    'grad_all_with_sens',
-    'grad_with_sens',
     'env_get',
     'core',
     'add_flags',

@@ -163,14 +163,6 @@ class GradOperation(GradOperation_):
         return self.grad_fn

-grad = GradOperation('grad')
-grad_all = GradOperation('get_all', get_all=True)
-grad_by_list = GradOperation('get_by_list', get_by_list=True)
-grad_with_sens = GradOperation('grad_with_sens', sens_param=True)
-grad_all_with_sens = GradOperation('grad_all_with_sens', get_all=True, sens_param=True)
-grad_by_list_with_sens = GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True)

 class MultitypeFuncGraph(MultitypeFuncGraph_):
     """
     Generate multiply graph.

@@ -268,6 +268,7 @@ class HookBackward(PrimitiveWithInfer):
         >>> def hook_fn(grad_out):
         >>>     print(grad_out)
         >>>
+        >>> grad_all = GradOperation('get_all', get_all=True)
         >>> hook = P.HookBackward(hook_fn)
         >>>
         >>> def hook_test(x, y):
@@ -277,7 +278,7 @@ class HookBackward(PrimitiveWithInfer):
         >>>     return z
         >>>
         >>> def backward(x, y):
-        >>>     return C.grad_all(hook_test)(x, y)
+        >>>     return grad_all(hook_test)(x, y)
         >>>
         >>> backward(1, 2)
     """

@@ -23,6 +23,9 @@ from mindspore import Tensor
 from mindspore.common.api import _executor

+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 class InputBackward(nn.Cell):
     """ InputBackward definition """
@@ -30,7 +33,7 @@ class InputBackward(nn.Cell):
         super(InputBackward, self).__init__()
         self.network = network
         self.network.set_train()
-        self.grad = C.grad_all_with_sens
+        self.grad = grad_all_with_sens
         self.c1 = c1
         self.c2 = c2
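
Note on the *_with_sens variants used above: sens_param=True makes the compiled grad function expect one extra trailing argument, the sensitivity (initial output gradient) to backpropagate. A minimal sketch, assuming a hypothetical single-input SquareNet that is not part of the diff:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

    class SquareNet(nn.Cell):
        def __init__(self):
            super(SquareNet, self).__init__()
            self.square = P.Square()

        def construct(self, x):
            return self.square(x)

    x = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
    sens = Tensor(np.ones([3], np.float32))           # gradient fed in at the output
    grads = grad_all_with_sens(SquareNet())(x, sens)  # tuple: one gradient per input
    # grads[0] == 2 * x * sens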

@@ -26,6 +26,9 @@ from mindspore.common.api import _executor
 context.set_context(mode=context.GRAPH_MODE)

+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 batch_size = 1
 channel = 1
 height = 32
@@ -38,7 +41,7 @@ class LeNetGrad(nn.Cell):
     def __init__(self, network):
         super(LeNetGrad, self).__init__()
-        self.grad_op = C.grad_all_with_sens
+        self.grad_op = grad_all_with_sens
         self.network = network

     def construct(self, x, sens):

@@ -28,6 +28,10 @@ from mindspore.ops import operations as P
 # context.set_context(save_graphs=True)

+grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
+grad_all = C.GradOperation('get_all', get_all=True)

 def test_while_forward():
     class MyWhileNet(nn.Cell):
         def __init__(self):
@@ -70,7 +74,7 @@ def test_while_grad():
             self.net = net

         def construct(self, *inputs):
-            return C.grad_all(self.net)(*inputs)
+            return grad_all(self.net)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -157,7 +161,7 @@ def test_while_with_param_grad():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -222,7 +226,7 @@ def test_while_opt_endless():
             self.net = net

         def construct(self, *inputs):
-            return C.grad_all(self.net)(*inputs)
+            return grad_all(self.net)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -285,7 +289,7 @@ def test_while_with_param_grad_with_const_branch():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -325,7 +329,7 @@ def test_for_while_with_param_grad_with_const_branch():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -362,7 +366,7 @@ def test_for_while_with_param_grad_basic():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -399,7 +403,7 @@ def test_for_while_with_param_grad_normal():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -433,7 +437,7 @@ def test_while_with_param_basic_grad():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -467,7 +471,7 @@ def test_while_with_param_basic_grad_mul():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -502,7 +506,7 @@ def test_while_with_param_basic_grad_two():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -538,7 +542,7 @@ def test_while_with_param_basic_grad_three():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -575,7 +579,7 @@ def test_while_if_with_param_grad():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -608,7 +612,7 @@ def test_while_with_param_grad_not_enter_while():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, a, b, c):
-            return C.grad_by_list(self.net, self.weights)(a, b, c)
+            return grad_by_list(self.net, self.weights)(a, b, c)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     while_net = MyWhileNet()
@@ -670,7 +674,7 @@ def test_with_param_if_by_if_grad_inputs():
             self.net = net

         def construct(self, *inputs):
-            return C.grad_all(self.net)(*inputs)
+            return grad_all(self.net)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     if_net = MyIfByIfNet()
@@ -704,7 +708,7 @@ def test_with_param_if_by_if_grad_parameter():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, *inputs):
-            return C.grad_by_list(self.net, self.weights)(*inputs)
+            return grad_by_list(self.net, self.weights)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     if_net = MyIfByIfNet()
@@ -736,7 +740,7 @@ def test_with_param_if_by_if_grad_param_excute_null():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, *inputs):
-            return C.grad_by_list(self.net, self.weights)(*inputs)
+            return grad_by_list(self.net, self.weights)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     if_net = MyIfByIfNet()
@@ -770,7 +774,7 @@ def test_if_by_if_return_inside_grad():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, *inputs):
-            return C.grad_by_list(self.net, self.weights)(*inputs)
+            return grad_by_list(self.net, self.weights)(*inputs)

     context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True)
     if_net = MyIfByIfNet()
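
Note: every GradNet/GradWrap wrapper in the tests above follows the same shape. get_by_list=True differentiates with respect to a ParameterTuple rather than the network inputs, which is why the wrapper captures net.trainable_params(). A condensed sketch of that pattern (net is any Cell with parameters; the wrapper name is illustrative):

    import mindspore.nn as nn
    from mindspore import ParameterTuple
    from mindspore.ops import composite as C

    grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

    class GradNet(nn.Cell):
        def __init__(self, net):
            super(GradNet, self).__init__()
            self.net = net
            self.weights = ParameterTuple(net.trainable_params())

        def construct(self, *inputs):
            # one gradient per parameter in self.weights, not per input
            return grad_by_list(self.net, self.weights)(*inputs)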

@@ -25,12 +25,15 @@ from mindspore.common.api import _executor
 context.set_context(mode=context.GRAPH_MODE)

+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 class MeanAggregatorGrad(nn.Cell):
     """Backward of MeanAggregator"""

     def __init__(self, network):
         super(MeanAggregatorGrad, self).__init__()
-        self.grad_op = C.grad_all_with_sens
+        self.grad_op = grad_all_with_sens
         self.network = network

     def construct(self, x, sens):

@@ -28,6 +28,10 @@ from mindspore.ops import operations as P
 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+grad_all = C.GradOperation('get_all', get_all=True)

 class MulAdd(nn.Cell):
     def construct(self, x, y):
         return 2 * x + y
@@ -43,7 +47,7 @@ def test_grad_mul_add():
     mul_add = MulAdd()
     x = Tensor(1, dtype=ms.int32)
     y = Tensor(2, dtype=ms.int32)
-    assert C.grad_all(mul_add)(x, y) == (2, 4)
+    assert grad_all(mul_add)(x, y) == (2, 4)

 class InlineMulADD(nn.Cell):
@@ -62,7 +66,7 @@ def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
     x = Tensor(1, dtype=ms.int32)
     y = Tensor(2, dtype=ms.int32)
-    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)
+    assert grad_all(inline_mul_add)(x, y) == (3, 6)

 class WithParameter(nn.Cell):
@@ -84,7 +88,7 @@ class WithParameter(nn.Cell):
 def test_with_param():
     with_param = WithParameter()
     with pytest.raises(RuntimeError):
-        C.grad_all(with_param)(1, 2)
+        grad_all(with_param)(1, 2)

 class WithNoBprop(nn.Cell):
@@ -98,7 +102,7 @@ def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
     x = Tensor(1, dtype=ms.int32)
     y = Tensor(2, dtype=ms.int32)
-    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)
+    assert grad_all(with_no_bprop)(x, y) == (2, 1)

 @pytest.mark.level0
 @pytest.mark.platform_x86_ascend_training
@@ -118,10 +122,10 @@ def test_grad_in_bprop_1():
             self.f = GradInBprop_1()

         def construct(self, x, y):
-            return self.f(x, y), C.grad_all(self.f)(x, y)
+            return self.f(x, y), grad_all(self.f)(x, y)

         def bprop(self, x, y, out, dout):
-            grads = C.grad_all(self.f)(x, y)
+            grads = grad_all(self.f)(x, y)
             return out[1][0], grads[1]

     class GradInBprop_3(nn.Cell):
@@ -133,8 +137,8 @@ def test_grad_in_bprop_1():
             return self.f(x, y)

     grad_in_bprop = GradInBprop_3()
-    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
-                                      Tensor(np.ones([2, 2]).astype(np.float32)))
+    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
+                                    Tensor(np.ones([2, 2]).astype(np.float32)))
     assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
     assert (grads[1].asnumpy() == np.zeros([2, 2]).astype(np.float32)).all()
@@ -159,10 +163,10 @@ def test_grad_in_bprop_2():
             self.f = GradInBprop_1()

         def construct(self, x, y):
-            return self.f(x, y), C.grad_all(self.f)(x, y)
+            return self.f(x, y), grad_all(self.f)(x, y)

         def bprop(self, x, y, out, dout):
-            grads = C.grad_all(self.f)(x, y)
+            grads = grad_all(self.f)(x, y)
             return out[1][0], grads[1]

     class GradInBprop_3(nn.Cell):
@@ -174,8 +178,8 @@ def test_grad_in_bprop_2():
             return self.f(x, y)

     grad_in_bprop = GradInBprop_3()
-    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
-                                      Tensor(np.ones([2, 2]).astype(np.float32)))
+    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
+                                    Tensor(np.ones([2, 2]).astype(np.float32)))
     assert (grads[0].asnumpy() == np.ones([2, 2]).astype(np.float32)).all()
     assert (grads[1].asnumpy() == np.array([[2, 2], [2, 2]]).astype(np.float32)).all()
@@ -197,10 +201,10 @@ def test_grad_in_bprop_3():
             self.f = GradInBprop_1()

         def construct(self, x, y):
-            return self.f(x, y), C.grad_all(self.f)(x, y)
+            return self.f(x, y), grad_all(self.f)(x, y)

         def bprop(self, x, y, out, dout):
-            grads = C.grad_all(self.f)(x, y)
+            grads = grad_all(self.f)(x, y)
             return out[1][0], grads[1]

     class GradInBprop_3(nn.Cell):
@@ -215,8 +219,8 @@ def test_grad_in_bprop_3():
             return x + y + y + out[0], x + x + y + y + dout[0]

     grad_in_bprop = GradInBprop_3()
-    grads = C.grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
-                                      Tensor(np.ones([2, 2]).astype(np.float32)))
+    grads = grad_all(grad_in_bprop)(Tensor(np.ones([2, 2]).astype(np.float32)),
+                                    Tensor(np.ones([2, 2]).astype(np.float32)))
     assert (grads[0].asnumpy() == np.array([[4, 4], [4, 4]]).astype(np.float32)).all()
     assert (grads[1].asnumpy() == np.array([[5, 5], [5, 5]]).astype(np.float32)).all()
@@ -238,7 +242,7 @@ class OneInputBprop(nn.Cell):
 def test_grad_one_input_bprop():
     net = OneInputBprop()
     input1 = Tensor(np.ones([2, 2]).astype(np.float32))
-    grad = C.grad_all(net)(input1)
+    grad = grad_all(net)(input1)
     assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()
@@ -253,10 +257,10 @@ class InlineBpropTwoInput(nn.Cell):
         self.f = TwoInput()

     def construct(self, x, y):
-        return self.f(x, y), C.grad_all(self.f)(x, y)
+        return self.f(x, y), grad_all(self.f)(x, y)

     def bprop(self, x, y, out, dout):
-        grads = C.grad_all(self.f)(x, y)
+        grads = grad_all(self.f)(x, y)
         return grads[0] * 2, grads[1] * 2

 @pytest.mark.level0
@@ -266,7 +270,7 @@ def test_grad_inline_bprop_two_input():
     net = InlineBpropTwoInput()
     input1 = Tensor(np.ones([2, 2]).astype(np.float32))
     input2 = Tensor(np.ones([2, 2]).astype(np.float32))
-    grads = C.grad_all(net)(input1, input2)
+    grads = grad_all(net)(input1, input2)
     assert (grads[0].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
     assert (grads[1].asnumpy() == np.array([2, 2]).astype(np.float32)).all()
     assert len(grads) == 2
@@ -328,7 +332,7 @@ def test_grad_inline_bprop_multi_input():
     input1 = Tensor(np.ones([2, 2]).astype(np.float32))
     input2 = Tensor(np.ones([2, 2]).astype(np.float32))
     net.init_parameters_data()
-    grads = C.grad_all(net)(input1, input2)
+    grads = grad_all(net)(input1, input2)
     assert (grads[0].asnumpy() == np.array([[12, 12], [12, 12]]).astype(np.float32)).all()
     assert (grads[1].asnumpy() == np.array([[19, 19], [19, 19]]).astype(np.float32)).all()
     assert len(grads) == 2
@@ -378,7 +382,7 @@ def test_grad_mul_add_with_wrong_output_num():
     context.set_context(check_bprop=True)
     mul_add = MulAddWithWrongOutputNum()
     with pytest.raises(TypeError):
-        C.grad_all(mul_add)(1, 2)
+        grad_all(mul_add)(1, 2)

 class MulAddWithWrongOutputType(nn.Cell):
@@ -395,7 +399,7 @@ def test_grad_mul_add_with_wrong_output_type():
     context.set_context(check_bprop=True)
     mul_add = MulAddWithWrongOutputType()
     with pytest.raises(TypeError):
-        C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
+        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))

 class MulAddWithWrongOutputShape(nn.Cell):
@@ -416,4 +420,4 @@ def test_grad_mul_add_with_wrong_output_shape():
     context.set_context(check_bprop=True)
     mul_add = MulAddWithWrongOutputShape()
     with pytest.raises(TypeError):
-        C.grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
+        grad_all(mul_add)(1, Tensor(np.ones([2, 2])))
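
Note: these tests combine GradOperation with cells that define a custom bprop. A stripped-down sketch of the mechanism, mirroring the MulAdd cell used above (the bprop body here is an assumption consistent with the asserted result, not copied from the diff):

    import mindspore as ms
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import composite as C

    grad_all = C.GradOperation('get_all', get_all=True)

    class MulAdd(nn.Cell):
        def construct(self, x, y):
            return 2 * x + y

        def bprop(self, x, y, out, dout):
            # hand-written backward: must return one gradient per input
            return 2 * dout, 2 * y

    x = Tensor(1, dtype=ms.int32)
    y = Tensor(2, dtype=ms.int32)
    assert grad_all(MulAdd())(x, y) == (2, 4)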

@@ -22,6 +22,10 @@ from mindspore import Tensor
 from mindspore.ops import composite as C

 context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

+grad_with_sens = C.GradOperation('grad_with_sens', sens_param=True)

 class Net(nn.Cell):
     """Net definition"""
@@ -52,6 +56,6 @@ def test_grad_net():
     x = np.array([1.0, 4.0, 9.0]).astype(np.float32)
     sens = np.array([1.0, 1.0, 1.0]).astype(np.float32)
     square = Net()
-    dx = C.grad_with_sens(square)(Tensor(x), Tensor(sens))
+    dx = grad_with_sens(square)(Tensor(x), Tensor(sens))
     expect = np.array([2.0, 8.0, 18.0]).astype(np.float32)
     assert (dx.asnumpy() == expect).all()

@@ -30,6 +30,9 @@ from mindspore.common.initializer import TruncatedNormal
 context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")

+grad_all = C.GradOperation('get_all', get_all=True)

 def weight_variable():
     """weight initial"""
     return TruncatedNormal(0.02)
@@ -121,9 +124,6 @@ class test_custom_cell_base():

 class MulAdd(nn.Cell):
-    def __init__(self):
-        super(MulAdd, self).__init__()
     def construct(self, x, y):
         return 2 * x + y
@@ -181,8 +181,8 @@ def test_pynative_custom_bprop_and_Cell_MulAdd():
     custom_cell = test_custom_cell_base()
     mul_add = custom_cell.test_custom_cell_function(MulAdd())
     mul_add.bprop_debug = True
-    C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
-    assert C.grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == \
+    grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32))
+    assert grad_all(mul_add)(Tensor(1, mstype.float32), Tensor(2, mstype.float32)) == \
         (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32))
@@ -194,5 +194,5 @@ def test_pynative_custom_bprop_and_Cell_Ms_Cell():
     custom_cell = test_custom_cell_base()
     ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell())
     ms_Cell.bprop_debug = True
-    assert C.grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),)
+    assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),)

@@ -29,6 +29,9 @@ from mindspore.ops import operations as P
 np.random.seed(1)

+grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

 def weight_variable():
     """weight initial"""
     return TruncatedNormal(0.02)
@@ -122,7 +125,7 @@ class GradWrap(nn.Cell):
     def construct(self, x, label):
         weights = self.weights
-        return C.grad_by_list(self.network, weights)(x, label)
+        return grad_by_list(self.network, weights)(x, label)

 @pytest.mark.level0

@@ -40,6 +40,9 @@ np.random.seed(1)
 ds.config.set_seed(1)

+grad_by_list = CP.GradOperation('get_by_list', get_by_list=True)

 def weight_variable(shape):
     return initializer('XavierUniform', shape=shape, dtype=mstype.float32)
@@ -389,7 +392,7 @@ class GradWrap(Cell):
     def construct(self, x, label):
         weights = self.weights
-        return CP.grad_by_list(self.network, weights)(x, label)
+        return grad_by_list(self.network, weights)(x, label)

 @pytest.mark.level0

@@ -24,6 +24,9 @@ from mindspore.common.parameter import ParameterTuple
 from mindspore.ops import composite as C

+grad_by_list_with_sens = C.GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True)

 def setup_module():
     context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
@@ -319,9 +322,6 @@ def test_setitem_by_mixed_tensors_2():

 class TensorGetItemByMixedTensorsTypeError(Cell):
-    def __init__(self):
-        super(TensorGetItemByMixedTensorsTypeError, self).__init__()
     def construct(self, x, index_0, index_1):
         ret = x[index_0, index_1, 0:3, ..., 0:5, [1, 2, 3, 4]]
         return ret
@@ -667,7 +667,7 @@ def test_setitem_grad():
             self.weights = ParameterTuple(net.trainable_params())

         def construct(self, x, y, sens):
-            return C.grad_by_list_with_sens(self.net, self.weights)(x, y, sens)
+            return grad_by_list_with_sens(self.net, self.weights)(x, y, sens)

     net = GradNet(Net())
     x = Tensor(np.ones([4, 4, 5]).astype(np.float32), mstype.float32)
     y = Tensor(np.array([3]).astype(np.float32), mstype.float32)
@@ -676,27 +676,18 @@ def test_setitem_grad():

 class TensorAssignWithSliceError1(Cell):
-    def __init__(self):
-        super(TensorAssignWithSliceError1, self).__init__()
     def construct(self, a, b):
         a[1:3:-1, ::] = b
         return a

 class TensorAssignWithSliceError2(Cell):
-    def __init__(self):
-        super(TensorAssignWithSliceError2, self).__init__()
     def construct(self, a, b):
         a[1:3:-1] = b
         return a

 class TensorAssignWithSlice2(Cell):
-    def __init__(self):
-        super(TensorAssignWithSlice2, self).__init__()
     def construct(self, a, b, ck):
         a[1:5] = b
         a[3:4] = 5
@@ -864,18 +855,12 @@ def test_tensor_assign_exception():

 class TensorAssignWithTupleEllipsis2(Cell):
-    def __init__(self):
-        super(TensorAssignWithTupleEllipsis2, self).__init__()
     def construct(self, a, b):
         a[1:, ..., ::] = b
         return a

 class TensorAssignWithTupleEllipsis(Cell):
-    def __init__(self):
-        super(TensorAssignWithTupleEllipsis, self).__init__()
     def construct(self, a, b):
         a[:2, ...] = 1.0
         a[1:, ...] = b
@@ -883,9 +868,6 @@ class TensorAssignWithTupleEllipsis(Cell):

 class TensorAssignWithEllipsis(Cell):
-    def __init__(self):
-        super(TensorAssignWithEllipsis, self).__init__()
     def construct(self, a, b):
         a[...] = 1
         a[...] = b
@@ -893,9 +875,6 @@ class TensorAssignWithEllipsis(Cell):

 class TensorAssignWithInteger(Cell):
-    def __init__(self):
-        super(TensorAssignWithInteger, self).__init__()
     def construct(self, a, b, ck):
         a[1] = 1
         a[0] = b
@@ -904,9 +883,6 @@ class TensorAssignWithInteger(Cell):

 class TensorAssignWithTupleInteger(Cell):
-    def __init__(self):
-        super(TensorAssignWithTupleInteger, self).__init__()
     def construct(self, a, b, ck):
         a[(1)] = 1
         a[(1)] = b
@@ -930,9 +906,6 @@ class TensorAssignWithBoolTensorIndex(Cell):

 class TensorAssignWithBoolTensorIndexError(Cell):
-    def __init__(self):
-        super(TensorAssignWithBoolTensorIndexError, self).__init__()
     def construct(self, a, b, c, u_tensor):
         a[b][c] = u_tensor
         return a
@@ -955,9 +928,6 @@ class TensorAssignWithBoolTensorIndex2(Cell):

 class TensorAssignWithBoolTensorIndex2Error(Cell):
-    def __init__(self):
-        super(TensorAssignWithBoolTensorIndex2Error, self).__init__()
     def construct(self, a, u_tensor):
         a[a > 8][a > 5] = u_tensor
         return a

@@ -31,6 +31,9 @@ from tests.mindspore_test_framework.pipeline.forward.compile_forward \
 context.set_context(mode=context.GRAPH_MODE)

+grad_all = C.GradOperation('get_all', get_all=True)

 def test_list_equal():
     class Net(nn.Cell):
         def __init__(self, z: list):
@@ -303,7 +306,7 @@ def test_grad_make_list():
             self.net = net

         def construct(self, *inputs):
-            return C.grad_all(self.net)(*inputs)
+            return grad_all(self.net)(*inputs)

     while_net = MyWhileNet()
     net = GradNet(while_net)

@@ -18,8 +18,11 @@ import numpy as np
 from mindspore import Parameter, ParameterTuple, Tensor
 from mindspore.nn import Cell
 from mindspore.nn.optim import Optimizer
-from mindspore.ops.composite import grad_by_list
 from mindspore.ops.operations import BiasAdd, MatMul
+import mindspore.ops.composite as C

+grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

 class Net(Cell):

@@ -28,6 +28,9 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
     import pipeline_for_compile_forward_ge_graph_for_case_by_case_config

+grad_by_list_with_sens = C.GradOperation('grad_by_list_with_sens', get_by_list=True, sens_param=True)

 class DisOrderTest1(nn.Cell):
     """ DisOrderTest1 definition """
@@ -72,7 +75,7 @@ class GradNetWrap(nn.Cell):
         self.weights = ParameterTuple(net.get_parameters())

     def construct(self, x, sens):
-        return C.grad_by_list_with_sens(self.net, self.weights)(x, sens)
+        return grad_by_list_with_sens(self.net, self.weights)(x, sens)

 test_case_ops = [

@@ -30,6 +30,11 @@ from mindspore.common import ms_function
 context.set_context(mode=context.GRAPH_MODE)

+grad_by_list = C.GradOperation('get_by_list', get_by_list=True)
+grad_all = C.GradOperation('get_all', get_all=True)
+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 def cond_data_test(x_init, y_init):
     class Net(nn.Cell):
         def __init__(self):
@@ -401,9 +406,9 @@ def test_switch_layer():
     index = Tensor(0, dtype=mstype.int32)
     net = SwitchLayerCell()
     net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
-                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
+                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))

 def test_index_to_switch_layer():
@@ -439,9 +444,9 @@ def test_index_to_switch_layer():
     index = Tensor(0, dtype=mstype.int32)
     net = SwitchLayerCell()
     net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
-                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
+                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))

 def test_parser_switch_layer_switch_in_bprop():
@@ -477,7 +482,7 @@ def test_parser_switch_layer_switch_in_bprop():
     input1 = Tensor(np.ones([2, 2]).astype(np.float32))
     grad = Tensor(np.random.randn(2, 2).astype(np.float32))
     i = Tensor(1, mstype.int32)
-    grad_net = C.grad_all_with_sens(net)
+    grad_net = grad_all_with_sens(net)
     grad_net(i, input1, grad)
@@ -520,7 +525,7 @@ def test_parser_switch_layer_inputs_tuple():
     input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
     i = Tensor(1, mstype.int32)
     grad = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
-    back_net = C.grad_all_with_sens(net)
+    back_net = grad_all_with_sens(net)
     back_out = back_net(i, input1, input2, grad)
@@ -539,9 +544,9 @@ def test_switch_layer_with_single_prim():
     index = Tensor(0, dtype=mstype.int32)
     net = SwitchLayerCell()
     net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
-                                                                Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_by_list(net, ParameterTuple(net.trainable_params()))(index,
+                                                              Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))

 def test_switch_layer_env_eliminate():

@@ -38,6 +38,8 @@ context.set_context(mode=context.GRAPH_MODE)
 # W0613: unused-argument
 # W0231: super-init-not-called

+grad = C.GradOperation('grad')

 def test_multiply():
     """ test_multiply """
     input_x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]))
@@ -200,7 +202,7 @@ class GradWrap(nn.Cell):
         self.network = network

     def construct(self, x, y, b):
-        return C.grad(self.network)(x, y, b)
+        return grad(self.network)(x, y, b)

 class MatMulNet(nn.Cell):
@@ -236,7 +238,7 @@ class GradWrapSub(nn.Cell):
         self.network = network

     def construct(self, x, y):
-        return C.grad(self.network)(x, y)
+        return grad(self.network)(x, y)

 class SubNet(nn.Cell):
@@ -315,7 +317,7 @@ class GradWrapCumSum(nn.Cell):
         self.network = network

     def construct(self, input_):
-        return C.grad(self.network)(input_)
+        return grad(self.network)(input_)

 class NetCumSum(nn.Cell):
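
Note: a bare GradOperation('grad') with no flags differentiates with respect to the first input only, which is why the GradWrap cells above return a single gradient rather than a tuple. A minimal sketch with a hypothetical two-input cell (not from the diff):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.ops import composite as C
    from mindspore.ops import operations as P

    grad = C.GradOperation('grad')  # default flags: gradient of the first input only

    class MulNet(nn.Cell):
        def __init__(self):
            super(MulNet, self).__init__()
            self.mul = P.Mul()

        def construct(self, x, y):
            return self.mul(x, y)

    x = Tensor(np.array([2.0], np.float32))
    y = Tensor(np.array([3.0], np.float32))
    dx = grad(MulNet())(x, y)  # d(x*y)/dx == y; no gradient for y is returned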

@@ -34,6 +34,9 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
 run_opt = C.MultitypeFuncGraph("run_opt")

+grad_by_list = C.GradOperation('get_by_list', get_by_list=True)

 @run_opt.register("Function", "Tensor", "Tensor", "Tensor",
                   "Tensor", "Tensor",
                   "Tensor")
@@ -83,7 +86,7 @@ class TrainStepWrap(nn.Cell):
     def construct(self, x, label):
         weights = self.weights
-        grads = C.grad_by_list(self.network, weights)(x, label)
+        grads = grad_by_list(self.network, weights)(x, label)
         return self.optimizer(grads)

@@ -45,6 +45,10 @@ def conv1x1(in_channels, out_channels, stride=1, padding=0):
                      kernel_size=1, stride=stride, padding=padding)

+grad = C.GradOperation('grad')
+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 class ResidualBlock(nn.Cell):
     """
     residual Block
@@ -169,7 +173,7 @@ class SoftMaxGrad(nn.Cell):
         self.network = network

     def construct(self, x):
-        return C.grad(self.network)(x)
+        return grad(self.network)(x)

 class DropoutGrad(nn.Cell):
@@ -180,7 +184,7 @@ class DropoutGrad(nn.Cell):
         self.network = network

     def construct(self, x):
-        return C.grad(self.network)(x)
+        return grad(self.network)(x)

 class ScalarSummaryNet(nn.Cell):
@@ -255,7 +259,7 @@ class Grad(nn.Cell):
         self.network.set_train()

     def construct(self, x, label):
-        return C.grad(self.network)(x, label)
+        return grad(self.network)(x, label)

 class BatchnormNet(nn.Cell):
@@ -418,7 +422,7 @@ class GradWrapUnfold(nn.Cell):
         self.sens = Tensor(np.ones([1, 4, 2, 2], np.float32))

     def construct(self, x):
-        return C.grad_all_with_sens(self.network)(x, self.sens)
+        return grad_all_with_sens(self.network)(x, self.sens)

 class UnfoldNetValid(nn.Cell):

@@ -34,12 +34,16 @@ from ....mindspore_test_framework.pipeline.forward.compile_forward \
 from ....mindspore_test_framework.pipeline.gradient.compile_gradient \
     import pipeline_for_compile_grad_ge_graph_for_case_by_case_config

+grad_all_with_sens = C.GradOperation('grad_all_with_sens', get_all=True, sens_param=True)

 class InputBackward(nn.Cell):
     def __init__(self, network):
         super(InputBackward, self).__init__()
         self.network = network
         self.network.set_train()
-        self.grad = C.grad_all_with_sens
+        self.grad = grad_all_with_sens

     def construct(self, x1, x2, x3, sens):
         return self.grad(self.network)(x1, x2, x3, sens)

@@ -24,6 +24,9 @@ from mindspore.parallel._utils import _set_has_initializer
 from tests.ut.python.ops.test_math_ops import VirtualLoss

+grad_all = C.GradOperation('get_all', get_all=True)

 class AddRelu(nn.Cell):
     def __init__(self, strategy0=None, strategy1=None):
         super(AddRelu, self).__init__()
@@ -52,7 +55,7 @@ class Grad(nn.Cell):
         self.network = network

     def construct(self, x, y):
-        return C.grad_all(self.network)(x, y)
+        return grad_all(self.network)(x, y)

 def compile_net(net, x, y):

@@ -24,6 +24,9 @@ from mindspore.parallel._utils import _set_has_initializer
 from tests.ut.python.ops.test_math_ops import VirtualLoss

+grad_all = C.GradOperation('get_all', get_all=True)

 class NetWithLoss(nn.Cell):
     def __init__(self, network):
         super(NetWithLoss, self).__init__()
@@ -41,7 +44,7 @@ class GradWrap(nn.Cell):
         self.network = network

     def construct(self, x, y, b):
-        return C.grad_all(self.network)(x, y, b)
+        return grad_all(self.network)(x, y, b)

 def compile_net(net, x, y, b):
@@ -516,7 +519,7 @@ def test_assign_sub():
         self.network = network

     def construct(self, x):
-        return C.grad_all(self.network)(x)
+        return grad_all(self.network)(x)

 def compile_sub_net(net, x):
     net.set_auto_parallel()

@@ -27,6 +27,9 @@ from mindspore.common.parameter import Parameter
 from tests.ut.python.ops.test_math_ops import VirtualLoss

+grad_all = C.GradOperation('get_all', get_all=True)

 class NetWithLoss(nn.Cell):
     def __init__(self, network):
         super(NetWithLoss, self).__init__()
@@ -44,7 +47,7 @@ class GradWrap(nn.Cell):
         self.network = network

     def construct(self, x):
-        return C.grad_all(self.network)(x)
+        return grad_all(self.network)(x)

 def compile_net(net, x):

@@ -23,6 +23,9 @@ from mindspore.ops import composite as C
 from tests.ut.python.ops.test_math_ops import VirtualLoss

+grad_all = C.GradOperation('get_all', get_all=True)

 class NetWithLoss(nn.Cell):
     def __init__(self, network):
         super(NetWithLoss, self).__init__()
@@ -45,7 +48,7 @@ class GradWrap(nn.Cell):
         self.network = network

     def construct(self, x):
-        return C.grad_all(self.network)(x)
+        return grad_all(self.network)(x)

 # model_parallel test

Some files were not shown because too many files have changed in this diff.