Merge pull request !4851 from caozhou/modified_indent
pull/4851/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 61eff67ba5

@ -31,7 +31,8 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
@ -575,19 +576,19 @@ def parse_print(print_file_name):
def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
"""
Merge data slices to one tensor with whole data when strategy is not None.
Merge data slices to one tensor with whole data when strategy is not None.
Args:
sliced_data (list[numpy.ndarray]): data slices in order of rank_id.
parameter_name (str): name of parameter.
strategy (dict): parameter slice strategy.
is_even (bool): slice manner that True represents slicing evenly and False represents slicing unevenly.
Args:
sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
parameter_name (str): Name of parameter.
strategy (dict): Parameter slice strategy.
is_even (bool): Slice manner that True represents slicing evenly and False represents slicing unevenly.
Returns:
Tensor, the merged Tensor which has the whole data.
Returns:
Tensor, the merged Tensor which has the whole data.
Raises:
ValueError: failed to merge.
Raises:
ValueError: Failed to merge.
"""
layout = strategy.get(parameter_name)
try:
@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
def build_searched_strategy(strategy_filename):
"""
build strategy of every parameter in network.
Build strategy of every parameter in network.
Args:
strategy_filename (str): name of strategy file.
strategy_filename (str): Name of strategy file.
Returns:
Dictionary, whose key is parameter name and value is slice strategy of this parameter.
Raises:
ValueError: strategy file is incorrect.
TypeError: strategy_filename is not str.
ValueError: Strategy file is incorrect.
TypeError: Strategy_filename is not str.
Examples:
>>> strategy_filename = "./strategy_train.ckpt"
@ -707,32 +708,32 @@ def build_searched_strategy(strategy_filename):
def merge_sliced_parameter(sliced_parameters, strategy=None):
"""
Merge parameter slices to one whole parameter.
Args:
sliced_parameters (list[Parameter]): parameter slices in order of rank_id.
strategy (dict): parameter slice strategy. Default: None.
If strategy is None, just merge parameter slices in 0 axis order.
- key (str): parameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): slice strategy of this parameter.
Returns:
Parameter, the merged parameter which has the whole data.
Raises:
ValueError: failed to merge.
TypeError: the sliced_parameters is incorrect or strategy is not dict.
KeyError: the parameter name is not in keys of strategy.
Examples:
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> sliced_parameters = [\
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
Merge parameter slices to one whole parameter.
Args:
sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
strategy (dict): Parameter slice strategy. Default: None.
If strategy is None, just merge parameter slices in 0 axis order.
- key (str): Parameter name.
- value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
Returns:
Parameter, the merged parameter which has the whole data.
Raises:
ValueError: Failed to merge.
TypeError: The sliced_parameters is incorrect or strategy is not dict.
KeyError: The parameter name is not in keys of strategy.
Examples:
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
>>> sliced_parameters = [\
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_tabel"), \
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
"""
if not isinstance(sliced_parameters, list):
raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")

Loading…
Cancel
Save