|
|
|
@ -219,7 +219,7 @@ class OneInputBprop(nn.Cell):
|
|
|
|
|
return self.op(x)
|
|
|
|
|
|
|
|
|
|
def bprop(self, x, out, dout):
|
|
|
|
|
return 5 * x,
|
|
|
|
|
return (5 * x,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_grad_one_input_bprop():
|
|
|
|
@ -349,7 +349,7 @@ class MulAddWithWrongOutputNum(nn.Cell):
|
|
|
|
|
return 2 * x + y
|
|
|
|
|
|
|
|
|
|
def bprop(self, x, y, out, dout):
|
|
|
|
|
return 2 * dout,
|
|
|
|
|
return (2 * dout,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_grad_mul_add_with_wrong_output_num():
|
|
|
|
@ -380,7 +380,7 @@ def test_grad_mul_add_with_wrong_output_type():
|
|
|
|
|
class MulAddWithWrongOutputShape(nn.Cell):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(MulAddWithWrongOutputShape, self).__init__()
|
|
|
|
|
self.ones = Tensor(np.ones([2, ]))
|
|
|
|
|
self.ones = Tensor(np.ones([2,]))
|
|
|
|
|
|
|
|
|
|
def construct(self, x, y):
|
|
|
|
|
return 2 * x + y
|
|
|
|
|