pull/13761/head
lilei 4 years ago
parent 9359983123
commit b288efc87e

File diff suppressed because it is too large Load Diff

@ -0,0 +1,78 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import numpy as np
from mindspore import dtype as mstype
from mindspore import Model, context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.utils import create_sliding_window, CalculateDice
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, device_id=device_id)
def get_args():
parser = argparse.ArgumentParser(description='Test the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--ckpt_path', dest='ckpt_path', type=str, default='', help='checkpoint path')
return parser.parse_args()
def test_net(data_dir, seg_dir, ckpt_path, config=None):
eval_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, is_training=False)
eval_data_size = eval_dataset.get_dataset_size()
print("train dataset length is:", eval_data_size)
network = UNet3d(config=config)
network.set_train(False)
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(network, param_dict)
model = Model(network)
index = 0
total_dice = 0
for batch in eval_dataset.create_dict_iterator(num_epochs=1, output_numpy=True):
image = batch["image"]
seg = batch["seg"]
print("current image shape is {}".format(image.shape), flush=True)
sliding_window_list, slice_list = create_sliding_window(image, config.roi_size, config.overlap)
image_size = (config.batch_size, config.num_classes) + image.shape[2:]
output_image = np.zeros(image_size, np.float32)
count_map = np.zeros(image_size, np.float32)
importance_map = np.ones(config.roi_size, np.float32)
for window, slice_ in zip(sliding_window_list, slice_list):
window_image = Tensor(window, mstype.float32)
pred_probs = model.predict(window_image)
output_image[slice_] += pred_probs.asnumpy()
count_map[slice_] += importance_map
output_image = output_image / count_map
dice, _ = CalculateDice(output_image, seg)
print("The {} batch dice is {}".format(index, dice), flush=True)
total_dice += dice
index = index + 1
avg_dice = total_dice / eval_data_size
print("**********************End Eval***************************************")
print("eval average dice is {}".format(avg_dice))
if __name__ == '__main__':
args = get_args()
print("Testing setting:", args)
test_net(data_dir=args.data_url,
seg_dir=args.seg_url,
ckpt_path=args.ckpt_path,
config=cfg)

@ -0,0 +1,80 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 3 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [RANK_TABLE_FILE] [IMAGE_PATH] [SEG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -d $PATH2 ]
then
echo "error: IMAGE_PATH=$PATH2 is not a file"
exit 1
fi
PATH3=$(get_real_path $3)
echo $PATH3
if [ ! -d $PATH3 ]
then
echo "error: SEG_PATH=$PATH3 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py \
--run_distribute=True \
--data_url=$PATH2 \
--seg_url=$PATH3 > log.txt 2>&1 &
cd ../
done

@ -0,0 +1,82 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]
then
echo "=============================================================================================================="
echo "Please run the script as: "
echo "bash scripts/run_standalone_eval.sh [DATASET] [CHECKPOINT]"
echo "for example: bash run_standalone_eval.sh /path/to/data/ /path/to/checkpoint/"
echo "=============================================================================================================="
fi
if [ $# != 3 ]
then
echo "Usage: sh run_eval_ascend.sh [IMAGE_PATH] [SEG_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
IMAGE_PATH=$(get_real_path $1)
SEG_PATH=$(get_real_path $2)
CHECKPOINT_FILE_PATH=$(get_real_path $3)
echo $IMAGE_PATH
echo $SEG_PATH
echo $CHECKPOINT_FILE_PATH
if [ ! -d $IMAGE_PATH ]
then
echo "error: IMAGE_PATH=$IMAGE_PATH is not a path"
exit 1
fi
if [ ! -d $SEG_PATH ]
then
echo "error: SEG_PATH=$SEG_PATH is not a path"
exit 1
fi
if [ ! -f $CHECKPOINT_FILE_PATH ]
then
echo "error: CHECKPOINT_FILE_PATH=$CHECKPOINT_FILE_PATH is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval
echo "start eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
python eval.py --data_url=$IMAGE_PATH --seg_url=$SEG_PATH --ckpt_path=$CHECKPOINT_FILE_PATH > eval.log 2>&1 &
echo "end eval for checkpoint file: ${CHECKPOINT_FILE_PATH}"
cd ..

@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train_ascend.sh [IMAGE_PATH] [SEG_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -d $PATH1 ]
then
echo "error: IMAGE_PATH=$PATH1 is not a file"
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -d $PATH2 ]
then
echo "error: SEG_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
rm -rf ./train
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --data_url=$PATH1 --seg_url=$PATH2 > train.log 2>&1 &
cd ..

