Merge pull request !4851 from caozhou/modified_indent
pull/4851/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 61eff67ba5

@ -31,7 +31,8 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data from mindspore._checkparam import check_input_data
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"] __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
@ -578,16 +579,16 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
Merge data slices to one tensor with whole data when strategy is not None. Merge data slices to one tensor with whole data when strategy is not None.
Args: Args:
sliced_data (list[numpy.ndarray]): data slices in order of rank_id. sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
parameter_name (str): name of parameter. parameter_name (str): Name of parameter.
strategy (dict): parameter slice strategy. strategy (dict): Parameter slice strategy.
is_even (bool): slice manner that True represents slicing evenly and False represents slicing unevenly. is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
Returns: Returns:
Tensor, the merged Tensor which has the whole data. Tensor, the merged Tensor which has the whole data.
Raises: Raises:
ValueError: failed to merge. ValueError: Failed to merge.
""" """
layout = strategy.get(parameter_name) layout = strategy.get(parameter_name)
try: try:
@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
def build_searched_strategy(strategy_filename): def build_searched_strategy(strategy_filename):
""" """
build strategy of every parameter in network. Build strategy of every parameter in network.
Args: Args:
strategy_filename (str): name of strategy file. strategy_filename (str): Name of strategy file.
Returns: Returns:
Dictionary, whose key is parameter name and value is slice strategy of this parameter. Dictionary, whose key is parameter name and value is slice strategy of this parameter.
Raises: Raises:
ValueError: strategy file is incorrect. ValueError: Strategy file is incorrect.
TypeError: strategy_filename is not str. TypeError: Strategy_filename is not str.
Examples: Examples:
>>> strategy_filename = "./strategy_train.ckpt" >>> strategy_filename = "./strategy_train.ckpt"
@ -710,20 +711,20 @@ def merge_sliced_parameter(sliced_parameters, strategy=None):
Merge parameter slices to one whole parameter. Merge parameter slices to one whole parameter.
Args: Args:
sliced_parameters (list[Parameter]): parameter slices in order of rank_id. sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
strategy (dict): parameter slice strategy. Default: None. strategy (dict): Parameter slice strategy. Default: None.
If strategy is None, just merge parameter slices in 0 axis order. If strategy is None, just merge parameter slices in 0 axis order.
- key (str): parameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): slice strategy of this parameter. - key (str): Parameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
Returns: Returns:
Parameter, the merged parameter which has the whole data. Parameter, the merged parameter which has the whole data.
Raises: Raises:
ValueError: failed to merge. ValueError: Failed to merge.
TypeError: the sliced_parameters is incorrect or strategy is not dict. TypeError: The sliced_parameters is incorrect or strategy is not dict.
KeyError: the parameter name is not in keys of strategy. KeyError: The parameter name is not in keys of strategy.
Examples: Examples:
>>> strategy = build_searched_strategy("./strategy_train.ckpt") >>> strategy = build_searched_strategy("./strategy_train.ckpt")

Loading…
Cancel
Save