!13558 fix some bugs in resnet, ssd and naml

From: @zhao_ting_v
Reviewed-by: 
Signed-off-by:
pull/13558/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b8b833b762

@ -284,6 +284,8 @@ Please follow the instructions in the link [hccn_tools](https://gitee.com/mindsp
Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the following in log.
If you want to change device_id for standalone training, you can set environment variable `export DEVICE_ID=x` or set `device_id=x` in context.
#### Running on GPU
```bash

@ -268,6 +268,8 @@ bash run_eval_gpu.sh [resnet50|resnet101] [cifar10|imagenet2012] [DATASET_PATH]
训练结果保存在示例路径中文件夹名称以“train”或“train_parallel”开头。您可在此路径下的日志中找到检查点文件以及结果如下所示。
运行单卡用例时如果想更换运行卡号,可以通过设置环境变量 `export DEVICE_ID=x` 或者在context中设置 `device_id=x`指定相应的卡号。
#### GPU处理器环境运行
```text

@ -73,7 +73,6 @@ fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1

@ -391,7 +391,7 @@ def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="ssd.
def create_ssd_dataset(mindrecord_file, batch_size=32, repeat_num=10, device_num=1, rank=0,
is_training=True, num_parallel_workers=4, use_multiprocessing=True):
is_training=True, num_parallel_workers=6, use_multiprocessing=True):
"""Create SSD dataset with MindDataset."""
ds = de.MindDataset(mindrecord_file, columns_list=["img_id", "image", "annotation"], num_shards=device_num,
shard_id=rank, num_parallel_workers=num_parallel_workers, shuffle=is_training)

@ -99,7 +99,7 @@ def get_args(phase):
args.n_sub_categories = cfg.n_sub_categories
args.n_words = cfg.n_words
if phase == "train":
args.epochs = cfg.epochs if args.epochs is None else args.epochs * math.ceil(args.device_num ** 0.5)
args.epochs = cfg.epochs * math.ceil(args.device_num ** 0.5) if args.epochs is None else args.epochs
args.lr = cfg.lr if args.lr is None else args.lr
args.print_times = cfg.print_times if args.print_times is None else args.print_times
args.embedding_file = cfg.embedding_file.format(args.dataset_path)

Loading…
Cancel
Save