diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 91f976cb0b..070ce68545 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -31,7 +31,8 @@ from mindspore.common.api import _executor from mindspore.common import dtype as mstype from mindspore._checkparam import check_input_data -__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"] +__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", + "build_searched_strategy", "merge_sliced_parameter"] tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, @@ -575,19 +576,19 @@ def parse_print(print_file_name): def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): """ - Merge data slices to one tensor with whole data when strategy is not None. + Merge data slices to one tensor with whole data when strategy is not None. - Args: - sliced_data (list[numpy.ndarray]): data slices in order of rank_id. - parameter_name (str): name of parameter. - strategy (dict): parameter slice strategy. - is_even (bool): slice manner that True represents slicing evenly and False represents slicing unevenly. + Args: + sliced_data (list[numpy.ndarray]): Data slices in order of rank_id. + parameter_name (str): Name of parameter. + strategy (dict): Parameter slice strategy. + is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly. - Returns: - Tensor, the merged Tensor which has the whole data. + Returns: + Tensor, the merged Tensor which has the whole data. - Raises: - ValueError: failed to merge. + Raises: + ValueError: Failed to merge. 
""" layout = strategy.get(parameter_name) try: @@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): def build_searched_strategy(strategy_filename): """ - build strategy of every parameter in network. + Build strategy of every parameter in network. Args: - strategy_filename (str): name of strategy file. + strategy_filename (str): Name of strategy file. Returns: Dictionary, whose key is parameter name and value is slice strategy of this parameter. Raises: - ValueError: strategy file is incorrect. - TypeError: strategy_filename is not str. + ValueError: Strategy file is incorrect. + TypeError: The strategy_filename is not str. Examples: >>> strategy_filename = "./strategy_train.ckpt" @@ -707,32 +708,32 @@ def build_searched_strategy(strategy_filename): def merge_sliced_parameter(sliced_parameters, strategy=None): """ - Merge parameter slices to one whole parameter. - - Args: - sliced_parameters (list[Parameter]): parameter slices in order of rank_id. - strategy (dict): parameter slice strategy. Default: None. - - If strategy is None, just merge parameter slices in 0 axis order. - - key (str): parameter name. - - value (): slice strategy of this parameter. - - Returns: - Parameter, the merged parameter which has the whole data. - - Raises: - ValueError: failed to merge. - TypeError: the sliced_parameters is incorrect or strategy is not dict. - KeyError: the parameter name is not in keys of strategy. 
- - Examples: - >>> strategy = build_searched_strategy("./strategy_train.ckpt") - >>> sliced_parameters = [\ - Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \ - Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \ - Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \ - Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")] - >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) + Merge parameter slices to one whole parameter. + + Args: + sliced_parameters (list[Parameter]): Parameter slices in order of rank_id. + strategy (dict): Parameter slice strategy. Default: None. + If strategy is None, just merge parameter slices in 0 axis order. + + - key (str): Parameter name. + - value (): Slice strategy of this parameter. + + Returns: + Parameter, the merged parameter which has the whole data. + + Raises: + ValueError: Failed to merge. + TypeError: The sliced_parameters is incorrect or strategy is not dict. + KeyError: The parameter name is not in keys of strategy. + + Examples: + >>> strategy = build_searched_strategy("./strategy_train.ckpt") + >>> sliced_parameters = [\ + Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \ + Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \ + Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"), \ + Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")] + >>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy) """ if not isinstance(sliced_parameters, list): raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")