@@ -149,28 +149,46 @@ def monkey_patch_math_varbase():
                          reverse=False,
                          scalar_method=None):
         def __impl__(self, other_var):
-            # tensor and ComplexVariable opetator
+            # 0. check tensor and ComplexVariable operator
             if isinstance(other_var, ComplexVariable):
                 # need import paddle in closure
                 import paddle
                 math_op = getattr(paddle.incubate.complex.tensor, op_type)
                 return math_op(self, other_var)

-            # FIXME(zjl): elementwise_div between integers cannot be converted to scale,
-            # which may lose accuracy. This is a hot fix for release 1.6.
-            if scalar_method is not None and not (
-                    op_type == 'elementwise_div' and
-                    self.dtype in _supported_int_dtype_):
-                if isinstance(other_var, float):
-                    if self.dtype in _supported_int_dtype_:
-                        assert other_var == int(other_var), \
-                            "float value {} cannot convert to integer".format(other_var)
+            # 1. scalar exists cases
+            # we need to combine the tensor.dtype and the scalar's dtype, and cast the correct object
+            if isinstance(other_var, float):
+                # in all cases (+, -, *, /, **, //, %), we need to cast tensor.dtype to float
+                if self.dtype in _supported_int_dtype_:
+                    self = astype(self, 'float32')
+                # here we use `scale` instead of `elementwise` ops to get better performance
+                # but only +, -, *, / can use this method
+                if scalar_method is not None:
                     return scalar_method(self, other_var)
-                elif isinstance(other_var, int):
-                    return scalar_method(self, float(other_var))
+            elif isinstance(other_var, int):
+                # in all cases (+, -, *, /, **, //, %), we can cast it to float
+                # because the output tensor.dtype depends on the type of the input tensor
+                other_var = float(other_var)
+                # division is a special case
+                # NOTE(chenweihang): because we cast the tensor to float32 instead of float64,
+                # the division result can only guarantee numerical accuracy to 6 digits
+                # after the decimal point. The result of the numpy calculation is of float64 type,
+                # so the calculation result here and the numpy result may differ
+                # after the 6th decimal place. If necessary, we can also use float64 here.
+                # torch's behavior here is consistent with ours
+                if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
+                    self = astype(self, 'float32')
+                # here we use `scale` instead of `elementwise` ops to get better performance
+                # but only +, -, *, / can use this method
+                if scalar_method is not None:
+                    return scalar_method(self, other_var)
+            else:
+                # do nothing
+                pass

+            # 2. create varbase for scalar
             lhs_dtype = self.dtype

             if not isinstance(other_var, core.VarBase):
                 if reverse:
                     other_var = create_tensor(
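# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the NOTE(chenweihang) comment in
# the hunk above limits accuracy to ~6 decimal digits because the tensor is cast
# to float32 rather than float64 before an integer division. The snippet below
# reproduces that effect with plain numpy; no Paddle APIs are assumed.
# ---------------------------------------------------------------------------
import numpy as np

a = np.array([1, 2, 7], dtype=np.int64)
fp32_result = a.astype(np.float32) / np.float32(3)  # what the patched op computes
fp64_result = a.astype(np.float64) / 3.0            # what plain numpy computes
# float32 carries roughly 7 significant decimal digits, so the two results agree
# to about 6 places after the decimal point, but not exactly.
print(np.abs(fp32_result - fp64_result))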
@@ -179,6 +197,7 @@ def monkey_patch_math_varbase():
                     # add fill_op
                     other_var = create_scalar(value=other_var, dtype=lhs_dtype)

+            # 3. unify right var type to left var
             rhs_dtype = other_var.dtype
             if lhs_dtype != rhs_dtype:
                 other_var = astype(other_var, lhs_dtype)
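# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the "unify the right operand's
# dtype to the left operand, then swap for reflected operators" pattern used in
# step 3 above, shown with plain numpy arrays. The helper name `unify_and_order`
# is hypothetical and only exists for this example.
# ---------------------------------------------------------------------------
import numpy as np

def unify_and_order(lhs, rhs, reverse=False):
    # 3. unify the right operand's dtype to the left operand's dtype
    if rhs.dtype != lhs.dtype:
        rhs = rhs.astype(lhs.dtype)
    # reflected operators (e.g. __rsub__, __rdiv__) swap the operands afterwards
    if reverse:
        lhs, rhs = rhs, lhs
    return lhs, rhs

x = np.array([4.0, 9.0], dtype=np.float32)
y = np.array([2, 3], dtype=np.int64)
a, b = unify_and_order(x, y, reverse=True)
print(a.dtype, b.dtype)  # float32 float32: y was cast to x's dtype, then swapped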
@@ -187,6 +206,7 @@ def monkey_patch_math_varbase():
                 self = other_var
                 other_var = tmp

+            # 4. calculation
             axis = -1
             math_op = getattr(core.ops, op_type)
             return math_op(self, other_var, 'axis', axis)
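# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): how a _binary_creator_-style
# closure is typically bound onto a tensor class as a magic method, ending in the
# step-4 "resolve the op by name and call it" dispatch. MiniTensor and OPS are
# toy stand-ins for VarBase and core.ops, not real Paddle objects.
# ---------------------------------------------------------------------------
import numpy as np

OPS = {
    'elementwise_add': np.add,
    'elementwise_div': np.divide,
}

class MiniTensor:
    def __init__(self, data):
        self.data = np.asarray(data, dtype=np.float32)

def _binary_creator_(method_name, op_type, reverse=False):
    def __impl__(self, other):
        rhs = other.data if isinstance(other, MiniTensor) else other
        lhs = self.data
        if reverse:
            lhs, rhs = rhs, lhs
        # 4. calculation: resolve the op by name and dispatch to it
        return MiniTensor(OPS[op_type](lhs, rhs))
    __impl__.__name__ = method_name
    return __impl__

# monkey patch the magic methods onto the class, as the real patch does for VarBase
MiniTensor.__add__ = _binary_creator_('__add__', 'elementwise_add')
MiniTensor.__rtruediv__ = _binary_creator_('__rtruediv__', 'elementwise_div', reverse=True)

t = MiniTensor([1.0, 2.0, 4.0])
print((t + 1).data)  # [2. 3. 5.]
print((8 / t).data)  # [8. 4. 2.]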