|
|
|
@ -77,8 +77,8 @@ class Bprop(Cell):
|
|
|
|
|
self.grad = grad_op
|
|
|
|
|
self.with_sens = False
|
|
|
|
|
self.sens = sens
|
|
|
|
|
if sens:
|
|
|
|
|
self.sens = Tensor(sens, dtype=mstype.float32)
|
|
|
|
|
if not sens is None:
|
|
|
|
|
self.sens = sens if isinstance(sens, Tensor) else Tensor(sens, dtype=mstype.float32)
|
|
|
|
|
self.with_sens = True
|
|
|
|
|
|
|
|
|
|
def construct(self, *inputs):
|
|
|
|
@ -108,7 +108,7 @@ def test_all_var_args_grad_with_sens():
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(1.0, dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
net = VarNet(SecondNet())
|
|
|
|
|
grad_net = GradNet(net)
|
|
|
|
|
_ = grad_net(x, y, sens)
|
|
|
|
@ -160,7 +160,7 @@ def test_grad_all_var_args_with_sens():
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(1.0, dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
net = VarNet(SecondNet())
|
|
|
|
|
grad_net = GradNet(net)
|
|
|
|
|
_ = grad_net(x, y, sens)
|
|
|
|
@ -178,7 +178,7 @@ def test_grad_var_args_with_sens():
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(1.0, dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
net = VarNet(SecondNet())
|
|
|
|
|
grad_net = GradNet(net)
|
|
|
|
|
_ = grad_net(x, y, sens)
|
|
|
|
@ -237,7 +237,7 @@ def test_var_args_grad():
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(1.0, dtype=mstype.float32)
|
|
|
|
|
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
net = VarNet(SecondNet())
|
|
|
|
|
grad_net = GradNet(net)
|
|
|
|
|
_ = grad_net(x, y, sens)
|
|
|
|
@ -285,14 +285,14 @@ def test_grad_within_if_else():
|
|
|
|
|
self.net = net
|
|
|
|
|
grad_op = C.GradOperation(
|
|
|
|
|
name='grad', get_all=False, get_by_list=True, sens_param=True)
|
|
|
|
|
self.grad = Bprop(self.net, True, self.weights, grad_op, 1.0)
|
|
|
|
|
sens = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
self.grad = Bprop(self.net, True, self.weights, grad_op, sens)
|
|
|
|
|
|
|
|
|
|
def construct(self, *inputs):
|
|
|
|
|
return self.grad(*inputs)
|
|
|
|
|
|
|
|
|
|
x = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
y = Tensor(np.ones([3, 4, 5]), dtype=mstype.float32)
|
|
|
|
|
_ = Tensor(1.0, dtype=mstype.float32)
|
|
|
|
|
net = VarNet(SecondNet())
|
|
|
|
|
grad_net = GradNet(net)
|
|
|
|
|
out = grad_net(x, y)
|
|
|
|
|