|
|
@ -168,21 +168,21 @@ def _chunk_tensor_by_strategy(np_tensor, strategy):
|
|
|
|
raise ValueError("The length of np_tensor does not match the length of strategy!")
|
|
|
|
raise ValueError("The length of np_tensor does not match the length of strategy!")
|
|
|
|
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
|
|
|
return _chunk_tensor(np_tensor, strategy, len(strategy))
|
|
|
|
|
|
|
|
|
|
|
|
def _get_seed(dev_mat, tensor_map):
|
|
|
|
def _get_slice_index(dev_mat, tensor_map):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Get the random seed for current slice.
|
|
|
|
Get the slice index for current slice.
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
dev_mat (list): The device matrix of devices.
|
|
|
|
dev_mat (list): The device matrix of devices.
|
|
|
|
tensor_map (list): The split strategy of tensor.
|
|
|
|
tensor_map (list): The split strategy of tensor.
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
Integer, the local random seed for this device.
|
|
|
|
Integer, the slice index for slice on this device.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
rank = get_rank()
|
|
|
|
rank = get_rank()
|
|
|
|
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
|
|
|
tensor_strategy = _get_tensor_strategy(dev_mat, tensor_map)
|
|
|
|
tensor_slice_seed = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
|
|
|
tensor_slice_index = _get_tensor_slice_index(dev_mat, tensor_strategy, tensor_map, rank)
|
|
|
|
return tensor_slice_seed
|
|
|
|
return tensor_slice_index
|
|
|
|
|
|
|
|
|
|
|
|
def _load_tensor(tensor, dev_mat, tensor_map):
|
|
|
|
def _load_tensor(tensor, dev_mat, tensor_map):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|