|
|
|
@ -22,20 +22,11 @@ from mindspore.common.tensor import Tensor
|
|
|
|
|
from mindspore.train.model import Model
|
|
|
|
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
|
|
|
|
|
|
|
|
|
from mindspore import context
|
|
|
|
|
|
|
|
|
|
from src.dataset import load_dataset
|
|
|
|
|
from .transformer_for_infer import TransformerInferModel
|
|
|
|
|
from .transformer_for_train import TransformerTraining
|
|
|
|
|
from ..utils.load_weights import load_infer_weights
|
|
|
|
|
|
|
|
|
|
context.set_context(
|
|
|
|
|
mode=context.GRAPH_MODE,
|
|
|
|
|
save_graphs=False,
|
|
|
|
|
device_target="Ascend",
|
|
|
|
|
reserve_class_name_in_scope=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransformerInferCell(nn.Cell):
|
|
|
|
|
"""
|
|
|
|
|
Encapsulation class of transformer network infer.
|
|
|
|
|