add parallel mode for cell

commit 8c9730b3c5 (parent 7c64048d76) on pull/958/head by yangzhenzhang, 5 years ago

@ -35,8 +35,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) ||
(root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
return changes;
}
#if defined(_WIN32) || defined(_WIN64)

@ -121,7 +121,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if ((parallel_mode != AUTO_PARALLEL) || root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY]) {
if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
return changes;
}
// check whether strategy_search_mode is valid

@ -2220,7 +2220,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
// assume no change to graph
bool changes = false;
// control whether use model_parallel mode
if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
(root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
return changes;
}

@ -281,7 +281,7 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
info_[phase_s]->func_graph = func_graph;
if ((func_graph != nullptr) &&
if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();

@ -20,7 +20,6 @@ from collections import OrderedDict
from functools import wraps
from mindspore import context
from mindspore import log as logger
from mindspore.parallel._utils import _get_parallel_mode
from .._c_expression import generate_key, Executor_, Tensor, MetaTensor
from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
from .tensor import Tensor as MsTensor
@ -327,7 +326,7 @@ class _Executor:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
"""
Compiles graph.
@ -337,6 +336,7 @@ class _Executor:
phase (str): The name of compile phase. Default: 'predict'.
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
auto_parallel_mode (bool): When set to True, use auto parallel mode to compile the graph. Default: False.
Return:
Str, the full phase of the cell.
@ -370,8 +370,9 @@ class _Executor:
logger.error("%r graph compile failed.", phase)
if not do_convert:
return phase, True
if not enable_debug_runtime or enable_ge:
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
obj.load_parameter_slice(params)
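
The new keyword argument makes the caller, rather than _Executor, decide whether parameter layouts are loaded after compilation. Below is a minimal sketch of the call shape, assuming a toy parameterless OneInputNet, a compile-only setting with a simulated 8-device layout, and the mindspore.common.api import path used by the existing parallel tests; none of these names come from this change itself.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _executor  # assumed import path, as in the parallel tests
from mindspore.ops import operations as P

class OneInputNet(nn.Cell):
    """Illustrative placeholder network with no parameters."""
    def __init__(self):
        super(OneInputNet, self).__init__()
        self.relu = P.ReLU()

    def construct(self, x):
        return self.relu(x)

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  parallel_mode="semi_auto_parallel")
net = OneInputNet()
net.set_auto_parallel()  # sets the graph flag that the C++ parallel passes above now check
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
# The caller now states explicitly whether parameter layouts should be loaded
# after compilation, instead of _Executor consulting _get_parallel_mode() internally.
phase, _ = _executor.compile(net, x, phase='train',
                             auto_parallel_mode=net._auto_parallel_mode)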

@ -25,7 +25,6 @@ from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend
from ..ops.primitive import Primitive
from ..parallel._tensor import _load_tensor_by_layout
from ..parallel._utils import _get_parallel_mode
from ..common.tensor import Tensor
@ -71,8 +70,7 @@ class Cell:
gc.collect()
self._construct_inputs_num = 0
self._construct_inputs_names = []
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
self._get_construct_inputs_number_and_name()
self._auto_parallel_mode = False
self._parallel_inputs_run = None
if flags:
self.add_flags(**flags)
@ -298,9 +296,10 @@ class Cell:
Returns:
Object, the result of executing.
"""
_, compile_flag = _executor.compile(self, *inputs, phase=self.phase)
_, compile_flag = _executor.compile(self, *inputs, phase=self.phase,
auto_parallel_mode=self._auto_parallel_mode)
if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
if self._auto_parallel_mode:
if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag):
parallel_inputs_run = self._parallel_inputs_run
else:
@ -665,3 +664,15 @@ class Cell:
"""
self.add_flags_recursive(broadcast_flag=mode)
return self
def set_auto_parallel(self):
"""
Set the cell to auto parallel mode.
Note:
If a cell needs to use auto parallel or semi auto parallel mode for training, evaluation or prediction,
this interface needs to be called for the cell.
"""
self._auto_parallel_mode = True
self.add_flags(auto_parallel=True)
self._get_construct_inputs_number_and_name()
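
A minimal usage sketch distilled from the test updates later in this diff; the Net definition, the tensor shapes, and the mindspore.common.api import path for _executor are illustrative assumptions rather than part of this change.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _executor  # assumed import path, as in the parallel tests
from mindspore.ops import operations as P

class Net(nn.Cell):
    """Illustrative two-input network."""
    def __init__(self):
        super(Net, self).__init__()
        self.matmul = P.MatMul()
        self.relu = P.ReLU()

    def construct(self, x, y):
        return self.relu(self.matmul(x, y))

context.set_auto_parallel_context(device_num=8, global_rank=0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net = Net()
net.set_auto_parallel()  # replaces the removed implicit _get_parallel_mode() check
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
_executor.compile(net, x, y)  # compile-only, as in the parallel unit tests

Cells that are built and trained through Model do not need this call; the Model changes later in this diff add it automatically.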

@ -16,8 +16,7 @@
from mindspore._c_expression import reset_op_id
from mindspore.communication.management import get_group_size, get_rank
from mindspore.parallel._auto_parallel_context import auto_parallel_context, _set_auto_parallel_context,\
_reset_auto_parallel_context
from mindspore.parallel._auto_parallel_context import auto_parallel_context
def _get_parallel_mode():
@ -108,102 +107,6 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
.format(parallel_mode, parameter_broadcast))
_parallel_mode = None
_device_num = None
_global_rank = None
_parameter_broadcast = None
_mirror_mean = None
_cast_before_mirror = None
_loss_repeated_mean = None
_communication_backend = None
_has_checkpointed = False
_enable_all_reduce_fusion = None
def _checkpoint_auto_parallel_context():
"""checkpoint auto parallel context"""
global _has_checkpointed
if _has_checkpointed is True:
return
global _parallel_mode
global _device_num
global _global_rank
global _parameter_broadcast
global _mirror_mean
global _cast_before_mirror
global _loss_repeated_mean
global _communication_backend
global _enable_all_reduce_fusion
_parallel_mode = auto_parallel_context().get_parallel_mode()
_device_num = _get_device_num()
_global_rank = _get_global_rank()
_parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
_mirror_mean = auto_parallel_context().get_mirror_mean()
_cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
_loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
_communication_backend = auto_parallel_context().get_communication_backend()
_enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()
_has_checkpointed = True
def _restore_auto_parallel_context():
"""restore auto parallel context"""
global _parallel_mode
global _device_num
global _global_rank
global _parameter_broadcast
global _mirror_mean
global _cast_before_mirror
global _loss_repeated_mean
global _communication_backend
global _enable_all_reduce_fusion
_set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
auto_parallel_context().set_communication_backend(_communication_backend)
auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)
def _reset_checkpoint_auto_parallel_context():
"""reset the _has_checkpointed"""
global _has_checkpointed
_has_checkpointed = False
def _callback_wrapper(list_callback, run_context, callback_type):
"""
reset the context for callback of model train
Raises:
ValueError: If the type keyword is not recognized
"""
_callback_func_map = {
"begin": list_callback.begin,
"epoch_begin": list_callback.epoch_begin,
"step_begin": list_callback.step_begin,
"step_end": list_callback.step_end,
"epoch_end": list_callback.epoch_end,
"end": list_callback.end}
if callback_type not in _callback_func_map:
raise ValueError("Get type keyword %s is not recognized!" % callback_type)
func = _callback_func_map[callback_type]
if callback_type == "begin":
_reset_checkpoint_auto_parallel_context()
_checkpoint_auto_parallel_context()
global _parallel_mode
if _parallel_mode == "stand_alone":
func(run_context)
return
_reset_auto_parallel_context()
func(run_context)
_restore_auto_parallel_context()
PARAMETER_CLONED_INDEX = 0

@ -22,7 +22,7 @@ from .._checkparam import check_input_data, check_output_data, check_int_positiv
from .callback import _InternalCallbackParam, RunContext, _build_callbacks
from .. import context
from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
_get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
from ..nn.metrics import Loss
from .. import nn
from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@ -144,6 +144,9 @@ class Model:
elif self._loss_fn:
network = nn.WithLossCell(network, self._loss_fn)
# If need to check if loss_fn is not None, but optimizer is None
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
network.set_auto_parallel()
return network
def _build_eval_network(self, metrics, eval_network, eval_indexes):
@ -165,11 +168,15 @@ class Model:
self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
self._eval_indexes = [0, 1, 2]
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._eval_network.set_auto_parallel()
def _build_predict_network(self):
"""Build the network for prediction."""
self._predict_network = self._network
if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
self._predict_network = _VirtualDatasetCell(self._network)
self._predict_network.set_auto_parallel()
def _clear_metrics(self):
"""Clear metrics local values."""
@ -287,28 +294,28 @@ class Model:
cb_params.cur_step_num = 0
loop_size = dataset_helper.loop_size()
run_context = RunContext(cb_params)
_callback_wrapper(list_callback, run_context, "begin")
list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTime or stopAtStep
should_stop = False
for i in range(epoch):
cb_params.cur_epoch_num = i + 1
_callback_wrapper(list_callback, run_context, "epoch_begin")
list_callback.epoch_begin(run_context)
# for data sink, dataset_helper only iterates once; otherwise it iterates epoch_size times.
for inputs in dataset_helper:
cb_params.cur_step_num += loop_size
_callback_wrapper(list_callback, run_context, "step_begin")
list_callback.step_begin(run_context)
outputs = self._train_network(*inputs)
cb_params.net_outputs = outputs
_callback_wrapper(list_callback, run_context, "step_end")
list_callback.step_end(run_context)
_callback_wrapper(list_callback, run_context, "epoch_end")
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
_callback_wrapper(list_callback, run_context, "end")
list_callback.end(run_context)
def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
"""
@ -327,14 +334,14 @@ class Model:
dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
cb_params.cur_step_num = 0
run_context = RunContext(cb_params)
_callback_wrapper(list_callback, run_context, "begin")
list_callback.begin(run_context)
# used to stop training for early stop, such as stopAtTime or stopAtStep
should_stop = False
for i in range(epoch):
cb_params.cur_epoch_num = i + 1
_callback_wrapper(list_callback, run_context, "epoch_begin")
list_callback.epoch_begin(run_context)
for next_element in dataset_helper:
len_element = len(next_element)
@ -342,7 +349,7 @@ class Model:
raise ValueError("when loss_fn is not None, train_dataset should"
"return two elements, but got {}".format(len_element))
cb_params.cur_step_num += 1
_callback_wrapper(list_callback, run_context, "step_begin")
list_callback.step_begin(run_context)
overflow = False
if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
@ -356,19 +363,19 @@ class Model:
overflow = np.all(overflow.asnumpy())
self._loss_scale_manager.update_loss_scale(overflow)
_callback_wrapper(list_callback, run_context, "step_end")
list_callback.step_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
train_dataset.reset()
_callback_wrapper(list_callback, run_context, "epoch_end")
list_callback.epoch_end(run_context)
should_stop = should_stop or run_context.get_stop_requested()
if should_stop:
break
_callback_wrapper(list_callback, run_context, "end")
list_callback.end(run_context)
def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
"""

@ -92,6 +92,7 @@ class AddReluFactory:
def forward_mindspore_parallel_impl(self):
net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
x = Tensor(self.input_np1)
y = Tensor(self.input_np2, ms.float32)
inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
@ -118,6 +119,7 @@ class AddReluFactory:
net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
x1 = Tensor(inputs_x[self.x_id])

@ -249,6 +249,7 @@ class Conv2dFactory:
padding=self.padding, dilation=self.dilation,
group=self.group, has_bias=False, weight_init=weight, strategy=(self.strategy0[0], self.strategy0[1], self.strategy0[1]))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -307,7 +308,8 @@ class Conv2dFactory:
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_train()
grad_net.set_auto_parallel()
out_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
return out_grad

@ -95,6 +95,7 @@ class DropoutFactory:
x1 = Tensor(inputs_x[self.x_id])
net = Net(0.4, 0, 0, strategy=self.strategy0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
return out.asnumpy()

@ -118,6 +118,7 @@ class L2normalizeFactory:
y1 = Tensor(inputs_y[self.y_id])
net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -144,6 +145,7 @@ class L2normalizeFactory:
net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
return input_grad

@ -140,6 +140,7 @@ class AddReluFactory:
net_with_loss = NetWithLoss(net, strategy2=self.strategy2)
grad_net = Grad(net_with_loss)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grads = []
for i in range(0, 3):

@ -229,6 +229,7 @@ class BatchmatmulFactory:
y1 = Tensor(ys[self.y_id])  # needs to be derived from the device matrix
z1 = Tensor(zs[self.x_id])
matmul.set_train()
matmul.set_auto_parallel()
out_me = matmul(x, y, z, parallel_inputs_compile=[x, y, z], parallel_inputs_run=[x1, y1, z1])
return out_me.asnumpy()
@ -267,6 +268,7 @@ class BatchmatmulFactory:
out_grad1 = Tensor(out_grads[self.out_id])
net_me = Grad(matmul)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net_me.set_auto_parallel()
net_me.set_train()
out_grad = net_me(x, y, z, out_grad_me, parallel_inputs_compile = [x, y, z, out_grad1], parallel_inputs_run = [x1, y1, z1, out_grad1])

@ -119,6 +119,7 @@ class MaxFactory:
y1 = Tensor(ys[self.y_id])
net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -144,6 +145,7 @@ class MaxFactory:
net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grad = grad_net(x, y, out_grad, parallel_inputs_compile=[x, y, out_grad], parallel_inputs_run=[x1, y1, out_grad])
return input_grad

@ -93,6 +93,7 @@ class MulSoftmaxFactory:
def forward_mindspore_parallel_impl(self):
net = MulSoftmax(strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
x = Tensor(self.input_np1)
y = Tensor(self.input_np2, ms.float32)
inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
@ -120,6 +121,7 @@ class MulSoftmaxFactory:
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_train()
grad_net.set_auto_parallel()
inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
x1 = Tensor(inputs_x[self.x_id])
y1 = Tensor(self.input_np2, ms.float32)

@ -113,6 +113,7 @@ class OneHotFactory:
on_value=self.on_value,
off_value=self.off_value, strategy=self.strategy0)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
return out.asnumpy()

@ -86,6 +86,7 @@ class PReLUFactory:
def forward_mindspore_parallel_impl(self):
net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
x = Tensor(self.input_np)
z = Tensor(np.zeros(self.input_np.shape), ms.float32)
w = Tensor(self.weight)
@ -122,6 +123,7 @@ class PReLUFactory:
net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
inputs = self.get_parallel_blocks(self.input_np, self.strategy[1])

@ -176,6 +176,7 @@ class ReduceMeanFactory:
y1 = Tensor(inputs_y[self.y_id])
net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -202,6 +203,7 @@ class ReduceMeanFactory:
net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1],
parallel_inputs_run=[x1, y1, output_grad1])

@ -121,6 +121,7 @@ class ReshapeFactory:
y1 = Tensor(inputs_y[self.y_id])
net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -147,6 +148,7 @@ class ReshapeFactory:
net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
return input_grad

@ -148,6 +148,7 @@ class TransposeFactory:
y1 = Tensor(inputs_y[self.y_id])
net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
net.set_auto_parallel()
out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
return out.asnumpy()
@ -174,6 +175,7 @@ class TransposeFactory:
net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
grad_net = Grad(net)
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
grad_net.set_auto_parallel()
grad_net.set_train()
input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
return input_grad

@ -49,6 +49,12 @@ class Grad(nn.Cell):
def construct(self, x, y):
return C.grad_all(self.network)(x, y)
def compile(net, x, y):
net.set_auto_parallel()
_executor.compile(net, x, y)
def test_add_relu_stride_slice():
context.set_auto_parallel_context(device_num=8, global_rank=7)
@ -59,7 +65,7 @@ def test_add_relu_stride_slice():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y)
compile(net, x, y)
def test_add_relu_all_gather():
context.set_auto_parallel_context(device_num=8, global_rank=7)
@ -71,4 +77,4 @@ def test_add_relu_all_gather():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y)
compile(net, x, y)

@ -42,6 +42,11 @@ class GradWrap(nn.Cell):
return C.grad_all(self.network)(x, y, b)
def compile(net, x, y, b):
net.set_auto_parallel()
_executor.compile(net, x, y, b)
def test_matmul_sub():
class Net(nn.Cell):
def __init__(self, strategy1, strategy2):
@ -64,7 +69,7 @@ def test_matmul_sub():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_add():
@ -88,7 +93,7 @@ def test_matmul_add():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_mul():
@ -112,7 +117,7 @@ def test_matmul_mul():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_div():
@ -136,7 +141,7 @@ def test_matmul_div():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_greater():
class Net(nn.Cell):
@ -159,7 +164,7 @@ def test_matmul_greater():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_add_broadcast():
class Net(nn.Cell):
@ -182,7 +187,7 @@ def test_matmul_add_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_add_broadcast2():
@ -206,7 +211,7 @@ def test_matmul_add_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_sub_broadcast():
@ -230,7 +235,7 @@ def test_matmul_sub_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_sub_broadcast2():
@ -254,7 +259,7 @@ def test_matmul_sub_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_mul_broadcast():
@ -278,7 +283,7 @@ def test_matmul_mul_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_mul_broadcast2():
@ -302,7 +307,7 @@ def test_matmul_mul_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_div_broadcast():
@ -326,7 +331,7 @@ def test_matmul_div_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_div_broadcast2():
@ -350,7 +355,7 @@ def test_matmul_div_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_greater_broadcast():
class Net(nn.Cell):
@ -373,7 +378,7 @@ def test_matmul_greater_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_greater_broadcast2():
@ -397,7 +402,7 @@ def test_matmul_greater_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_floordiv():
class Net(nn.Cell):
@ -420,7 +425,7 @@ def test_matmul_floordiv():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_floordiv_broadcast():
@ -444,7 +449,7 @@ def test_matmul_floordiv_broadcast():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 64]), dtype=ms.float32)
b = Tensor(np.ones([64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_matmul_floordiv_broadcast2():
@ -468,7 +473,7 @@ def test_matmul_floordiv_broadcast2():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b)
compile(net, x, y, b)
def test_assign_sub():
@ -495,4 +500,4 @@ def test_assign_sub():
x = Tensor(np.ones([128, 32]), dtype=ms.float32)
y = Tensor(np.ones([128, 32]), dtype=ms.float32)
z = Tensor(np.ones([128, 32]), dtype=ms.float32)
_executor.compile(net, x, y, z)
compile(net, x, y, z)

@ -66,4 +66,5 @@ def test_auto_parallel_bn_with_prelu():
net = GradWrap(NetWithLoss(Net()))
context.set_auto_parallel_context(parallel_mode="auto_parallel")
net.set_auto_parallel()
_executor.compile(net, x)

@ -43,6 +43,12 @@ class GradWrap(nn.Cell):
def construct(self, x, y, b):
return C.grad_all(self.network)(x, y, b)
def compile(net, x, y, b, phase):
net.set_auto_parallel()
_executor.compile(net, x, y, b, phase=phase)
def test_auto_parallel_arithmetic():
class Net(nn.Cell):
def __init__(self):
@ -63,7 +69,7 @@ def test_auto_parallel_arithmetic():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 128]), dtype=ms.float32)
b = Tensor(np.ones([64, 128]), dtype=ms.float32)
_executor.compile(net, x, y, b, phase='train')
compile(net, x, y, b, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}
@ -89,7 +95,7 @@ def test_auto_parallel_arithmetic_broadcast_both():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 1]), dtype=ms.float32)
b = Tensor(np.ones([1, 64]), dtype=ms.float32)
_executor.compile(net, x, y, b, phase='train')
compile(net, x, y, b, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}
@ -116,7 +122,7 @@ def test_auto_parallel_arithmetic_broadcast_right():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 32]), dtype=ms.float32)
b = Tensor(np.ones([32]), dtype=ms.float32)
_executor.compile(net, x, y, b, phase='train')
compile(net, x, y, b, phase='train')
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
@ -143,7 +149,7 @@ def test_auto_parallel_arithmetic_broadcast_left():
x = Tensor(np.ones([64, 32]), dtype=ms.float32)
y = Tensor(np.ones([32, 32]), dtype=ms.float32)
b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_executor.compile(net, x, y, b, phase="train")
compile(net, x, y, b, phase="train")
strategies = _executor._get_strategy(net)
expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}

Some files were not shown because too many files have changed in this diff.