@ -0,0 +1,34 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from easydict import EasyDict
config = EasyDict({
'model': 'Unet3d',
'lr': 0.0005,
'epoch_size': 10,
'batch_size': 1,
'warmup_step': 120,
'warmup_ratio': 0.3,
'num_classes': 4,
'in_channels': 1,
'keep_checkpoint_max': 5,
'loss_scale': 256.0,
'roi_size': [224, 224, 96],
'overlap': 0.25,
'min_val': -500,
'max_val': 1000,
'upper_limit': 5,
'lower_limit': 3,
})

@ -0,0 +1,90 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import Parameter
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops.operations import nn_ops as nps
from mindspore.common.initializer import initializer
def weight_variable(shape):
init_value = initializer('Normal', shape, mstype.float32)
return Parameter(init_value)
class Conv3D(nn.Cell):
def __init__(self,
in_channel,
out_channel,
kernel_size,
mode=1,
pad_mode="valid",
pad=0,
stride=1,
dilation=1,
group=1,
data_format="NCDHW",
bias_init="zeros",
has_bias=True):
super().__init__()
self.weight_shape = (out_channel, in_channel, kernel_size[0], kernel_size[1], kernel_size[2])
self.weight = weight_variable(self.weight_shape)
self.conv = nps.Conv3D(out_channel=out_channel, kernel_size=kernel_size, mode=mode, \
pad_mode=pad_mode, pad=pad, stride=stride, dilation=dilation, \
group=group, data_format=data_format)
self.bias_init = bias_init
self.has_bias = has_bias
self.bias_add = P.BiasAdd(data_format=data_format)
if self.has_bias:
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
def construct(self, x):
output = self.conv(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output
class Conv3DTranspose(nn.Cell):
def __init__(self,
in_channel,
out_channel,
kernel_size,
mode=1,
pad=0,
stride=1,
dilation=1,
group=1,
output_padding=0,
data_format="NCDHW",
bias_init="zeros",
has_bias=True):
super().__init__()
self.weight_shape = (in_channel, out_channel, kernel_size[0], kernel_size[1], kernel_size[2])
self.weight = weight_variable(self.weight_shape)
self.conv_transpose = nps.Conv3DTranspose(in_channel=in_channel, out_channel=out_channel,\
kernel_size=kernel_size, mode=mode, pad=pad, stride=stride, \
dilation=dilation, group=group, output_padding=output_padding, \
data_format=data_format)
self.bias_init = bias_init
self.has_bias = has_bias
self.bias_add = P.BiasAdd(data_format=data_format)
if self.has_bias:
self.bias = Parameter(initializer(self.bias_init, [out_channel]), name='bias')
def construct(self, x):
output = self.conv_transpose(x, self.weight)
if self.has_bias:
output = self.bias_add(output, self.bias)
return output

@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from pathlib import Path
import SimpleITK as sitk
from src.config import config
parser = argparse.ArgumentParser()
parser.add_argument("--input_path", type=str, help="Input image directory to be processed.")
parser.add_argument("--output_path", type=str, help="Output file path.")
args = parser.parse_args()
def get_list_of_files_in_dir(directory, file_types='*'):
"""
Get list of certain format files.
Args:
directory (str): The input directory for image.
file_types (str): The file_types to filter the files.
"""
return [f for f in Path(directory).glob(file_types) if f.is_file()]
def convert_nifti(input_dir, output_dir, roi_size, file_types):
"""
Convert dataset into mifti format.
Args:
input_dir (str): The input directory for image.
output_dir (str): The output directory to save nifti format data.
roi_size (str): The size to crop the image.
file_types: File types to convert into nifti.
"""
file_list = get_list_of_files_in_dir(input_dir, file_types)
for file_name in file_list:
file_name = str(file_name)
input_file_name, _ = os.path.splitext(os.path.basename(file_name))
img = sitk.ReadImage(file_name)
image_array = sitk.GetArrayFromImage(img)
D, H, W = image_array.shape
if H < roi_size[0] or W < roi_size[1] or D < roi_size[2]:
print("file {} size is smaller than roi size, ignore it.".format(input_file_name))
continue
output_path = os.path.join(output_dir, input_file_name + ".nii.gz")
sitk.WriteImage(img, output_path)
print("create output file {} success.".format(output_path))
if __name__ == '__main__':
convert_nifti(args.input_path, args.output_path, config.roi_size, "*.mhd")

@ -0,0 +1,74 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========================================================================
import os
import glob
import numpy as np
import mindspore.dataset as ds
from mindspore.dataset.transforms.py_transforms import Compose
from src.config import config as cfg
from src.transform import Dataset, ExpandChannel, LoadData, Orientation, ScaleIntensityRange, RandomCropSamples, OneHot
class ConvertLabel:
"""
Crop at the center of image with specified ROI size.
Args:
roi_size: the spatial size of the crop region e.g. [224,224,128]
If its components have non-positive values, the corresponding size of input image will be used.
"""
def operation(self, data):
"""
Apply the transform to `img`, assuming `img` is channel-first and
slicing doesn't apply to the channel dim.
"""
data[data > cfg['upper_limit']] = 0
data = data - (cfg['lower_limit'] - 1)
data = np.clip(data, 0, cfg['lower_limit'])
return data
def __call__(self, image, label):
label = self.operation(label)
return image, label
def create_dataset(data_path, seg_path, config, rank_size=1, rank_id=0, is_training=True):
seg_files = sorted(glob.glob(os.path.join(seg_path, "*.nii.gz")))
train_files = [os.path.join(data_path, os.path.basename(seg)) for seg in seg_files]
train_ds = Dataset(data=train_files, seg=seg_files)
train_loader = ds.GeneratorDataset(train_ds, column_names=["image", "seg"], num_parallel_workers=4, \
shuffle=is_training, num_shards=rank_size, shard_id=rank_id)
if is_training:
transform_image = Compose([LoadData(),
ExpandChannel(),
Orientation(),
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
tgt_max=1.0, is_clip=True),
RandomCropSamples(roi_size=config.roi_size, num_samples=2),
ConvertLabel(),
OneHot(num_classes=config.num_classes)])
else:
transform_image = Compose([LoadData(),
ExpandChannel(),
Orientation(),
ScaleIntensityRange(src_min=config.min_val, src_max=config.max_val, tgt_min=0.0, \
tgt_max=1.0, is_clip=True),
ConvertLabel()])
train_loader = train_loader.map(operations=transform_image, input_columns=["image", "seg"], num_parallel_workers=12,
python_multiprocessing=True)
if not is_training:
train_loader = train_loader.batch(1)
return train_loader

