|
|
|
@ -93,7 +93,7 @@ def _update_param(param, new_param):
|
|
|
|
|
if isinstance(param.data, Tensor) and not isinstance(new_param.data, Tensor):
|
|
|
|
|
if param.data.shape != (1,) and param.data.shape != ():
|
|
|
|
|
logger.error("Failed to combine the net and the parameters for param %s.", param.name)
|
|
|
|
|
msg = ("Net parameters {} shape({}) is not (1,), inconsitent with parameter_dict's(scalar)."
|
|
|
|
|
msg = ("Net parameters {} shape({}) is not (1,), inconsistent with parameter_dict's(scalar)."
|
|
|
|
|
.format(param.name, param.data.shape))
|
|
|
|
|
raise RuntimeError(msg)
|
|
|
|
|
param.set_data(initializer(new_param.data, param.data.shape, param.data.dtype))
|
|
|
|
@ -244,31 +244,7 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
|
|
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
|
|
|
|
|
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(ckpt_file_name, str):
|
|
|
|
|
raise ValueError("The ckpt_file_name must be string.")
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(ckpt_file_name):
|
|
|
|
|
raise ValueError("The checkpoint file is not exist.")
|
|
|
|
|
|
|
|
|
|
if ckpt_file_name[-5:] != ".ckpt":
|
|
|
|
|
raise ValueError("Please input the correct checkpoint file name.")
|
|
|
|
|
|
|
|
|
|
if os.path.getsize(ckpt_file_name) == 0:
|
|
|
|
|
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
|
|
|
|
|
|
|
|
|
if filter_prefix is not None:
|
|
|
|
|
if not isinstance(filter_prefix, (str, list, tuple)):
|
|
|
|
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
|
|
|
|
|
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
|
|
|
|
|
if isinstance(filter_prefix, str):
|
|
|
|
|
filter_prefix = (filter_prefix,)
|
|
|
|
|
if not filter_prefix:
|
|
|
|
|
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
|
|
|
|
|
for index, prefix in enumerate(filter_prefix):
|
|
|
|
|
if not isinstance(prefix, str):
|
|
|
|
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
|
|
|
|
f"but got {str(type(prefix))} at index {index}.")
|
|
|
|
|
|
|
|
|
|
ckpt_file_name, filter_prefix = _check_checkpoint_param(ckpt_file_name, filter_prefix)
|
|
|
|
|
logger.info("Execute the process of loading checkpoint files.")
|
|
|
|
|
checkpoint_list = Checkpoint()
|
|
|
|
|
|
|
|
|
@ -297,7 +273,6 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
|
|
param_data = np.concatenate((param_data_list), axis=0)
|
|
|
|
|
param_data_list.clear()
|
|
|
|
|
dims = element.tensor.dims
|
|
|
|
|
|
|
|
|
|
if dims == [0]:
|
|
|
|
|
if 'Float' in data_type:
|
|
|
|
|
param_data = float(param_data[0])
|
|
|
|
@ -328,6 +303,32 @@ def load_checkpoint(ckpt_file_name, net=None, strict_load=False, filter_prefix=N
|
|
|
|
|
return parameter_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _check_checkpoint_param(ckpt_file_name, filter_prefix=None):
|
|
|
|
|
"""Check function load_checkpoint's parameter."""
|
|
|
|
|
if not isinstance(ckpt_file_name, str):
|
|
|
|
|
raise ValueError("The ckpt_file_name must be string.")
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(ckpt_file_name):
|
|
|
|
|
raise ValueError("The checkpoint file is not exist.")
|
|
|
|
|
|
|
|
|
|
if ckpt_file_name[-5:] != ".ckpt":
|
|
|
|
|
raise ValueError("Please input the correct checkpoint file name.")
|
|
|
|
|
|
|
|
|
|
if filter_prefix is not None:
|
|
|
|
|
if not isinstance(filter_prefix, (str, list, tuple)):
|
|
|
|
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str] "
|
|
|
|
|
f"when filter_prefix is not None, but got {str(type(filter_prefix))}.")
|
|
|
|
|
if isinstance(filter_prefix, str):
|
|
|
|
|
filter_prefix = (filter_prefix,)
|
|
|
|
|
if not filter_prefix:
|
|
|
|
|
raise ValueError("The filter_prefix can't be empty when filter_prefix is list or tuple.")
|
|
|
|
|
for index, prefix in enumerate(filter_prefix):
|
|
|
|
|
if not isinstance(prefix, str):
|
|
|
|
|
raise TypeError(f"The type of filter_prefix must be str, list[str] or tuple[str], "
|
|
|
|
|
f"but got {str(type(prefix))} at index {index}.")
|
|
|
|
|
return ckpt_file_name, filter_prefix
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_param_into_net(net, parameter_dict, strict_load=False):
|
|
|
|
|
"""
|
|
|
|
|
Loads parameters into network.
|
|
|
|
@ -560,13 +561,15 @@ def _export(net, file_name, file_format, *inputs):
|
|
|
|
|
if file_format == 'AIR':
|
|
|
|
|
phase_name = 'export.air'
|
|
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name)
|
|
|
|
|
file_name += ".air"
|
|
|
|
|
if not file_name.endswith('.air'):
|
|
|
|
|
file_name += ".air"
|
|
|
|
|
_executor.export(file_name, graph_id)
|
|
|
|
|
elif file_format == 'ONNX':
|
|
|
|
|
phase_name = 'export.onnx'
|
|
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
|
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id)
|
|
|
|
|
file_name += ".onnx"
|
|
|
|
|
if not file_name.endswith('.onnx'):
|
|
|
|
|
file_name += ".onnx"
|
|
|
|
|
with open(file_name, 'wb') as f:
|
|
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
|
|
|
f.write(onnx_stream)
|
|
|
|
@ -574,7 +577,8 @@ def _export(net, file_name, file_format, *inputs):
|
|
|
|
|
phase_name = 'export.mindir'
|
|
|
|
|
graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False)
|
|
|
|
|
onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir')
|
|
|
|
|
file_name += ".mindir"
|
|
|
|
|
if not file_name.endswith('.mindir'):
|
|
|
|
|
file_name += ".mindir"
|
|
|
|
|
with open(file_name, 'wb') as f:
|
|
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
|
|
|
f.write(onnx_stream)
|
|
|
|
|