|
|
|
@ -397,6 +397,22 @@ def get_bprop_xlogy(self):
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.SquareSumAll)
|
|
|
|
|
def get_bprop_square_sum_all(self):
|
|
|
|
|
"""Grad definition for `Square` operation."""
|
|
|
|
|
mul_func = P.Mul()
|
|
|
|
|
fill_func = P.Fill()
|
|
|
|
|
dtype = P.DType()
|
|
|
|
|
|
|
|
|
|
def bprop(x, y, out, dout):
|
|
|
|
|
temp_x = mul_func(dout[0], x)
|
|
|
|
|
temp_y = mul_func(dout[1], y)
|
|
|
|
|
dx = mul_func(fill_func(dtype(temp_x), shape_op(x), 2.0), temp_x)
|
|
|
|
|
dy = mul_func(fill_func(dtype(temp_y), shape_op(y), 2.0), temp_y)
|
|
|
|
|
return (dx, dy)
|
|
|
|
|
|
|
|
|
|
return bprop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@bprop_getters.register(P.Sqrt)
|
|
|
|
|
def get_bprop_sqrt(self):
|
|
|
|
|