diff --git a/mindspore/common/api.py b/mindspore/common/api.py index 0c14d00c80..5e8f830a3b 100644 --- a/mindspore/common/api.py +++ b/mindspore/common/api.py @@ -337,7 +337,6 @@ class _Executor: self.is_init = False self._executor = Executor_.get_instance() self.compile_cache = {} - self.phase_prefix = "" def init_dataset(self, queue_name, dataset_size, batch_size, dataset_types, dataset_shapes, input_indexs, phase='dataset'): @@ -383,7 +382,12 @@ class _Executor: """Build broadcast graph.""" from mindspore.nn.wrap.cell_wrapper import _BroadCastCell - _broadcast_net = _BroadCastCell(broadcast_params_dict.values()) + if not broadcast_params_dict: + broadcast_params_dict = {} + broadcast_params = [] + for param in broadcast_params_dict.values(): + broadcast_params.append(Tensor(param.asnumpy())) + _broadcast_net = _BroadCastCell(broadcast_params) _broadcast_net.phase = broadcast_phase broadcasted_params = _broadcast_net() for param_name, param in zip(broadcast_params_dict.keys(), broadcasted_params): @@ -440,11 +444,11 @@ class _Executor: if not hasattr(obj, "inputs_to_attr"): dic = dict(zip(args_names, args_list)) key = generate_key(phase, dic) - self.phase_prefix = str(key[1]) + obj.phase_prefix = str(key[1]) if 'export' in phase: - phase = phase + '.' + self.phase_prefix + '.' + str(obj.create_time) + phase = phase + '.' + obj.phase_prefix + '.' + str(obj.create_time) else: - phase = self.phase_prefix + phase + '.' + str(obj.create_time) + phase = obj.phase_prefix + phase + '.' + str(obj.create_time) if phase in self.compile_cache.keys(): logger.debug("%r graph has existed.", phase) @@ -518,9 +522,8 @@ class _Executor: for param_name, param in obj.parameters_broadcast_dict().items(): if param_name not in auto_split_param_names: broadcast_params_dict[param_name] = param - broadcast_phase = "_broadcast_subgraph" + "." + str(obj.create_time) + broadcast_phase = "_broadcast_subgraph" self._build_broadcast_graph(broadcast_params_dict, broadcast_phase) - self.compile_cache[phase] = broadcast_phase return phase, True @@ -529,15 +532,15 @@ class _Executor: return self._executor.updata_param_node_default_input(phase, new_param) def _get_shard_strategy(self, obj): - real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time) + real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) return self._executor.get_strategy(real_phase) def _get_num_parallel_ops(self, obj): - real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time) + real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) return self._executor.get_num_parallel_ops(real_phase) def _get_allreduce_fusion(self, obj): - real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time) + real_phase = obj.phase_prefix + obj.phase + '.' + str(obj.create_time) return self._executor.get_allreduce_fusion(real_phase) def has_compiled(self, phase='predict'): @@ -581,7 +584,7 @@ class _Executor: if phase == 'save': return self._executor((), phase + '.' + str(obj.create_time)) - phase_real = self.phase_prefix + phase + '.' + str(obj.create_time) + phase_real = obj.phase_prefix + phase + '.' + str(obj.create_time) if self.has_compiled(phase_real): return self._exec_pip(obj, *args, phase=phase_real) raise KeyError('{} graph is not exist.'.format(phase_real)) @@ -589,10 +592,10 @@ class _Executor: def del_net_res(self, net_id): self._executor.del_net_res(net_id) - def _get_func_graph_proto(self, exec_id, ir_type="onnx_ir", use_prefix=False): + def _get_func_graph_proto(self, obj, exec_id, ir_type="onnx_ir", use_prefix=False): """Get graph proto from pipeline.""" if use_prefix: - exec_id = self.phase_prefix + exec_id + exec_id = obj.phase_prefix + exec_id if self._executor.has_compiled(exec_id) is False: return None return self._executor.get_func_graph_proto(exec_id, ir_type) diff --git a/mindspore/context.py b/mindspore/context.py index c79a33b05f..d28fa91983 100644 --- a/mindspore/context.py +++ b/mindspore/context.py @@ -570,8 +570,8 @@ def set_context(**kwargs): >>> context.set_context(reserve_class_name_in_scope=True) >>> context.set_context(variable_memory_max_size="6GB") >>> context.set_context(mode=context.GRAPH_MODE, - >>> device_target="Ascend",device_id=0, save_graphs=True, - >>> save_graphs_path="/mindspore") + ... device_target="Ascend",device_id=0, save_graphs=True, + ... save_graphs_path="/mindspore") >>> context.set_context(enable_profiling=True, profiling_options="training_trace") >>> context.set_context(max_device_memory="3.5GB") >>> context.set_context(print_file_path="print.pb") diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 9b359a5a30..0c06cd2886 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -87,6 +87,7 @@ class Cell(Cell_): self._phase = 'train' self._parameter_layout_dict = {} self._create_time = int(time.time() * 1e9) + self.phase_prefix = "" init_backend() # call gc to release GE session resources used by non-used cell objects @@ -237,7 +238,7 @@ class Cell(Cell_): def get_func_graph_proto(self): """Return graph binary proto.""" - return _executor._get_func_graph_proto(self.phase + "." + str(self.create_time), "anf_ir", True) + return _executor._get_func_graph_proto(self, self.phase + "." + str(self.create_time), "anf_ir", True) def __getattr__(self, name): if '_params' in self.__dict__: diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index e218fbabb9..24c8ad3818 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -556,7 +556,7 @@ def _export(net, file_name, file_format, *inputs): elif file_format == 'ONNX': # file_format is 'ONNX' phase_name = 'export.onnx' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) - onnx_stream = _executor._get_func_graph_proto(graph_id) + onnx_stream = _executor._get_func_graph_proto(net, graph_id) file_name += ".onnx" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) @@ -564,7 +564,7 @@ def _export(net, file_name, file_format, *inputs): elif file_format == 'MINDIR': # file_format is 'MINDIR' phase_name = 'export.mindir' graph_id, _ = _executor.compile(net, *inputs, phase=phase_name, do_convert=False) - onnx_stream = _executor._get_func_graph_proto(graph_id, 'mind_ir') + onnx_stream = _executor._get_func_graph_proto(net, graph_id, 'mind_ir') file_name += ".mindir" with open(file_name, 'wb') as f: os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR) diff --git a/tests/st/broadcast/env.sh b/tests/st/broadcast/env.sh new file mode 100644 index 0000000000..7091c2ccca --- /dev/null +++ b/tests/st/broadcast/env.sh @@ -0,0 +1,22 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +LOCAL_HIAI=/usr/local/HiAI +export TBE_IMPL_PATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/impl/ +export LD_LIBRARY_PATH=${LOCAL_HIAI}/runtime/lib64/:${LD_LIBRARY_PATH} +export PATH=${LOCAL_HIAI}/runtime/ccec_compiler/bin/:${PATH} +export PYTHONPATH=${LOCAL_HIAI}/runtime/ops/op_impl/built-in/ai_core/tbe/:${PYTHONPATH} +export DEVICE_MEMORY_CAPACITY=1073741824000 +export NOT_FULLY_USE_DEVICES=off diff --git a/tests/st/broadcast/lenet_broadcast_auto_parallel.py b/tests/st/broadcast/lenet_broadcast_auto_parallel.py new file mode 100644 index 0000000000..a4a5e4ddf9 --- /dev/null +++ b/tests/st/broadcast/lenet_broadcast_auto_parallel.py @@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os + +import numpy as np + +import mindspore.communication.management as distributedTool +import mindspore.nn as nn +from mindspore import context +from mindspore.nn.metrics import Accuracy +from mindspore.train import Model +from mindspore.train.callback import LossMonitor, TimeMonitor +from model_zoo.official.cv.lenet.src.dataset import create_dataset +from model_zoo.official.cv.lenet.src.lenet import LeNet5 + +np.set_printoptions(threshold=np.inf) +device_num = 2 +device_id = int(os.getenv('DEVICE_ID')) +rank_id = 0 + + +def setup_module(): + global device_num + global rank_id + context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") + context.set_context(device_id=device_id) + distributedTool.init() + rank_id = distributedTool.get_rank() + device_num = distributedTool.get_group_size() + context.set_auto_parallel_context(device_num=device_num, global_rank=device_id, parameter_broadcast=True) + + +def teardown_module(): + distributedTool.release() + + +def test_all_trains(): + ds_train = create_dataset(os.path.join('/home/workspace/mindspore_dataset/mnist', "train"), 32, 1) + + network = LeNet5(10) + net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") + net_opt = nn.Momentum(network.trainable_params(), 0.01, 0.9) + time_cb = TimeMonitor(data_size=ds_train.get_dataset_size()) + + model = Model(network, net_loss, net_opt, metrics={"Accuracy": Accuracy()}) + + print("============== Starting Training ==============") + model.train(1, ds_train, callbacks=[time_cb, LossMonitor()]) diff --git a/tests/st/broadcast/run_broadcast_auto_parallel.sh b/tests/st/broadcast/run_broadcast_auto_parallel.sh new file mode 100644 index 0000000000..e3d877c335 --- /dev/null +++ b/tests/st/broadcast/run_broadcast_auto_parallel.sh @@ -0,0 +1,53 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +set -e +BASE_PATH=$( + cd "$(dirname $0)" + pwd +) +CONFIG_PATH=/home/workspace/mindspore_config +export DEVICE_NUM=8 +export RANK_SIZE=$DEVICE_NUM +source ${BASE_PATH}/env.sh +unset SLOG_PRINT_TO_STDOUT +export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json + +process_pid=() +for ((i = 0; i < $DEVICE_NUM; i++)); do + rm -rf ${BASE_PATH}/lenet_broadcast${i} + mkdir ${BASE_PATH}/lenet_broadcast${i} + cp -r ${BASE_PATH}/lenet_broadcast_auto_parallel.py ${BASE_PATH}/lenet_broadcast${i}/ + cd ${BASE_PATH}/lenet_broadcast${i} + export RANK_ID=${i} + export DEVICE_ID=${i} + echo "start training for device $i" + env >env$i.log + pytest -s -v lenet_broadcast_auto_parallel.py >test_lenet_auto_parallel_broadcast_8p_log$i.log 2>&1 & + process_pid[${i}]=$(echo $!) +done + +for ((i = 0; i < ${DEVICE_NUM}; i++)); do + wait ${process_pid[i]} + status=$(echo $?) + if [ "${status}" != "0" ]; then + echo "[ERROR] test_broadcast_auto_parallel failed. status: ${status}" + exit 1 + else + echo "[INFO] test_broadcast_auto_parallel success." + fi +done + +exit 0 diff --git a/tests/st/broadcast/test_broadcast_auto_parallel.py b/tests/st/broadcast/test_broadcast_auto_parallel.py new file mode 100644 index 0000000000..576cde6d4c --- /dev/null +++ b/tests/st/broadcast/test_broadcast_auto_parallel.py @@ -0,0 +1,27 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os + +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_single +def test_broadcast_auto_parallel(): + sh_path = os.path.split(os.path.realpath(__file__))[0] + ret = os.system(f"sh {sh_path}/run_broadcast_auto_parallel.sh") + assert ret == 0