change manual float16 to auto mixed_precision

pull/13825/head
chenhaozhe 4 years ago
parent 12a29ce040
commit 8a4d44f4d6
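This commit replaces CRNN's hand-placed float16 casts (float16 parameters plus an explicit Cast in construct) with MindSpore's Cell.to_float automatic mixed precision: parameters are created in float32, and a new crnn() factory applies a float16 compute policy to the whole cell. A minimal sketch of the two patterns, using a hypothetical toy cell rather than the real CRNN:

import numpy as np
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import Parameter

# Before: manual mixed precision -- the parameter is stored in float16
# and the input is cast by hand inside construct.
class ManualFP16Net(nn.Cell):
    def __init__(self):
        super(ManualFP16Net, self).__init__()
        self.weight = Parameter(np.ones((4, 4)).astype(np.float16), name="w")
        self.matmul = P.MatMul()
        self.cast = P.Cast()

    def construct(self, x):
        x = self.cast(x, mstype.float16)  # hand-placed cast
        return self.matmul(x, self.weight)

# After: the parameter stays float32; to_float(mstype.float16) makes the
# framework insert float16 casts around the cell's compute, so no manual
# Cast or float16 storage is needed.
class AutoNet(nn.Cell):
    def __init__(self):
        super(AutoNet, self).__init__()
        self.weight = Parameter(np.ones((4, 4)).astype(np.float32), name="w")
        self.matmul = P.MatMul()

    def construct(self, x):
        return self.matmul(x, self.weight)

net = AutoNet().to_float(mstype.float16)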

eval.py
@@ -22,7 +22,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.loss import CTCLoss
 from src.dataset import create_dataset
-from src.crnn import CRNN
+from src.crnn import crnn
 from src.metric import CRNNAccuracy

 set_seed(1)
@@ -60,7 +60,7 @@ if __name__ == '__main__':
     loss = CTCLoss(max_sequence_length=config.num_step,
                    max_label_length=max_text_length,
                    batch_size=config.batch_size)
-    net = CRNN(config)
+    net = crnn(config)
     # load checkpoint
     param_dict = load_checkpoint(args_opt.checkpoint_path)
     load_param_into_net(net, param_dict)

export.py
@@ -20,7 +20,7 @@ import numpy as np
 import mindspore as ms
 from mindspore import Tensor, context, load_checkpoint, export
-from src.crnn import CRNN
+from src.crnn import crnn
 from src.config import config1 as config

 parser = argparse.ArgumentParser(description="CRNN_export")
@@ -37,7 +37,7 @@ if args.device_target == "Ascend":
 if __name__ == "__main__":
     config.batch_size = 1
-    net = CRNN(config)
+    net = crnn(config)
     load_checkpoint(args.ckpt_file, net=net)
     net.set_train(False)

src/crnn.py
@@ -96,28 +96,28 @@ class CRNN(nn.Cell):
         self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)
         w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
-        self.w1 = Parameter(w1.astype(np.float16), name="w1")
+        self.w1 = Parameter(w1.astype(np.float32), name="w1")
         w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
-        self.w2 = Parameter(w2.astype(np.float16), name="w2")
+        self.w2 = Parameter(w2.astype(np.float32), name="w2")
         w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
-        self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
+        self.w1_bw = Parameter(w1_bw.astype(np.float32), name="w1_bw")
         w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
-        self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")
+        self.w2_bw = Parameter(w2_bw.astype(np.float32), name="w2_bw")
-        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
+        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1")
-        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
+        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2")
-        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
+        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b1_bw")
-        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")
+        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float32), name="b2_bw")
-        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
-        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
+        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float32))
         self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
         self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
@@ -142,7 +142,6 @@ class CRNN(nn.Cell):
     def construct(self, x):
         x = self.vgg(x)
-        x = self.cast(x, mstype.float16)
         x = self.reshape(x, (self.batch_size, self.input_size, -1))
         x = self.transpose(x, (2, 0, 1))
@@ -169,3 +168,11 @@ class CRNN(nn.Cell):
         output += (y2_after_fc,)
         output = self.concat(output)
         return output
+
+
+def crnn(config, full_precision=False):
+    """Create a CRNN network with mixed_precision or full_precision"""
+    net = CRNN(config)
+    if not full_precision:
+        net = net.to_float(mstype.float16)
+    return net
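Usage of the new factory, as the updated call sites in eval.py, export.py, and train.py show: crnn(config) returns the mixed-precision network, while full_precision=True skips the to_float cast and keeps everything in float32 (the flag's intended use is not stated in the diff; a plausible case is running on targets without fast float16):

from src.config import config1 as config
from src.crnn import crnn

net = crnn(config)                            # float32 weights, float16 compute
net_fp32 = crnn(config, full_precision=True)  # everything stays float32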

train.py
@@ -26,7 +26,7 @@ from mindspore.communication.management import init, get_group_size, get_rank
 from src.loss import CTCLoss
 from src.dataset import create_dataset
-from src.crnn import CRNN
+from src.crnn import crnn
 from src.crnn_for_train import TrainOneStepCellWithGradClip

 set_seed(1)
@@ -83,7 +83,7 @@ if __name__ == '__main__':
     loss = CTCLoss(max_sequence_length=config.num_step,
                    max_label_length=max_text_length,
                    batch_size=config.batch_size)
-    net = CRNN(config)
+    net = crnn(config)
     opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)

     net = WithLossCell(net, loss)
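For reference, a condensed sketch of how these pieces assemble in train.py (the identifiers are the repo's own; the config import, max_text_length, and lr values below are hypothetical placeholders for what train.py actually derives from its CLI args and learning-rate schedule):

import mindspore.nn as nn
from mindspore.nn import WithLossCell
from src.loss import CTCLoss
from src.crnn import crnn
from src.crnn_for_train import TrainOneStepCellWithGradClip
from src.config import config1 as config

max_text_length = 23  # hypothetical placeholder
lr = 0.02             # hypothetical placeholder

loss = CTCLoss(max_sequence_length=config.num_step,
               max_label_length=max_text_length,
               batch_size=config.batch_size)
net = crnn(config)  # mixed precision by default after this commit
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr,
             momentum=config.momentum, nesterov=config.nesterov)
# Only the CRNN backbone carries the float16 policy; the loss cell is
# not cast, since to_float was applied to the network alone.
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt)
net.set_train()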
