|
|
|
@ -19,7 +19,7 @@ __all__ = ['monkey_patch_variable']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def monkey_patch_variable():
|
|
|
|
|
def new_name():
|
|
|
|
|
def unique_tmp_name():
|
|
|
|
|
return unique_name("tmp")
|
|
|
|
|
|
|
|
|
|
def safe_get_dtype(var):
|
|
|
|
@ -29,21 +29,9 @@ def monkey_patch_variable():
|
|
|
|
|
raise ValueError("Cannot get data type from %s", var.name)
|
|
|
|
|
return dtype
|
|
|
|
|
|
|
|
|
|
def create_scalar(block, value, dtype):
|
|
|
|
|
value = float(value)
|
|
|
|
|
tmp_name = new_name()
|
|
|
|
|
var = block.create_var(name=tmp_name, shape=[1], dtype=dtype)
|
|
|
|
|
block.append_op(
|
|
|
|
|
type="fill",
|
|
|
|
|
outputs={"Out": [var]},
|
|
|
|
|
attrs={"value": [value],
|
|
|
|
|
"shape": [1],
|
|
|
|
|
"dtype": dtype})
|
|
|
|
|
return var
|
|
|
|
|
|
|
|
|
|
def create_tensor(block, value, dtype, shape):
|
|
|
|
|
value = float(value)
|
|
|
|
|
tmp_name = new_name()
|
|
|
|
|
tmp_name = unique_tmp_name()
|
|
|
|
|
var = block.create_var(name=tmp_name, shape=shape, dtype=dtype)
|
|
|
|
|
block.append_op(
|
|
|
|
|
type="fill_constant",
|
|
|
|
@ -53,10 +41,13 @@ def monkey_patch_variable():
|
|
|
|
|
'value': value})
|
|
|
|
|
return var
|
|
|
|
|
|
|
|
|
|
def create_scalar(block, value, dtype):
|
|
|
|
|
return create_tensor(block, value, dtype, shape=[1])
|
|
|
|
|
|
|
|
|
|
def create_tensor_with_batchsize(ref_var, value, dtype):
|
|
|
|
|
assert isinstance(ref_var, Variable)
|
|
|
|
|
value = float(value)
|
|
|
|
|
tmp_name = new_name()
|
|
|
|
|
tmp_name = unique_tmp_name()
|
|
|
|
|
var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
|
|
|
|
|
ref_var.block.append_op(
|
|
|
|
|
type='fill_constant_batch_size_like',
|
|
|
|
@ -68,7 +59,7 @@ def monkey_patch_variable():
|
|
|
|
|
|
|
|
|
|
def astype(self, dtype):
|
|
|
|
|
"""
|
|
|
|
|
Cast a variable to data type.
|
|
|
|
|
Cast a variable to a specified data type.
|
|
|
|
|
NOTE: The variable must be a Tensor
|
|
|
|
|
Args:
|
|
|
|
|
self(Variable): The source variable
|
|
|
|
@ -77,7 +68,7 @@ def monkey_patch_variable():
|
|
|
|
|
Returns:
|
|
|
|
|
Variable with new dtype
|
|
|
|
|
"""
|
|
|
|
|
tmp_name = new_name()
|
|
|
|
|
tmp_name = unique_tmp_name()
|
|
|
|
|
out = self.block.create_var(name=tmp_name, dtype=dtype)
|
|
|
|
|
self.block.append_op(
|
|
|
|
|
type="cast",
|
|
|
|
@ -120,7 +111,7 @@ def monkey_patch_variable():
|
|
|
|
|
self = other_var
|
|
|
|
|
other_var = tmp
|
|
|
|
|
|
|
|
|
|
tmp_name = new_name()
|
|
|
|
|
tmp_name = unique_tmp_name()
|
|
|
|
|
out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
|
|
|
|
|
self.block.append_op(
|
|
|
|
|
type=op_type,
|
|
|
|
|