diff --git a/model_zoo/official/cv/densenet121/scripts/run_distribute_eval.sh b/model_zoo/official/cv/densenet121/scripts/run_distribute_eval.sh index f1ee68417b..21e2761cfb 100644 --- a/model_zoo/official/cv/densenet121/scripts/run_distribute_eval.sh +++ b/model_zoo/official/cv/densenet121/scripts/run_distribute_eval.sh @@ -37,7 +37,7 @@ do cp -r ./src ./eval_$i cd ./eval_$i || exit export RANK_ID=$i - echo "start training for rank $i, device $DEVICE_ID" + echo "start infering for rank $i, device $DEVICE_ID" env > env.log python eval.py \ --data_dir=$DATASET \ diff --git a/model_zoo/official/cv/densenet121/src/datasets/classification.py b/model_zoo/official/cv/densenet121/src/datasets/classification.py index f7754e38cb..0e9f2124e5 100644 --- a/model_zoo/official/cv/densenet121/src/datasets/classification.py +++ b/model_zoo/official/cv/densenet121/src/datasets/classification.py @@ -141,7 +141,6 @@ def classification_dataset(data_dir, image_size, per_batch_size, max_epoch, rank dataset = TxtDataset(root, data_dir) sampler = DistributedSampler(dataset, rank, group_size, shuffle=shuffle) de_dataset = de.GeneratorDataset(dataset, ["image", "label"], sampler=sampler) - de_dataset.set_dataset_size(len(sampler)) de_dataset = de_dataset.map(input_columns="image", num_parallel_workers=8, operations=transform_img) de_dataset = de_dataset.map(input_columns="label", num_parallel_workers=8, operations=transform_label) diff --git a/model_zoo/official/cv/unet/export.py b/model_zoo/official/cv/unet/export.py new file mode 100644 index 0000000000..12c2b9187c --- /dev/null +++ b/model_zoo/official/cv/unet/export.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# less required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import argparse +import numpy as np + +from mindspore import Tensor +from mindspore.train.serialization import export, load_checkpoint, load_param_into_net + +from src.unet.unet_model import UNet + +parser = argparse.ArgumentParser(description='Export ckpt to air') +parser.add_argument('--ckpt_file', type=str, default="ckpt_unet_medical_adam-1_600.ckpt", + help='The path of input ckpt file') +parser.add_argument('--air_file', type=str, default="unet_medical_adam-1_600.air", help='The path of output air file') +args = parser.parse_args() + +net = UNet(n_channels=1, n_classes=2) +# return a parameter dict for model +param_dict = load_checkpoint(args.ckpt_file) +# load the parameter into net +load_param_into_net(net, param_dict) +input_data = np.random.uniform(0.0, 1.0, size=[1, 1, 572, 572]).astype(np.float32) +export(net, Tensor(input_data), file_name=args.air_file, file_format='AIR') diff --git a/model_zoo/official/cv/yolov3_darknet53/README.md b/model_zoo/official/cv/yolov3_darknet53/README.md index 36c3742d1f..de93846d96 100644 --- a/model_zoo/official/cv/yolov3_darknet53/README.md +++ b/model_zoo/official/cv/yolov3_darknet53/README.md @@ -69,7 +69,7 @@ After installing MindSpore via the official website, you can start training and ``` # The darknet53_backbone.ckpt in the follow script is got from darknet53 training like paper. -# The parameter of pretrained_backbone is not necessary. +# pretrained_backbone can use src/convert_weight.py, convert darknet53.conv.74 to mindspore ckpt, darknet53.conv.74 can get from `https://pjreddie.com/media/files/darknet53.conv.74` . # The parameter of training_shape define image shape for network, default is "". # It means use 10 kinds of shape as input shape, or it can be set some kind of shape. # run training example(1p) by python command. diff --git a/model_zoo/official/cv/yolov3_darknet53/src/__init__.py b/model_zoo/official/cv/yolov3_darknet53/src/__init__.py index e69de29bb2..e30774307c 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/__init__.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ diff --git a/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py b/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py new file mode 100644 index 0000000000..e5d10e313b --- /dev/null +++ b/model_zoo/official/cv/yolov3_darknet53/src/convert_weight.py @@ -0,0 +1,80 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Convert weight to mindspore ckpt.""" +import os +import argparse +import numpy as np +from mindspore.train.serialization import save_checkpoint +from mindspore import Tensor + +from src.yolo import YOLOV3DarkNet53 + +def load_weight(weights_file): + """Loads pre-trained weights.""" + if not os.path.isfile(weights_file): + raise ValueError(f'"{weights_file}" is not a valid weight file.') + with open(weights_file, 'rb') as fp: + np.fromfile(fp, dtype=np.int32, count=5) + return np.fromfile(fp, dtype=np.float32) + + +def build_network(): + """Build YOLOv3 network.""" + network = YOLOV3DarkNet53(is_training=True) + params = network.get_parameters() + params = [p for p in params if 'backbone' in p.name] + return params + + +def convert(weights_file, output_file): + """Conver weight to mindspore ckpt.""" + params = build_network() + weights = load_weight(weights_file) + index = 0 + param_list = [] + for i in range(0, len(params), 5): + weight = params[i] + mean = params[i+1] + var = params[i+2] + gamma = params[i+3] + beta = params[i+4] + beta_data = weights[index: index+beta.size()].reshape(beta.shape) + index += beta.size() + gamma_data = weights[index: index+gamma.size()].reshape(gamma.shape) + index += gamma.size() + mean_data = weights[index: index+mean.size()].reshape(mean.shape) + index += mean.size() + var_data = weights[index: index + var.size()].reshape(var.shape) + index += var.size() + weight_data = weights[index: index+weight.size()].reshape(weight.shape) + index += weight.size() + + param_list.append({'name': weight.name, 'type': weight.dtype, 'shape': weight.shape, + 'data': Tensor(weight_data)}) + param_list.append({'name': mean.name, 'type': mean.dtype, 'shape': mean.shape, 'data': Tensor(mean_data)}) + param_list.append({'name': var.name, 'type': var.dtype, 'shape': var.shape, 'data': Tensor(var_data)}) + param_list.append({'name': gamma.name, 'type': gamma.dtype, 'shape': gamma.shape, 'data': Tensor(gamma_data)}) + param_list.append({'name': beta.name, 'type': beta.dtype, 'shape': beta.shape, 'data': Tensor(beta_data)}) + + save_checkpoint(param_list, output_file) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="yolov3 weight convert.") + parser.add_argument("--input_file", type=str, default="./darknet53.conv.74", help="input file path.") + parser.add_argument("--output_file", type=str, default="./ackbone_darknet53.ckpt", help="output file path.") + args_opt = parser.parse_args() + + convert(args_opt.input_file, args_opt.output_file) diff --git a/model_zoo/official/cv/yolov3_darknet53/src/darknet.py b/model_zoo/official/cv/yolov3_darknet53/src/darknet.py index 4a2eb1de78..7e2e04b1fd 100644 --- a/model_zoo/official/cv/yolov3_darknet53/src/darknet.py +++ b/model_zoo/official/cv/yolov3_darknet53/src/darknet.py @@ -115,39 +115,38 @@ class DarkNet(nn.Cell): out_channels[0], kernel_size=3, stride=2) - self.conv2 = conv_block(in_channels[1], - out_channels[1], - kernel_size=3, - stride=2) - self.conv3 = conv_block(in_channels[2], - out_channels[2], - kernel_size=3, - stride=2) - self.conv4 = conv_block(in_channels[3], - out_channels[3], - kernel_size=3, - stride=2) - self.conv5 = conv_block(in_channels[4], - out_channels[4], - kernel_size=3, - stride=2) - self.layer1 = self._make_layer(block, layer_nums[0], in_channel=out_channels[0], out_channel=out_channels[0]) + self.conv2 = conv_block(in_channels[1], + out_channels[1], + kernel_size=3, + stride=2) self.layer2 = self._make_layer(block, layer_nums[1], in_channel=out_channels[1], out_channel=out_channels[1]) + self.conv3 = conv_block(in_channels[2], + out_channels[2], + kernel_size=3, + stride=2) self.layer3 = self._make_layer(block, layer_nums[2], in_channel=out_channels[2], out_channel=out_channels[2]) + self.conv4 = conv_block(in_channels[3], + out_channels[3], + kernel_size=3, + stride=2) self.layer4 = self._make_layer(block, layer_nums[3], in_channel=out_channels[3], out_channel=out_channels[3]) + self.conv5 = conv_block(in_channels[4], + out_channels[4], + kernel_size=3, + stride=2) self.layer5 = self._make_layer(block, layer_nums[4], in_channel=out_channels[4],