@ -0,0 +1,37 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from src.config import config
class SoftmaxCrossEntropyWithLogits(_Loss):
def __init__(self):
super(SoftmaxCrossEntropyWithLogits, self).__init__()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
self.loss_fn = nn.SoftmaxCrossEntropyWithLogits(sparse=False)
self.cast = P.Cast()
self.reduce_mean = P.ReduceMean()
def construct(self, logits, label):
logits = self.transpose(logits, (0, 2, 3, 4, 1))
label = self.transpose(label, (0, 2, 3, 4, 1))
label = self.cast(label, mstype.float32)
loss = self.reduce_mean(self.loss_fn(self.reshape(logits, (-1, config['num_classes'])), \
self.reshape(label, (-1, config['num_classes']))))
return self.get_loss(loss)

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
def linear_warmup_learning_rate(current_step, warmup_steps, base_lr, init_lr):
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
learning_rate = float(init_lr) + lr_inc * current_step
return learning_rate
def a_cosine_learning_rate(current_step, base_lr, warmup_steps, decay_steps):
base = float(current_step - warmup_steps) / float(decay_steps)
learning_rate = (1 + math.cos(base * math.pi)) / 2 * base_lr
return learning_rate
def dynamic_lr(config, base_step):
"""dynamic learning rate generator"""
base_lr = config.lr
total_steps = int(base_step * config.epoch_size)
warmup_steps = config.warmup_step
lr = []
for i in range(total_steps):
if i < warmup_steps:
lr.append(linear_warmup_learning_rate(i, warmup_steps, base_lr, base_lr * config.warmup_ratio))
else:
lr.append(a_cosine_learning_rate(i, base_lr, warmup_steps, total_steps))
return lr

File diff suppressed because it is too large Load Diff

