@@ -20,7 +20,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
-from mindspore.nn.optim import Adam, FTRL
+from mindspore.nn.optim import Adam, FTRL, LazyAdam
 # from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
 # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
@@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
     """

     def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
-                 keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False):
+                 keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(
@@ -137,8 +137,10 @@ class WideDeepModel(nn.Cell):
     def __init__(self, config):
         super(WideDeepModel, self).__init__()
         self.batch_size = config.batch_size
+        host_device_mix = bool(config.host_device_mix)
         parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
         self.field_size = config.field_size
         self.vocab_size = config.vocab_size
@@ -187,16 +189,29 @@ class WideDeepModel(nn.Cell):
                                         self.weight_bias_init,
                                         self.deep_layer_act,
                                         use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)
-
-        self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
-        self.mul = P.Mul()
+        self.wide_mul = P.Mul()
+        self.deep_mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
+        self.deep_reshape = P.Reshape()
         self.square = P.Square()
         self.shape = P.Shape()
         self.tile = P.Tile()
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
+        if is_auto_parallel and host_device_mix:
+            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.deep_embeddinglookup = nn.EmbeddingLookup()
+            self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+            self.wide_embeddinglookup = nn.EmbeddingLookup()
+            self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
+            self.deep_reshape.add_prim_attr("skip_redistribution", True)
+            self.reduce_sum.add_prim_attr("cross_batch", True)
+        else:
+            self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
+            self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')

     def construct(self, id_hldr, wt_hldr):
         """
@@ -206,13 +221,13 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
-        wx = self.mul(wide_id_weight, mask)
+        wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
+        wx = self.wide_mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
-        vx = self.mul(deep_id_embs, mask)
-        deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
+        deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
+        vx = self.deep_mul(deep_id_embs, mask)
+        deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
@@ -233,19 +248,28 @@ class NetWithLossClass(nn.Cell):

     def __init__(self, network, config):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
+        host_device_mix = bool(config.host_device_mix)
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        self.no_l2loss = host_device_mix and is_auto_parallel
         self.network = network
         self.l2_coef = config.l2_coef
         self.loss = P.SigmoidCrossEntropyWithLogits()
         self.square = P.Square()
         self.reduceMean_false = P.ReduceMean(keep_dims=False)
+        if is_auto_parallel:
+            self.reduceMean_false.add_prim_attr("cross_batch", True)
         self.reduceSum_false = P.ReduceSum(keep_dims=False)

     def construct(self, batch_ids, batch_wts, label):
         predict, embedding_table = self.network(batch_ids, batch_wts)
         log_loss = self.loss(predict, label)
         wide_loss = self.reduceMean_false(log_loss)
-        l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
-        deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
+        if self.no_l2loss:
+            deep_loss = wide_loss
+        else:
+            l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
+            deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v

         return wide_loss, deep_loss
@@ -267,12 +291,15 @@ class TrainStepWrap(nn.Cell):
     Append Adam and FTRL optimizers to the training network after that construct
     function can be called to create the backward graph.
     Args:
-        network (Cell): the training network. Note that loss function should have been added.
-        sens (Number): The adjust parameter. Default: 1000.0
+        network (Cell): The training network. Note that loss function should have been added.
+        sens (Number): The adjust parameter. Default: 1024.0
+        host_device_mix (Bool): Whether run in host and device mix mode. Default: False
     """

-    def __init__(self, network, sens=1024.0):
+    def __init__(self, network, sens=1024.0, host_device_mix=False):
         super(TrainStepWrap, self).__init__()
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()
         self.trainable_params = network.trainable_params()
@@ -285,10 +312,19 @@ class TrainStepWrap(nn.Cell):
                 weights_d.append(params)
         self.weights_w = ParameterTuple(weights_w)
         self.weights_d = ParameterTuple(weights_d)
-        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
-                                l1=1e-8, l2=1e-8, initial_accum=1.0)
-        self.optimizer_d = Adam(
-            self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+        if host_device_mix and is_auto_parallel:
+            self.optimizer_d = LazyAdam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
+            self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
+            self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
+        else:
+            self.optimizer_d = Adam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
         self.hyper_map = C.HyperMap()
         self.grad_w = C.GradOperation('grad_w', get_by_list=True,
                                       sens_param=True)
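The classes touched above can be wired together roughly as follows. This is a hedged sketch, not part of the patch: the helper name build_train_net is illustrative, and the only config field assumed to exist is config.host_device_mix, which the patch itself reads.

def build_train_net(config):
    """Compose the cells this patch modifies: model -> loss cell -> train-step cell."""
    net = WideDeepModel(config)               # wide and deep towers
    loss_net = NetWithLossClass(net, config)  # returns (wide_loss, deep_loss)
    train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
    train_net.set_train()
    return train_net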