pull/8420/head
bai-yangfan 4 years ago
parent 3177a4245e
commit a9cba56191

@ -20,5 +20,11 @@ Helper functions in train piplines.
from .model import Model from .model import Model
from .dataset_helper import DatasetHelper, connect_network_with_dataset from .dataset_helper import DatasetHelper, connect_network_with_dataset
from . import amp from . import amp
from .amp import build_train_network
from .loss_scale_manager import LossScaleManager, FixedLossScaleManager, DynamicLossScaleManager
from .serialization import save_checkpoint, load_checkpoint, load_param_into_net, export, parse_print,\
build_searched_strategy, merge_sliced_parameter
__all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset"] __all__ = ["Model", "DatasetHelper", "amp", "connect_network_with_dataset", "build_train_network", "LossScaleManager",
"FixedLossScaleManager", "DynamicLossScaleManager", "save_checkpoint", "load_checkpoint",
"load_param_into_net", "export", "parse_print", "build_searched_strategy", "merge_sliced_parameter"]

@ -26,8 +26,6 @@ from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
from ..context import ParallelMode from ..context import ParallelMode
from .. import context from .. import context
__all__ = ["build_train_network"]
class OutputTo16(nn.Cell): class OutputTo16(nn.Cell):
"Wrap cell for amp. Cast network output back to float16" "Wrap cell for amp. Cast network output back to float16"

@ -17,8 +17,6 @@
from .._checkparam import Validator as validator from .._checkparam import Validator as validator
from .. import nn from .. import nn
__all__ = ["LossScaleManager", "FixedLossScaleManager", "DynamicLossScaleManager"]
class LossScaleManager: class LossScaleManager:
"""Loss scale manager abstract class.""" """Loss scale manager abstract class."""

@ -33,8 +33,6 @@ from mindspore._checkparam import check_input_data, Validator
from mindspore.compression.export import quant_export from mindspore.compression.export import quant_export
import mindspore.context as context import mindspore.context as context
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
"build_searched_strategy", "merge_sliced_parameter"]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16, tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
"Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64, "Int32": mstype.int32, "Uint32": mstype.uint32, "Int64": mstype.int64, "Uint64": mstype.uint64,

@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg from src.config import alexnet_cifar10_cfg, alexnet_imagenet_cfg
from src.alexnet import AlexNet from src.alexnet import AlexNet

@ -16,9 +16,8 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, export
import mindspore.common.dtype as mstype import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, export
from src.config import Config_CNNCTC from src.config import Config_CNNCTC
from src.cnn_ctc import CNNCTC_Model from src.cnn_ctc import CNNCTC_Model

@ -16,10 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.serialization import export
from src.nets import net_factory from src.nets import net_factory

@ -17,8 +17,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50 from src.FasterRcnn.faster_rcnn_r50 import Faster_Rcnn_Resnet50
from src.config import config from src.config import config

@ -20,8 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import cifar_cfg, imagenet_cfg from src.config import cifar_cfg, imagenet_cfg
from src.googlenet import GoogleNet from src.googlenet import GoogleNet

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_gpu as cfg from src.config import config_gpu as cfg
from src.inception_v3 import InceptionV3 from src.inception_v3 import InceptionV3

@ -20,9 +20,7 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet import LeNet5 from src.lenet import LeNet5

@ -20,10 +20,8 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion from src.lenet_fusion import LeNet5 as LeNet5Fusion

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50 from src.maskrcnn.mask_rcnn_r50 import Mask_Rcnn_Resnet50
from src.config import config from src.config import config

@ -16,8 +16,7 @@
mobilenetv2 export mindir. mobilenetv2 export mindir.
""" """
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, export
from mindspore.train.serialization import export
from src.config import set_config from src.config import set_config
from src.args import export_parse_args from src.args import export_parse_args
from src.models import define_net, load_ckpt from src.models import define_net, load_ckpt

@ -18,9 +18,7 @@ import argparse
import numpy as np import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2

@ -17,8 +17,7 @@ mobilenetv3 export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config_gpu from src.config import config_gpu
from src.mobilenetV3 import mobilenet_v3_large from src.mobilenetV3 import mobilenet_v3_large

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import nasnet_a_mobile_config_gpu as cfg from src.config import nasnet_a_mobile_config_gpu as cfg
from src.nasnet_a_mobile import NASNetAMobile from src.nasnet_a_mobile import NASNetAMobile

@ -19,8 +19,7 @@ import argparse
import numpy as np import numpy as np
import mindspore as ms import mindspore as ms
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config from src.config import config
from src.ETSNET.etsnet import ETSNet from src.ETSNET.etsnet import ETSNet

@ -19,8 +19,7 @@ python export.py
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser(description='resnet export') parser = argparse.ArgumentParser(description='resnet export')

@ -16,9 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore import Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.resnet_thor import resnet50 as resnet from src.resnet_thor import resnet50 as resnet
from src.config import config from src.config import config

@ -17,8 +17,7 @@ resnext export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import config from src.config import config
from src.image_classification import get_network from src.image_classification import get_network

@ -17,8 +17,7 @@ ssd export mindir.
""" """
import argparse import argparse
import numpy as np import numpy as np
from mindspore import context, Tensor from mindspore import context, Tensor, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.ssd import SSD300, ssd_mobilenet_v2 from src.ssd import SSD300, ssd_mobilenet_v2
from src.config import config from src.config import config

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor from mindspore import Tensor, export, load_checkpoint, load_param_into_net
from mindspore.train.serialization import export, load_checkpoint, load_param_into_net
from src.unet.unet_model import UNet from src.unet.unet_model import UNet

@ -16,8 +16,7 @@
import argparse import argparse
import numpy as np import numpy as np
from mindspore import Tensor, context from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.warpctc import StackedRNN from src.warpctc import StackedRNN
from src.config import config from src.config import config

@ -24,7 +24,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore import load_checkpoint, load_param_into_net, export
from mindspore.train import Model from mindspore.train import Model
from mindspore.compression.quant import QuantizationAwareTraining from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net

Loading…
Cancel
Save