Pre Merge pull request !14270 from MapleGrove/simclr
commit
e8dd7eae7f
@ -0,0 +1,62 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
##############export checkpoint file into air, onnx, mindir models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.simclr_model import SimCLR
|
||||
from src.resnet import resnet50 as resnet
|
||||
|
||||
parser = argparse.ArgumentParser(description='SimCLR')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||
parser.add_argument("--batch_size", type=int, default=128, help="batch size")
|
||||
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['cifar10'],
|
||||
help='Dataset, Currently only cifar10 is supported.')
|
||||
parser.add_argument('--device_target', type=str, default="Ascend",
|
||||
choices=['Ascend'],
|
||||
help='Device target, Currently only Ascend is supported.')
|
||||
parser.add_argument("--ckpt_file", type=str, required=True, help="Checkpoint file path.")
|
||||
parser.add_argument("--file_name", type=str, default="simclr", help="output file name.")
|
||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="AIR", help="file format")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
|
||||
if args_opt.device_target == "Ascend":
|
||||
context.set_context(device_id=args_opt.device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.dataset_name == 'cifar10':
|
||||
width_multiplier = 1
|
||||
cifar_stem = True
|
||||
projection_dimension = 128
|
||||
image_height = 32
|
||||
image_width = 32
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
|
||||
base_net = resnet(1, width_multiplier=width_multiplier, cifar_stem=cifar_stem)
|
||||
net = SimCLR(base_net, projection_dimension, base_net.end_point.in_channels)
|
||||
|
||||
param_dict = load_checkpoint(args_opt.ckpt_file)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
input_arr = Tensor(np.zeros([args_opt.batch_size, 3, image_height, image_width]), ms.float32)
|
||||
export(net, input_arr, file_name=args_opt.file_name, file_format=args_opt.file_format)
|
||||
@ -0,0 +1,210 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## eval SimCLR example ########################
|
||||
eval SimCLR according to model file:
|
||||
python eval.py --encoder_checkpoint_path Your.ckpt --train_dataset_path /YourDataPath1
|
||||
--eval_dataset_path /YourDataPath2
|
||||
"""
|
||||
import ast
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import nn
|
||||
from mindspore import ops
|
||||
from mindspore import context
|
||||
from mindspore.common.initializer import TruncatedNormal
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from src.dataset import create_dataset
|
||||
from src.simclr_model import SimCLR
|
||||
from src.resnet import resnet50 as resnet
|
||||
from src.reporter import Reporter
|
||||
from src.optimizer import get_eval_optimizer as get_optimizer
|
||||
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Linear Evaluation Protocol')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend',
|
||||
help="Device target, Currently only Ascend is supported.")
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=False, help='Running distributed evaluation.')
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
|
||||
help='Whether it is running on CloudBrain platform.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
|
||||
parser.add_argument('--dataset_name', type=str, default="cifar10", help='Dataset, Currently only cifar10 is supported.')
|
||||
parser.add_argument('--train_url', default=None, help='Cloudbrain Location of training outputs.\
|
||||
This parameter needs to be set when running on the cloud brain platform.')
|
||||
parser.add_argument('--data_url', default=None, help='Cloudbrain Location of data.\
|
||||
This parameter needs to be set when running on the cloud brain platform.')
|
||||
parser.add_argument('--train_dataset_path', type=str, default="./cifar/train",\
|
||||
help='Dataset path for training classifier.\
|
||||
This parameter needs to be set when running on the host.')
|
||||
parser.add_argument('--eval_dataset_path', type=str, default="./cifar/eval",\
|
||||
help='Dataset path for evaluating classifier.\
|
||||
This parameter needs to be set when running on the host.')
|
||||
parser.add_argument('--train_output_path', type=str, default="./outputs", help='Location of ckpt and log.\
|
||||
This parameter needs to be set when running on the host.')
|
||||
parser.add_argument("--class_num", type=int, default=10, help="dataset classification number")
|
||||
parser.add_argument('--batch_size', type=int, default=128, help='batch_size for training classifier, default is 128.')
|
||||
parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training classifier, default is 200.')
|
||||
parser.add_argument('--projection_dimension', type=int, default=128,
|
||||
help='Projection output dimensionality, default is 128.')
|
||||
parser.add_argument('--width_multiplier', type=int, default=1, help='width_multiplier=4,resnet50x4')
|
||||
parser.add_argument('--pre_classifier_checkpoint_path', type=str, default=None, help='Classifier Checkpoint file path.')
|
||||
parser.add_argument('--encoder_checkpoint_path', type=str, default="simclr_156.ckpt",
|
||||
help='Encoder Checkpoint file path.')
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=10, help="Save checkpoint epochs, default is 1.")
|
||||
parser.add_argument("--print_iter", type=int, default=100, help="log print iter, default is 100.")
|
||||
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False,
|
||||
help='whether save graphs, default is False.')
|
||||
parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Dataset normalize.')
|
||||
|
||||
args = parser.parse_args()
|
||||
set_seed(1)
|
||||
local_data_url = './cache/data'
|
||||
local_train_url = './cache/train'
|
||||
_local_train_url = local_train_url
|
||||
|
||||
if args.device_target != "Ascend":
|
||||
raise ValueError("Unsupported device target.")
|
||||
if args.run_distribute:
|
||||
args.device_id = int(os.getenv("DEVICE_ID"))
|
||||
if args.device_num > int(os.getenv("RANK_SIZE")) or args.device_num == 1:
|
||||
args.device_num = int(os.getenv("RANK_SIZE"))
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True, device_num=args.device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
local_data_url = os.path.join(local_data_url, str(args.device_id))
|
||||
local_train_url = os.path.join(local_train_url, str(args.device_id))
|
||||
args.train_output_path = os.path.join(args.train_output_path, str(args.device_id))
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
|
||||
save_graphs=args.save_graphs, device_id=args.device_id)
|
||||
args.rank = 0
|
||||
args.device_num = 1
|
||||
|
||||
if args.run_cloudbrain:
|
||||
import moxing as mox
|
||||
args.train_dataset_path = os.path.join(local_data_url, "train")
|
||||
args.eval_dataset_path = os.path.join(local_data_url, "val")
|
||||
args.train_output_path = local_train_url
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
|
||||
|
||||
class LogisticRegression(nn.Cell):
|
||||
"""
|
||||
Logistic regression
|
||||
"""
|
||||
def __init__(self, n_features, n_classes):
|
||||
super(LogisticRegression, self).__init__()
|
||||
self.model = nn.Dense(n_features, n_classes, TruncatedNormal(0.02), TruncatedNormal(0.02))
|
||||
|
||||
def construct(self, x):
|
||||
x = self.model(x)
|
||||
return x
|
||||
|
||||
class Linear_Eval():
|
||||
"""
|
||||
Linear classifier
|
||||
"""
|
||||
def __init__(self, net, loss):
|
||||
super(Linear_Eval, self).__init__()
|
||||
self.net = net
|
||||
self.softmax = nn.Softmax()
|
||||
self.loss = loss
|
||||
def __call__(self, x, y):
|
||||
x = self.net(x)
|
||||
loss = self.loss(x, y)
|
||||
x = self.softmax(x)
|
||||
predicts = ops.Argmax(output_type=mstype.int32)(x)
|
||||
acc = np.sum(predicts.asnumpy() == y.asnumpy())/len(y.asnumpy())
|
||||
return loss.asnumpy(), acc
|
||||
|
||||
class Linear_Train(nn.Cell):
|
||||
"""
|
||||
Train linear classifier
|
||||
"""
|
||||
def __init__(self, net, loss, opt):
|
||||
super(Linear_Train, self).__init__()
|
||||
self.netwithloss = nn.WithLossCell(net, loss)
|
||||
self.train_net = nn.TrainOneStepCell(self.netwithloss, opt)
|
||||
self.train_net.set_train()
|
||||
def construct(self, x, y):
|
||||
return self.train_net(x, y)
|
||||
|
||||
if __name__ == "__main__":
|
||||
base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10")
|
||||
simclr_model = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels)
|
||||
if args.run_cloudbrain:
|
||||
mox.file.copy_parallel(src_url=args.encoder_checkpoint_path, dst_url=local_data_url+'/encoder.ckpt')
|
||||
simclr_param = load_checkpoint(local_data_url+'/encoder.ckpt')
|
||||
else:
|
||||
simclr_param = load_checkpoint(args.encoder_checkpoint_path)
|
||||
load_param_into_net(simclr_model.encoder, simclr_param)
|
||||
classifier = LogisticRegression(simclr_model.n_features, args.class_num)
|
||||
dataset = create_dataset(args, dataset_mode="train_classifier")
|
||||
optimizer = get_optimizer(classifier, dataset.get_dataset_size(), args)
|
||||
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
net_Train = Linear_Train(net=classifier, loss=criterion, opt=optimizer)
|
||||
reporter = Reporter(args, linear_eval=True)
|
||||
reporter.dataset_size = dataset.get_dataset_size()
|
||||
reporter.linear_eval = True
|
||||
if args.pre_classifier_checkpoint_path:
|
||||
if args.run_cloudbrain:
|
||||
mox.file.copy_parallel(src_url=args.pre_classifier_checkpoint_path,
|
||||
dst_url=local_data_url+'/pre_classifier.ckpt')
|
||||
classifier_param = load_checkpoint(local_data_url+'/pre_classifier.ckpt')
|
||||
else:
|
||||
classifier_param = load_checkpoint(args.pre_classifier_checkpoint_path)
|
||||
load_param_into_net(classifier, classifier_param)
|
||||
else:
|
||||
dataset_train = []
|
||||
for _, data in enumerate(dataset, start=1):
|
||||
_, images, labels = data
|
||||
features = simclr_model.inference(images)
|
||||
dataset_train.append([features, labels])
|
||||
reporter.info('==========start training linear classifier===============')
|
||||
# Train.
|
||||
for _ in range(args.epoch_size):
|
||||
reporter.epoch_start()
|
||||
for idx, data in enumerate(dataset_train, start=1):
|
||||
features, labels = data
|
||||
out = net_Train(features, labels)
|
||||
reporter.step_end(out)
|
||||
reporter.epoch_end(classifier)
|
||||
reporter.info('==========end training linear classifier===============')
|
||||
|
||||
dataset = create_dataset(args, dataset_mode="eval_classifier")
|
||||
reporter.dataset_size = dataset.get_dataset_size()
|
||||
net_Eval = Linear_Eval(net=classifier, loss=criterion)
|
||||
# Eval.
|
||||
reporter.info('==========start evaluating linear classifier===============')
|
||||
reporter.start_predict()
|
||||
for idx, data in enumerate(dataset, start=1):
|
||||
_, images, labels = data
|
||||
features = simclr_model.inference(images)
|
||||
batch_loss, batch_acc = net_Eval(features, labels)
|
||||
reporter.predict_step_end(batch_loss, batch_acc)
|
||||
reporter.end_predict()
|
||||
reporter.info('==========end evaluating linear classifier===============')
|
||||
if args.run_cloudbrain:
|
||||
mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url)
|
||||
@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
if [ $# != 4 ]
|
||||
then
|
||||
echo "Usage: sh run_distribution_ascend.sh [DEVICENUM] [RANK_TABLE_FILE] [cifar10] [TRAIN_DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
#
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
#
|
||||
if [ ! -f $2 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=$1
|
||||
export RANK_SIZE=$1
|
||||
RANK_TABLE_FILE=$(get_real_path $2)
|
||||
export RANK_TABLE_FILE
|
||||
export DATASET_NAME=$3
|
||||
export TRAIN_DATASET_PATH=$(get_real_path $4)
|
||||
echo "RANK_TABLE_FILE=${RANK_TABLE_FILE}"
|
||||
|
||||
export SERVER_ID=0
|
||||
rank_start=$((DEVICE_NUM * SERVER_ID))
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$((rank_start + i))
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cp ../train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \
|
||||
--run_cloudbrain=False --run_distribute=True > log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
@ -0,0 +1,37 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
if [ $# != 5 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_eval_ascend.sh [cifar10] [DEVICE_ID] [SIMCLR_MODEL_PATH] [TRAIN_DATASET_PATH] [EVAL_DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export DATASET_NAME=$1
|
||||
export DEVICE_ID=$2
|
||||
export SIMCLR_MODEL_PATH=$3
|
||||
export TRAIN_DATASET_PATH=$4
|
||||
export EVAL_DATASET_PATH=$5
|
||||
|
||||
|
||||
python ${self_path}/../linear_eval.py --dataset_name=$DATASET_NAME \
|
||||
--encoder_checkpoint_path=$SIMCLR_MODEL_PATH \
|
||||
--train_dataset_path=$TRAIN_DATASET_PATH \
|
||||
--eval_dataset_path=$EVAL_DATASET_PATH \
|
||||
--device_id=$DEVICE_ID --device_target="Ascend" \
|
||||
--run_distribute=False --run_cloudbrain=False > eval_log 2>&1 &
|
||||
@ -0,0 +1,31 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
# an simple tutorial as follows, more parameters can be setting
|
||||
if [ $# != 3 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_ascend.sh [cifar10] [TRAIN_DATASET_PATH] [DEVICE_ID]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
script_self=$(readlink -f "$0")
|
||||
self_path=$(dirname "${script_self}")
|
||||
export DATASET_NAME=$1
|
||||
export TRAIN_DATASET_PATH=$2
|
||||
export DEVICE_ID=$3
|
||||
|
||||
python ${self_path}/../train.py --dataset_name=$DATASET_NAME --train_dataset_path=$TRAIN_DATASET_PATH \
|
||||
--device_id=$DEVICE_ID --device_target="Ascend" \
|
||||
--run_cloudbrain=False --run_distribute=False > log 2>&1 &
|
||||
@ -0,0 +1,94 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
create train or eval dataset.
|
||||
"""
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.transforms.c_transforms as C2
|
||||
import mindspore.dataset.vision.py_transforms as py_vision
|
||||
from mindspore.dataset.vision import Inter
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
ds.config.set_seed(0)
|
||||
|
||||
def gaussian_blur(im):
|
||||
sigma = 0
|
||||
_, w = im.shape[:2]
|
||||
kernel_size = int(w // 10)
|
||||
if kernel_size % 2 == 0:
|
||||
kernel_size -= 1
|
||||
return np.array(cv2.GaussianBlur(im, (kernel_size, kernel_size), sigma))
|
||||
|
||||
def copy_column(x, y):
|
||||
return x, x, y
|
||||
|
||||
def create_dataset(args, dataset_mode, repeat_num=1):
|
||||
"""
|
||||
create a train or evaluate cifar10 dataset for SimCLR
|
||||
"""
|
||||
if args.dataset_name != "cifar10":
|
||||
raise ValueError("Unsupported dataset.")
|
||||
if dataset_mode in ("train_endcoder", "train_classifier"):
|
||||
dataset_path = args.train_dataset_path
|
||||
else:
|
||||
dataset_path = args.eval_dataset_path
|
||||
if args.run_distribute and args.device_target == "Ascend":
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True,
|
||||
num_shards=args.device_num, shard_id=args.device_id)
|
||||
else:
|
||||
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=8, shuffle=True)
|
||||
# define map operations
|
||||
trans = []
|
||||
if dataset_mode == "train_endcoder":
|
||||
if args.use_crop:
|
||||
trans += [C.Resize(256, interpolation=Inter.BICUBIC)]
|
||||
trans += [C.RandomResizedCrop(size=(32, 32), scale=(0.31, 1),
|
||||
interpolation=Inter.BICUBIC, max_attempts=100)]
|
||||
if args.use_flip:
|
||||
trans += [C.RandomHorizontalFlip(prob=0.5)]
|
||||
if args.use_color_jitter:
|
||||
scale = 0.6
|
||||
color_jitter = C.RandomColorAdjust(0.8 * scale, 0.8 * scale, 0.8 * scale, 0.2 * scale)
|
||||
trans += [C2.RandomApply([color_jitter], prob=0.8)]
|
||||
if args.use_color_gray:
|
||||
trans += [py_vision.ToPIL(),
|
||||
py_vision.RandomGrayscale(prob=0.2),
|
||||
np.array] # need to convert PIL image to a NumPy array to pass it to C++ operation
|
||||
if args.use_blur:
|
||||
trans += [C2.RandomApply([gaussian_blur], prob=0.8)]
|
||||
if args.use_norm:
|
||||
trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]
|
||||
trans += [C2.TypeCast(mstype.float32), C.HWC2CHW()]
|
||||
else:
|
||||
trans += [C.Resize(32)]
|
||||
trans += [C2.TypeCast(mstype.float32)]
|
||||
if args.use_norm:
|
||||
trans += [C.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])]
|
||||
trans += [C.HWC2CHW()]
|
||||
type_cast_op = C2.TypeCast(mstype.int32)
|
||||
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=copy_column, input_columns=["image", "label"],
|
||||
output_columns=["image1", "image2", "label"],
|
||||
column_order=["image1", "image2", "label"], num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns=["image1"], num_parallel_workers=8)
|
||||
data_set = data_set.map(operations=trans, input_columns=["image2"], num_parallel_workers=8)
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(args.batch_size, drop_remainder=True)
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
return data_set
|
||||
@ -0,0 +1,198 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""learning rate generator"""
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies three steps decay to generate learning rate array.
|
||||
"""
|
||||
decay_epoch_index = [0.3 * total_steps, 0.6 * total_steps, 0.8 * total_steps]
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
if i < decay_epoch_index[0]:
|
||||
lr = lr_max
|
||||
elif i < decay_epoch_index[1]:
|
||||
lr = lr_max * 0.1
|
||||
elif i < decay_epoch_index[2]:
|
||||
lr = lr_max * 0.01
|
||||
else:
|
||||
lr = lr_max * 0.001
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies polynomial decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
if warmup_steps != 0:
|
||||
inc_each_step = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
else:
|
||||
inc_each_step = 0
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = float(lr_init) + inc_each_step * float(i)
|
||||
else:
|
||||
base = (1.0 - (float(i) - float(warmup_steps)) / (float(total_steps) - float(warmup_steps)))
|
||||
lr = float(lr_max) * base * base
|
||||
if lr < 0.0:
|
||||
lr = 0.0
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies cosine decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
decay_steps = total_steps - warmup_steps
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr_inc = (float(lr_max) - float(lr_init)) / float(warmup_steps)
|
||||
lr = float(lr_init) + lr_inc * (i + 1)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = lr_max * decayed
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps):
|
||||
"""
|
||||
Applies liner decay to generate learning rate array.
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate.
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate.
|
||||
total_steps(int): all steps in training.
|
||||
warmup_steps(int): all steps in warmup epochs.
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array.
|
||||
"""
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = lr_init + (lr_max - lr_init) * i / warmup_steps
|
||||
else:
|
||||
lr = lr_max - (lr_max - lr_end) * (i - warmup_steps) / (total_steps - warmup_steps)
|
||||
lr_each_step.append(lr)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
|
||||
def get_lr(lr_init, lr_end, lr_max, warmup_epochs, total_epochs, steps_per_epoch, lr_decay_mode):
|
||||
"""
|
||||
generate learning rate array
|
||||
|
||||
Args:
|
||||
lr_init(float): init learning rate
|
||||
lr_end(float): end learning rate
|
||||
lr_max(float): max learning rate
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
total_epochs(int): total epoch of training
|
||||
steps_per_epoch(int): steps of one epoch
|
||||
lr_decay_mode(string): learning rate decay mode, including steps, poly, cosine or liner(default)
|
||||
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
warmup_steps = steps_per_epoch * warmup_epochs
|
||||
|
||||
if lr_decay_mode == 'steps':
|
||||
lr_each_step = _generate_steps_lr(lr_init, lr_max, total_steps, warmup_steps)
|
||||
elif lr_decay_mode == 'poly':
|
||||
lr_each_step = _generate_poly_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
elif lr_decay_mode == 'cosine':
|
||||
lr_each_step = _generate_cosine_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
else:
|
||||
lr_each_step = _generate_liner_lr(lr_init, lr_end, lr_max, total_steps, warmup_steps)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
return lr_each_step
|
||||
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch=120, global_step=0):
|
||||
"""
|
||||
generate learning rate array with cosine
|
||||
|
||||
Args:
|
||||
lr(float): base learning rate
|
||||
steps_per_epoch(int): steps size of one epoch
|
||||
warmup_epochs(int): number of warmup epochs
|
||||
max_epoch(int): total epochs of training
|
||||
global_step(int): the current start index of lr array
|
||||
Returns:
|
||||
np.array, learning rate array
|
||||
"""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
decay_steps = total_steps - warmup_steps
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
linear_decay = (total_steps - i) / decay_steps
|
||||
cosine_decay = 0.5 * (1 + math.cos(math.pi * 2 * 0.47 * i / decay_steps))
|
||||
decayed = linear_decay * cosine_decay + 0.00001
|
||||
lr = base_lr * decayed
|
||||
lr_each_step.append(lr)
|
||||
|
||||
lr_each_step = np.array(lr_each_step).astype(np.float32)
|
||||
learning_rate = lr_each_step[global_step:]
|
||||
return learning_rate
|
||||
@ -0,0 +1,91 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SimCLR Loss class."""
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore import ops as P
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropyLoss(nn.Cell):
|
||||
"""
|
||||
Cross Entropy Loss.
|
||||
"""
|
||||
def __init__(self, reduction="mean"):
|
||||
super(CrossEntropyLoss, self).__init__()
|
||||
self.cross_entropy = P.SoftmaxCrossEntropyWithLogits()
|
||||
if reduction == "sum":
|
||||
self.reduction = P.ReduceSum()
|
||||
if reduction == "mean":
|
||||
self.reduction = P.ReduceMean()
|
||||
self.one_hot = P.OneHot()
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
|
||||
def construct(self, logits, label):
|
||||
loss = self.cross_entropy(logits, label)[0]
|
||||
loss = self.reduction(loss, (-1,))
|
||||
return loss
|
||||
|
||||
|
||||
class NT_Xent_Loss(nn.Cell):
|
||||
"""
|
||||
Loss for SimCLR.
|
||||
"""
|
||||
def __init__(self, batch_size, temperature=1, world_size=1):
|
||||
super(NT_Xent_Loss, self).__init__()
|
||||
# Parameters.
|
||||
self.LARGE_NUM = 1e9
|
||||
self.batch_size = batch_size
|
||||
self.temperature = temperature
|
||||
self.world_size = world_size
|
||||
self.N = 2 * self.batch_size * self.world_size
|
||||
# Tail_Loss.
|
||||
self.criterion = CrossEntropyLoss(reduction="mean")
|
||||
self.norm = P.L2Normalize(axis=1)
|
||||
self.one_hot = P.OneHot()
|
||||
self.range = nn.Range(0, self.batch_size)
|
||||
self.one = Tensor(1.0, mstype.float32)
|
||||
self.zero = Tensor(0.0, mstype.float32)
|
||||
self.transpose = P.Transpose()
|
||||
self.matmul = nn.MatMul()
|
||||
# Operations.
|
||||
self.ones = P.Ones()
|
||||
self.zeros = P.Zeros()
|
||||
self.cat1 = P.Concat(axis=1)
|
||||
|
||||
def construct(self, z_i, z_j):
|
||||
"""
|
||||
Forward.
|
||||
"""
|
||||
hidden1 = self.norm(z_i)
|
||||
hidden2 = self.norm(z_j)
|
||||
hidden1_large = hidden1
|
||||
hidden2_large = hidden2
|
||||
ones_mask = self.range()
|
||||
zeros_mask = self.zeros((self.batch_size, self.batch_size), mstype.float32)
|
||||
masks = self.one_hot(ones_mask, self.batch_size, self.one, self.zero)
|
||||
labels = self.cat1((masks, zeros_mask))
|
||||
logits_aa = self.matmul(hidden1, self.transpose(hidden1_large, (1, 0))) / self.temperature
|
||||
logits_aa = logits_aa - masks * self.LARGE_NUM
|
||||
logits_bb = self.matmul(hidden2, self.transpose(hidden2_large, (1, 0))) / self.temperature
|
||||
logits_bb = logits_bb - masks * self.LARGE_NUM
|
||||
logits_ab = self.matmul(hidden1, self.transpose(hidden2_large, (1, 0))) / self.temperature
|
||||
logits_ba = self.matmul(hidden2, self.transpose(hidden1_large, (1, 0))) / self.temperature
|
||||
loss_a = self.criterion(self.cat1((logits_ab, logits_aa)), labels)
|
||||
loss_b = self.criterion(self.cat1((logits_ba, logits_bb)), labels)
|
||||
loss = loss_a + loss_b
|
||||
return loss
|
||||
@ -0,0 +1,52 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""optimizer generator"""
|
||||
from mindspore import nn, Tensor
|
||||
from .lr_generator import get_lr
|
||||
|
||||
def get_train_optimizer(net, steps_per_epoch, args):
|
||||
"""
|
||||
generate optimizer for updating the weights.
|
||||
"""
|
||||
if args.optimizer == "Adam":
|
||||
lr = get_lr(lr_init=1e-4, lr_end=1e-6, lr_max=9e-4,
|
||||
warmup_epochs=args.warmup_epochs, total_epochs=args.epoch_size,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
lr_decay_mode="linear")
|
||||
lr = Tensor(lr)
|
||||
decayed_params = []
|
||||
no_decayed_params = []
|
||||
for param in net.trainable_params():
|
||||
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
|
||||
decayed_params.append(param)
|
||||
else:
|
||||
no_decayed_params.append(param)
|
||||
group_params = [{'params': decayed_params, 'weight_decay': args.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainable_params()}]
|
||||
optimizer = nn.Adam(params=group_params, learning_rate=lr)
|
||||
else:
|
||||
raise ValueError("Unsupported optimizer.")
|
||||
|
||||
return optimizer
|
||||
|
||||
def get_eval_optimizer(net, steps_per_epoch, args):
|
||||
lr = get_lr(lr_init=1e-3, lr_end=6e-6, lr_max=1e-2,
|
||||
warmup_epochs=5, total_epochs=args.epoch_size,
|
||||
steps_per_epoch=steps_per_epoch,
|
||||
lr_decay_mode="linear")
|
||||
lr = Tensor(lr)
|
||||
optimizer = nn.Adam(params=net.trainable_params(), learning_rate=lr)
|
||||
return optimizer
|
||||
@ -0,0 +1,135 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Reporter class."""
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from mindspore.train.serialization import save_checkpoint
|
||||
|
||||
class Reporter(logging.Logger):
|
||||
"""
|
||||
This class includes several functions that can save images/checkpoints and print/save logging information.
|
||||
"""
|
||||
def __init__(self, args, linear_eval):
|
||||
super(Reporter, self).__init__("clean")
|
||||
self.log_dir = os.path.join(args.train_output_path, 'log')
|
||||
if not os.path.exists(self.log_dir):
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
if linear_eval:
|
||||
self.ckpts_dir = os.path.join(args.train_output_path, "checkpoint")
|
||||
if not os.path.exists(self.ckpts_dir):
|
||||
os.makedirs(self.ckpts_dir, exist_ok=True)
|
||||
self.rank = args.rank
|
||||
self.save_checkpoint_epochs = args.save_checkpoint_epochs
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
# console handler
|
||||
console = logging.StreamHandler()
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
# file handler
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(self.rank)
|
||||
self.log_fn = os.path.join(self.log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
if args:
|
||||
self.save_args(args)
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.dataset_size = 0
|
||||
self.print_iter = args.print_iter
|
||||
self.contrastive_loss = []
|
||||
self.linear_eval = False
|
||||
self.Loss = 0
|
||||
self.Acc = 0
|
||||
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.logger.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
def epoch_start(self):
|
||||
self.step_start_time = time.time()
|
||||
self.epoch_start_time = time.time()
|
||||
self.step = 0
|
||||
self.epoch += 1
|
||||
self.contrastive_loss = []
|
||||
|
||||
def step_end(self, loss):
|
||||
"""print log when step end."""
|
||||
self.step += 1
|
||||
self.contrastive_loss.append(loss.asnumpy())
|
||||
if self.step % self.print_iter == 0:
|
||||
step_cost = (time.time() - self.step_start_time) * 1000 / self.print_iter
|
||||
self.info("Epoch[{}] [{}/{}] step cost: {:.2f} ms, loss: {}".format(
|
||||
self.epoch, self.step, self.dataset_size, step_cost, loss))
|
||||
self.step_start_time = time.time()
|
||||
|
||||
def epoch_end(self, net):
|
||||
"""print log and save cgeckpoints when epoch end."""
|
||||
epoch_cost = (time.time() - self.epoch_start_time) * 1000
|
||||
pre_step_time = epoch_cost / self.dataset_size
|
||||
mean_loss = sum(self.contrastive_loss) / self.dataset_size
|
||||
|
||||
self.info("Epoch [{}] total cost: {:.2f} ms, pre step: {:.2f} ms, mean_loss: {:.2f}"\
|
||||
.format(self.epoch, epoch_cost, pre_step_time, mean_loss))
|
||||
if self.epoch % self.save_checkpoint_epochs == 0:
|
||||
if self.linear_eval:
|
||||
save_checkpoint(net, os.path.join(self.ckpts_dir, f"linearClassifier_{self.epoch}.ckpt"))
|
||||
else:
|
||||
save_checkpoint(net, os.path.join(self.ckpts_dir, f"simclr_{self.epoch}.ckpt"))
|
||||
|
||||
def start_predict(self):
|
||||
self.predict_start_time = time.time()
|
||||
self.step = 0
|
||||
self.info('==========start predict===============')
|
||||
|
||||
def end_predict(self):
|
||||
avg_loss = self.Loss / self.step
|
||||
avg_acc = self.Acc / self.step
|
||||
self.info('Average loss {:.5f}, Average accuracy {:.5f}'.format(avg_loss, avg_acc))
|
||||
self.info('==========end predict===============\n')
|
||||
|
||||
def predict_step_end(self, loss, acc):
|
||||
self.step += 1
|
||||
self.Loss = self.Loss + loss
|
||||
self.Acc = self.Acc + acc
|
||||
if self.step % self.print_iter == 0:
|
||||
current_loss = self.Loss / self.step
|
||||
current_acc = self.Acc / self.step
|
||||
self.info('[{}/{}] Current total loss {:.5f}, Current total accuracy {:.5f}'\
|
||||
.format(self.step, self.dataset_size, current_loss, current_acc))
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,53 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""SimCLR Model class."""
|
||||
from mindspore import nn
|
||||
from .resnet import _fc
|
||||
|
||||
class Identity(nn.Cell):
|
||||
def construct(self, x):
|
||||
return x
|
||||
|
||||
class SimCLR(nn.Cell):
|
||||
"""
|
||||
SimCLR Model.
|
||||
"""
|
||||
def __init__(self, encoder, project_dim, n_features):
|
||||
super(SimCLR, self).__init__()
|
||||
self.encoder = encoder
|
||||
self.n_features = n_features
|
||||
self.encoder.end_point = Identity()
|
||||
self.dense1 = _fc(self.n_features, self.n_features)
|
||||
self.relu = nn.ReLU()
|
||||
self.end_point = _fc(self.n_features, project_dim)
|
||||
|
||||
# Projector MLP.
|
||||
def projector(self, x):
|
||||
out = self.dense1(x)
|
||||
out = self.relu(out)
|
||||
out = self.end_point(out)
|
||||
return out
|
||||
|
||||
def construct(self, x_i, x_j):
|
||||
h_i = self.encoder(x_i)
|
||||
z_i = self.projector(h_i)
|
||||
|
||||
h_j = self.encoder(x_j)
|
||||
z_j = self.projector(h_j)
|
||||
return h_i, h_j, z_i, z_j
|
||||
|
||||
def inference(self, x):
|
||||
h = self.encoder(x)
|
||||
return h
|
||||
@ -0,0 +1,158 @@
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
######################## train SimCLR example ########################
|
||||
train simclr and get network model files(.ckpt) :
|
||||
python train.py --train_dataset_path /YourDataPath
|
||||
"""
|
||||
import ast
|
||||
import argparse
|
||||
import os
|
||||
from src.nt_xent import NT_Xent_Loss
|
||||
from src.optimizer import get_train_optimizer as get_optimizer
|
||||
from src.dataset import create_dataset
|
||||
from src.simclr_model import SimCLR
|
||||
from src.resnet import resnet50 as resnet
|
||||
from mindspore import nn
|
||||
from mindspore import context
|
||||
from mindspore.train import Model
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.common import initializer as weight_init
|
||||
from mindspore.common import set_seed
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
parser = argparse.ArgumentParser(description='MindSpore SimCLR')
|
||||
parser.add_argument('--device_target', type=str, default='Ascend',
|
||||
help='Device target, Currently only Ascend is supported.')
|
||||
parser.add_argument('--run_cloudbrain', type=ast.literal_eval, default=True,
|
||||
help='Whether it is running on CloudBrain platform.')
|
||||
parser.add_argument('--run_distribute', type=ast.literal_eval, default=True, help='Run distributed training.')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num.')
|
||||
parser.add_argument("--device_id", type=int, default=0, help="device id, default is 0.")
|
||||
parser.add_argument('--dataset_name', type=str, default="cifar10", help='Dataset, Currently only cifar10 is supported.')
|
||||
parser.add_argument('--train_url', default=None, help='Cloudbrain Location of training outputs.\
|
||||
This parameter needs to be set when running on the cloud brain platform.')
|
||||
parser.add_argument('--data_url', default=None, help='Cloudbrain Location of data.\
|
||||
This parameter needs to be set when running on the cloud brain platform.')
|
||||
parser.add_argument('--train_dataset_path', type=str, default="./cifar/train",
|
||||
help='Dataset path for training classifier. '
|
||||
'This parameter needs to be set when running on the host.')
|
||||
parser.add_argument('--train_output_path', type=str, default="./outputs", help='Location of ckpt and log.\
|
||||
This parameter needs to be set when running on the host.')
|
||||
parser.add_argument('--batch_size', type=int, default=128, help='batch_size, default is 128.')
|
||||
parser.add_argument('--epoch_size', type=int, default=100, help='epoch size for training, default is 200.')
|
||||
parser.add_argument('--projection_dimension', type=int, default=128,
|
||||
help='Projection output dimensionality, default is 128.')
|
||||
parser.add_argument('--width_multiplier', type=int, default=1, help='width_multiplier for ResNet50')
|
||||
parser.add_argument("--temperature", type=float, default=0.5, help="temperature for loss")
|
||||
parser.add_argument('--pre_trained_path', type=str, default=None, help='Pretrained checkpoint path')
|
||||
parser.add_argument("--pretrain_epoch_size", type=int, default=0,
|
||||
help="real_epoch_size = epoch_size - pretrain_epoch_size.")
|
||||
parser.add_argument("--save_checkpoint_epochs", type=int, default=1, help="Save checkpoint epochs, default is 1.")
|
||||
parser.add_argument('--save_graphs', type=ast.literal_eval, default=False,
|
||||
help='whether save graphs, default is False.')
|
||||
parser.add_argument('--optimizer', type=str, default="Adam", help='Optimizer, Currently only Adam is supported.')
|
||||
parser.add_argument("--weight_decay", type=float, default=3e-4, help="weight decay")
|
||||
parser.add_argument("--warmup_epochs", type=int, default=15, help="warmup epochs.")
|
||||
parser.add_argument('--use_crop', type=ast.literal_eval, default=True, help='RandomResizedCrop')
|
||||
parser.add_argument('--use_flip', type=ast.literal_eval, default=True, help='RandomHorizontalFlip')
|
||||
parser.add_argument('--use_color_jitter', type=ast.literal_eval, default=True, help='RandomColorAdjust')
|
||||
parser.add_argument('--use_color_gray', type=ast.literal_eval, default=True, help='RandomGrayscale')
|
||||
parser.add_argument('--use_blur', type=ast.literal_eval, default=False, help='GaussianBlur')
|
||||
parser.add_argument('--use_norm', type=ast.literal_eval, default=False, help='Normalize')
|
||||
|
||||
args = parser.parse_args()
|
||||
local_data_url = './cache/data'
|
||||
local_train_url = './cache/train'
|
||||
_local_train_url = local_train_url
|
||||
|
||||
if args.device_target != "Ascend":
|
||||
raise ValueError("Unsupported device target.")
|
||||
if args.run_distribute:
|
||||
args.device_id = int(os.getenv("DEVICE_ID"))
|
||||
if args.device_num > int(os.getenv("RANK_SIZE")) or args.device_num == 1:
|
||||
args.device_num = int(os.getenv("RANK_SIZE"))
|
||||
context.set_context(device_id=args.device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=args.save_graphs)
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True, device_num=args.device_num)
|
||||
init()
|
||||
args.rank = get_rank()
|
||||
local_data_url = os.path.join(local_data_url, str(args.device_id))
|
||||
local_train_url = os.path.join(local_train_url, str(args.device_id))
|
||||
args.train_output_path = os.path.join(args.train_output_path, str(args.device_id))
|
||||
else:
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target,
|
||||
save_graphs=args.save_graphs, device_id=args.device_id)
|
||||
args.rank = 0
|
||||
args.device_num = 1
|
||||
|
||||
if args.run_cloudbrain:
|
||||
import moxing as mox
|
||||
args.train_dataset_path = os.path.join(local_data_url, "train")
|
||||
args.train_output_path = local_train_url
|
||||
mox.file.copy_parallel(src_url=args.data_url, dst_url=local_data_url)
|
||||
|
||||
set_seed(1)
|
||||
|
||||
class NetWithLossCell(nn.Cell):
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(NetWithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, data_x, data_y, label):
|
||||
_, _, x_pred, y_pred = self._backbone(data_x, data_y)
|
||||
return self._loss_fn(x_pred, y_pred)
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = create_dataset(args, dataset_mode="train_endcoder")
|
||||
# Net.
|
||||
base_net = resnet(1, args.width_multiplier, cifar_stem=args.dataset_name == "cifar10")
|
||||
net = SimCLR(base_net, args.projection_dimension, base_net.end_point.in_channels)
|
||||
# init weight
|
||||
if args.pre_trained_path:
|
||||
if args.run_cloudbrain:
|
||||
mox.file.copy_parallel(src_url=args.pre_trained_path, dst_url=local_data_url+'/pre_train.ckpt')
|
||||
param_dict = load_checkpoint(local_data_url+'/pre_train.ckpt')
|
||||
else:
|
||||
param_dict = load_checkpoint(args.pre_trained_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
else:
|
||||
for _, cell in net.cells_and_names():
|
||||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.set_data(weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.set_data(weight_init.initializer(weight_init.TruncatedNormal(),
|
||||
cell.weight.shape,
|
||||
cell.weight.dtype))
|
||||
optimizer = get_optimizer(net, dataset.get_dataset_size(), args)
|
||||
loss = NT_Xent_Loss(args.batch_size, args.temperature)
|
||||
net_loss = NetWithLossCell(net, loss)
|
||||
train_net = nn.TrainOneStepCell(net_loss, optimizer)
|
||||
model = Model(train_net)
|
||||
time_cb = TimeMonitor(data_size=dataset.get_dataset_size())
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=args.save_checkpoint_epochs)
|
||||
ckpts_dir = os.path.join(args.train_output_path, "checkpoint")
|
||||
ckpoint_cb = ModelCheckpoint(prefix="checkpoint_simclr", directory=ckpts_dir, config=config_ck)
|
||||
print("============== Starting Training ==============")
|
||||
model.train(args.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, LossMonitor()])
|
||||
if args.run_cloudbrain and args.device_id == 0:
|
||||
mox.file.copy_parallel(src_url=_local_train_url, dst_url=args.train_url)
|
||||
Loading…
Reference in new issue