|
|
|
@ -148,10 +148,15 @@ def assign(input, output):
|
|
|
|
|
dtype = convert_np_dtype_to_dtype_(input.dtype)
|
|
|
|
|
if dtype == DataType.FP32:
|
|
|
|
|
value_name = "fp32_values"
|
|
|
|
|
values = [float(v) for v in input.flat]
|
|
|
|
|
elif dtype == DataType.INT32:
|
|
|
|
|
value_name = "int32_values"
|
|
|
|
|
values = [int(v) for v in input.flat]
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Unsupported dtype %s", input.dtype)
|
|
|
|
|
if input.size > 1024 * 1024:
|
|
|
|
|
raise ValueError("The size of input is too big. Please consider "
|
|
|
|
|
"saving it to file and 'load_op' to load it")
|
|
|
|
|
|
|
|
|
|
helper.append_op(
|
|
|
|
|
type='assign_value',
|
|
|
|
@ -159,7 +164,7 @@ def assign(input, output):
|
|
|
|
|
attrs={
|
|
|
|
|
'dtype': dtype,
|
|
|
|
|
'shape': list(input.shape),
|
|
|
|
|
value_name: [float(v) for v in input.flat]
|
|
|
|
|
value_name: values
|
|
|
|
|
})
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError("Wrong type for assign input: %s" % type(input))
|
|
|
|
|