diff --git a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
index cb487a5fe6..0a0b220093 100644
--- a/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
+++ b/mindspore/ccsrc/transform/express_ir/mindir_exporter.cc
@@ -192,7 +192,6 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, mind_ir::G
       mind_ir::TensorProto *parameter_proto = graph_proto->add_parameter();
       parameter_proto->set_name(param_name);
       SetParamToTensorProto(param, parameter_proto);
-      auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
     } else {
       mind_ir::ValueInfoProto *input_proto = graph_proto->add_input();
       input_proto->set_name(param_name);
diff --git a/mindspore/train/dataset_helper.py b/mindspore/train/dataset_helper.py
index 01f1ff1c3c..534ae2176e 100644
--- a/mindspore/train/dataset_helper.py
+++ b/mindspore/train/dataset_helper.py
@@ -183,7 +183,7 @@ class DatasetHelper:
         >>> train_dataset = create_custom_dataset()
         >>> set_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
         >>> for next_element in set_helper:
-        >>>     print(next_element)
+        ...     print(next_element)
     """
 
     def __init__(self, dataset, dataset_sink_mode=True, sink_size=-1, epoch_num=1):
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 52b59ed2fc..45936ca51c 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -889,7 +889,6 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
         raise ValueError(f"The sliced_parameters length should be equal to device_count. "
                          f"the sliced_parameters length is {len(sliced_data)} but device_count is {device_count}.")
 
-    merged_tensor = None
     if not param_split_shape:
         if not is_even:
             raise ValueError("The shape of every parameter in sliced_parameters should be the same "
@@ -1052,7 +1051,6 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
     layerwise_parallel = sliced_parameters[0].layerwise_parallel
     requires_grad = sliced_parameters[0].requires_grad
     sliced_data = [parameter.data.asnumpy() for parameter in sliced_parameters]
-    merged_parameter = None
 
     if not strategy:
         merged_tensor = Tensor(np.concatenate(sliced_data))
@@ -1121,7 +1119,7 @@ def load_distributed_checkpoint(network, checkpoint_filenames, predict_strategy=
         param_rank = rank_list[param.name][0]
         skip_merge_split = rank_list[param.name][1]
         for rank in param_rank:
-            sliced_param = _load_single_param(checkpoint_filenames[rank], param.name)
+            sliced_param = load_checkpoint(checkpoint_filenames[rank])[param.name]
             sliced_params.append(sliced_param)
         if skip_merge_split:
             split_param = sliced_params[0]
@@ -1213,59 +1211,3 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
     layerwise_parallel = merged_param.layerwise_parallel
     split_param = Parameter(split_tensor, param_name, requires_grad, layerwise_parallel)
     return split_param
-
-
-def _load_single_param(ckpt_file_name, param_name):
-    """Load a parameter from checkpoint."""
-    checkpoint_list = Checkpoint()
-
-    try:
-        with open(ckpt_file_name, "rb") as f:
-            pb_content = f.read()
-        checkpoint_list.ParseFromString(pb_content)
-    except BaseException as e:
-        logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
-                     " please check the correct of the file.", ckpt_file_name)
-        raise ValueError(e.__str__())
-
-    parameter = None
-    try:
-        param_data_list = []
-        for element_id, element in enumerate(checkpoint_list.value):
-            if element.tag != param_name:
-                continue
-            data = element.tensor.tensor_content
-            data_type = element.tensor.tensor_type
-            np_type = tensor_to_np_type[data_type]
-            ms_type = tensor_to_ms_type[data_type]
-            element_data = np.frombuffer(data, np_type)
-            param_data_list.append(element_data)
-            if (element_id == len(checkpoint_list.value) - 1) or \
-                    (element.tag != checkpoint_list.value[element_id + 1].tag):
-                param_data = np.concatenate((param_data_list), axis=0)
-                param_data_list.clear()
-                dims = element.tensor.dims
-                if dims == [0]:
-                    if 'Float' in data_type:
-                        param_data = float(param_data[0])
-                    elif 'Int' in data_type:
-                        param_data = int(param_data[0])
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                elif dims == [1]:
-                    parameter = Parameter(Tensor(param_data, ms_type), name=element.tag)
-                else:
-                    param_dim = []
-                    for dim in dims:
-                        param_dim.append(dim)
-                    param_value = param_data.reshape(param_dim)
-                    parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
-                break
-
-    except BaseException as e:
-        logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
-        raise RuntimeError(e.__str__())
-
-    if parameter is None:
-        raise ValueError(f"There is no parameter named {param_name} in this checkpoint file {ckpt_file_name}, "
-                         f"please check parameter name or checkpoint file.")
-    return parameter
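
Note on the serialization.py change: load_checkpoint already parses an entire checkpoint file into a dict mapping parameter names to Parameter objects, so indexing its result by name covers everything the hand-rolled protobuf walk in _load_single_param did, which is why the ~56-line helper can be deleted outright. A minimal sketch of the new call pattern used in load_distributed_checkpoint (the file names and parameter name below are hypothetical, for illustration only):

    from mindspore.train.serialization import load_checkpoint

    # Hypothetical per-rank checkpoint files and parameter name.
    checkpoint_filenames = ["rank_0.ckpt", "rank_1.ckpt"]
    param_name = "conv1.weight"

    sliced_params = []
    for rank in range(len(checkpoint_filenames)):
        # load_checkpoint returns {name: Parameter}; a single dict lookup
        # replaces the per-parameter protobuf parsing _load_single_param did.
        sliced_param = load_checkpoint(checkpoint_filenames[rank])[param_name]
        sliced_params.append(sliced_param)

Each call now deserializes the whole file rather than one entry, trading some repeated parsing for the removal of duplicate protobuf-handling code. The dataset_helper.py hunk is the standard doctest convention: continuation lines of a compound statement use "..." rather than a second ">>>" prompt.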