This is an example of training BERT on the MLPerf v0.7 dataset with the second-order optimizer THOR. THOR is a novel approximate second-order optimization method in MindSpore. Because it needs fewer iterations, THOR can train BERT-Large to a masked LM accuracy of 71.3% in 14 minutes on 8 Ascend 910 chips, which is much faster than SGD with Momentum.
The BERT architecture contains three embedding layers, which look up token embeddings, position embeddings and segmentation embeddings. Their output feeds a stack of Transformer encoder blocks. Finally, BERT is trained on two tasks: Masked Language Model and Next Sentence Prediction.
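The embedding stage described above can be sketched in a few lines of plain Python. This is a minimal illustration, not the repository's implementation: the table sizes and the `embed` helper are hypothetical, and it assumes the standard BERT formulation in which the three looked-up vectors are summed element-wise (the layer normalization and dropout that usually follow are omitted).

```python
import random

# Hypothetical toy sizes, chosen only for illustration (not BERT-Large values).
VOCAB_SIZE, MAX_POSITIONS, NUM_SEGMENTS, HIDDEN = 100, 16, 2, 8

random.seed(0)


def make_table(rows, dim):
    """Build a randomly initialised embedding table of shape (rows, dim)."""
    return [[random.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(rows)]


token_table = make_table(VOCAB_SIZE, HIDDEN)
position_table = make_table(MAX_POSITIONS, HIDDEN)
segment_table = make_table(NUM_SEGMENTS, HIDDEN)


def embed(token_ids, segment_ids):
    """Look up the three embeddings for each token and sum them."""
    vectors = []
    for position, (token_id, segment_id) in enumerate(zip(token_ids, segment_ids)):
        summed = [
            t + p + s
            for t, p, s in zip(
                token_table[token_id],
                position_table[position],
                segment_table[segment_id],
            )
        ]
        vectors.append(summed)
    return vectors


# Four tokens, the first two in segment 0 and the last two in segment 1.
encoder_input = embed([5, 17, 42, 3], [0, 0, 1, 1])
print(len(encoder_input), len(encoder_input[0]))  # 4 8
```

The result has one hidden-size vector per input token, which is the shape the Transformer encoder stack consumes.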
- Note: Data will be processed using scripts in [pretraining data creation](
Classical first-order optimization algorithms, such as SGD, need little computation per step, but they converge slowly and require many iterations. Second-order optimization algorithms use the second-order derivative of the objective function to accelerate convergence, so they reach the optimal value of the model in fewer iterations. They are nevertheless uncommon in deep neural network training because of their high computation cost. That cost lies mainly in inverting the second-order information matrix (Hessian matrix, Fisher information matrix (FIM), etc.), with a time complexity of about O(n^3). Building on the existing natural gradient algorithm, we developed the second-order optimizer THOR in MindSpore, which adopts approximation and shearing of the FIM to reduce the computational complexity of the matrix inversion. With eight Ascend 910 chips, THOR can complete BERT-Large training in 14 minutes.
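The trade-off between the two families can be illustrated on a toy problem. The sketch below is not THOR: it assumes a small two-dimensional quadratic objective and uses the exact Hessian inverse (a plain Newton step), whereas THOR approximates the FIM. All names (`grad`, `inverse_2x2`, `steps_to_converge`) and all numeric values are hypothetical and chosen only for illustration.

```python
def grad(A, b, x):
    """Gradient of f(x) = 0.5 * x^T A x - b^T x for a 2x2 matrix A."""
    return [
        A[0][0] * x[0] + A[0][1] * x[1] - b[0],
        A[1][0] * x[0] + A[1][1] * x[1] - b[1],
    ]


def inverse_2x2(A):
    """Explicit inverse of a 2x2 matrix; for n x n this costs about O(n^3)."""
    det = A[0][0] * A[1][1] - A[0][1] * A[1][0]
    return [
        [A[1][1] / det, -A[0][1] / det],
        [-A[1][0] / det, A[0][0] / det],
    ]


def steps_to_converge(A, b, P, tol=1e-8, max_steps=10000):
    """Count updates x <- x - P @ grad(x) until the gradient norm is below tol."""
    x = [0.0, 0.0]
    for step in range(max_steps):
        g = grad(A, b, x)
        if (g[0] ** 2 + g[1] ** 2) ** 0.5 < tol:
            return step
        x = [
            x[0] - (P[0][0] * g[0] + P[0][1] * g[1]),
            x[1] - (P[1][0] * g[0] + P[1][1] * g[1]),
        ]
    return max_steps


# An ill-conditioned, positive-definite toy problem (assumed values).
A = [[10.0, 2.0], [2.0, 1.0]]
b = [1.0, 1.0]

learning_rate = 0.1
first_order_steps = steps_to_converge(A, b, [[learning_rate, 0.0], [0.0, learning_rate]])
second_order_steps = steps_to_converge(A, b, inverse_2x2(A))

print("gradient descent steps:", first_order_steps)
print("Newton steps:", second_order_steps)
```

On a quadratic the Newton step lands on the minimum after a single update, while gradient descent needs hundreds of cheap updates. For real networks the matrix is far too large to invert exactly, which is the cost that THOR's approximation targets.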
> For distributed training, an HCCL configuration file in JSON format needs to be created in advance. For details about the configuration file, refer to the [HCCL_TOOL](
Training results are stored in the current path, in a folder whose name begins with the file name that the user defines. In that folder you can find the checkpoint files together with log output like the following.
epoch: 1, step: 1, outputs are [5.0842705], total_time_span is 795.4807660579681, step_time_span is 795.4807660579681
epoch: 1, step: 100, outputs are [4.4550357], total_time_span is 579.6836116313934, step_time_span is 5.855390016478721
epoch: 1, step: 101, outputs are [4.804837], total_time_span is 0.6697461605072021, step_time_span is 0.6697461605072021
epoch: 1, step: 200, outputs are [4.453913], total_time_span is 26.3735454082489, step_time_span is 0.2663994485681707
epoch: 1, step: 201, outputs are [4.6619444], total_time_span is 0.6340286731719971, step_time_span is 0.6340286731719971
epoch: 1, step: 300, outputs are [4.251204], total_time_span is 26.366267919540405, step_time_span is 0.2663259385812162
epoch: 1, step: 301, outputs are [4.1396527], total_time_span is 0.6269843578338623, step_time_span is 0.6269843578338623
epoch: 1, step: 400, outputs are [4.3717675], total_time_span is 26.37460947036743, step_time_span is 0.2664101966703781
epoch: 1, step: 401, outputs are [4.9887424], total_time_span is 0.6313872337341309, step_time_span is 0.6313872337341309
epoch: 1, step: 500, outputs are [4.7275505], total_time_span is 26.377585411071777, step_time_span is 0.2664402566774927
epoch: 3, step: 2001, outputs are [1.5040319], total_time_span is 0.6242287158966064, step_time_span is 0.6242287158966064
epoch: 3, step: 2100, outputs are [1.232682], total_time_span is 26.37802791595459, step_time_span is 0.26644472642378375
epoch: 3, step: 2101, outputs are [1.1442064], total_time_span is 0.6277685165405273, step_time_span is 0.6277685165405273
epoch: 3, step: 2200, outputs are [1.8860981], total_time_span is 26.378745555877686, step_time_span is 0.2664519753118958
epoch: 3, step: 2201, outputs are [1.4248213], total_time_span is 0.6273438930511475, step_time_span is 0.6273438930511475
epoch: 3, step: 2300, outputs are [1.2741681], total_time_span is 26.374130964279175, step_time_span is 0.2664053632755472
epoch: 3, step: 2301, outputs are [1.2470423], total_time_span is 0.6276984214782715, step_time_span is 0.6276984214782715
epoch: 3, step: 2400, outputs are [1.2646998], total_time_span is 26.37843370437622, step_time_span is 0.2664488252967295
epoch: 3, step: 2401, outputs are [1.2794371], total_time_span is 0.6266779899597168, step_time_span is 0.6266779899597168
epoch: 3, step: 2500, outputs are [1.265375], total_time_span is 26.374578714370728, step_time_span is 0.2664098860037447
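To track convergence or step time from such a log, the lines can be parsed with a regular expression. The following is a minimal sketch that assumes every line follows exactly the format shown above; `LOG_PATTERN` and `parse_log_line` are hypothetical helpers, not part of the training scripts, and the field names simply mirror the log text.

```python
import re

# Matches one log line in the format shown above.
LOG_PATTERN = re.compile(
    r"epoch: (\d+), step: (\d+), outputs are \[([-+.\deE]+)\], "
    r"total_time_span is ([-+.\deE]+), step_time_span is ([-+.\deE]+)"
)


def parse_log_line(line):
    """Return the fields of one log line as a dict, or None if it does not match."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    epoch, step, outputs, total_time, step_time = match.groups()
    return {
        "epoch": int(epoch),
        "step": int(step),
        "outputs": float(outputs),
        "total_time_span": float(total_time),
        "step_time_span": float(step_time),
    }


# Two lines copied from the log above.
sample_lines = [
    "epoch: 1, step: 1, outputs are [5.0842705], total_time_span is 795.4807660579681, step_time_span is 795.4807660579681",
    "epoch: 3, step: 2500, outputs are [1.265375], total_time_span is 26.374578714370728, step_time_span is 0.2664098860037447",
]
records = [parse_log_line(line) for line in sample_lines]
for record in records:
    print(record["step"], record["outputs"], record["step_time_span"])
```

Lines that do not match the pattern return `None`, so unrelated log output can be filtered out before plotting or averaging.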
### Evaluation Process
Before running the command below, check the checkpoint path used for evaluation. Set the checkpoint path to the absolute full path, e.g., "username/bert_thor/LOG0/checkpoint_bert-3_1000.ckpt".