From d9ecfb1858fe20736f5b460992bb4738e1bd8271 Mon Sep 17 00:00:00 2001 From: caojian05 Date: Thu, 2 Jul 2020 23:18:01 +0800 Subject: [PATCH] support multi server multi process --- model_zoo/googlenet/scripts/run_train.sh | 4 +++- model_zoo/googlenet/src/dataset.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/model_zoo/googlenet/scripts/run_train.sh b/model_zoo/googlenet/scripts/run_train.sh index c21c2f04b6..e8c045c8b1 100644 --- a/model_zoo/googlenet/scripts/run_train.sh +++ b/model_zoo/googlenet/scripts/run_train.sh @@ -33,10 +33,12 @@ MINDSPORE_HCCL_CONFIG_PATH=$(realpath $1) export MINDSPORE_HCCL_CONFIG_PATH echo "MINDSPORE_HCCL_CONFIG_PATH=${MINDSPORE_HCCL_CONFIG_PATH}" +export SERVER_ID=0 +rank_start=$((DEVICE_NUM * SERVER_ID)) for((i=0; i<${DEVICE_NUM}; i++)) do export DEVICE_ID=$i - export RANK_ID=$i + export RANK_ID=$((rank_start + i)) rm -rf ./train_parallel$i mkdir ./train_parallel$i cp -r ./src ./train_parallel$i diff --git a/model_zoo/googlenet/src/dataset.py b/model_zoo/googlenet/src/dataset.py index a1cbc2cdab..a3f74a0617 100644 --- a/model_zoo/googlenet/src/dataset.py +++ b/model_zoo/googlenet/src/dataset.py @@ -31,8 +31,7 @@ def create_dataset(data_home, repeat_num=1, training=True): if not training: data_dir = os.path.join(data_home, "cifar-10-verify-bin") - rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else None - rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else None + rank_size, rank_id = _get_rank_info() data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id) resize_height = cfg.image_height @@ -65,3 +64,19 @@ def create_dataset(data_home, repeat_num=1, training=True): data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True) return data_set + + +def _get_rank_info(): + """ + get rank size and rank id + """ + rank_size = int(os.environ.get("RANK_SIZE", 1)) + + if rank_size > 1: + from
mindspore.communication.management import get_rank, get_group_size + rank_size = get_group_size() + rank_id = get_rank() + else: + rank_size = rank_id = None + + return rank_size, rank_id