add parallel mode for cell

pull/958/head
yangzhenzhang 5 years ago
parent 7c64048d76
commit 8c9730b3c5

@@ -35,8 +35,8 @@ bool StepAllreduceFusion(const FuncGraphPtr &root, const opt::OptimizerPtr &opti
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) || (!enable_all_reduce_fusion) ||
-      (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
+  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+      (!enable_all_reduce_fusion) || (root->has_flag(ALLREDUCE_FUSION_RUN_ONCE_ONLY))) {
     return changes;
   }
 #if defined(_WIN32) || defined(_WIN64)

@@ -121,7 +121,8 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if ((parallel_mode != AUTO_PARALLEL) || root->flags()[AUTO_PARALLEL_RUN_ONCE_ONLY]) {
+  if (!root->has_flag(AUTO_PARALLEL) || (parallel_mode != AUTO_PARALLEL) ||
+      root->has_flag(AUTO_PARALLEL_RUN_ONCE_ONLY)) {
     return changes;
   }
   // check whether strategy_search_mode is valid

@@ -2220,7 +2220,7 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // assume no change to graph
   bool changes = false;
   // control whether use model_parallel mode
-  if (((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
-      (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
+  if (!root->has_flag(AUTO_PARALLEL) || ((parallel_mode != AUTO_PARALLEL) && (parallel_mode != SEMI_AUTO_PARALLEL)) ||
+      (root->has_flag(SEMI_AUTO_PARALLEL_RUN_ONCE_ONLY))) {
     return changes;
   }

@@ -281,7 +281,7 @@ void ExecutorPy::SaveCompiledGraph(const std::string &phase_s) {
   MS_LOG(INFO) << "Save compiled func graph(" << func_graph->ToString() << ") phase(" << phase_s << ")!";
   info_[phase_s]->func_graph = func_graph;
-  if ((func_graph != nullptr) &&
+  if ((func_graph != nullptr) && func_graph->has_flag(parallel::AUTO_PARALLEL) &&
       ((parallel_mode == parallel::AUTO_PARALLEL) || (parallel_mode == parallel::SEMI_AUTO_PARALLEL))) {
     MS_LOG(DEBUG) << "Save model parallel parameter layout graph!";
     func_graph = info_[phase_s]->resource->results()[kStepParallelGraph].cast<FuncGraphPtr>();

@@ -20,7 +20,6 @@ from collections import OrderedDict
 from functools import wraps
 from mindspore import context
 from mindspore import log as logger
-from mindspore.parallel._utils import _get_parallel_mode
 from .._c_expression import generate_key, Executor_, Tensor, MetaTensor
 from .._c_expression import verify_inputs_signature, init_exec_dataset, _set_dataset_mode_config, init_backend
 from .tensor import Tensor as MsTensor
@@ -327,7 +326,7 @@ class _Executor:
             raise TypeError('Parameters need OrderedDict type, but got {}'.
                             format(type(params)))

-    def compile(self, obj, *args, phase='predict', params=None, do_convert=True):
+    def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False):
         """
         Compiles graph.
@@ -337,6 +336,7 @@ class _Executor:
             phase (str): The name of compile phase. Default: 'predict'.
             params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
             do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
+            auto_parallel_mode: When set to True, use auto parallel mode to compile graph.

         Return:
             Str, the full phase of the cell.
@@ -370,8 +370,9 @@ class _Executor:
             logger.error("%r graph compile failed.", phase)
         if not do_convert:
             return phase, True

         if not enable_debug_runtime or enable_ge:
-            if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
+            if auto_parallel_mode:
                 obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
                 obj.load_parameter_slice(params)
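As a rough illustration of the executor-side change, here is a minimal sketch of a direct compile call once this commit is applied. The nn.Dense cell, tensor shape and context settings are illustrative assumptions; only set_auto_parallel() and the auto_parallel_mode keyword come from this change.

import numpy as np
import mindspore as ms
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.common.api import _executor

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)

net = nn.Dense(32, 16)   # any nn.Cell; Dense is only a placeholder
net.set_auto_parallel()  # marks the cell so the compiled graph carries the AUTO_PARALLEL flag

x = Tensor(np.ones([64, 32]), dtype=ms.float32)
# The executor no longer queries the global parallel mode; the caller states it explicitly.
_executor.compile(net, x, phase='train', auto_parallel_mode=True)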

@@ -25,7 +25,6 @@ from ..common.parameter import Parameter, ParameterTuple
 from .._c_expression import init_backend
 from ..ops.primitive import Primitive
 from ..parallel._tensor import _load_tensor_by_layout
-from ..parallel._utils import _get_parallel_mode
 from ..common.tensor import Tensor
@@ -71,8 +70,7 @@ class Cell:
         gc.collect()
         self._construct_inputs_num = 0
         self._construct_inputs_names = []
-        if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
-            self._get_construct_inputs_number_and_name()
+        self._auto_parallel_mode = False
         self._parallel_inputs_run = None
         if flags:
             self.add_flags(**flags)
@@ -298,9 +296,10 @@ class Cell:
         Returns:
             Object, the result of executing.
         """
-        _, compile_flag = _executor.compile(self, *inputs, phase=self.phase)
+        _, compile_flag = _executor.compile(self, *inputs, phase=self.phase,
+                                            auto_parallel_mode=self._auto_parallel_mode)

-        if _get_parallel_mode() in ["auto_parallel", "semi_auto_parallel"]:
+        if self._auto_parallel_mode:
             if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag and (not compile_flag):
                 parallel_inputs_run = self._parallel_inputs_run
             else:
@@ -665,3 +664,15 @@ class Cell:
         """
         self.add_flags_recursive(broadcast_flag=mode)
         return self
+
+    def set_auto_parallel(self):
+        """
+        Set the cell to auto parallel mode.
+
+        Note:
+            If a cell needs to use auto parallel or semi auto parallel mode for training, evaluation or prediction,
+            this interface needs to be called for the cell.
+        """
+        self._auto_parallel_mode = True
+        self.add_flags(auto_parallel=True)
+        self._get_construct_inputs_number_and_name()
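The behavioural shift in this hunk is that a cell no longer inherits parallel behaviour implicitly from the global context at construction time; it has to opt in. A minimal sketch of the intended call order, assuming a hypothetical AddRelu example cell (the ops and attribute check are only for illustration):

import mindspore.nn as nn
import mindspore.ops.operations as P
from mindspore import context

context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

class AddRelu(nn.Cell):  # hypothetical example cell
    def __init__(self):
        super(AddRelu, self).__init__()
        self.add = P.TensorAdd()
        self.relu = P.ReLU()

    def construct(self, x, y):
        return self.relu(self.add(x, y))

net = AddRelu()
net.set_auto_parallel()  # new in this commit: sets _auto_parallel_mode, adds the
                         # auto_parallel flag and caches the construct input names
assert net._auto_parallel_mode is True  # defaults to False until the call above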

@@ -16,8 +16,7 @@
 from mindspore._c_expression import reset_op_id
 from mindspore.communication.management import get_group_size, get_rank
-from mindspore.parallel._auto_parallel_context import auto_parallel_context, _set_auto_parallel_context,\
-    _reset_auto_parallel_context
+from mindspore.parallel._auto_parallel_context import auto_parallel_context


 def _get_parallel_mode():
@@ -108,102 +107,6 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
                          .format(parallel_mode, parameter_broadcast))

-
-_parallel_mode = None
-_device_num = None
-_global_rank = None
-_parameter_broadcast = None
-_mirror_mean = None
-_cast_before_mirror = None
-_loss_repeated_mean = None
-_communication_backend = None
-_has_checkpointed = False
-_enable_all_reduce_fusion = None
-
-
-def _checkpoint_auto_parallel_context():
-    """checkpoint auto parallel context"""
-    global _has_checkpointed
-    if _has_checkpointed is True:
-        return
-
-    global _parallel_mode
-    global _device_num
-    global _global_rank
-    global _parameter_broadcast
-    global _mirror_mean
-    global _cast_before_mirror
-    global _loss_repeated_mean
-    global _communication_backend
-    global _enable_all_reduce_fusion
-    _parallel_mode = auto_parallel_context().get_parallel_mode()
-    _device_num = _get_device_num()
-    _global_rank = _get_global_rank()
-    _parameter_broadcast = auto_parallel_context().get_parameter_broadcast()
-    _mirror_mean = auto_parallel_context().get_mirror_mean()
-    _cast_before_mirror = auto_parallel_context().get_cast_before_mirror()
-    _loss_repeated_mean = auto_parallel_context().get_loss_repeated_mean()
-    _communication_backend = auto_parallel_context().get_communication_backend()
-    _enable_all_reduce_fusion = auto_parallel_context().get_enable_all_reduce_fusion()
-
-    _has_checkpointed = True
-
-
-def _restore_auto_parallel_context():
-    """restore auto parallel context"""
-    global _parallel_mode
-    global _device_num
-    global _global_rank
-    global _parameter_broadcast
-    global _mirror_mean
-    global _cast_before_mirror
-    global _loss_repeated_mean
-    global _communication_backend
-    global _enable_all_reduce_fusion
-    _set_auto_parallel_context(parallel_mode=_parallel_mode, device_num=_device_num, global_rank=_global_rank,
-                               parameter_broadcast=_parameter_broadcast, mirror_mean=_mirror_mean,
-                               cast_before_mirror=_cast_before_mirror, loss_repeated_mean=_loss_repeated_mean)
-    auto_parallel_context().set_communication_backend(_communication_backend)
-    auto_parallel_context().set_enable_all_reduce_fusion(_enable_all_reduce_fusion)
-
-
-def _reset_checkpoint_auto_parallel_context():
-    """reset the _has_checkpointed"""
-    global _has_checkpointed
-    _has_checkpointed = False
-
-
-def _callback_wrapper(list_callback, run_context, callback_type):
-    """
-    reset the context for callback of model train
-
-    Raises:
-        ValueError: If the type keyword is not recognized
-    """
-    _callback_func_map = {
-        "begin": list_callback.begin,
-        "epoch_begin": list_callback.epoch_begin,
-        "step_begin": list_callback.step_begin,
-        "step_end": list_callback.step_end,
-        "epoch_end": list_callback.epoch_end,
-        "end": list_callback.end}
-
-    if callback_type not in _callback_func_map:
-        raise ValueError("Get type keyword %s is not recognized!" % callback_type)
-    func = _callback_func_map[callback_type]
-
-    if callback_type == "begin":
-        _reset_checkpoint_auto_parallel_context()
-        _checkpoint_auto_parallel_context()
-
-    global _parallel_mode
-    if _parallel_mode == "stand_alone":
-        func(run_context)
-        return
-
-    _reset_auto_parallel_context()
-    func(run_context)
-    _restore_auto_parallel_context()
-
-
 PARAMETER_CLONED_INDEX = 0

@@ -22,7 +22,7 @@ from .._checkparam import check_input_data, check_output_data, check_int_positiv
 from .callback import _InternalCallbackParam, RunContext, _build_callbacks
 from .. import context
 from ..parallel._utils import _get_parallel_mode, _get_device_num, _get_global_rank, \
-    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check, _callback_wrapper
+    _get_parameter_broadcast, _device_number_check, _parameter_broadcast_check
 from ..nn.metrics import Loss
 from .. import nn
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
@@ -144,6 +144,9 @@ class Model:
         elif self._loss_fn:
             network = nn.WithLossCell(network, self._loss_fn)
         # If need to check if loss_fn is not None, but optimizer is None
+
+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            network.set_auto_parallel()
         return network

     def _build_eval_network(self, metrics, eval_network, eval_indexes):
@@ -165,11 +168,15 @@ class Model:
             self._eval_network = nn.WithEvalCell(self._network, self._loss_fn, self._amp_level == "O2")
             self._eval_indexes = [0, 1, 2]

+        if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+            self._eval_network.set_auto_parallel()
+
     def _build_predict_network(self):
         """Build the network for prediction."""
         self._predict_network = self._network
         if self._parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
             self._predict_network = _VirtualDatasetCell(self._network)
+            self._predict_network.set_auto_parallel()

     def _clear_metrics(self):
         """Clear metrics local values."""
@@ -287,28 +294,28 @@ class Model:
         cb_params.cur_step_num = 0
         loop_size = dataset_helper.loop_size()
         run_context = RunContext(cb_params)
-        _callback_wrapper(list_callback, run_context, "begin")
+        list_callback.begin(run_context)

         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
-            _callback_wrapper(list_callback, run_context, "epoch_begin")
+            list_callback.epoch_begin(run_context)

             # for data sink dataset_helper only iter once, other wise iter epoch_size times.
             for inputs in dataset_helper:
                 cb_params.cur_step_num += loop_size
-                _callback_wrapper(list_callback, run_context, "step_begin")
+                list_callback.step_begin(run_context)
                 outputs = self._train_network(*inputs)
                 cb_params.net_outputs = outputs
-                _callback_wrapper(list_callback, run_context, "step_end")
+                list_callback.step_end(run_context)

-            _callback_wrapper(list_callback, run_context, "epoch_end")
+            list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
                 break

-        _callback_wrapper(list_callback, run_context, "end")
+        list_callback.end(run_context)

     def _train_process(self, epoch, train_dataset, list_callback=None, cb_params=None):
         """
@@ -327,14 +334,14 @@ class Model:
         dataset_helper = DatasetHelper(train_dataset, dataset_sink_mode=False)
         cb_params.cur_step_num = 0
         run_context = RunContext(cb_params)
-        _callback_wrapper(list_callback, run_context, "begin")
+        list_callback.begin(run_context)

         # used to stop training for early stop, such as stopAtTIme or stopATStep
         should_stop = False
         for i in range(epoch):
             cb_params.cur_epoch_num = i + 1
-            _callback_wrapper(list_callback, run_context, "epoch_begin")
+            list_callback.epoch_begin(run_context)

             for next_element in dataset_helper:
                 len_element = len(next_element)
@@ -342,7 +349,7 @@ class Model:
                     raise ValueError("when loss_fn is not None, train_dataset should"
                                      "return two elements, but got {}".format(len_element))
                 cb_params.cur_step_num += 1
-                _callback_wrapper(list_callback, run_context, "step_begin")
+                list_callback.step_begin(run_context)

                 overflow = False
                 if self._loss_scale_manager and self._loss_scale_manager.get_drop_overflow_update():
@@ -356,19 +363,19 @@ class Model:
                     overflow = np.all(overflow.asnumpy())
                     self._loss_scale_manager.update_loss_scale(overflow)

-                _callback_wrapper(list_callback, run_context, "step_end")
+                list_callback.step_end(run_context)
                 should_stop = should_stop or run_context.get_stop_requested()
                 if should_stop:
                     break

             train_dataset.reset()
-            _callback_wrapper(list_callback, run_context, "epoch_end")
+            list_callback.epoch_end(run_context)
             should_stop = should_stop or run_context.get_stop_requested()
             if should_stop:
                 break

-        _callback_wrapper(list_callback, run_context, "end")
+        list_callback.end(run_context)

     def train(self, epoch, train_dataset, callbacks=None, dataset_sink_mode=True):
         """

@@ -92,6 +92,7 @@ class AddReluFactory:
     def forward_mindspore_parallel_impl(self):
         net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np1)
         y = Tensor(self.input_np2, ms.float32)
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
@@ -118,6 +119,7 @@ class AddReluFactory:
         net = AddRelu(strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
         x1 = Tensor(inputs_x[self.x_id])

@@ -249,6 +249,7 @@ class Conv2dFactory:
                               padding=self.padding, dilation=self.dilation,
                              group=self.group, has_bias=False, weight_init=weight, strategy=(self.strategy0[0], self.strategy0[1], self.strategy0[1]))
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -307,7 +308,8 @@ class Conv2dFactory:
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
         grad_net.set_train()
+        grad_net.set_auto_parallel()
         out_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return out_grad

@@ -95,6 +95,7 @@ class DropoutFactory:
         x1 = Tensor(inputs_x[self.x_id])
         net = Net(0.4, 0, 0, strategy=self.strategy0)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
         return out.asnumpy()

@@ -118,6 +118,7 @@ class L2normalizeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -144,6 +145,7 @@ class L2normalizeFactory:
         net = L2normalize(self.axis, self.epsilon, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad

@@ -140,6 +140,7 @@ class AddReluFactory:
         net_with_loss = NetWithLoss(net, strategy2=self.strategy2)
         grad_net = Grad(net_with_loss)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grads = []
         for i in range(0, 3):

@@ -229,6 +229,7 @@ class BatchmatmulFactory:
         y1 = Tensor(ys[self.y_id])  # needs to be derived from the device matrix
         z1 = Tensor(zs[self.x_id])
         matmul.set_train()
+        matmul.set_auto_parallel()
         out_me = matmul(x, y, z, parallel_inputs_compile=[x, y, z], parallel_inputs_run=[x1, y1, z1])
         return out_me.asnumpy()
@@ -267,6 +268,7 @@ class BatchmatmulFactory:
         out_grad1 = Tensor(out_grads[self.out_id])
         net_me = Grad(matmul)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net_me.set_auto_parallel()
         net_me.set_train()
         out_grad = net_me(x, y, z, out_grad_me, parallel_inputs_compile = [x, y, z, out_grad1], parallel_inputs_run = [x1, y1, z1, out_grad1])

@@ -119,6 +119,7 @@ class MaxFactory:
         y1 = Tensor(ys[self.y_id])
         net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -144,6 +145,7 @@ class MaxFactory:
         net = Max(axis=self.axis, keep_dims=self.keep_dims, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, out_grad, parallel_inputs_compile=[x, y, out_grad], parallel_inputs_run=[x1, y1, out_grad])
         return input_grad

@@ -93,6 +93,7 @@ class MulSoftmaxFactory:
     def forward_mindspore_parallel_impl(self):
         net = MulSoftmax(strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np1)
         y = Tensor(self.input_np2, ms.float32)
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
@@ -120,6 +121,7 @@ class MulSoftmaxFactory:
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
         grad_net.set_train()
+        grad_net.set_auto_parallel()
         inputs_x = self.get_parallel_blocks(self.input_np1, self.strategy0[1])
         x1 = Tensor(inputs_x[self.x_id])
         y1 = Tensor(self.input_np2, ms.float32)

@@ -113,6 +113,7 @@ class OneHotFactory:
                      on_value=self.on_value,
                      off_value=self.off_value, strategy=self.strategy0)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, parallel_inputs_compile=[x], parallel_inputs_run=[x1])
         return out.asnumpy()

@@ -86,6 +86,7 @@ class PReLUFactory:
     def forward_mindspore_parallel_impl(self):
         net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         x = Tensor(self.input_np)
         z = Tensor(np.zeros(self.input_np.shape), ms.float32)
         w = Tensor(self.weight)
@@ -122,6 +123,7 @@ class PReLUFactory:
         net = PReLU(channel=self.channel, w=self.weight, strategy_=self.strategy, strategy1_=(self.strategy[0], self.strategy[1], self.strategy[1]))
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         inputs = self.get_parallel_blocks(self.input_np, self.strategy[1])

@@ -176,6 +176,7 @@ class ReduceMeanFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -202,6 +203,7 @@ class ReduceMeanFactory:
         net = ReduceMean(keep_dims=self.keep_dims, axis=self.axis, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1],
                               parallel_inputs_run=[x1, y1, output_grad1])

@@ -121,6 +121,7 @@ class ReshapeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -147,6 +148,7 @@ class ReshapeFactory:
         net = Reshape(self.target_shape, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad

@@ -148,6 +148,7 @@ class TransposeFactory:
         y1 = Tensor(inputs_y[self.y_id])
         net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        net.set_auto_parallel()
         out = net(x, y, parallel_inputs_compile=[x, y], parallel_inputs_run=[x1, y1])
         return out.asnumpy()
@@ -174,6 +175,7 @@ class TransposeFactory:
         net = Net(self.perm_in, strategy0=self.strategy0, strategy1=self.strategy1)
         grad_net = Grad(net)
         context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
+        grad_net.set_auto_parallel()
         grad_net.set_train()
         input_grad = grad_net(x, y, output_grad, parallel_inputs_compile=[x, y, output_grad1], parallel_inputs_run=[x1, y1, output_grad1])
         return input_grad

@@ -49,6 +49,12 @@ class Grad(nn.Cell):
     def construct(self, x, y):
         return C.grad_all(self.network)(x, y)

+
+def compile(net, x, y):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y)
+
+
 def test_add_relu_stride_slice():
     context.set_auto_parallel_context(device_num=8, global_rank=7)
@@ -59,7 +65,7 @@ def test_add_relu_stride_slice():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)


 def test_add_relu_all_gather():
     context.set_auto_parallel_context(device_num=8, global_rank=7)
@@ -71,4 +77,4 @@ def test_add_relu_all_gather():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y)
+    compile(net, x, y)

@@ -42,6 +42,11 @@ class GradWrap(nn.Cell):
         return C.grad_all(self.network)(x, y, b)

+
+def compile(net, x, y, b):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b)
+
 def test_matmul_sub():
     class Net(nn.Cell):
         def __init__(self, strategy1, strategy2):
@@ -64,7 +69,7 @@ def test_matmul_sub():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_add():
@@ -88,7 +93,7 @@ def test_matmul_add():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul():
@@ -112,7 +117,7 @@ def test_matmul_mul():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div():
@@ -136,7 +141,7 @@ def test_matmul_div():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_greater():
     class Net(nn.Cell):
@@ -159,7 +164,7 @@ def test_matmul_greater():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_add_broadcast():
     class Net(nn.Cell):
@@ -182,7 +187,7 @@ def test_matmul_add_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_add_broadcast2():
@@ -206,7 +211,7 @@ def test_matmul_add_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_sub_broadcast():
@@ -230,7 +235,7 @@ def test_matmul_sub_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_sub_broadcast2():
@@ -254,7 +259,7 @@ def test_matmul_sub_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul_broadcast():
@@ -278,7 +283,7 @@ def test_matmul_mul_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_mul_broadcast2():
@@ -302,7 +307,7 @@ def test_matmul_mul_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div_broadcast():
@@ -326,7 +331,7 @@ def test_matmul_div_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_div_broadcast2():
@@ -350,7 +355,7 @@ def test_matmul_div_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_greater_broadcast():
     class Net(nn.Cell):
@@ -373,7 +378,7 @@ def test_matmul_greater_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_greater_broadcast2():
@@ -397,7 +402,7 @@ def test_matmul_greater_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_floordiv():
     class Net(nn.Cell):
@@ -420,7 +425,7 @@ def test_matmul_floordiv():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_floordiv_broadcast():
@@ -444,7 +449,7 @@ def test_matmul_floordiv_broadcast():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 64]), dtype=ms.float32)
     b = Tensor(np.ones([64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_matmul_floordiv_broadcast2():
@@ -468,7 +473,7 @@ def test_matmul_floordiv_broadcast2():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
+    compile(net, x, y, b)


 def test_assign_sub():
@@ -495,4 +500,4 @@ def test_assign_sub():
     x = Tensor(np.ones([128, 32]), dtype=ms.float32)
     y = Tensor(np.ones([128, 32]), dtype=ms.float32)
     z = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, z)
+    compile(net, x, y, z)

@@ -66,4 +66,5 @@ def test_auto_parallel_bn_with_prelu():
     net = GradWrap(NetWithLoss(Net()))
     context.set_auto_parallel_context(parallel_mode="auto_parallel")
+    net.set_auto_parallel()
     _executor.compile(net, x)

@@ -43,6 +43,12 @@ class GradWrap(nn.Cell):
     def construct(self, x, y, b):
         return C.grad_all(self.network)(x, y, b)

+
+def compile(net, x, y, b, phase):
+    net.set_auto_parallel()
+    _executor.compile(net, x, y, b, phase=phase)
+
+
 def test_auto_parallel_arithmetic():
     class Net(nn.Cell):
         def __init__(self):
@@ -63,7 +69,7 @@ def test_auto_parallel_arithmetic():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 128]), dtype=ms.float32)
     b = Tensor(np.ones([64, 128]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[2, 4], [2, 4]],
                            'Default/network-Net/MatMul-op1': [[2, 1], [1, 4]]}
@@ -89,7 +95,7 @@ def test_auto_parallel_arithmetic_broadcast_both():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 1]), dtype=ms.float32)
     b = Tensor(np.ones([1, 64]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[8, 1], [1, 1]],
                            'Default/network-Net/MatMul-op1': [[8, 1], [1, 1]]}
@@ -116,7 +122,7 @@ def test_auto_parallel_arithmetic_broadcast_right():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 32]), dtype=ms.float32)
     b = Tensor(np.ones([32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase='train')
+    compile(net, x, y, b, phase='train')
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [2]],
                            'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}
@@ -143,7 +149,7 @@ def test_auto_parallel_arithmetic_broadcast_left():
     x = Tensor(np.ones([64, 32]), dtype=ms.float32)
     y = Tensor(np.ones([32, 32]), dtype=ms.float32)
     b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b, phase="train")
+    compile(net, x, y, b, phase="train")
     strategies = _executor._get_strategy(net)
     expected_strategies = {'Default/network-Net/FloorDiv-op0': [[4, 2], [1, 4, 2]],
                            'Default/network-Net/MatMul-op1': [[4, 1], [1, 2]]}

Some files were not shown because too many files have changed in this diff.