@ -0,0 +1,62 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from src.unet3d_parts import Down, Up
class UNet3d(nn.Cell):
def __init__(self, config=None):
super(UNet3d, self).__init__()
self.n_channels = config.in_channels
self.n_classes = config.num_classes
# down
self.transpose = P.Transpose()
self.down1 = Down(in_channel=self.n_channels, out_channel=16, dtype=mstype.float16).to_float(mstype.float16)
self.down2 = Down(in_channel=16, out_channel=32, dtype=mstype.float16).to_float(mstype.float16)
self.down3 = Down(in_channel=32, out_channel=64, dtype=mstype.float16).to_float(mstype.float16)
self.down4 = Down(in_channel=64, out_channel=128, dtype=mstype.float16).to_float(mstype.float16)
self.down5 = Down(in_channel=128, out_channel=256, stride=1, kernel_size=(1, 1, 1), \
dtype=mstype.float16).to_float(mstype.float16)
# up
self.up1 = Up(in_channel=256, down_in_channel=128, out_channel=64, \
dtype=mstype.float16).to_float(mstype.float16)
self.up2 = Up(in_channel=64, down_in_channel=64, out_channel=32, \
dtype=mstype.float16).to_float(mstype.float16)
self.up3 = Up(in_channel=32, down_in_channel=32, out_channel=16, \
dtype=mstype.float16).to_float(mstype.float16)
self.up4 = Up(in_channel=16, down_in_channel=16, out_channel=self.n_classes, \
dtype=mstype.float16, is_output=True).to_float(mstype.float16)
self.cast = P.Cast()
def construct(self, input_data):
input_data = self.cast(input_data, mstype.float16)
x1 = self.down1(input_data)
x2 = self.down2(x1)
x3 = self.down3(x2)
x4 = self.down4(x3)
x5 = self.down5(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
x = self.cast(x, mstype.float32)
return x

@ -0,0 +1,112 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from src.conv import Conv3D, Conv3DTranspose
class BatchNorm3d(nn.Cell):
def __init__(self, num_features):
super().__init__()
self.reshape = P.Reshape()
self.shape = P.Shape()
self.bn2d = nn.BatchNorm2d(num_features, data_format="NCHW")
def construct(self, x):
x_shape = self.shape(x)
x = self.reshape(x, (x_shape[0], x_shape[1], x_shape[2] * x_shape[3], x_shape[4]))
bn2d_out = self.bn2d(x)
bn3d_out = self.reshape(bn2d_out, x_shape)
return bn3d_out
class ResidualUnit(nn.Cell):
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), down=True, is_output=False):
super().__init__()
self.stride = stride
self.down = down
self.in_channel = in_channel
self.out_channel = out_channel
self.down_conv_1 = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=self.stride, pad=1)
self.is_output = is_output
if not is_output:
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
self.relu1 = nn.PReLU()
if self.down:
self.down_conv_2 = Conv3D(out_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=1, pad=1)
self.relu2 = nn.PReLU()
if kernel_size[0] == 1:
self.residual = Conv3D(in_channel, out_channel, kernel_size=(1, 1, 1), \
pad_mode="valid", stride=self.stride)
else:
self.residual = Conv3D(in_channel, out_channel, kernel_size=(3, 3, 3), \
pad_mode="pad", stride=self.stride, pad=1)
self.batchNormal2 = BatchNorm3d(num_features=self.out_channel)
def construct(self, x):
out = self.down_conv_1(x)
if self.is_output:
return out
out = self.batchNormal1(out)
out = self.relu1(out)
if self.down:
out = self.down_conv_2(out)
out = self.batchNormal2(out)
out = self.relu2(out)
res = self.residual(x)
else:
res = x
return out + res
class Down(nn.Cell):
def __init__(self, in_channel, out_channel, stride=2, kernel_size=(3, 3, 3), dtype=mstype.float16):
super().__init__()
self.stride = stride
self.in_channel = in_channel
self.out_channel = out_channel
self.down_conv = ResidualUnit(self.in_channel, self.out_channel, stride, kernel_size).to_float(dtype)
def construct(self, x):
x = self.down_conv(x)
return x
class Up(nn.Cell):
def __init__(self, in_channel, down_in_channel, out_channel, stride=2, is_output=False, dtype=mstype.float16):
super().__init__()
self.in_channel = in_channel
self.down_in_channel = down_in_channel
self.out_channel = out_channel
self.stride = stride
self.conv3d_transpose = Conv3DTranspose(in_channel=self.in_channel + self.down_in_channel, \
pad=1, out_channel=self.out_channel, kernel_size=(3, 3, 3), \
stride=self.stride, output_padding=(1, 1, 1))
self.concat = P.Concat(axis=1)
self.conv = ResidualUnit(self.out_channel, self.out_channel, stride=1, down=False, \
is_output=is_output).to_float(dtype)
self.batchNormal1 = BatchNorm3d(num_features=self.out_channel)
self.relu = nn.PReLU()
def construct(self, input_data, down_input):
x = self.concat((input_data, down_input))
x = self.conv3d_transpose(x)
x = self.batchNormal1(x)
x = self.relu(x)
x = self.conv(x)
return x

