From 7d5dc4735c6ec6832aa04623e04515c0de72eb16 Mon Sep 17 00:00:00 2001 From: zhangrunjiao Date: Mon, 14 Dec 2020 15:12:03 +0800 Subject: [PATCH] add auto parallel recursive use case --- .../auto_parallel/parallel_strategy_search.py | 32 +++++++++++- .../run_parallel_recursive_strategy_search.sh | 52 +++++++++++++++++++ .../run_parallel_strategy_search.sh | 2 +- ...test_parallel_recursive_strategy_search.py | 29 +++++++++++ .../test_parallel_strategy_search.py | 2 +- 5 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 tests/st/auto_parallel/run_parallel_recursive_strategy_search.sh create mode 100644 tests/st/auto_parallel/test_parallel_recursive_strategy_search.py diff --git a/tests/st/auto_parallel/parallel_strategy_search.py b/tests/st/auto_parallel/parallel_strategy_search.py index 0631df0cef..057c45b61f 100644 --- a/tests/st/auto_parallel/parallel_strategy_search.py +++ b/tests/st/auto_parallel/parallel_strategy_search.py @@ -295,12 +295,13 @@ class ParallelStrategySearchFactory: newest_ckpt_file = find_newest_ckpt_file(ckpt_path) return load_checkpoint(newest_ckpt_file) - def mindspore_auto_parallel_impl(self, dataset, epoch, device_num): + def mindspore_auto_parallel_impl(self, dataset, epoch, device_num, auto_parallel_search_mode="dynamic_programming"): parallel_mode_net = self.parallel_mode_net set_algo_parameters(fully_use_devices=False) context.reset_auto_parallel_context() context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, - device_num=device_num) + device_num=device_num, + auto_parallel_search_mode=auto_parallel_search_mode) self.parallel_ckpt = self._model_train_and_save_ckpt(net=parallel_mode_net, dataset=dataset, epoch=epoch) context.reset_auto_parallel_context() @@ -352,3 +353,30 @@ def test_auto_parallel_strategy_search_axis_1_basic(): fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, epoch=2, device_num=8) fact.checkpoint_cmp(inputs_np=inputs_np) + + +def test_auto_parallel_recursive_strategy_search_axis_1_basic(): + inputs_np = np.random.randn(32, 3, 224, 224).astype(np.float32) + standalone_mode_net = ParallelStrategySearchNet(in_channel=3, + out_channel=8, axis=1, input_shape=(32, 4, 110, -1), + mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880), + prelu_size=(1,), transpose_b=True, matmul_size=(1, 12), + num_class=12) + context.reset_auto_parallel_context() + context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL) + parallel_mode_net = ParallelStrategySearchNet(in_channel=3, + out_channel=8, axis=1, input_shape=(32, 4, 110, -1), + mul_size=(32, 1, 220, 220), test_size=(32, 4, 110, 880), + prelu_size=(1,), transpose_b=True, matmul_size=(1, 12), + num_class=12) + standalone_dataset = FakeData(size=128, batch_size=32, + image_size=(3, 224, 224), num_classes=12) + fact = ParallelStrategySearchFactory(standalone_mode_net=standalone_mode_net, + parallel_mode_net=parallel_mode_net) + fact.mindspore_standalone_impl(dataset=standalone_dataset, epoch=2) + parallel_dataset = FakeData(size=128, batch_size=4, + image_size=(3, 224, 224), use_parallel=True, + num_classes=12) + fact.mindspore_auto_parallel_impl(dataset=parallel_dataset, + epoch=2, device_num=8, auto_parallel_search_mode="recursive_programming") + fact.checkpoint_cmp(inputs_np=inputs_np) diff --git a/tests/st/auto_parallel/run_parallel_recursive_strategy_search.sh b/tests/st/auto_parallel/run_parallel_recursive_strategy_search.sh new file mode 100644 index 0000000000..c7d45612c9 --- /dev/null +++ b/tests/st/auto_parallel/run_parallel_recursive_strategy_search.sh @@ -0,0 +1,52 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +set -e +BASE_PATH=$(cd "$(dirname $0)"; pwd) +CONFIG_PATH=/home/workspace/mindspore_config +export DEVICE_NUM=8 +export RANK_SIZE=$DEVICE_NUM +source ${BASE_PATH}/env.sh +unset SLOG_PRINT_TO_STDOUT +export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json +export LD_LIBRARY_PATH=/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling:${LD_LIBRARY_PATH} +export ASCEND_OPP_PATH=/usr/local/Ascend/opp/ + +process_pid=() +for((i=0; i<$DEVICE_NUM; i++)); do + rm -rf ${BASE_PATH}/parallel_recursive_strategy_search${i} + mkdir ${BASE_PATH}/parallel_recursive_strategy_search${i} + cp -r ${BASE_PATH}/parallel_strategy_search.py ${BASE_PATH}/parallel_recursive_strategy_search${i}/ + cd ${BASE_PATH}/parallel_recursive_strategy_search${i} + export RANK_ID=${i} + export DEVICE_ID=${i} + echo "start training for device $i" + env > env$i.log + pytest -s -v parallel_strategy_search.py::test_auto_parallel_recursive_strategy_search_axis_1_basic > parallel_recursive_strategy_search$i.log 2>&1 & + process_pid[${i}]=`echo $!` +done + +for((i=0; i<${DEVICE_NUM}; i++)); do + wait ${process_pid[i]} + status=`echo $?` + if [ "${status}" != "0" ]; then + echo "[ERROR] test_parallel_recursive_strategy_search failed. status: ${status}" + exit 1 + else + echo "[INFO] test_parallel_recursive_strategy_search success." + fi +done + +exit 0 diff --git a/tests/st/auto_parallel/run_parallel_strategy_search.sh b/tests/st/auto_parallel/run_parallel_strategy_search.sh index 38a4814ca0..2593e7e68c 100644 --- a/tests/st/auto_parallel/run_parallel_strategy_search.sh +++ b/tests/st/auto_parallel/run_parallel_strategy_search.sh @@ -34,7 +34,7 @@ for((i=0; i<$DEVICE_NUM; i++)); do export DEVICE_ID=${i} echo "start training for device $i" env > env$i.log - pytest -s -v parallel_strategy_search.py > parallel_strategy_search$i.log 2>&1 & + pytest -s -v parallel_strategy_search.py::test_auto_parallel_strategy_search_axis_1_basic > parallel_strategy_search$i.log 2>&1 & process_pid[${i}]=`echo $!` done diff --git a/tests/st/auto_parallel/test_parallel_recursive_strategy_search.py b/tests/st/auto_parallel/test_parallel_recursive_strategy_search.py new file mode 100644 index 0000000000..9f4fa70af4 --- /dev/null +++ b/tests/st/auto_parallel/test_parallel_recursive_strategy_search.py @@ -0,0 +1,29 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import pytest + + +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_single +def test_sit_parallel_recursive_strategy_search(): + sh_path = os.path.split(os.path.realpath(__file__))[0] + ret = os.system(f"sh {sh_path}/run_parallel_recursive_strategy_search.sh") + os.system( + f"grep -E 'ERROR|error' " + f"{sh_path}/parallel_recursive_strategy_search*/parallel_recursive_strategy_search*log -C 3") + assert ret == 0 diff --git a/tests/st/auto_parallel/test_parallel_strategy_search.py b/tests/st/auto_parallel/test_parallel_strategy_search.py index d46a39bd79..d543830ffc 100644 --- a/tests/st/auto_parallel/test_parallel_strategy_search.py +++ b/tests/st/auto_parallel/test_parallel_strategy_search.py @@ -20,7 +20,7 @@ import pytest @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @pytest.mark.env_single -def test_parallel_strategy_search(): +def test_sit_parallel_strategy_search(): sh_path = os.path.split(os.path.realpath(__file__))[0] ret = os.system(f"sh {sh_path}/run_parallel_strategy_search.sh") os.system(f"grep -E 'ERROR|error' {sh_path}/parallel_strategy_search*/parallel_strategy_search*log -C 3")