|
|
|
@ -17,7 +17,8 @@ from .bert_for_pre_training import BertNetworkWithLoss, BertPreTraining, \
|
|
|
|
|
BertPretrainingLoss, GetMaskedLMOutput, GetNextSentenceOutput, \
|
|
|
|
|
BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell, \
|
|
|
|
|
BertTrainAccumulationAllReduceEachWithLossScaleCell, \
|
|
|
|
|
BertTrainAccumulationAllReducePostWithLossScaleCell
|
|
|
|
|
BertTrainAccumulationAllReducePostWithLossScaleCell, \
|
|
|
|
|
BertTrainOneStepWithLossScaleCellForAdam
|
|
|
|
|
from .bert_model import BertAttention, BertConfig, BertEncoderCell, BertModel, \
|
|
|
|
|
BertOutput, BertSelfAttention, BertTransformer, EmbeddingLookup, \
|
|
|
|
|
EmbeddingPostprocessor, RelaPosEmbeddingsGenerator, RelaPosMatrixGenerator, \
|
|
|
|
@ -31,5 +32,6 @@ __all__ = [
|
|
|
|
|
"BertAttention", "BertConfig", "BertEncoderCell", "BertModel", "BertOutput",
|
|
|
|
|
"BertSelfAttention", "BertTransformer", "EmbeddingLookup",
|
|
|
|
|
"EmbeddingPostprocessor", "RelaPosEmbeddingsGenerator", "AdamWeightDecayForBert",
|
|
|
|
|
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask"
|
|
|
|
|
"RelaPosMatrixGenerator", "SaturateCast", "CreateAttentionMaskFromInputMask",
|
|
|
|
|
"BertTrainOneStepWithLossScaleCellForAdam"
|
|
|
|
|
]
|
|
|
|
|