@ -0,0 +1,170 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import math
import numpy as np
from src.config import config
def correct_nifti_head(img):
"""
Check nifti object header's format, update the header if needed.
In the updated image pixdim matches the affine.
Args:
img: nifti image object
"""
dim = img.header["dim"][0]
if dim >= 5:
return img
pixdim = np.asarray(img.header.get_zooms())[:dim]
norm_affine = np.sqrt(np.sum(np.square(img.affine[:dim, :dim]), 0))
if np.allclose(pixdim, norm_affine):
return img
if hasattr(img, "get_sform"):
return rectify_header_sform_qform(img)
return img
def get_random_patch(dims, patch_size, rand_fn=None):
"""
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size`.
Args:
dims: shape of source array
patch_size: shape of patch size to generate
rand_fn: generate random numbers
Returns:
(tuple of slice): a tuple of slice objects defining the patch
"""
rand_int = np.random.randint if rand_fn is None else rand_fn.randint
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
def first(iterable, default=None):
"""
Returns the first item in the given iterable or `default` if empty, meaningful mostly with 'for' expressions.
"""
for i in iterable:
return i
return default
def _get_scan_interval(image_size, roi_size, num_image_dims, overlap):
"""
Compute scan interval according to the image size, roi size and overlap.
Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0,
use 1 instead to make sure sliding window works.
"""
if len(image_size) != num_image_dims:
raise ValueError("image different from spatial dims.")
if len(roi_size) != num_image_dims:
raise ValueError("roi size different from spatial dims.")
scan_interval = []
for i in range(num_image_dims):
if roi_size[i] == image_size[i]:
scan_interval.append(int(roi_size[i]))
else:
interval = int(roi_size[i] * (1 - overlap))
scan_interval.append(interval if interval > 0 else 1)
return tuple(scan_interval)
def dense_patch_slices(image_size, patch_size, scan_interval):
"""
Enumerate all slices defining ND patches of size `patch_size` from an `image_size` input image.
Args:
image_size: dimensions of image to iterate over
patch_size: size of patches to generate slices
scan_interval: dense patch sampling interval
Returns:
a list of slice objects defining each patch
"""
num_spatial_dims = len(image_size)
patch_size = patch_size
scan_num = []
for i in range(num_spatial_dims):
if scan_interval[i] == 0:
scan_num.append(1)
else:
num = int(math.ceil(float(image_size[i]) / scan_interval[i]))
scan_dim = first(d for d in range(num) if d * scan_interval[i] + patch_size[i] >= image_size[i])
scan_num.append(scan_dim + 1 if scan_dim is not None else 1)
starts = []
for dim in range(num_spatial_dims):
dim_starts = []
for idx in range(scan_num[dim]):
start_idx = idx * scan_interval[dim]
start_idx -= max(start_idx + patch_size[dim] - image_size[dim], 0)
dim_starts.append(start_idx)
starts.append(dim_starts)
out = np.asarray([x.flatten() for x in np.meshgrid(*starts, indexing="ij")]).T
return [(slice(None),)*2 + tuple(slice(s, s + patch_size[d]) for d, s in enumerate(x)) for x in out]
def create_sliding_window(image, roi_size, overlap):
num_image_dims = len(image.shape) - 2
if overlap < 0 or overlap >= 1:
raise AssertionError("overlap must be >= 0 and < 1.")
image_size_temp = list(image.shape[2:])
image_size = tuple(max(image_size_temp[i], roi_size[i]) for i in range(num_image_dims))
scan_interval = _get_scan_interval(image_size, roi_size, num_image_dims, overlap)
slices = dense_patch_slices(image_size, roi_size, scan_interval)
windows_sliding = [image[slice] for slice in slices]
return windows_sliding, slices
def one_hot(labels):
N, _, D, H, W = labels.shape
labels = np.reshape(labels, (N, -1))
labels = labels.astype(np.int32)
N, K = labels.shape
one_hot_encoding = np.zeros((N, config['num_classes'], K), dtype=np.float32)
for i in range(N):
for j in range(K):
one_hot_encoding[i, labels[i][j], j] = 1
labels = np.reshape(one_hot_encoding, (N, config['num_classes'], D, H, W))
return labels
def CalculateDice(y_pred, label):
"""
Args:
y_pred: predictions. As for classification tasks,
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
the shape should be [BNHW] or [BNHWD].
label: ground truth, the first dim is batch.
"""
y_pred_output = np.expand_dims(np.argmax(y_pred, axis=1), axis=1)
y_pred = one_hot(y_pred_output)
y = one_hot(label)
y_pred, y = ignore_background(y_pred, y)
inter = np.dot(y_pred.flatten(), y.flatten()).astype(np.float64)
union = np.dot(y_pred.flatten(), y_pred.flatten()).astype(np.float64) + np.dot(y.flatten(), \
y.flatten()).astype(np.float64)
single_dice_coeff = 2 * inter / (union + 1e-6)
return single_dice_coeff, y_pred_output
def ignore_background(y_pred, label):
"""
This function is used to remove background (the first channel) for `y_pred` and `y`.
Args:
y_pred: predictions. As for classification tasks,
`y_pred` should has the shape [BN] where N is larger than 1. As for segmentation tasks,
the shape should be [BNHW] or [BNHWD].
label: ground truth, the first dim is batch.
"""
label = label[:, 1:] if label.shape[1] > 1 else label
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
return y_pred, label

