parent
4e3abb2434
commit
03125ee773
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,113 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Evaluation for retinanet"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.retinanet import retinanet50, resnet50, retinanetInferWithDecoder
|
||||
from src.dataset import create_retinanet_dataset, data_to_mindrecord_byte_image, voc_data_to_mindrecord
|
||||
from src.config import config
|
||||
from src.coco_eval import metrics
|
||||
from src.box_utils import default_boxes
|
||||
|
||||
def retinanet_eval(dataset_path, ckpt_path):
|
||||
"""retinanet evaluation."""
|
||||
batch_size = 1
|
||||
ds = create_retinanet_dataset(dataset_path, batch_size=batch_size, repeat_num=1, is_training=False)
|
||||
backbone = resnet50(config.num_classes)
|
||||
net = retinanet50(backbone, config)
|
||||
net = retinanetInferWithDecoder(net, Tensor(default_boxes), config)
|
||||
print("Load Checkpoint!")
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
net.init_parameters_data()
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
net.set_train(False)
|
||||
i = batch_size
|
||||
total = ds.get_dataset_size() * batch_size
|
||||
start = time.time()
|
||||
pred_data = []
|
||||
print("\n========================================\n")
|
||||
print("total images num: ", total)
|
||||
print("Processing, please wait a moment.")
|
||||
for data in ds.create_dict_iterator(output_numpy=True):
|
||||
img_id = data['img_id']
|
||||
img_np = data['image']
|
||||
image_shape = data['image_shape']
|
||||
|
||||
output = net(Tensor(img_np))
|
||||
for batch_idx in range(img_np.shape[0]):
|
||||
pred_data.append({"boxes": output[0].asnumpy()[batch_idx],
|
||||
"box_scores": output[1].asnumpy()[batch_idx],
|
||||
"img_id": int(np.squeeze(img_id[batch_idx])),
|
||||
"image_shape": image_shape[batch_idx]})
|
||||
percent = round(i / total * 100., 2)
|
||||
|
||||
print(f' {str(percent)} [{i}/{total}]', end='\r')
|
||||
i += batch_size
|
||||
cost_time = int((time.time() - start) * 1000)
|
||||
print(f' 100% [{total}/{total}] cost {cost_time} ms')
|
||||
mAP = metrics(pred_data)
|
||||
print("\n========================================\n")
|
||||
print(f"mAP: {mAP}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='retinanet evaluation')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
|
||||
help="run platform, only support Ascend.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.run_platform, device_id=args_opt.device_id)
|
||||
|
||||
prefix = "retinanet_eval.mindrecord"
|
||||
mindrecord_dir = config.mindrecord_dir
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
if args_opt.dataset == "voc":
|
||||
config.coco_root = config.voc_root
|
||||
if not os.path.exists(mindrecord_file):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if args_opt.dataset == "coco":
|
||||
if os.path.isdir(config.coco_root):
|
||||
print("Create Mindrecord.")
|
||||
data_to_mindrecord_byte_image("coco", False, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("coco_root not exits.")
|
||||
elif args_opt.dataset == "voc":
|
||||
if os.path.isdir(config.voc_dir) and os.path.isdir(config.voc_root):
|
||||
print("Create Mindrecord.")
|
||||
voc_data_to_mindrecord(mindrecord_dir, False, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("voc_root or voc_dir not exits.")
|
||||
else:
|
||||
if os.path.isdir(config.image_dir) and os.path.exists(config.anno_path):
|
||||
print("Create Mindrecord.")
|
||||
data_to_mindrecord_byte_image("other", False, prefix)
|
||||
print("Create Mindrecord Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("IMAGE_DIR or ANNO_PATH not exits.")
|
||||
|
||||
print("Start Eval!")
|
||||
retinanet_eval(mindrecord_file, config.checkpoint_path)
|
@ -0,0 +1,83 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_distribute_train.sh DEVICE_NUM EPOCH_SIZE LR DATASET RANK_TABLE_FILE PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh run_distribute_train.sh 8 500 0.1 coco /data/hccl.json /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
if [ $# != 5 ] && [ $# != 7 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [DEVICE_NUM] [EPOCH_SIZE] [LR] [DATASET] \
|
||||
[RANK_TABLE_FILE] [PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Before start distribute train, first create mindrecord files.
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
python train.py --only_create_dataset=True
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export RANK_SIZE=$1
|
||||
EPOCH_SIZE=$2
|
||||
LR=$3
|
||||
DATASET=$4
|
||||
PRE_TRAINED=$6
|
||||
PRE_TRAINED_EPOCH_SIZE=$7
|
||||
export RANK_TABLE_FILE=$5
|
||||
|
||||
for((i=0;i<RANK_SIZE;i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
rm -rf LOG$i
|
||||
mkdir ./LOG$i
|
||||
cp ./*.py ./LOG$i
|
||||
cp -r ./src ./LOG$i
|
||||
cp -r ./scripts ./LOG$i
|
||||
cd ./LOG$i || exit
|
||||
export RANK_ID=$i
|
||||
echo "start training for rank $i, device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ $# == 5 ]
|
||||
then
|
||||
python train.py \
|
||||
--distribute=True \
|
||||
--lr=$LR \
|
||||
--dataset=$DATASET \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
if [ $# == 7 ]
|
||||
then
|
||||
python train.py \
|
||||
--distribute=True \
|
||||
--lr=$LR \
|
||||
--dataset=$DATASET \
|
||||
--device_num=$RANK_SIZE \
|
||||
--device_id=$DEVICE_ID \
|
||||
--pre_trained=$PRE_TRAINED \
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
done
|
@ -0,0 +1,49 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval.sh [DATASET] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DATASET=$1
|
||||
echo $DATASET
|
||||
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=$2
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
|
||||
if [ -d "eval$2" ];
|
||||
then
|
||||
rm -rf ./eval$2
|
||||
fi
|
||||
|
||||
mkdir ./eval$2
|
||||
cp ./*.py ./eval$2
|
||||
cp -r ./src ./eval$2
|
||||
cd ./eval$2 || exit
|
||||
env > env.log
|
||||
echo "start inferring for device $DEVICE_ID"
|
||||
python eval.py \
|
||||
--dataset=$DATASET \
|
||||
--device_id=$2 > log.txt 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,76 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "sh run_single_train.sh DEVICE_ID EPOCH_SIZE LR DATASET PRE_TRAINED PRE_TRAINED_EPOCH_SIZE"
|
||||
echo "for example: sh run_single_train.sh 0 500 0.1 coco /opt/retinanet-500_458.ckpt(optional) 200(optional)"
|
||||
echo "It is better to use absolute path."
|
||||
echo "================================================================================================================="
|
||||
|
||||
if [ $# != 4 ] && [ $# != 6 ]
|
||||
then
|
||||
echo "Usage: sh run_single_train.sh [DEVICE_ID] [EPOCH_SIZE] [LR] [DATASET] \
|
||||
[PRE_TRAINED](optional) [PRE_TRAINED_EPOCH_SIZE](optional)"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Before start single train, first create mindrecord files.
|
||||
BASE_PATH=$(cd "`dirname $0`" || exit; pwd)
|
||||
cd $BASE_PATH/../ || exit
|
||||
python train.py --only_create_dataset=True
|
||||
|
||||
echo "After running the script, the network runs in the background. The log will be generated in LOGx/log.txt"
|
||||
|
||||
export DEVICE_ID=$1
|
||||
EPOCH_SIZE=$2
|
||||
LR=$3
|
||||
DATASET=$4
|
||||
PRE_TRAINED=$5
|
||||
PRE_TRAINED_EPOCH_SIZE=$6
|
||||
|
||||
rm -rf LOG$1
|
||||
mkdir ./LOG$1
|
||||
cp ./*.py ./LOG$1
|
||||
cp -r ./src ./LOG$1
|
||||
cd ./LOG$1 || exit
|
||||
echo "start training for device $1"
|
||||
env > env.log
|
||||
if [ $# == 4 ]
|
||||
then
|
||||
python train.py \
|
||||
--distribute=False \
|
||||
--lr=$LR \
|
||||
--dataset=$DATASET \
|
||||
--device_num=1 \
|
||||
--device_id=$DEVICE_ID \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
if [ $# == 6 ]
|
||||
then
|
||||
python train,py \
|
||||
--distribute=False \
|
||||
--lr=$LR \
|
||||
--dataset=$DATASET \
|
||||
--device_num=1 \
|
||||
--device_id=$DEVICE_ID \
|
||||
--pre_trained=$PRE_TRAINED \
|
||||
--pre_trained_epoch_size=$PRE_TRAINED_EPOCH_SIZE \
|
||||
--epoch_size=$EPOCH_SIZE > log.txt 2>&1 &
|
||||
fi
|
||||
|
||||
cd ../
|
||||
|
@ -0,0 +1,165 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Bbox utils"""
|
||||
|
||||
import math
|
||||
import itertools as it
|
||||
import numpy as np
|
||||
from .config import config
|
||||
|
||||
|
||||
class GeneratDefaultBoxes():
|
||||
"""
|
||||
Generate Default boxes for retinanet, follows the order of (W, H, archor_sizes).
|
||||
`self.default_boxes` has a shape of [archor_sizes, H, W, 4], the last dimension is [y, x, h, w].
|
||||
`self.default_boxes_ltrb` has a shape as `self.default_boxes`, the last dimension is [y1, x1, y2, x2].
|
||||
"""
|
||||
def __init__(self):
|
||||
fk = config.img_shape[0] / np.array(config.steps)
|
||||
scales = np.array([2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)])
|
||||
anchor_size = np.array(config.anchor_size)
|
||||
self.default_boxes = []
|
||||
for idex, feature_size in enumerate(config.feature_size):
|
||||
base_size = anchor_size[idex] / config.img_shape[0]
|
||||
size1 = base_size*scales[0]
|
||||
size2 = base_size*scales[1]
|
||||
size3 = base_size*scales[2]
|
||||
all_sizes = []
|
||||
for aspect_ratio in config.aspect_ratios[idex]:
|
||||
w1, h1 = size1 * math.sqrt(aspect_ratio), size1 / math.sqrt(aspect_ratio)
|
||||
all_sizes.append((h1, w1))
|
||||
w2, h2 = size2 * math.sqrt(aspect_ratio), size2 / math.sqrt(aspect_ratio)
|
||||
all_sizes.append((h2, w2))
|
||||
w3, h3 = size3 * math.sqrt(aspect_ratio), size3 / math.sqrt(aspect_ratio)
|
||||
all_sizes.append((h3, w3))
|
||||
|
||||
assert len(all_sizes) == config.num_default[idex]
|
||||
|
||||
for i, j in it.product(range(feature_size), repeat=2):
|
||||
for h, w in all_sizes:
|
||||
cx, cy = (j + 0.5) / fk[idex], (i + 0.5) / fk[idex]
|
||||
self.default_boxes.append([cy, cx, h, w])
|
||||
|
||||
def to_ltrb(cy, cx, h, w):
|
||||
return cy - h / 2, cx - w / 2, cy + h / 2, cx + w / 2
|
||||
|
||||
# For IoU calculation
|
||||
self.default_boxes_ltrb = np.array(tuple(to_ltrb(*i) for i in self.default_boxes), dtype='float32')
|
||||
self.default_boxes = np.array(self.default_boxes, dtype='float32')
|
||||
|
||||
|
||||
default_boxes_ltrb = GeneratDefaultBoxes().default_boxes_ltrb
|
||||
default_boxes = GeneratDefaultBoxes().default_boxes
|
||||
y1, x1, y2, x2 = np.split(default_boxes_ltrb[:, :4], 4, axis=-1)
|
||||
vol_anchors = (x2 - x1) * (y2 - y1)
|
||||
matching_threshold = config.match_thershold
|
||||
|
||||
|
||||
def retinanet_bboxes_encode(boxes):
|
||||
"""
|
||||
Labels anchors with ground truth inputs.
|
||||
|
||||
Args:
|
||||
boxex: ground truth with shape [N, 5], for each row, it stores [y, x, h, w, cls].
|
||||
|
||||
Returns:
|
||||
gt_loc: location ground truth with shape [num_anchors, 4].
|
||||
gt_label: class ground truth with shape [num_anchors, 1].
|
||||
num_matched_boxes: number of positives in an image.
|
||||
"""
|
||||
|
||||
def jaccard_with_anchors(bbox):
|
||||
"""Compute jaccard score a box and the anchors."""
|
||||
# Intersection bbox and volume.
|
||||
ymin = np.maximum(y1, bbox[0])
|
||||
xmin = np.maximum(x1, bbox[1])
|
||||
ymax = np.minimum(y2, bbox[2])
|
||||
xmax = np.minimum(x2, bbox[3])
|
||||
w = np.maximum(xmax - xmin, 0.)
|
||||
h = np.maximum(ymax - ymin, 0.)
|
||||
|
||||
# Volumes.
|
||||
inter_vol = h * w
|
||||
union_vol = vol_anchors + (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) - inter_vol
|
||||
jaccard = inter_vol / union_vol
|
||||
return np.squeeze(jaccard)
|
||||
|
||||
pre_scores = np.zeros((config.num_retinanet_boxes), dtype=np.float32)
|
||||
t_boxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
|
||||
t_label = np.zeros((config.num_retinanet_boxes), dtype=np.int64)
|
||||
for bbox in boxes:
|
||||
label = int(bbox[4])
|
||||
scores = jaccard_with_anchors(bbox)
|
||||
idx = np.argmax(scores)
|
||||
scores[idx] = 2.0
|
||||
mask = (scores > matching_threshold)
|
||||
mask = mask & (scores > pre_scores)
|
||||
pre_scores = np.maximum(pre_scores, scores * mask)
|
||||
t_label = mask * label + (1 - mask) * t_label
|
||||
for i in range(4):
|
||||
t_boxes[:, i] = mask * bbox[i] + (1 - mask) * t_boxes[:, i]
|
||||
|
||||
index = np.nonzero(t_label)
|
||||
|
||||
# Transform to ltrb.
|
||||
bboxes = np.zeros((config.num_retinanet_boxes, 4), dtype=np.float32)
|
||||
bboxes[:, [0, 1]] = (t_boxes[:, [0, 1]] + t_boxes[:, [2, 3]]) / 2
|
||||
bboxes[:, [2, 3]] = t_boxes[:, [2, 3]] - t_boxes[:, [0, 1]]
|
||||
|
||||
# Encode features.
|
||||
bboxes_t = bboxes[index]
|
||||
default_boxes_t = default_boxes[index]
|
||||
bboxes_t[:, :2] = (bboxes_t[:, :2] - default_boxes_t[:, :2]) / (default_boxes_t[:, 2:] * config.prior_scaling[0])
|
||||
tmp = np.maximum(bboxes_t[:, 2:4] / default_boxes_t[:, 2:4], 0.000001)
|
||||
bboxes_t[:, 2:4] = np.log(tmp) / config.prior_scaling[1]
|
||||
bboxes[index] = bboxes_t
|
||||
|
||||
num_match = np.array([len(np.nonzero(t_label)[0])], dtype=np.int32)
|
||||
return bboxes, t_label.astype(np.int32), num_match
|
||||
|
||||
|
||||
def retinanet_bboxes_decode(boxes):
|
||||
"""Decode predict boxes to [y, x, h, w]"""
|
||||
boxes_t = boxes.copy()
|
||||
default_boxes_t = default_boxes.copy()
|
||||
boxes_t[:, :2] = boxes_t[:, :2] * config.prior_scaling[0] * default_boxes_t[:, 2:] + default_boxes_t[:, :2]
|
||||
boxes_t[:, 2:4] = np.exp(boxes_t[:, 2:4] * config.prior_scaling[1]) * default_boxes_t[:, 2:4]
|
||||
|
||||
bboxes = np.zeros((len(boxes_t), 4), dtype=np.float32)
|
||||
|
||||
bboxes[:, [0, 1]] = boxes_t[:, [0, 1]] - boxes_t[:, [2, 3]] / 2
|
||||
bboxes[:, [2, 3]] = boxes_t[:, [0, 1]] + boxes_t[:, [2, 3]] / 2
|
||||
|
||||
return np.clip(bboxes, 0, 1)
|
||||
|
||||
|
||||
def intersect(box_a, box_b):
|
||||
"""Compute the intersect of two sets of boxes."""
|
||||
max_yx = np.minimum(box_a[:, 2:4], box_b[2:4])
|
||||
min_yx = np.maximum(box_a[:, :2], box_b[:2])
|
||||
inter = np.clip((max_yx - min_yx), a_min=0, a_max=np.inf)
|
||||
return inter[:, 0] * inter[:, 1]
|
||||
|
||||
|
||||
def jaccard_numpy(box_a, box_b):
|
||||
"""Compute the jaccard overlap of two sets of boxes."""
|
||||
inter = intersect(box_a, box_b)
|
||||
area_a = ((box_a[:, 2] - box_a[:, 0]) *
|
||||
(box_a[:, 3] - box_a[:, 1]))
|
||||
area_b = ((box_b[2] - box_b[0]) *
|
||||
(box_b[3] - box_b[1]))
|
||||
union = area_a + area_b - inter
|
||||
return inter / union
|
@ -0,0 +1,125 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Coco metrics utils"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
from .config import config
|
||||
|
||||
|
||||
def apply_nms(all_boxes, all_scores, thres, max_boxes):
|
||||
"""Apply NMS to bboxes."""
|
||||
y1 = all_boxes[:, 0]
|
||||
x1 = all_boxes[:, 1]
|
||||
y2 = all_boxes[:, 2]
|
||||
x2 = all_boxes[:, 3]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
|
||||
order = all_scores.argsort()[::-1]
|
||||
keep = []
|
||||
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
|
||||
if len(keep) >= max_boxes:
|
||||
break
|
||||
|
||||
xx1 = np.maximum(x1[i], x1[order[1:]])
|
||||
yy1 = np.maximum(y1[i], y1[order[1:]])
|
||||
xx2 = np.minimum(x2[i], x2[order[1:]])
|
||||
yy2 = np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w = np.maximum(0.0, xx2 - xx1 + 1)
|
||||
h = np.maximum(0.0, yy2 - yy1 + 1)
|
||||
inter = w * h
|
||||
|
||||
ovr = inter / (areas[i] + areas[order[1:]] - inter)
|
||||
|
||||
inds = np.where(ovr <= thres)[0]
|
||||
|
||||
order = order[inds + 1]
|
||||
return keep
|
||||
|
||||
|
||||
def metrics(pred_data):
|
||||
"""Calculate mAP of predicted bboxes."""
|
||||
from pycocotools.coco import COCO
|
||||
from pycocotools.cocoeval import COCOeval
|
||||
num_classes = config.num_classes
|
||||
|
||||
coco_root = config.coco_root
|
||||
data_type = config.val_data_type
|
||||
|
||||
#Classes need to train or test.
|
||||
val_cls = config.coco_classes
|
||||
val_cls_dict = {}
|
||||
for i, cls in enumerate(val_cls):
|
||||
val_cls_dict[i] = cls
|
||||
|
||||
anno_json = os.path.join(coco_root, config.instances_set.format(data_type))
|
||||
coco_gt = COCO(anno_json)
|
||||
classs_dict = {}
|
||||
cat_ids = coco_gt.loadCats(coco_gt.getCatIds())
|
||||
for cat in cat_ids:
|
||||
classs_dict[cat["name"]] = cat["id"]
|
||||
|
||||
predictions = []
|
||||
img_ids = []
|
||||
|
||||
for sample in pred_data:
|
||||
pred_boxes = sample['boxes']
|
||||
box_scores = sample['box_scores']
|
||||
img_id = sample['img_id']
|
||||
h, w = sample['image_shape']
|
||||
|
||||
final_boxes = []
|
||||
final_label = []
|
||||
final_score = []
|
||||
img_ids.append(img_id)
|
||||
|
||||
for c in range(1, num_classes):
|
||||
class_box_scores = box_scores[:, c]
|
||||
score_mask = class_box_scores > config.min_score
|
||||
class_box_scores = class_box_scores[score_mask]
|
||||
class_boxes = pred_boxes[score_mask] * [h, w, h, w]
|
||||
|
||||
if score_mask.any():
|
||||
nms_index = apply_nms(class_boxes, class_box_scores, config.nms_thershold, config.max_boxes)
|
||||
class_boxes = class_boxes[nms_index]
|
||||
class_box_scores = class_box_scores[nms_index]
|
||||
|
||||
final_boxes += class_boxes.tolist()
|
||||
final_score += class_box_scores.tolist()
|
||||
final_label += [classs_dict[val_cls_dict[c]]] * len(class_box_scores)
|
||||
|
||||
for loc, label, score in zip(final_boxes, final_label, final_score):
|
||||
res = {}
|
||||
res['image_id'] = img_id
|
||||
res['bbox'] = [loc[1], loc[0], loc[3] - loc[1], loc[2] - loc[0]]
|
||||
res['score'] = score
|
||||
res['category_id'] = label
|
||||
predictions.append(res)
|
||||
with open('predictions.json', 'w') as f:
|
||||
json.dump(predictions, f)
|
||||
|
||||
coco_dt = coco_gt.loadRes('predictions.json')
|
||||
E = COCOeval(coco_gt, coco_dt, iouType='bbox')
|
||||
E.params.imgIds = img_ids
|
||||
E.evaluate()
|
||||
E.accumulate()
|
||||
E.summarize()
|
||||
return E.stats[0]
|
@ -0,0 +1,86 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
|
||||
"""Config parameters for retinanet models."""
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
config = ed({
|
||||
"img_shape": [600, 600],
|
||||
"num_retinanet_boxes": 67995,
|
||||
"match_thershold": 0.5,
|
||||
"nms_thershold": 0.6,
|
||||
"min_score": 0.1,
|
||||
"max_boxes": 100,
|
||||
|
||||
# learing rate settings
|
||||
"global_step": 0,
|
||||
"lr_init": 1e-6,
|
||||
"lr_end_rate": 5e-3,
|
||||
"warmup_epochs1": 2,
|
||||
"warmup_epochs2": 5,
|
||||
"warmup_epochs3": 23,
|
||||
"warmup_epochs4": 60,
|
||||
"warmup_epochs5": 160,
|
||||
"momentum": 0.9,
|
||||
"weight_decay": 1.5e-4,
|
||||
|
||||
# network
|
||||
"num_default": [9, 9, 9, 9, 9],
|
||||
"extras_out_channels": [256, 256, 256, 256, 256],
|
||||
"feature_size": [75, 38, 19, 10, 5],
|
||||
"aspect_ratios": [(0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0), (0.5, 1.0, 2.0)],
|
||||
"steps": (8, 16, 32, 64, 128),
|
||||
"anchor_size": (32, 64, 128, 256, 512),
|
||||
"prior_scaling": (0.1, 0.2),
|
||||
"gamma": 2.0,
|
||||
"alpha": 0.75,
|
||||
|
||||
# `mindrecord_dir` and `coco_root` are better to use absolute path.
|
||||
"mindrecord_dir": "/data/hitwh/retinanet/MindRecord_COCO",
|
||||
"coco_root": "/data/dataset/coco2017",
|
||||
"train_data_type": "train2017",
|
||||
"val_data_type": "val2017",
|
||||
"instances_set": "/data/dataset/coco2017/annotations/instances_{}.json",
|
||||
"coco_classes": ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
|
||||
'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
|
||||
'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
|
||||
'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra',
|
||||
'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',
|
||||
'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
|
||||
'kite', 'baseball bat', 'baseball glove', 'skateboard',
|
||||
'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',
|
||||
'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
|
||||
'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
|
||||
'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',
|
||||
'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
|
||||
'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
|
||||
'refrigerator', 'book', 'clock', 'vase', 'scissors',
|
||||
'teddy bear', 'hair drier', 'toothbrush'),
|
||||
"num_classes": 81,
|
||||
# The annotation.json position of voc validation dataset.
|
||||
"voc_root": "",
|
||||
# voc original dataset.
|
||||
"voc_dir": "",
|
||||
# if coco or voc used, `image_dir` and `anno_path` are useless.
|
||||
"image_dir": "",
|
||||
"anno_path": "",
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_epochs": 1,
|
||||
"keep_checkpoint_max": 1,
|
||||
"save_checkpoint_path": "./model",
|
||||
"finish_epoch": 0,
|
||||
"checkpoint_path": "/home/hitwh1/1.0/ckpt_0/retinanet-500_458_59.ckpt"
|
||||
})
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,35 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Parameters utils"""
|
||||
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
|
||||
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||
"""Init the parameters in net."""
|
||||
params = network.trainable_params()
|
||||
for p in params:
|
||||
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
if initialize_mode == 'TruncatedNormal':
|
||||
p.set_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
||||
else:
|
||||
p.set_data(initialize_mode, p.data.shape, p.data.dtype)
|
||||
|
||||
|
||||
|
||||
def filter_checkpoint_parameter(param_dict):
|
||||
"""remove useless parameters"""
|
||||
for key in list(param_dict.keys()):
|
||||
if 'multi_loc_layers' in key or 'multi_cls_layers' in key:
|
||||
del param_dict[key]
|
@ -0,0 +1,74 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Learning rate schedule"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_lr(global_step, lr_init, lr_end, lr_max, warmup_epochs1, warmup_epochs2,
|
||||
warmup_epochs3, warmup_epochs4, warmup_epochs5, total_epochs, steps_per_epoch):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
global_step(int): total steps of the training
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(float): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps1 = steps_per_epoch * warmup_epochs1
|
||||
warmup_steps2 = warmup_steps1 + steps_per_epoch * warmup_epochs2
|
||||
warmup_steps3 = warmup_steps2 + steps_per_epoch * warmup_epochs3
|
||||
warmup_steps4 = warmup_steps3 + steps_per_epoch * warmup_epochs4
|
||||
warmup_steps5 = warmup_steps4 + steps_per_epoch * warmup_epochs5
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps1:
|
||||
lr = lr_init*(warmup_steps1-i) / (warmup_steps1) + \
|
||||
(lr_max*1e-4) * i / (warmup_steps1*3)
|
||||
elif warmup_steps1 <= i < warmup_steps2:
|
||||
lr = 1e-5*(warmup_steps2-i) / (warmup_steps2 - warmup_steps1) + \
|
||||
(lr_max*1e-3) * (i-warmup_steps1) / (warmup_steps2 - warmup_steps1)
|
||||
elif warmup_steps2 <= i < warmup_steps3:
|
||||
lr = 1e-4*(warmup_steps3-i) / (warmup_steps3 - warmup_steps2) + \
|
||||
(lr_max*1e-2) * (i-warmup_steps2) / (warmup_steps3 - warmup_steps2)
|
||||
elif warmup_steps3 <= i < warmup_steps4:
|
||||
lr = 1e-3*(warmup_steps4-i) / (warmup_steps4 - warmup_steps3) + \
|
||||
(lr_max*1e-1) * (i-warmup_steps3) / (warmup_steps4 - warmup_steps3)
|
||||
elif warmup_steps4 <= i < warmup_steps5:
|
||||
lr = 1e-2*(warmup_steps5-i) / (warmup_steps5 - warmup_steps4) + \
|
||||
lr_max * (i-warmup_steps4) / (warmup_steps5 - warmup_steps4)
|
||||
else:
|
||||
lr = lr_end + \
|
||||
(lr_max - lr_end) * \
|
||||
(1. + math.cos(math.pi * (i-warmup_steps5) / (total_steps - warmup_steps5))) / 2.
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
|
||||
current_step = global_step
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[current_step:]
|
||||
|
||||
return learning_rate
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,153 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Train retinanet and get checkpoint files."""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor, Callback
|
||||
from mindspore.train import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from src.retinanet import retinanetWithLossCell, TrainingWrapper, retinanet50, resnet50
|
||||
from src.config import config
|
||||
from src.dataset import create_retinanet_dataset, create_mindrecord
|
||||
from src.lr_schedule import get_lr
|
||||
from src.init_params import init_net_param, filter_checkpoint_parameter
|
||||
|
||||
|
||||
set_seed(1)
|
||||
class Monitor(Callback):
|
||||
"""
|
||||
Monitor loss and time.
|
||||
|
||||
Args:
|
||||
lr_init (numpy array): train lr
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Examples:
|
||||
>>> Monitor(100,lr_init=Tensor([0.05]*100).asnumpy())
|
||||
"""
|
||||
|
||||
def __init__(self, lr_init=None):
|
||||
super(Monitor, self).__init__()
|
||||
self.lr_init = lr_init
|
||||
self.lr_init_len = len(lr_init)
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
print("lr:[{:8.6f}]".format(self.lr_init[cb_params.cur_step_num-1]), flush=True)
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="retinanet training")
|
||||
parser.add_argument("--only_create_dataset", type=ast.literal_eval, default=False,
|
||||
help="If set it true, only create Mindrecord, default is False.")
|
||||
parser.add_argument("--distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is False.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--device_num", type=int, default=1, help="Use device nums, default is 1.")
|
||||
parser.add_argument("--lr", type=float, default=0.1, help="Learning rate, default is 0.1.")
|
||||
parser.add_argument("--mode", type=str, default="sink", help="Run sink mode or not, default is sink.")
|
||||
parser.add_argument("--dataset", type=str, default="coco", help="Dataset, default is coco.")
|
||||
parser.add_argument("--epoch_size", type=int, default=500, help="Epoch size, default is 500.")
|
||||
parser.add_argument("--batch_size", type=int, default=32, help="Batch size, default is 32.")
|
||||
parser.add_argument("--pre_trained", type=str, default=None, help="Pretrained Checkpoint file path.")
|
||||
parser.add_argument("--pre_trained_epoch_size", type=int, default=0, help="Pretrained epoch size.")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
|
||||
parser.add_argument("--loss_scale", type=int, default=1024, help="Loss scale, default is 1024.")
|
||||
parser.add_argument("--filter_weight", type=ast.literal_eval, default=False,
|
||||
help="Filter weight parameters, default is False.")
|
||||
parser.add_argument("--run_platform", type=str, default="Ascend", choices=("Ascend"),
|
||||
help="run platform, only support Ascend.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.run_platform == "Ascend":
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
if args_opt.distribute:
|
||||
if os.getenv("DEVICE_ID", "not_set").isdigit():
|
||||
context.set_context(device_id=int(os.getenv("DEVICE_ID")))
|
||||
init()
|
||||
device_num = args_opt.device_num
|
||||
rank = get_rank()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, gradients_mean=True,
|
||||
device_num=device_num)
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
mindrecord_file = create_mindrecord(args_opt.dataset, "retinanet.mindrecord", True)
|
||||
|
||||
if not args_opt.only_create_dataset:
|
||||
loss_scale = float(args_opt.loss_scale)
|
||||
|
||||
# When create MindDataset, using the fitst mindrecord file, such as retinanet.mindrecord0.
|
||||
dataset = create_retinanet_dataset(mindrecord_file, repeat_num=1,
|
||||
batch_size=args_opt.batch_size, device_num=device_num, rank=rank)
|
||||
|
||||
dataset_size = dataset.get_dataset_size()
|
||||
print("Create dataset done!")
|
||||
|
||||
|
||||
backbone = resnet50(config.num_classes)
|
||||
retinanet = retinanet50(backbone, config)
|
||||
net = retinanetWithLossCell(retinanet, config)
|
||||
net.to_float(mindspore.float16)
|
||||
init_net_param(net)
|
||||
|
||||
if args_opt.pre_trained:
|
||||
if args_opt.pre_trained_epoch_size <= 0:
|
||||
raise KeyError("pre_trained_epoch_size must be greater than 0.")
|
||||
param_dict = load_checkpoint(args_opt.pre_trained)
|
||||
if args_opt.filter_weight:
|
||||
filter_checkpoint_parameter(param_dict)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
lr = Tensor(get_lr(global_step=config.global_step,
|
||||
lr_init=config.lr_init, lr_end=config.lr_end_rate * args_opt.lr, lr_max=args_opt.lr,
|
||||
warmup_epochs1=config.warmup_epochs1, warmup_epochs2=config.warmup_epochs2,
|
||||
warmup_epochs3=config.warmup_epochs3, warmup_epochs4=config.warmup_epochs4,
|
||||
warmup_epochs5=config.warmup_epochs5, total_epochs=args_opt.epoch_size,
|
||||
steps_per_epoch=dataset_size))
|
||||
opt = nn.Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr,
|
||||
config.momentum, config.weight_decay, loss_scale)
|
||||
net = TrainingWrapper(net, opt, loss_scale)
|
||||
model = Model(net)
|
||||
print("Start train retinanet, the first epoch will be slower because of the graph compilation.")
|
||||
cb = [TimeMonitor(), LossMonitor()]
|
||||
cb += [Monitor(lr_init=lr.asnumpy())]
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=dataset_size * args_opt.save_checkpoint_epochs,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="retinanet", directory=config.save_checkpoint_path, config=config_ck)
|
||||
if args_opt.distribute:
|
||||
if rank == 0:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
else:
|
||||
cb += [ckpt_cb]
|
||||
model.train(args_opt.epoch_size, dataset, callbacks=cb, dataset_sink_mode=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in new issue