diff --git a/model_zoo/official/cv/yolov3_darknet53/README.md b/model_zoo/official/cv/yolov3_darknet53/README.md index 7a0c2fab31..5d345eb670 100644 --- a/model_zoo/official/cv/yolov3_darknet53/README.md +++ b/model_zoo/official/cv/yolov3_darknet53/README.md @@ -53,8 +53,8 @@ Dataset used: [COCO2014](https://cocodataset.org/#download) # [Environment Requirements](#contents) -- Hardware(Ascend) - - Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. +- Hardware(Ascend/GPU) + - Prepare hardware environment with Ascend or GPU processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources. - Framework - [MindSpore](http://10.90.67.50/mindspore/archive/20200506/OpenSource/me_vm_x86/) - For more information, please check the resources below: @@ -65,7 +65,7 @@ Dataset used: [COCO2014](https://cocodataset.org/#download) # [Quick Start](#contents) -After installing MindSpore via the official website, you can start training and evaluation in Ascend as follows: +After installing MindSpore via the official website, you can start training and evaluation in as follows. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh"). ``` # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. @@ -87,9 +87,12 @@ python train.py \ # standalone training example(1p) by shell script sh run_standalone_train.sh dataset/coco2014 darknet53_backbone.ckpt -# distributed training example(8p) by shell script +# For Ascend device, distributed training example(8p) by shell script sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json +# For GPU device, distributed training example(8p) by shell script +sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt + # run evaluation by python command python eval.py \ --data_dir=./dataset/coco2014 \ @@ -113,6 +116,9 @@ sh run_eval.sh dataset/coco2014/ checkpoint/0-319_102400.ckpt ├─run_standalone_train.sh # launch standalone training(1p) in ascend ├─run_distribute_train.sh # launch distributed training(8p) in ascend └─run_eval.sh # launch evaluating in ascend + ├─run_standalone_train_gpu.sh # launch standalone training(1p) in gpu + ├─run_distribute_train_gpu.sh # launch distributed training(8p) in gpu + └─run_eval_gpu.sh # launch evaluating in gpu ├─src ├─__init__.py # python init file ├─config.py # parameter configuration @@ -138,6 +144,7 @@ Major parameters in train.py as follow. optional arguments: -h, --help show this help message and exit + --device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend" --data_dir DATA_DIR Train dataset directory. --per_batch_size PER_BATCH_SIZE Batch size for Training. Default: 32. @@ -212,7 +219,7 @@ python train.py \ --lr_scheduler=cosine_annealing > log.txt 2>&1 & ``` -The python command above will run in the background, you can view the results through the file `log.txt`. +The python command above will run in the background, you can view the results through the file `log.txt`. If running on GPU, please add `--device_target=GPU` in the python command. After training, you'll get some checkpoint files under the outputs folder by default. The loss value will be achieved as follows: @@ -228,9 +235,14 @@ The model checkpoint will be saved in outputs directory. ### Distributed Training +For Ascend device, distributed training example(8p) by shell script ``` sh run_distribute_train.sh dataset/coco2014 darknet53_backbone.ckpt rank_table_8p.json ``` +For GPU device, distributed training example(8p) by shell script +``` +sh run_distribute_train_gpu.sh dataset/coco2014 darknet53_backbone.ckpt +``` The above shell script will run distribute training in the background. You can view the results through the file `train_parallel[X]/log.txt`. The loss value will be achieved as follows: @@ -254,7 +266,7 @@ epoch[319], iter[102300], loss:35.430038, 423.49 imgs/sec, lr:2.409552052995423e ### Evaluation -Before running the command below. +Before running the command below. If running on GPU, please add `--device_target=GPU` in the python command or use the "_gpu" shell script ("xxx_gpu.sh"). ``` python eval.py \ diff --git a/model_zoo/official/cv/yolov3_darknet53/eval.py b/model_zoo/official/cv/yolov3_darknet53/eval.py index f04ed2447c..cb8a89ec65 100644 --- a/model_zoo/official/cv/yolov3_darknet53/eval.py +++ b/model_zoo/official/cv/yolov3_darknet53/eval.py @@ -35,9 +35,6 @@ from src.logger import get_logger from src.yolo_dataset import create_yolo_dataset from src.config import ConfigYOLOV3DarkNet53 -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=True, device_id=devid) - class Redirct: def __init__(self): @@ -208,6 +205,10 @@ def parse_args(): """Parse arguments.""" parser = argparse.ArgumentParser('mindspore coco testing') + # device related + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + # dataset related parser.add_argument('--data_dir', type=str, default='', help='train data dir') parser.add_argument('--per_batch_size', default=1, type=int, help='batch size for per gpu') @@ -243,10 +244,13 @@ def test(): start_time = time.time() args = parse_args() + devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True, device_id=devid) + # logger args.outputs_dir = os.path.join(args.log_path, datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S')) - rank_id = int(os.environ.get('RANK_ID')) + rank_id = int(os.environ.get('RANK_ID')) if os.environ.get('RANK_ID') else 0 args.logger = get_logger(args.outputs_dir, rank_id) context.reset_auto_parallel_context() diff --git a/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train_gpu.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train_gpu.sh new file mode 100644 index 0000000000..b54070844c --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/scripts/run_distribute_train_gpu.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_distribute_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +PRETRAINED_BACKBONE=$(get_real_path $2) +echo $DATASET_PATH +echo $PRETRAINED_BACKBONE + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $PRETRAINED_BACKBONE ] +then + echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file" +exit 1 +fi + +export DEVICE_NUM=8 + +rm -rf ./train_parallel +mkdir ./train_parallel +cp ../*.py ./train_parallel +cp -r ../src ./train_parallel +cd ./train_parallel || exit +env > env.log +mpirun --allow-run-as-root -n ${DEVICE_NUM} python train.py \ + --data_dir=$DATASET_PATH \ + --pretrained_backbone=$PRETRAINED_BACKBONE \ + --device_target=GPU \ + --is_distributed=1 \ + --lr=0.1 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --training_shape=416 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/yolov3_darknet53/scripts/run_eval_gpu.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_eval_gpu.sh new file mode 100644 index 0000000000..ccf94a9302 --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/scripts/run_eval_gpu.sh @@ -0,0 +1,67 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_eval_gpu.sh [DATASET_PATH] [CHECKPOINT_PATH]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +DATASET_PATH=$(get_real_path $1) +CHECKPOINT_PATH=$(get_real_path $2) +echo $DATASET_PATH +echo $CHECKPOINT_PATH + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$PATH1 is not a directory" +exit 1 +fi + +if [ ! -f $CHECKPOINT_PATH ] +then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" +exit 1 +fi + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp -r ../src ./eval +cd ./eval || exit +env > env.log +echo "start infering for device $DEVICE_ID" +python eval.py \ + --device_target="GPU" \ + --data_dir=$DATASET_PATH \ + --pretrained=$CHECKPOINT_PATH \ + --testing_shape=416 > log.txt 2>&1 & +cd .. diff --git a/model_zoo/official/cv/yolov3_darknet53/scripts/run_standalone_train_gpu.sh b/model_zoo/official/cv/yolov3_darknet53/scripts/run_standalone_train_gpu.sh new file mode 100644 index 0000000000..186d1446ae --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/scripts/run_standalone_train_gpu.sh @@ -0,0 +1,75 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ] +then + echo "Usage: sh run_standalone_train_gpu.sh [DATASET_PATH] [PRETRAINED_BACKBONE]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +DATASET_PATH=$(get_real_path $1) +echo $DATASET_PATH +PRETRAINED_BACKBONE=$(get_real_path $2) +echo $PRETRAINED_BACKBONE + +if [ ! -d $DATASET_PATH ] +then + echo "error: DATASET_PATH=$DATASET_PATH is not a directory" +exit 1 +fi + +if [ ! -f $PRETRAINED_BACKBONE ] +then + echo "error: PRETRAINED_PATH=$PRETRAINED_BACKBONE is not a file" +exit 1 +fi + +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log + +python train.py \ + --device_targe="GPU" \ + --data_dir=$DATASET_PATH \ + --pretrained_backbone=$PRETRAINED_BACKBONE \ + --is_distributed=0 \ + --lr=0.1 \ + --T_max=320 \ + --max_epoch=320 \ + --warmup_epochs=4 \ + --training_shape=416 \ + --lr_scheduler=cosine_annealing > log.txt 2>&1 & +cd .. \ No newline at end of file diff --git a/model_zoo/official/cv/yolov3_darknet53/src/transforms.py b/model_zoo/official/cv/yolov3_darknet53/src/transforms.py index 4756a141f0..da0b313cf7 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/transforms.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/transforms.py @@ -465,6 +465,11 @@ class MultiScaleTrans: self.seed_list = self.generate_seed_list(seed_num=self.seed_num) self.resize_count_num = int(np.ceil(self.dataset_size / self.resize_rate)) self.device_num = device_num + self.anchor_scales = config.anchor_scales + self.num_classes = config.num_classes + self.max_box = config.max_box + self.label_smooth = config.label_smooth + self.label_smooth_factor = config.label_smooth_factor def generate_seed_list(self, init_seed=1234, seed_num=int(1e6), seed_range=(1, 1000)): seed_list = [] @@ -474,13 +479,20 @@ class MultiScaleTrans: seed_list.append(seed) return seed_list - def __call__(self, imgs, annos, batchInfo): + def __call__(self, imgs, annos, x1, x2, x3, x4, x5, x6, batchInfo): epoch_num = batchInfo.get_epoch_num() size_idx = int(batchInfo.get_batch_num() / self.resize_rate) seed_key = self.seed_list[(epoch_num * self.resize_count_num + size_idx) % self.seed_num] ret_imgs = [] ret_annos = [] + bbox1 = [] + bbox2 = [] + bbox3 = [] + gt1 = [] + gt2 = [] + gt3 = [] + if self.size_dict.get(seed_key, None) is None: random.seed(seed_key) new_size = random.choice(self.config.multi_scale) @@ -491,8 +503,19 @@ class MultiScaleTrans: for img, anno in zip(imgs, annos): img, anno = preprocess_fn(img, anno, self.config, input_size, self.device_num) ret_imgs.append(img.transpose(2, 0, 1).copy()) - ret_annos.append(anno) - return np.array(ret_imgs), np.array(ret_annos) + bbox_true_1, bbox_true_2, bbox_true_3, gt_box1, gt_box2, gt_box3 = \ + _preprocess_true_boxes(true_boxes=anno, anchors=self.anchor_scales, in_shape=img.shape[0:2], + num_classes=self.num_classes, max_boxes=self.max_box, + label_smooth=self.label_smooth, label_smooth_factor=self.label_smooth_factor) + bbox1.append(bbox_true_1) + bbox2.append(bbox_true_2) + bbox3.append(bbox_true_3) + gt1.append(gt_box1) + gt2.append(gt_box2) + gt3.append(gt_box3) + ret_annos.append(0) + return np.array(ret_imgs), np.array(ret_annos), np.array(bbox1), np.array(bbox2), np.array(bbox3), \ + np.array(gt1), np.array(gt2), np.array(gt3) def thread_batch_preprocess_true_box(annos, config, input_shape, result_index, batch_bbox_true_1, batch_bbox_true_2, diff --git a/model_zoo/official/cv/yolov3_darknet53/src/util.py b/model_zoo/official/cv/yolov3_darknet53/src/util.py index f97bdd2548..315662bf13 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/util.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/util.py @@ -15,6 +15,9 @@ """Util class or function.""" from mindspore.train.serialization import load_checkpoint import mindspore.nn as nn +import mindspore.common.dtype as mstype + +from .yolo import YoloLossBlock class AverageMeter: @@ -175,3 +178,10 @@ class ShapeRecord: for key in self.shape_record: rate = self.shape_record[key] / float(self.shape_record['total']) logger.info('shape {}: {:.2f}%'.format(key, rate*100)) + + +def keep_loss_fp32(network): + """Keep loss of network with float32""" + for _, cell in network.cells_and_names(): + if isinstance(cell, (YoloLossBlock,)): + cell.to_float(mstype.float32) diff --git a/model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py b/model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py index 45657db823..8045c468f2 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/yolo_dataset.py @@ -15,6 +15,7 @@ """YOLOV3 dataset.""" import os +import multiprocessing from PIL import Image from pycocotools.coco import COCO import mindspore.dataset as de @@ -126,7 +127,7 @@ class COCOYoloDataset: tmp.append(int(label)) # tmp [x_min y_min x_max y_max, label] out_target.append(tmp) - return img, out_target + return img, out_target, [], [], [], [], [], [] def __len__(self): return len(self.img_ids) @@ -155,20 +156,22 @@ def create_yolo_dataset(image_dir, anno_path, batch_size, max_epoch, device_num, hwc_to_chw = CV.HWC2CHW() config.dataset_size = len(yolo_dataset) - num_parallel_workers1 = int(64 / device_num) - num_parallel_workers2 = int(16 / device_num) + cores = multiprocessing.cpu_count() + num_parallel_workers = int(cores / device_num) if is_training: multi_scale_trans = MultiScaleTrans(config, device_num) + dataset_column_names = ["image", "annotation", "bbox1", "bbox2", "bbox3", + "gt_box1", "gt_box2", "gt_box3"] if device_num != 8: - ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], - num_parallel_workers=num_parallel_workers1, + ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, + num_parallel_workers=min(32, num_parallel_workers), sampler=distributed_sampler) - ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], - num_parallel_workers=num_parallel_workers2, drop_remainder=True) + ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names, + num_parallel_workers=min(32, num_parallel_workers), drop_remainder=True) else: - ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "annotation"], sampler=distributed_sampler) - ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=['image', 'annotation'], - num_parallel_workers=8, drop_remainder=True) + ds = de.GeneratorDataset(yolo_dataset, column_names=dataset_column_names, sampler=distributed_sampler) + ds = ds.batch(batch_size, per_batch_map=multi_scale_trans, input_columns=dataset_column_names, + num_parallel_workers=min(8, num_parallel_workers), drop_remainder=True) else: ds = de.GeneratorDataset(yolo_dataset, column_names=["image", "img_id"], sampler=distributed_sampler) diff --git a/model_zoo/official/cv/yolov3_darknet53/train.py b/model_zoo/official/cv/yolov3_darknet53/train.py index 92ac9f5353..63d53f6d2a 100644 --- a/model_zoo/official/cv/yolov3_darknet53/train.py +++ b/model_zoo/official/cv/yolov3_darknet53/train.py @@ -28,6 +28,8 @@ from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig import mindspore as ms from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore import amp +from mindspore.train.loss_scale_manager import FixedLossScaleManager from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper from src.logger import get_logger @@ -37,13 +39,7 @@ from src.lr_scheduler import warmup_step_lr, warmup_cosine_annealing_lr, \ from src.yolo_dataset import create_yolo_dataset from src.initializer import default_recurisive_init from src.config import ConfigYOLOV3DarkNet53 -from src.transforms import batch_preprocess_true_box, batch_preprocess_true_box_single -from src.util import ShapeRecord - - -devid = int(os.getenv('DEVICE_ID')) -context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, - device_target="Ascend", save_graphs=True, device_id=devid) +from src.util import keep_loss_fp32 class BuildTrainNetwork(nn.Cell): @@ -62,6 +58,10 @@ def parse_args(): """Parse train arguments.""" parser = argparse.ArgumentParser('mindspore coco training') + # device related + parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'], + help='device where the code will be implemented. (Default: Ascend)') + # dataset related parser.add_argument('--data_dir', type=str, help='Train dataset directory.') parser.add_argument('--per_batch_size', default=32, type=int, help='Batch size for Training. Default: 32.') @@ -136,9 +136,16 @@ def train(): """Train function.""" args = parse_args() + devid = int(os.getenv('DEVICE_ID')) if os.getenv('DEVICE_ID') else 0 + context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, + device_target=args.device_target, save_graphs=True, device_id=devid) + # init distributed if args.is_distributed: - init() + if args.device_target == "Ascend": + init() + else: + init("nccl") args.rank = get_rank() args.group_size = get_group_size() @@ -259,9 +266,19 @@ def train(): momentum=args.momentum, weight_decay=args.weight_decay, loss_scale=args.loss_scale) - - network = TrainingWrapper(network, opt) - network.set_train() + enable_amp = False + is_gpu = context.get_context("device_target") == "GPU" + if is_gpu: + enable_amp = True + if enable_amp: + loss_scale_value = 1.0 + loss_scale = FixedLossScaleManager(loss_scale_value, drop_overflow_update=False) + network = amp.build_train_network(network, optimizer=opt, loss_scale_manager=loss_scale, + level="O2", keep_batchnorm_fp32=True) + keep_loss_fp32(network) + else: + network = TrainingWrapper(network, opt) + network.set_train() if args.rank_save_ckpt_flag: # checkpoint save @@ -282,28 +299,19 @@ def train(): t_end = time.time() data_loader = ds.create_dict_iterator() - shape_record = ShapeRecord() for i, data in enumerate(data_loader): images = data["image"] input_shape = images.shape[2:4] args.logger.info('iter[{}], shape{}'.format(i, input_shape[0])) - shape_record.set(input_shape) images = Tensor(images) - annos = data["annotation"] - if args.group_size == 1: - batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ - batch_preprocess_true_box(annos, config, input_shape) - else: - batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1, batch_gt_box2 = \ - batch_preprocess_true_box_single(annos, config, input_shape) - - batch_y_true_0 = Tensor(batch_y_true_0) - batch_y_true_1 = Tensor(batch_y_true_1) - batch_y_true_2 = Tensor(batch_y_true_2) - batch_gt_box0 = Tensor(batch_gt_box0) - batch_gt_box1 = Tensor(batch_gt_box1) - batch_gt_box2 = Tensor(batch_gt_box2) + + batch_y_true_0 = Tensor(data['bbox1']) + batch_y_true_1 = Tensor(data['bbox2']) + batch_y_true_2 = Tensor(data['bbox3']) + batch_gt_box0 = Tensor(data['gt_box1']) + batch_gt_box1 = Tensor(data['gt_box2']) + batch_gt_box2 = Tensor(data['gt_box3']) input_shape = Tensor(tuple(input_shape[::-1]), ms.float32) loss = network(images, batch_y_true_0, batch_y_true_1, batch_y_true_2, batch_gt_box0, batch_gt_box1,