|
|
@ -35,11 +35,7 @@ def monkey_patch_math_varbase():
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def safe_get_dtype(var):
|
|
|
|
def safe_get_dtype(var):
|
|
|
|
try:
|
|
|
|
return var.dtype
|
|
|
|
dtype = var.dtype
|
|
|
|
|
|
|
|
except:
|
|
|
|
|
|
|
|
raise ValueError("Cannot get data type from %s", var.name)
|
|
|
|
|
|
|
|
return dtype
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@no_grad
|
|
|
|
@no_grad
|
|
|
|
def create_tensor(value, dtype, shape):
|
|
|
|
def create_tensor(value, dtype, shape):
|
|
|
@ -117,6 +113,9 @@ def monkey_patch_math_varbase():
|
|
|
|
outs = core.ops.scale(inputs, attrs)
|
|
|
|
outs = core.ops.scale(inputs, attrs)
|
|
|
|
return outs['Out'][0]
|
|
|
|
return outs['Out'][0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _neg_(var):
|
|
|
|
|
|
|
|
return _scalar_elementwise_op_(var, -1.0, 0.0)
|
|
|
|
|
|
|
|
|
|
|
|
def _scalar_elementwise_add_(var, value):
|
|
|
|
def _scalar_elementwise_add_(var, value):
|
|
|
|
return _scalar_elementwise_op_(var, 1.0, value)
|
|
|
|
return _scalar_elementwise_op_(var, 1.0, value)
|
|
|
|
|
|
|
|
|
|
|
@ -217,6 +216,7 @@ def monkey_patch_math_varbase():
|
|
|
|
|
|
|
|
|
|
|
|
setattr(core.VarBase, method_name,
|
|
|
|
setattr(core.VarBase, method_name,
|
|
|
|
_elemwise_method_creator_(method_name, op_type, reverse,
|
|
|
|
_elemwise_method_creator_(method_name, op_type, reverse,
|
|
|
|
scalar_method))
|
|
|
|
scalar_method)),
|
|
|
|
|
|
|
|
# b = -a
|
|
|
|
|
|
|
|
core.VarBase.__neg__ = _neg_
|
|
|
|
core.VarBase.astype = astype
|
|
|
|
core.VarBase.astype = astype
|
|
|
|