|
|
|
@ -31,7 +31,8 @@ from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore._checkparam import check_input_data
|
|
|
|
|
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print"]
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
|
|
|
|
|
"build_searched_strategy", "merge_sliced_parameter"]
|
|
|
|
|
|
|
|
|
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
|
|
|
|
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,
|
|
|
|
@ -575,19 +576,19 @@ def parse_print(print_file_name):
|
|
|
|
|
|
|
|
|
|
def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
|
|
|
"""
|
|
|
|
|
Merge data slices to one tensor with whole data when strategy is not None.
|
|
|
|
|
Merge data slices to one tensor with whole data when strategy is not None.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
sliced_data (list[numpy.ndarray]): data slices in order of rank_id.
|
|
|
|
|
parameter_name (str): name of parameter.
|
|
|
|
|
strategy (dict): parameter slice strategy.
|
|
|
|
|
is_even (bool): slice manner that True represents slicing evenly and False represents slicing unevenly.
|
|
|
|
|
Args:
|
|
|
|
|
sliced_data (list[numpy.ndarray]): Data slices in order of rank_id.
|
|
|
|
|
parameter_name (str): Name of parameter.
|
|
|
|
|
strategy (dict): Parameter slice strategy.
|
|
|
|
|
is_even (bool): Slice manner. True means the parameter is sliced evenly, False means it is sliced unevenly.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, the merged Tensor which has the whole data.
|
|
|
|
|
Returns:
|
|
|
|
|
Tensor, the merged Tensor which has the whole data.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: failed to merge.
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: Failed to merge.
|
|
|
|
|
"""
|
|
|
|
|
layout = strategy.get(parameter_name)
|
|
|
|
|
try:
|
|
|
|
@ -661,17 +662,17 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even):
|
|
|
|
|
|
|
|
|
|
def build_searched_strategy(strategy_filename):
|
|
|
|
|
"""
|
|
|
|
|
build strategy of every parameter in network.
|
|
|
|
|
Build strategy of every parameter in network.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
strategy_filename (str): name of strategy file.
|
|
|
|
|
strategy_filename (str): Name of strategy file.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Dictionary, whose key is parameter name and value is slice strategy of this parameter.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: strategy file is incorrect.
|
|
|
|
|
TypeError: strategy_filename is not str.
|
|
|
|
|
ValueError: Strategy file is incorrect.
|
|
|
|
|
TypeError: The strategy_filename is not str.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> strategy_filename = "./strategy_train.ckpt"
|
|
|
|
@ -707,32 +708,32 @@ def build_searched_strategy(strategy_filename):
|
|
|
|
|
|
|
|
|
|
def merge_sliced_parameter(sliced_parameters, strategy=None):
|
|
|
|
|
"""
|
|
|
|
|
Merge parameter slices to one whole parameter.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
sliced_parameters (list[Parameter]): parameter slices in order of rank_id.
|
|
|
|
|
strategy (dict): parameter slice strategy. Default: None.
|
|
|
|
|
|
|
|
|
|
If strategy is None, just merge parameter slices in 0 axis order.
|
|
|
|
|
- key (str): parameter name.
|
|
|
|
|
- value (<class 'node_strategy_pb2.ParallelLayouts'>): slice strategy of this parameter.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Parameter, the merged parameter which has the whole data.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: failed to merge.
|
|
|
|
|
TypeError: the sliced_parameters is incorrect or strategy is not dict.
|
|
|
|
|
KeyError: the parameter name is not in keys of strategy.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
|
|
|
|
|
>>> sliced_parameters = [\
|
|
|
|
|
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
|
|
|
|
|
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
|
|
|
|
|
Merge parameter slices to one whole parameter.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
sliced_parameters (list[Parameter]): Parameter slices in order of rank_id.
|
|
|
|
|
strategy (dict): Parameter slice strategy. Default: None.
|
|
|
|
|
If strategy is None, just merge parameter slices in 0 axis order.
|
|
|
|
|
|
|
|
|
|
- key (str): Parameter name.
|
|
|
|
|
- value (<class 'node_strategy_pb2.ParallelLayouts'>): Slice strategy of this parameter.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Parameter, the merged parameter which has the whole data.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: Failed to merge.
|
|
|
|
|
TypeError: The sliced_parameters is incorrect or strategy is not dict.
|
|
|
|
|
KeyError: The parameter name is not in keys of strategy.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> strategy = build_searched_strategy("./strategy_train.ckpt")
|
|
|
|
|
>>> sliced_parameters = [\
|
|
|
|
|
Parameter(Tensor(np.array([0.00023915, 0.00013939, -0.00098059])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00015815, 0.00015458, -0.00012125])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00042165, 0.00029692, -0.00007941])), "network.embedding_table"), \
|
|
|
|
|
Parameter(Tensor(np.array([0.00084451, 0.00089960, -0.00010431])), "network.embedding_table")]
|
|
|
|
|
>>> merged_parameter = merge_sliced_parameter(sliced_parameters, strategy)
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(sliced_parameters, list):
|
|
|
|
|
raise TypeError(f"The sliced_parameters should be list, but got {type(sliced_parameters)}.")
|
|
|
|
|