@ -0,0 +1,94 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# less required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
import ast
import mindspore
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor, Model, context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint, LossMonitor, TimeMonitor
from src.dataset import create_dataset
from src.unet3d_model import UNet3d
from src.config import config as cfg
from src.lr_schedule import dynamic_lr
from src.loss import SoftmaxCrossEntropyWithLogits
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, \
device_id=device_id)
mindspore.set_seed(1)
def get_args():
parser = argparse.ArgumentParser(description='Train the UNet3D on images and target masks')
parser.add_argument('--data_url', dest='data_url', type=str, default='', help='image data directory')
parser.add_argument('--seg_url', dest='seg_url', type=str, default='', help='seg data directory')
parser.add_argument('--run_distribute', dest='run_distribute', type=ast.literal_eval, default=False, \
help='Run distribute, default: false')
return parser.parse_args()
def train_net(data_dir,
seg_dir,
run_distribute,
config=None):
if run_distribute:
init()
rank_id = get_rank()
rank_size = get_group_size()
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode,
device_num=rank_size,
gradients_mean=True)
else:
rank_id = 0
rank_size = 1
train_dataset = create_dataset(data_path=data_dir, seg_path=seg_dir, config=config, \
rank_size=rank_size, rank_id=rank_id, is_training=True)
train_data_size = train_dataset.get_dataset_size()
print("train dataset length is:", train_data_size)
network = UNet3d(config=config)
loss = SoftmaxCrossEntropyWithLogits()
lr = Tensor(dynamic_lr(config, train_data_size), mstype.float32)
optimizer = nn.Adam(params=network.trainable_params(), learning_rate=lr)
scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
network.set_train()
model = Model(network, loss_fn=loss, optimizer=optimizer, loss_scale_manager=scale_manager)
time_cb = TimeMonitor(data_size=train_data_size)
loss_cb = LossMonitor()
ckpt_config = CheckpointConfig(save_checkpoint_steps=train_data_size,
keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix='{}'.format(config.model),
directory='./ckpt_{}/'.format(device_id),
config=ckpt_config)
callbacks_list = [loss_cb, time_cb, ckpoint_cb]
print("============== Starting Training ==============")
model.train(config.epoch_size, train_dataset, callbacks=callbacks_list)
print("============== End Training ==============")
if __name__ == '__main__':
args = get_args()
print("Training setting:", args)
train_net(data_dir=args.data_url,
seg_dir=args.seg_url,
run_distribute=args.run_distribute,
config=cfg)
Loading…
Cancel
Save