!12885 Add transfer training to unet and update readme.

From: @c_34
Reviewed-by: 
Signed-off-by:
pull/12885/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 1fb56a2481

@ -0,0 +1,6 @@
# Build on a caller-supplied base image (e.g. a MindSpore Ascend image).
ARG FROM_IMAGE_NAME
FROM ${FROM_IMAGE_NAME}
# libgl1-mesa-glx provides libGL.so, needed by opencv-python at import time.
# Refresh the apt package lists first: many base images ship without them,
# in which case a bare `apt install` fails. Clean the lists afterwards to
# keep the layer small.
RUN apt update && \
    apt install -y libgl1-mesa-glx && \
    rm -rf /var/lib/apt/lists/*
# Install the project's Python dependencies with the image's python3.7 pip.
COPY requirements.txt .
RUN pip3.7 install -r requirements.txt

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -0,0 +1,29 @@
#!/bin/bash
# Launch an interactive Ascend NPU training container.
#
# Usage: docker_start.sh <docker_image> <data_dir> <model_dir>
#   docker_image  image to run
#   data_dir      host dataset directory, bind-mounted at the same path
#   model_dir     host model/checkpoint directory, bind-mounted at the same path
#
# Maps all eight davinci NPU devices plus the Ascend driver/add-ons and the
# NPU log/profiling directories into the container.

if [ $# -ne 3 ]; then
    echo "Usage: $0 <docker_image> <data_dir> <model_dir>" >&2
    exit 1
fi

docker_image=$1
data_dir=$2
model_dir=$3

# Quote every expansion so paths containing spaces survive word splitting.
docker run -it --ipc=host \
               --device=/dev/davinci0 \
               --device=/dev/davinci1 \
               --device=/dev/davinci2 \
               --device=/dev/davinci3 \
               --device=/dev/davinci4 \
               --device=/dev/davinci5 \
               --device=/dev/davinci6 \
               --device=/dev/davinci7 \
               --device=/dev/davinci_manager \
               --device=/dev/devmm_svm \
               --device=/dev/hisi_hdc \
               --privileged \
               -v /usr/local/Ascend/driver:/usr/local/Ascend/driver \
               -v /usr/local/Ascend/add-ons/:/usr/local/Ascend/add-ons \
               -v "${data_dir}":"${data_dir}" \
               -v "${model_dir}":"${model_dir}" \
               -v /var/log/npu/conf/slog/slog.conf:/var/log/npu/conf/slog/slog.conf \
               -v /var/log/npu/slog/:/var/log/npu/slog/ \
               -v /var/log/npu/profiling/:/var/log/npu/profiling \
               -v /var/log/npu/dump/:/var/log/npu/dump \
               -v /var/log/npu/:/usr/slog "${docker_image}" \
               /bin/bash

@ -32,6 +32,8 @@ cfg_unet_medical = {
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['outc.weight', 'outc.bias']
}
cfg_unet_nested = {
@ -56,6 +58,8 @@ cfg_unet_nested = {
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
}
cfg_unet_nested_cell = {
@ -81,6 +85,8 @@ cfg_unet_nested_cell = {
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ['final1.weight', 'final2.weight', 'final3.weight', 'final4.weight']
}
cfg_unet_simple = {
@ -102,6 +108,8 @@ cfg_unet_simple = {
'resume': False,
'resume_ckpt': './',
'transfer_training': False,
'filter_weight': ["final.weight"]
}
cfg_unet = cfg_unet_medical

@ -68,6 +68,15 @@ class StepLossTimeMonitor(Callback):
print("epoch: {:3d}, avg loss:{:.4f}, total cost: {:.3f} s, per step fps:{:5.3f}".format(
cb_params.cur_epoch_num, np.mean(self.losses), epoch_cost, step_fps), flush=True)
def mask_to_image(mask):
    """Convert a numeric mask (values presumably in [0, 1] — confirm with callers)
    into an 8-bit grayscale PIL image."""
    scaled = (mask * 255).astype(np.uint8)
    return Image.fromarray(scaled)
def filter_checkpoint_parameter_by_list(param_dict, filter_list):
    """Remove, in place, every checkpoint parameter whose key contains any
    substring listed in *filter_list* (used to drop task-specific heads
    before transfer training). Prints each deleted key."""
    doomed = [key for key in param_dict
              if any(pattern in key for pattern in filter_list)]
    for key in doomed:
        print("Delete parameter from checkpoint: ", key)
        del param_dict[key]

@ -30,7 +30,7 @@ from src.unet_medical import UNetMedical
from src.unet_nested import NestedUNet, UNet
from src.data_loader import create_dataset, create_cell_nuclei_dataset
from src.loss import CrossEntropyWithLogits, MultiCrossEntropyWithLogits
from src.utils import StepLossTimeMonitor
from src.utils import StepLossTimeMonitor, filter_checkpoint_parameter_by_list
from src.config import cfg_unet
device_id = int(os.getenv('DEVICE_ID'))
@ -45,7 +45,6 @@ def train_net(data_dir,
lr=0.0001,
run_distribute=False,
cfg=None):
rank = 0
group_size = 1
if run_distribute:
@ -69,6 +68,8 @@ def train_net(data_dir,
if cfg['resume']:
param_dict = load_checkpoint(cfg['resume_ckpt'])
if cfg['transfer_training']:
filter_checkpoint_parameter_by_list(param_dict, cfg['filter_weight'])
load_param_into_net(net, param_dict)
if 'use_ds' in cfg and cfg['use_ds']:

Loading…
Cancel
Save