[API 2.0] Add transformer apis (#26418)

* Add MultiHeadAttention api.
test=develop

* Add MultiHeadAttention cache type and gen_cache.
test=develop

* Add TransformerEncoderLayer and TransformerEncoder.
test=develop

* Add Transformer decoder apis.
test=develop

* Add Transformer api.
test=develop

* add unittests for transformer api

* add unittests for transformer api

* Fix some bugs in Transformer apis.
test=develop

* add unittests for encoder, decoder and transformer

* clean conflicts infor in code

* clean Chinese comments

* Add TransformerDecoderCell and TransformerBeamSearchDecoder.
test=develop

* Remove TransformerDecoderCell and TransformerBeamSearchDecoder temporarily.
test=develop

* Add import for Transformer apis.
test=develop

* Update usage of weight_attr and Tensor in Transformer api docs.
test=develop

* Update Transformer apis by renaming MultiheadAttention and cal_kv according to comments.
test=develop

* Fix MultiHeadAttention in test_transformer_api.py.
test=develop

Co-authored-by: LiuChiaChi <709153940@qq.com>
test_feature_precision_test_c
Guo Sheng 5 years ago committed by GitHub
parent 8645591d66
commit 317f7ce2ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

File diff suppressed because it is too large Load Diff

@ -130,6 +130,12 @@ from .layer.norm import InstanceNorm #DEFINE_ALIAS
# from .layer.rnn import RNNCell #DEFINE_ALIAS
# from .layer.rnn import GRUCell #DEFINE_ALIAS
# from .layer.rnn import LSTMCell #DEFINE_ALIAS
from .layer.transformer import MultiHeadAttention
from .layer.transformer import TransformerEncoderLayer
from .layer.transformer import TransformerEncoder
from .layer.transformer import TransformerDecoderLayer
from .layer.transformer import TransformerDecoder
from .layer.transformer import Transformer
from .layer.distance import PairwiseDistance #DEFINE_ALIAS
from .layer import loss #DEFINE_ALIAS

@ -21,6 +21,7 @@ from . import extension
from . import activation
from . import norm
from . import distance
from . import transformer
from .activation import *
from .loss import *
@ -28,6 +29,7 @@ from .conv import *
from .extension import *
from .activation import *
from .norm import *
from .transformer import *
# from .activation import PReLU #DEFINE_ALIAS
from .activation import ReLU #DEFINE_ALIAS
from .activation import LeakyReLU #DEFINE_ALIAS

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save