@@ -14,7 +14,7 @@
 # ============================================================================
 """wide and deep model"""
 import numpy as np
-from mindspore import nn
+from mindspore import nn, context
 from mindspore import Parameter, ParameterTuple
 import mindspore.common.dtype as mstype
 from mindspore.ops import functional as F
@@ -22,10 +22,7 @@ from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam, FTRL, LazyAdam
-# from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
-# from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
-from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_mirror_mean
 from mindspore.train.parallel_utils import ParallelMode
 from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 from mindspore.communication.management import get_group_size
@@ -142,7 +139,7 @@ class WideDeepModel(nn.Cell):
         self.batch_size = config.batch_size
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
@@ -259,7 +256,7 @@ class NetWithLossClass(nn.Cell):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
         host_device_mix = bool(config.host_device_mix)
         parameter_server = bool(config.parameter_server)
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
         self.network = network
@@ -312,7 +309,7 @@ class TrainStepWrap(nn.Cell):
 
     def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False):
         super(TrainStepWrap, self).__init__()
-        parallel_mode = _get_parallel_mode()
+        parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()
@@ -351,12 +348,11 @@ class TrainStepWrap(nn.Cell):
         self.reducer_flag = False
         self.grad_reducer_w = None
         self.grad_reducer_d = None
-        parallel_mode = _get_parallel_mode()
         self.reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
                                               ParallelMode.HYBRID_PARALLEL)
         if self.reducer_flag:
-            mean = _get_mirror_mean()
-            degree = _get_device_num()
+            mean = context.get_auto_parallel_context("mirror_mean")
+            degree = context.get_auto_parallel_context("device_num")
             self.grad_reducer_w = DistributedGradReducer(self.optimizer_w.parameters, mean, degree)
             self.grad_reducer_d = DistributedGradReducer(self.optimizer_d.parameters, mean, degree)
 
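Taken together, the hunks swap the private helpers _get_parallel_mode, _get_mirror_mean, and _get_device_num from mindspore.parallel._utils for the public context.get_auto_parallel_context accessor with the "parallel_mode", "mirror_mean", and "device_num" keys. Below is a minimal sketch of the resulting pattern, lifted out of the model classes; the function name and the per_card_batch_size argument (standing in for config.batch_size) are hypothetical, it reuses the same imports as the patched file, and it assumes distributed communication has already been initialized so get_group_size() is valid.

from mindspore import context
from mindspore.train.parallel_utils import ParallelMode
from mindspore.communication.management import get_group_size

def resolve_parallel_settings(per_card_batch_size):
    """Sketch: query parallel settings through the public context API
    instead of the private mindspore.parallel._utils helpers."""
    parallel_mode = context.get_auto_parallel_context("parallel_mode")
    is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL,
                                         ParallelMode.AUTO_PARALLEL)
    if is_auto_parallel:
        # Mirror the patched model: scale the configured batch size by the
        # size of the device group.
        per_card_batch_size = per_card_batch_size * get_group_size()

    reducer_flag = parallel_mode in (ParallelMode.DATA_PARALLEL,
                                     ParallelMode.HYBRID_PARALLEL)
    mean, degree = None, None
    if reducer_flag:
        # Same keys the patch passes on to DistributedGradReducer.
        mean = context.get_auto_parallel_context("mirror_mean")
        degree = context.get_auto_parallel_context("device_num")
    return per_card_batch_size, is_auto_parallel, reducer_flag, mean, degree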