@@ -24,12 +24,12 @@ from mindspore import dataset as de
 from mindspore.train.model import Model, ParallelMode
 from mindspore.nn.wrap import WithLossCell
 from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
-from mindspore.communication.management import init
+from mindspore.communication.management import init, get_group_size, get_rank
-from src.loss import CTCLoss
+from src.loss import CTCLoss, CTCLossV2
 from src.config import config as cf
 from src.dataset import create_dataset
-from src.warpctc import StackedRNN
+from src.warpctc import StackedRNN, StackedRNNForGPU
 from src.warpctc_for_train import TrainOneStepCellWithGradClip
 from src.lr_schedule import get_lr
@@ -38,38 +38,60 @@ np.random.seed(1)
 de.config.set_seed(1)
 parser = argparse.ArgumentParser(description="Warpctc training")
-parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
-parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
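+# argparse's type=bool treats any non-empty string (including "False") as truthy, hence the switch to a store_true flag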
parser . add_argument ( " --run_distribute " , action = ' store_true ' , help = " Run distribute, default is false. " )
parser . add_argument ( ' --dataset_path ' , type = str , default = None , help = ' Dataset path, default is None ' )
parser . add_argument ( ' --platform ' , type = str , default = ' Ascend ' , choices = [ ' Ascend ' , ' GPU ' ] ,
help = ' Running platform, choose from Ascend, GPU, and default is Ascend. ' )
parser . set_defaults ( run_distribute = False )
args_opt = parser . parse_args ( )
-device_id = int(os.getenv('DEVICE_ID'))
-context.set_context(mode=context.GRAPH_MODE,
-                    device_target="Ascend",
-                    save_graphs=False,
-                    device_id=device_id)
+context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
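+# on Ascend the target device is picked via the DEVICE_ID environment variable; no explicit device_id is needed on GPU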
+if args_opt.platform == 'Ascend':
+    device_id = int(os.getenv('DEVICE_ID'))
+    context.set_context(device_id=device_id)
 if __name__ == '__main__':
+    lr_scale = 1
     if args_opt.run_distribute:
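+        # Ascend reads rank/size from env vars set by its launcher; GPU derives them from the NCCL communicator after init('nccl')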
+        if args_opt.platform == 'Ascend':
+            init()
+            lr_scale = 1
+            device_num = int(os.environ.get("RANK_SIZE"))
+            rank = int(os.environ.get("RANK_ID"))
+        else:
+            init('nccl')
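+            # the device-count lr scaling below is damped by 0.5 on GPU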
+            lr_scale = 0.5
+            device_num = get_group_size()
+            rank = get_rank()
         context.reset_auto_parallel_context()
-        context.set_auto_parallel_context(device_num=args_opt.device_num,
+        context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           mirror_mean=True)
-        init()
+    else:
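+        # standalone run: one shard, consumed by rank 0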
+        device_num = 1
+        rank = 0
     max_captcha_digits = cf.max_captcha_digits
     input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
     # create dataset
-    dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
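+    # shard the dataset across devices so each rank reads a distinct subset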
+    dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
+                             num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
     step_size = dataset.get_dataset_size()
     # define lr
-    lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
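+    # linear lr scaling with device count, damped by lr_scale on GPU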
+    lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
     lr = get_lr(cf.epoch_size, step_size, lr_init)
-    # define loss
-    loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
-    # define net
-    net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
-    # define opt
-    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
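+    # the loss, network and optimizer are now platform-specific: Ascend keeps CTCLoss/StackedRNN/SGD, GPU uses CTCLossV2/StackedRNNForGPU/Momentum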
+    if args_opt.platform == 'Ascend':
+        loss = CTCLoss(max_sequence_length=cf.captcha_width,
+                       max_label_length=max_captcha_digits,
+                       batch_size=cf.batch_size)
+        net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
+        opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
+    else:
+        loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
+        net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
+        opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
     net = WithLossCell(net, loss)
     net = TrainOneStepCellWithGradClip(net, opt).set_train()
     # define model
@@ -79,6 +101,6 @@ if __name__ == '__main__':
     if cf.save_checkpoint:
         config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
                                      keep_checkpoint_max=cf.keep_checkpoint_max)
-        ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
+        ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck)
         callbacks.append(ckpt_cb)
     model.train(cf.epoch_size, dataset, callbacks=callbacks)
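
A possible launch, assuming an OpenMPI-style launcher for the distributed GPU path (the flag names come from the argparse block above; the mpirun invocation itself is an assumption, since the launch scripts are not part of this diff):

    # standalone, either platform; the dataset path is a placeholder
    python train.py --platform=GPU --dataset_path=/path/to/captcha

    # distributed GPU: one process per device; init('nccl') supplies rank and group size
    mpirun -n 8 python train.py --run_distribute --platform=GPU --dataset_path=/path/to/captcha

With 8 GPUs the initial learning rate works out to cf.learning_rate * 8 * 0.5, i.e. 4x the single-device value, while 8 Ascend devices keep the full 8x scaling.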