parent
9359983123
commit
b288efc87e
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,78 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore import Model, context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.utils import create_sliding_window, CalculateDice
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
|
||||
return parser.parse_args()
|
||||
|
||||
def test_net(data_dir, seg_dir, ckpt_path, config=None):
|
||||
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
|
||||
eval_data_size = eval_dataset.get_dataset_size()
|
||||
print("train dataset length is:", eval_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
network.set_train(False)
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
model = Model(network)
|
||||
index = 0
|
||||
total_dice = 0
|
||||
for batch in eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
|
||||
image = batch["image"]
|
||||
seg = batch["seg"]
|
||||
print("current image shape is {}".format(image.shape), flush=True)
|
||||
sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
|
||||
image_size = (config.batch_size, config.num_classes) + image.shape[2:]
|
||||
output_image = np.zeros(image_size, np.float32)
|
||||
count_map = np.zeros(image_size, np.float32)
|
||||
importance_map = np.ones(config.roi_size, np.float32)
|
||||
for window, slice_ in zip(sliding_window_list, slice_list):
|
||||
window_image = Tensor(window, mstype.float32)
|
||||
pred_probs = model.predict(window_image)
|
||||
output_image[slice_] += pred_probs.asnumpy()
|
||||
count_map[slice_] += importance_map
|
||||
output_image = output_image / count_map
|
||||
dice, _ = CalculateDice(output_image, seg)
|
||||
print("The {} batch dice is {}".format(index, dice), flush=True)
|
||||
total_dice += dice
|
||||
index = index + 1
|
||||
avg_dice = total_dice / eval_data_size
|
||||
print("**********************End Eval***************************************")
|
||||
print("eval average dice is {}".format(avg_dice))
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Testing setting:", args)
|
||||
test_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
ckpt_path=args.ckpt_path,
|
||||
config=cfg)
|
@ -0,0 +1,80 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 3 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH3=$(get_real_path $3)
|
||||
echo $PATH3
|
||||
if [ ! -d $PATH3 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH3 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
|
||||
python train.py \
|
||||
--run_distribute=True \
|
||||
--data_url=$PATH2 \
|
||||
--seg_url=$PATH3 > log.txt 2>&1 &
|
||||
|
||||
cd ../
|
||||
done
|
@ -0,0 +1,82 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "=============================================================================================================="
|
||||
echo "Please run the script as: "
|
||||
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
|
||||
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
|
||||
echo "=============================================================================================================="
|
||||
fi
|
||||
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
IMAGE_PATH=$(get_real_path $1)
|
||||
SEG_PATH=$(get_real_path $2)
|
||||
CHECKPOINT_FILE_PATH=$(get_real_path $3)
|
||||
echo $IMAGE_PATH
|
||||
echo $SEG_PATH
|
||||
echo $CHECKPOINT_FILE_PATH
|
||||
|
||||
if [ ! -d $IMAGE_PATH ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $SEG_PATH ]
|
||||
then
|
||||
echo "error: SEG_PATH=$SEG_PATH is not a path"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $CHECKPOINT_FILE_PATH ]
|
||||
then
|
||||
echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval
|
||||
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
|
||||
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
|
||||
cd ..
|
@ -0,0 +1,62 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: IMAGE_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -d $PATH2 ]
|
||||
then
|
||||
echo "error: SEG_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
rm -rf ./train
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
|
||||
cd ..
|
@ -0,0 +1,34 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
from easydict import EasyDict
|
||||
config = EasyDict({
|
||||
'model': 'Unet3d',
|
||||
'lr': 0.0005,
|
||||
'epoch_size': 10,
|
||||
'batch_size': 1,
|
||||
'warmup_step': 120,
|
||||
'warmup_ratio': 0.3,
|
||||
'num_classes': 4,
|
||||
'in_channels': 1,
|
||||
'keep_checkpoint_max': 5,
|
||||
'loss_scale': 256.0,
|
||||
'roi_size': [224, 224, 96],
|
||||
'overlap': 0.25,
|
||||
'min_val': -500,
|
||||
'max_val': 1000,
|
||||
'upper_limit': 5,
|
||||
'lower_limit': 3,
|
||||
})
|
@ -0,0 +1,90 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Parameter
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops.operations import nn_ops as nps
|
||||
from mindspore.common.initializer import initializer
|
||||
|
||||
def weight_variable(shape):
|
||||
init_value = initializer('Normal', shape, mstype.float32)
|
||||
return Parameter(init_value)
|
||||
|
||||
class Conv3D(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad_mode="valid",
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
|
||||
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
|
||||
group=group, data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
||||
|
||||
class Conv3DTranspose(nn.Cell):
|
||||
def __init__(self,
|
||||
in_channel,
|
||||
out_channel,
|
||||
kernel_size,
|
||||
mode=1,
|
||||
pad=0,
|
||||
stride=1,
|
||||
dilation=1,
|
||||
group=1,
|
||||
output_padding=0,
|
||||
data_format="NCDHW",
|
||||
bias_init="zeros",
|
||||
has_bias=True):
|
||||
super().__init__()
|
||||
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
|
||||
self.weight = weight_variable(self.weight_shape)
|
||||
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
|
||||
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
|
||||
dilation=dilation, group=group, output_padding=output_padding, \
|
||||
data_format=data_format)
|
||||
self.bias_init = bias_init
|
||||
self.has_bias = has_bias
|
||||
self.bias_add = P.BiasAdd(data_format=data_format)
|
||||
if self.has_bias:
|
||||
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
|
||||
|
||||
def construct(self, x):
|
||||
output = self.conv_transpose(x, self.weight)
|
||||
if self.has_bias:
|
||||
output = self.bias_add(output, self.bias)
|
||||
return output
|
@ -0,0 +1,62 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import SimpleITK as sitk
|
||||
from src.config import config
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")
|
||||
parser.add_argument("--output_path", type=str, help="Output file path.")
|
||||
args = parser.parse_args()
|
||||
|
||||
def get_list_of_files_in_dir(directory, file_types='*'):
|
||||
"""
|
||||
Get list of certain format files.
|
||||
|
||||
Args:
|
||||
directory (str): The input directory for image.
|
||||
file_types (str): The file_types to filter the files.
|
||||
"""
|
||||
return [f for f in Path(directory).glob(file_types) if f.is_file()]
|
||||
|
||||
def convert_nifti(input_dir, output_dir, roi_size, file_types):
|
||||
"""
|
||||
Convert dataset into mifti format.
|
||||
|
||||
Args:
|
||||
input_dir (str): The input directory for image.
|
||||
output_dir (str): The output directory to save nifti format data.
|
||||
roi_size (str): The size to crop the image.
|
||||
file_types: File types to convert into nifti.
|
||||
"""
|
||||
file_list = get_list_of_files_in_dir(input_dir, file_types)
|
||||
for file_name in file_list:
|
||||
file_name = str(file_name)
|
||||
input_file_name, _ = os.path.splitext(os.path.basename(file_name))
|
||||
img = sitk.ReadImage(file_name)
|
||||
image_array = sitk.GetArrayFromImage(img)
|
||||
D, H, W = image_array.shape
|
||||
if H < roi_size[0] or W < roi_size[1] or D < roi_size[2]:
|
||||
print("file {} size is smaller than roi size, ignore it.".format(input_file_name))
|
||||
continue
|
||||
output_path = os.path.join(output_dir, input_file_name + ".nii.gz")
|
||||
sitk.WriteImage(img, output_path)
|
||||
print("create output file {} success.".format(output_path))
|
||||
|
||||
if __name__ == '__main__':
|
||||
convert_nifti(args.input_path, args.output_path, config.roi_size, "*.mhd")
|
@ -0,0 +1,74 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# =========================================================================
|
||||
|
||||
import os
|
||||
import glob
|
||||
import numpy as np
|
||||
import mindspore.dataset as ds
|
||||
from mindspore.dataset.transforms.py_transforms import Compose
|
||||
from src.config import config as cfg
|
||||
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
|
||||
|
||||
class ConvertLabel:
|
||||
"""
|
||||
Crop at the center of image with specified ROI size.
|
||||
|
||||
Args:
|
||||
roi_size: the spatial size of the crop region e.g. [224,224,128]
|
||||
If its components have non-positive values, the corresponding size of input image will be used.
|
||||
"""
|
||||
def operation(self, data):
|
||||
"""
|
||||
Apply the transform to `img`, assuming `img` is channel-first and
|
||||
slicing doesn't apply to the channel dim.
|
||||
"""
|
||||
data[data > cfg['upper_limit']] = 0
|
||||
data = data - (cfg['lower_limit'] - 1)
|
||||
data = np.clip(data, 0, cfg['lower_limit'])
|
||||
return data
|
||||
|
||||
def __call__(self, image, label):
|
||||
label = self.operation(label)
|
||||
return image, label
|
||||
|
||||
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
|
||||
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
|
||||
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
|
||||
train_ds = Dataset(data=train_files, seg=seg_files)
|
||||
train_loader = ds.GeneratorDataset(train_ds, column_names=["image", "seg"], num_parallel_workers=4, \
|
||||
shuffle=is_training, num_shards=rank_size, shard_id=rank_id)
|
||||
|
||||
if is_training:
|
||||
transform_image = Compose([LoadData(),
|
||||
ExpandChannel(),
|
||||
Orientation(),
|
||||
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
|
||||
tgt_max=1.0, is_clip=True),
|
||||
RandomCropSamples(roi_size=config.roi_size, num_samples=2),
|
||||
ConvertLabel(),
|
||||
OneHot(num_classes=config.num_classes)])
|
||||
else:
|
||||
transform_image = Compose([LoadData(),
|
||||
ExpandChannel(),
|
||||
Orientation(),
|
||||
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
|
||||
tgt_max=1.0, is_clip=True),
|
||||
ConvertLabel()])
|
||||
|
||||
train_loader = train_loader.map(operations=transform_image, input_columns=["image", "seg"], num_parallel_workers=12,
|
||||
python_multiprocessing=True)
|
||||
if not is_training:
|
||||
train_loader = train_loader.batch(1)
|
||||
return train_loader
|
@ -0,0 +1,37 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from src.config import config
|
||||
|
||||
class SoftmaxCrossEntropyWithLogits(_Loss):
|
||||
def __init__(self):
|
||||
super(SoftmaxCrossEntropyWithLogits, self).__init__()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
|
||||
self.cast = P.Cast()
|
||||
self.reduce_mean = P.ReduceMean()
|
||||
|
||||
def construct(self, logits, label):
|
||||
logits = self.transpose(logits, (0, 2, 3, 4, 1))
|
||||
label = self.transpose(label, (0, 2, 3, 4, 1))
|
||||
label = self.cast(label, mstype.float32)
|
||||
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
|
||||
self.reshape(label, (-1, config['num_classes']))))
|
||||
return self.get_loss(loss)
|
@ -0,0 +1,39 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
|
||||
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
learning_rate = float(init_lr) + lr_inc * current_step
|
||||
return learning_rate
|
||||
|
||||
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
|
||||
base = float(current_step - warmup_steps) / float(decay_steps)
|
||||
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
|
||||
return learning_rate
|
||||
|
||||
def dynamic_lr(config, base_step):
|
||||
"""dynamic learning rate generator"""
|
||||
base_lr = config.lr
|
||||
total_steps = int(base_step * config.epoch_size)
|
||||
warmup_steps = config.warmup_step
|
||||
lr = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
|
||||
else:
|
||||
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
|
||||
return lr
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,62 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.unet3d_parts import Down, Up
|
||||
|
||||
class UNet3d(nn.Cell):
|
||||
def __init__(self, config=None):
|
||||
super(UNet3d, self).__init__()
|
||||
self.n_channels = config.in_channels
|
||||
self.n_classes = config.num_classes
|
||||
|
||||
# down
|
||||
self.transpose = P.Transpose()
|
||||
self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
|
||||
# up
|
||||
self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
|
||||
dtype=mstype.float16).to_float(mstype.float16)
|
||||
self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
|
||||
dtype=mstype.float16, is_output=True).to_float(mstype.float16)
|
||||
|
||||
self.cast = P.Cast()
|
||||
|
||||
|
||||
def construct(self, input_data):
|
||||
input_data = self.cast(input_data, mstype.float16)
|
||||
x1 = self.down1(input_data)
|
||||
x2 = self.down2(x1)
|
||||
x3 = self.down3(x2)
|
||||
x4 = self.down4(x3)
|
||||
x5 = self.down5(x4)
|
||||
|
||||
x = self.up1(x5, x4)
|
||||
x = self.up2(x, x3)
|
||||
x = self.up3(x, x2)
|
||||
x = self.up4(x, x1)
|
||||
x = self.cast(x, mstype.float32)
|
||||
return x
|
@ -0,0 +1,112 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from src.conv import Conv3D, Conv3DTranspose
|
||||
|
||||
class BatchNorm3d(nn.Cell):
|
||||
def __init__(self, num_features):
|
||||
super().__init__()
|
||||
self.reshape = P.Reshape()
|
||||
self.shape = P.Shape()
|
||||
self.bn2d = nn.BatchNorm2d(num_features, data_format="NCHW")
|
||||
|
||||
def construct(self, x):
|
||||
x_shape = self.shape(x)
|
||||
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
|
||||
bn2d_out = self.bn2d(x)
|
||||
bn3d_out = self.reshape(bn2d_out, x_shape)
|
||||
return bn3d_out
|
||||
|
||||
class ResidualUnit(nn.Cell):
|
||||
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), down=True, is_output=False):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.down = down
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.down_conv_1 = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=self.stride, pad=1)
|
||||
self.is_output = is_output
|
||||
if not is_output:
|
||||
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
||||
self.relu1 = nn.PReLU()
|
||||
if self.down:
|
||||
self.down_conv_2 = Conv3D(out_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=1, pad=1)
|
||||
self.relu2 = nn.PReLU()
|
||||
if kernel_size[0] == 1:
|
||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(1, 1, 1), \
|
||||
pad_mode="valid", stride=self.stride)
|
||||
else:
|
||||
self.residual = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
|
||||
pad_mode="pad", stride=self.stride, pad=1)
|
||||
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
|
||||
|
||||
|
||||
def construct(self, x):
|
||||
out = self.down_conv_1(x)
|
||||
if self.is_output:
|
||||
return out
|
||||
out = self.batchNormal1(out)
|
||||
out = self.relu1(out)
|
||||
if self.down:
|
||||
out = self.down_conv_2(out)
|
||||
out = self.batchNormal2(out)
|
||||
out = self.relu2(out)
|
||||
res = self.residual(x)
|
||||
else:
|
||||
res = x
|
||||
return out + res
|
||||
|
||||
class Down(nn.Cell):
|
||||
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), dtype=mstype.float16):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.in_channel = in_channel
|
||||
self.out_channel = out_channel
|
||||
self.down_conv = ResidualUnit(self.in_channel, self.out_channel, stride, kernel_size).to_float(dtype)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.down_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Up(nn.Cell):
|
||||
def __init__(self, in_channel, down_in_channel, out_channel, stride=2, is_output=False, dtype=mstype.float16):
|
||||
super().__init__()
|
||||
self.in_channel = in_channel
|
||||
self.down_in_channel = down_in_channel
|
||||
self.out_channel = out_channel
|
||||
self.stride = stride
|
||||
self.conv3d_transpose = Conv3DTranspose(in_channel=self.in_channel + self.down_in_channel, \
|
||||
pad=1, out_channel=self.out_channel, kernel_size=(3, 3, 3), \
|
||||
stride=self.stride, output_padding=(1, 1, 1))
|
||||
|
||||
self.concat = P.Concat(axis=1)
|
||||
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
|
||||
is_output=is_output).to_float(dtype)
|
||||
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
|
||||
self.relu = nn.PReLU()
|
||||
|
||||
def construct(self, input_data, down_input):
|
||||
x = self.concat((input_data, down_input))
|
||||
x = self.conv3d_transpose(x)
|
||||
x = self.batchNormal1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv(x)
|
||||
return x
|
@ -0,0 +1,170 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
from src.config import config
|
||||
|
||||
def correct_nifti_head(img):
|
||||
"""
|
||||
Check nifti object header's format, update the header if needed.
|
||||
In the updated image pixdim matches the affine.
|
||||
|
||||
Args:
|
||||
img: nifti image object
|
||||
"""
|
||||
dim = img.header["dim"][0]
|
||||
if dim >= 5:
|
||||
return img
|
||||
pixdim = np.asarray(img.header.get_zooms())[:dim]
|
||||
norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
|
||||
if np.allclose(pixdim, norm_affine):
|
||||
return img
|
||||
if hasattr(img, "get_sform"):
|
||||
return rectify_header_sform_qform(img)
|
||||
return img
|
||||
|
||||
def get_random_patch(dims, patch_size, rand_fn=None):
|
||||
"""
|
||||
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`.
|
||||
|
||||
Args:
|
||||
dims: shape of source array
|
||||
patch_size: shape of patch size to generate
|
||||
rand_fn: generate random numbers
|
||||
|
||||
Returns:
|
||||
(tuple of slice): a tuple of slice objects defining the patch
|
||||
"""
|
||||
rand_int = np.random.randint if rand_fn is None else rand_fn.randint
|
||||
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
|
||||
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
|
||||
|
||||
|
||||
def first(iterable, default=None):
|
||||
"""
|
||||
Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
|
||||
"""
|
||||
for i in iterable:
|
||||
return i
|
||||
return default
|
||||
|
||||
def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
|
||||
"""
|
||||
Compute scan interval according to the image size, roi size and overlap.
|
||||
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
|
||||
use 1 instead to make sure sliding window works.
|
||||
"""
|
||||
if len(image_size) != num_image_dims:
|
||||
raise ValueError("image different from spatial dims.")
|
||||
if len(roi_size) != num_image_dims:
|
||||
raise ValueError("roi size different from spatial dims.")
|
||||
|
||||
scan_interval = []
|
||||
for i in range(num_image_dims):
|
||||
if roi_size[i] == image_size[i]:
|
||||
scan_interval.append(int(roi_size[i]))
|
||||
else:
|
||||
interval = int(roi_size[i] * (1 - overlap))
|
||||
scan_interval.append(interval if interval > 0 else 1)
|
||||
return tuple(scan_interval)
|
||||
|
||||
def dense_patch_slices(image_size, patch_size, scan_interval):
|
||||
"""
|
||||
Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
|
||||
|
||||
Args:
|
||||
image_size: dimensions of image to iterate over
|
||||
patch_size: size of patches to generate slices
|
||||
scan_interval: dense patch sampling interval
|
||||
|
||||
Returns:
|
||||
a list of slice objects defining each patch
|
||||
"""
|
||||
num_spatial_dims = len(image_size)
|
||||
patch_size = patch_size
|
||||
scan_num = []
|
||||
for i in range(num_spatial_dims):
|
||||
if scan_interval[i] == 0:
|
||||
scan_num.append(1)
|
||||
else:
|
||||
num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
|
||||
scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
|
||||
scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
|
||||
starts = []
|
||||
for dim in range(num_spatial_dims):
|
||||
dim_starts = []
|
||||
for idx in range(scan_num[dim]):
|
||||
start_idx = idx * scan_interval[dim]
|
||||
start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
|
||||
dim_starts.append(start_idx)
|
||||
starts.append(dim_starts)
|
||||
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
|
||||
return [(slice(None),)*2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
|
||||
|
||||
def create_sliding_window(image, roi_size, overlap):
|
||||
num_image_dims = len(image.shape) - 2
|
||||
if overlap < 0 or overlap >= 1:
|
||||
raise AssertionError("overlap must be >= 0 and < 1.")
|
||||
image_size_temp = list(image.shape[2:])
|
||||
image_size = tuple(max(image_size_temp[i], roi_size[i]) for i in range(num_image_dims))
|
||||
|
||||
scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
|
||||
slices = dense_patch_slices(image_size, roi_size, scan_interval)
|
||||
windows_sliding = [image[slice] for slice in slices]
|
||||
return windows_sliding, slices
|
||||
|
||||
def one_hot(labels):
|
||||
N, _, D, H, W = labels.shape
|
||||
labels = np.reshape(labels, (N, -1))
|
||||
labels = labels.astype(np.int32)
|
||||
N, K = labels.shape
|
||||
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
|
||||
for i in range(N):
|
||||
for j in range(K):
|
||||
one_hot_encoding[i, labels[i][j], j] = 1
|
||||
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
|
||||
return labels
|
||||
|
||||
def CalculateDice(y_pred, label):
|
||||
"""
|
||||
Args:
|
||||
y_pred: predictions. As for classification tasks,
|
||||
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
|
||||
the shape should be [BNHW] or [BNHWD].
|
||||
label: ground truth, the first dim is batch.
|
||||
"""
|
||||
y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
|
||||
y_pred = one_hot(y_pred_output)
|
||||
y = one_hot(label)
|
||||
y_pred, y = ignore_background(y_pred, y)
|
||||
inter = np.dot(y_pred.flatten(), y.flatten()).astype(np.float64)
|
||||
union = np.dot(y_pred.flatten(), y_pred.flatten()).astype(np.float64) + np.dot(y.flatten(), \
|
||||
y.flatten()).astype(np.float64)
|
||||
single_dice_coeff = 2 * inter / (union + 1e-6)
|
||||
return single_dice_coeff, y_pred_output
|
||||
|
||||
def ignore_background(y_pred, label):
|
||||
"""
|
||||
This function is used to remove background (the first channel) for `y_pred` and `y`.
|
||||
Args:
|
||||
y_pred: predictions. As for classification tasks,
|
||||
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
|
||||
the shape should be [BNHW] or [BNHWD].
|
||||
label: ground truth, the first dim is batch.
|
||||
"""
|
||||
label = label[:, 1:] if label.shape[1] > 1 else label
|
||||
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
|
||||
return y_pred, label
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# less required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import ast
|
||||
import mindspore
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import Tensor, Model, context
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
|
||||
from src.dataset import create_dataset
|
||||
from src.unet3d_model import UNet3d
|
||||
from src.config import config as cfg
|
||||
from src.lr_schedule import dynamic_lr
|
||||
from src.loss import SoftmaxCrossEntropyWithLogits
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
|
||||
device_id=device_id)
|
||||
mindspore.set_seed(1)
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
|
||||
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
|
||||
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
|
||||
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
|
||||
help='Run distribute, default: false')
|
||||
return parser.parse_args()
|
||||
|
||||
def train_net(data_dir,
|
||||
seg_dir,
|
||||
run_distribute,
|
||||
config=None):
|
||||
if run_distribute:
|
||||
init()
|
||||
rank_id = get_rank()
|
||||
rank_size = get_group_size()
|
||||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode,
|
||||
device_num=rank_size,
|
||||
gradients_mean=True)
|
||||
else:
|
||||
rank_id = 0
|
||||
rank_size = 1
|
||||
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
|
||||
rank_size=rank_size, rank_id=rank_id, is_training=True)
|
||||
train_data_size = train_dataset.get_dataset_size()
|
||||
print("train dataset length is:", train_data_size)
|
||||
|
||||
network = UNet3d(config=config)
|
||||
|
||||
loss = SoftmaxCrossEntropyWithLogits()
|
||||
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
|
||||
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr)
|
||||
scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
network.set_train()
|
||||
|
||||
model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)
|
||||
|
||||
time_cb = TimeMonitor(data_size=train_data_size)
|
||||
loss_cb = LossMonitor()
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
|
||||
directory='./ckpt_{}/'.format(device_id),
|
||||
config=ckpt_config)
|
||||
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
|
||||
print("============== Starting Training ==============")
|
||||
model.train(config.epoch_size, train_dataset, callbacks=callbacks_list)
|
||||
print("============== End Training ==============")
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = get_args()
|
||||
print("Training setting:", args)
|
||||
train_net(data_dir=args.data_url,
|
||||
seg_dir=args.seg_url,
|
||||
run_distribute=args.run_distribute,
|
||||
config=cfg)
|
Loading…
Reference in new issue