!9577 support distributed predict

From: @gong_zi_yan
Reviewed-by: @caozhou_huawei,@yao_yf,@stsuteng,@zh_qh
Signed-off-by: @stsuteng
pull/9577/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit ec3983b77d

@ -15,7 +15,8 @@
"""Utils of auto parallel"""
import numpy as np
from mindspore import log as logger
from mindspore import context, log as logger
from mindspore.context import ParallelMode
from mindspore._c_expression import reset_op_id
from mindspore.common.tensor import Tensor
from mindspore.common.dtype import dtype_to_nptype
@ -193,3 +194,70 @@ def _get_python_op(op_name, op_path, instance_name, arglist):
def _reset_op_id():
"""Reset op id."""
reset_op_id()
def _parallel_predict_check():
"""validate parallel model prediction"""
if _get_parallel_mode() in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
if not context.get_auto_parallel_context("full_batch"):
raise RuntimeError('Model prediction only supports full batch dataset. Please set "full_batch" with True.')
if context.get_auto_parallel_context("enable_parallel_optimizer"):
raise RuntimeError('Model prediction does not support parallel optimizer. Please set'
'"enable_parallel_optimizer" with False.')
def _check_similar_layout(tensor_layout1, tensor_layout2):
"""check if two tensor layouts are same"""
if tensor_layout1[1] != tensor_layout2[1]:
return False
for i in tensor_layout1[1]:
if i == -1:
continue
if tensor_layout1[0][-1-i] != tensor_layout2[0][-1-i]:
return False
return True
def _remove_repeated_slices(tensor_layout):
"""generate unrepeated tensor layout"""
import copy
new_tensor_layout = copy.deepcopy(tensor_layout)
dev_mat = tensor_layout[0][:]
tensor_map = tensor_layout[1]
for dim in range(len(dev_mat)):
if dim not in tensor_map:
dev_mat[-1-dim] = 1
new_tensor_layout[0] = dev_mat
return new_tensor_layout
def _infer_rank_list(train_map, predict_map=None):
"""infer checkpoint slices to be loaded"""
ret = {}
for param_name in train_map:
train_layout = train_map[param_name]
new_train_layout = _remove_repeated_slices(train_layout)
train_dev_mat = train_layout[0]
dev_num = np.array(train_dev_mat).prod()
array = np.arange(dev_num).reshape(train_dev_mat)
index = ()
for i in new_train_layout[0]:
if i == 1:
index = index + (0,)
else:
index = index + (slice(None),)
rank_list = array[index].flatten()
if not predict_map:
ret[param_name] = rank_list
continue
if param_name not in predict_map:
logger.warning("predict_map does not contain %s", param_name)
continue
predict_layout = predict_map[param_name]
# optimization pass
if _check_similar_layout(train_layout, predict_layout):
dev_rank = _get_global_rank()
ret[param_name] = [rank_list[dev_rank]]
else:
ret[param_name] = rank_list
return ret

@ -26,7 +26,7 @@ from .._checkparam import check_input_data, check_output_data, Validator
from .callback import _InternalCallbackParam, RunContext, _CallbackManager, Callback
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _parallel_predict_check
from ..parallel._ps_context import _is_role_pserver, _is_role_sched
from ..nn.metrics import Loss
from .. import nn
@ -736,10 +736,46 @@ class Model:
"""
self._predict_network.set_train(False)
check_input_data(*predict_data, data_class=Tensor)
_parallel_predict_check()
result = self._predict_network(*predict_data)
check_output_data(result)
return result
def infer_predict_layout(self, *predict_data):
"""
Generate parameter layout for the predict network in auto or semi auto parallel mode.
Data could be a single tensor, a list of tensor, or a tuple of tensor.
Note:
Batch data should be put together in one tensor.
Args:
predict_data (Tensor): Tensor of predict data. can be array, list or tuple.
Returns:
parameter_layout_dict (dict): Parameter layout dictionary used for load distributed checkpoint
Examples:
>>> input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), mindspore.float32)
>>> model = Model(Net())
>>> model.infer_predict_layout(input_data)
"""
if context.get_context("mode") != context.GRAPH_MODE:
raise RuntimeError('infer predict layout only supports GRAPH MODE currently.')
# remove this restriction after support inferring repeated strategy
if _get_parallel_mode() not in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
raise RuntimeError('infer predict layout only supports semi auto parallel and auto parallel mode.')
_parallel_predict_check()
check_input_data(*predict_data, data_class=Tensor)
predict_net = self._predict_network
# Unlike the cases in build_train_network() and build_eval_network(), 'multi_subgraphs' is not set
predict_net.set_auto_parallel()
predict_net.set_train(False)
predict_net.compile(*predict_data)
return predict_net.parameter_layout_dict
__all__ = ["Model"]

@ -26,7 +26,7 @@ init()
def test_train_32k_8p(batch_size=32, num_classes=32768):
dev_num = 8
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num)
context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, device_num=dev_num, full_batch=True)
set_algo_parameters(elementwise_op_strategy_follow=True)
np.random.seed(6)
input_np = Tensor(np.ones([batch_size, 3, 224, 224]).astype(np.float32))

@ -0,0 +1,73 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test distribute predict """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Model
from mindspore.ops import operations as P
from mindspore import context
class Net(nn.Cell):
"""Net definition"""
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Dense(128, 768, activation='relu')
self.fc2 = nn.Dense(128, 768, activation='relu')
self.fc3 = nn.Dense(128, 768, activation='relu')
self.fc4 = nn.Dense(768, 768, activation='relu')
self.relu4 = nn.ReLU()
self.relu5 = nn.ReLU()
self.transpose = P.Transpose()
self.matmul1 = P.MatMul()
self.matmul2 = P.MatMul()
def construct(self, x):
q = self.fc1(x)
k = self.fc2(x)
v = self.fc3(x)
k = self.transpose(k, (1, 0))
c = self.relu4(self.matmul1(q, k))
s = self.relu5(self.matmul2(c, v))
s = self.fc4(s)
return s
def test_distribute_predict():
context.set_context(mode=context.GRAPH_MODE)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, full_batch=True)
inputs = Tensor(np.ones([32, 128]).astype(np.float32))
net = Net()
model = Model(net)
predict_map = model.infer_predict_layout(inputs)
output = model.predict(inputs)
context.reset_auto_parallel_context()
return predict_map, output
def test_edge_case():
context.set_context(mode=context.GRAPH_MODE)
inputs = Tensor(np.ones([32, 48]).astype(np.float32))
net = Net()
model = Model(net)
with pytest.raises(RuntimeError):
model.infer_predict_layout(inputs)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
with pytest.raises(RuntimeError):
model.infer_predict_layout(inputs)
context.set_auto_parallel_context(full_batch=True, enable_parallel_optimizer=True)
with pytest.raises(RuntimeError):
model.predict(inputs)
Loading…
Cancel
Save