@@ -24,9 +24,12 @@ from mindspore.ops import operations as P
 from mindspore.nn import Dropout
 from mindspore.nn.optim import Adam
 from mindspore.nn.metrics import Metric
-from mindspore import nn, ParameterTuple, Parameter
-from mindspore.common.initializer import Uniform, initializer, Normal
+from mindspore import nn, Tensor, ParameterTuple, Parameter
+from mindspore.common.initializer import Uniform, initializer
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.parallel._utils import _get_device_num, _get_parallel_mode, _get_gradients_mean
+from mindspore.context import ParallelMode
+from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
 
 from .callback import EvalCallBack, LossCallBack
@@ -60,7 +63,7 @@ class AUCMetric(Metric):
         return auc
 
 
-def init_method(method, shape, name, max_val=0.01):
+def init_method(method, shape, name, max_val=1.0):
     """
     The method of init parameters.
 
@@ -73,18 +76,18 @@ def init_method(method, shape, name, max_val=0.01):
     Returns:
         Parameter.
     """
-    if method in ['random', 'uniform']:
+    if method in ['uniform']:
         params = Parameter(initializer(Uniform(max_val), shape, ms_type), name=name)
     elif method == "one":
         params = Parameter(initializer("ones", shape, ms_type), name=name)
     elif method == 'zero':
         params = Parameter(initializer("zeros", shape, ms_type), name=name)
     elif method == "normal":
-        params = Parameter(initializer(Normal(max_val), shape, ms_type), name=name)
+        params = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(dtype=np_type)), name=name)
     return params
 
 
-def init_var_dict(init_args, values):
+def init_var_dict(init_args, var_list):
     """
     Init parameter.
 
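Note on the `init_method` change above: the `normal` branch no longer goes through MindSpore's `Normal` initializer but draws eagerly with NumPy at a hard-coded scale of 0.01, which is why `Normal` drops out of the imports and `Tensor` comes in. A minimal standalone sketch of the two paths, assuming the module-level aliases `ms_type`/`np_type` resolve to float32 as elsewhere in deepfm.py:

```python
import numpy as np
import mindspore.common.dtype as mstype
from mindspore import Tensor, Parameter
from mindspore.common.initializer import initializer, Normal

shape = [8, 4]
# Old path: lazy MindSpore initializer, sigma taken from max_val.
w_old = Parameter(initializer(Normal(0.01), shape, mstype.float32), name="w_old")
# New path: eager NumPy draw, scale fixed at 0.01 regardless of max_val.
w_new = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).astype(np.float32)),
                  name="w_new")
```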
@@ -96,17 +99,19 @@ def init_var_dict(init_args, values):
         dict, a dict of Parameter.
     """
     var_map = {}
-    _, _max_val = init_args
-    for key, shape, init_flag in values:
+    _, max_val = init_args
+    for i, _ in enumerate(var_list):
+        key, shape, method = var_list[i]
         if key not in var_map.keys():
-            if init_flag in ['random', 'uniform']:
-                var_map[key] = Parameter(initializer(Uniform(_max_val), shape, ms_type), name=key)
-            elif init_flag == "one":
+            if method in ['random', 'uniform']:
+                var_map[key] = Parameter(initializer(Uniform(max_val), shape, ms_type), name=key)
+            elif method == "one":
                 var_map[key] = Parameter(initializer("ones", shape, ms_type), name=key)
-            elif init_flag == "zero":
+            elif method == "zero":
                 var_map[key] = Parameter(initializer("zeros", shape, ms_type), name=key)
-            elif init_flag == 'normal':
-                var_map[key] = Parameter(initializer(Normal(_max_val), shape, ms_type), name=key)
+            elif method == 'normal':
+                var_map[key] = Parameter(Tensor(np.random.normal(loc=0.0, scale=0.01, size=shape).
+                                                astype(dtype=np_type)), name=key)
     return var_map
 
 
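For context, a sketch of how `DeepFMModel.__init__` consumes the reworked `init_var_dict` (see the later hunks); the sizes here are illustrative placeholders, not values from the diff:

```python
# init_args supplies (min_val, max_val); only max_val is read above.
vocab_size, emb_dim = 184965, 80                        # illustrative Criteo-style sizes
init_args = [-0.01, 0.01]
init_acts = [('W_l2', [vocab_size, 1], 'normal'),       # first-order (linear) FM weights
             ('V_l2', [vocab_size, emb_dim], 'normal')] # feature embedding table
var_map = init_var_dict(init_args, init_acts)
fm_w, embedding_table = var_map["W_l2"], var_map["V_l2"]
```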
@@ -122,7 +127,9 @@ class DenseLayer(nn.Cell):
         keep_prob (float): Dropout Layer keep_prob_rate;
         scale_coef (float): input scale coefficient;
     """
-    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, keep_prob=0.9, scale_coef=1.0):
+    def __init__(self, input_dim, output_dim, weight_bias_init, act_str, scale_coef=1.0, convert_dtype=True,
+                 use_act=True):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(weight_init, [input_dim, output_dim], name="weight")
@@ -131,12 +138,15 @@ class DenseLayer(nn.Cell):
         self.matmul = P.MatMul(transpose_b=False)
         self.bias_add = P.BiasAdd()
         self.cast = P.Cast()
-        self.dropout = Dropout(keep_prob=keep_prob)
+        self.dropout = Dropout(keep_prob=1.0)
         self.mul = P.Mul()
         self.realDiv = P.RealDiv()
         self.scale_coef = scale_coef
+        self.convert_dtype = convert_dtype
+        self.use_act = use_act
 
     def _init_activation(self, act_str):
+        """Init activation function"""
         act_str = act_str.lower()
         if act_str == "relu":
             act_func = P.ReLU()
@@ -147,17 +157,23 @@ class DenseLayer(nn.Cell):
         return act_func
 
     def construct(self, x):
-        x = self.act_func(x)
-        if self.training:
-            x = self.dropout(x)
-        x = self.mul(x, self.scale_coef)
-        x = self.cast(x, mstype.float16)
-        weight = self.cast(self.weight, mstype.float16)
-        wx = self.matmul(x, weight)
-        wx = self.cast(wx, mstype.float32)
-        wx = self.realDiv(wx, self.scale_coef)
-        output = self.bias_add(wx, self.bias)
-        return output
+        """Construct function"""
+        x = self.dropout(x)
+        if self.convert_dtype:
+            x = self.cast(x, mstype.float16)
+            weight = self.cast(self.weight, mstype.float16)
+            bias = self.cast(self.bias, mstype.float16)
+            wx = self.matmul(x, weight)
+            wx = self.bias_add(wx, bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+            wx = self.cast(wx, mstype.float32)
+        else:
+            wx = self.matmul(x, self.weight)
+            wx = self.bias_add(wx, self.bias)
+            if self.use_act:
+                wx = self.act_func(wx)
+        return wx
 
 
 class DeepFMModel(nn.Cell):
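The rewritten `construct` applies dropout unconditionally (a no-op while `keep_prob=1.0`), gates the float16 matmul path behind `convert_dtype`, and makes the activation optional so the same class can serve hidden layers and the final logit layer. A hypothetical instantiation of both flavors, with placeholder dimensions and init methods:

```python
# Hidden layer: float16 matmul path with ReLU.
hidden = DenseLayer(3120, 1024, ('uniform', 'uniform'), 'relu', convert_dtype=True)
# Output layer: use_act=False, so it emits raw logits for the final sum.
logits = DenseLayer(128, 1, ('uniform', 'uniform'), 'relu', convert_dtype=True, use_act=False)
```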
@@ -176,6 +192,7 @@ class DeepFMModel(nn.Cell):
                            (list[str], weight_bias_init=['random', 'zero'])
         keep_prob (float): if dropout_flag is True, keep_prob rate to keep connect; (float, keep_prob=0.8)
     """
 
     def __init__(self, config):
         super(DeepFMModel, self).__init__()
@@ -188,24 +205,24 @@ class DeepFMModel(nn.Cell):
         self.weight_bias_init = config.weight_bias_init
         self.keep_prob = config.keep_prob
         init_acts = [('W_l2', [self.vocab_size, 1], 'normal'),
-                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal'),
-                     ('b', [1], 'normal')]
+                     ('V_l2', [self.vocab_size, self.emb_dim], 'normal')]
         var_map = init_var_dict(self.init_args, init_acts)
         self.fm_w = var_map["W_l2"]
-        self.fm_b = var_map["b"]
         self.embedding_table = var_map["V_l2"]
-        # Deep Layers
-        self.deep_input_dims = self.field_size * self.emb_dim + 1
+        " Deep Layers "
+        self.deep_input_dims = self.field_size * self.emb_dim
         self.all_dim_list = [self.deep_input_dims] + self.deep_layer_dims_list + [1]
-        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4],
-                                        self.weight_bias_init, self.deep_layer_act, self.keep_prob)
-        # FM, linear Layers
+        self.dense_layer_1 = DenseLayer(self.all_dim_list[0], self.all_dim_list[1], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_2 = DenseLayer(self.all_dim_list[1], self.all_dim_list[2], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_3 = DenseLayer(self.all_dim_list[2], self.all_dim_list[3], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_4 = DenseLayer(self.all_dim_list[3], self.all_dim_list[4], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True)
+        self.dense_layer_5 = DenseLayer(self.all_dim_list[4], self.all_dim_list[5], self.weight_bias_init,
+                                        self.deep_layer_act, self.keep_prob, convert_dtype=True, use_act=False)
+        " FM, linear Layers "
         self.Gatherv2 = P.GatherV2()
         self.Mul = P.Mul()
         self.ReduceSum = P.ReduceSum(keep_dims=False)
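With the tiled bias column dropped, the deep tower's input is exactly `field_size * emb_dim`, and the added fifth layer closes the chain down to one logit. A worked dimension chain with illustrative config values:

```python
field_size, emb_dim = 39, 80                  # illustrative Criteo-style values
deep_layer_dims_list = [1024, 512, 256, 128]
all_dim_list = [field_size * emb_dim] + deep_layer_dims_list + [1]
# [3120, 1024, 512, 256, 128, 1] -> dense_layer_1 ... dense_layer_5 (last one without activation)
```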
@@ -238,16 +255,14 @@ class DeepFMModel(nn.Cell):
         fm_out = 0.5 * self.ReduceSum(v1 - v2, 1)
         fm_out = self.Reshape(fm_out, (-1, 1))
         # Deep layer
-        b = self.Reshape(self.fm_b, (1, 1))
-        b = self.Tile(b, (self.batch_size, 1))
         deep_in = self.Reshape(vx, (-1, self.field_size * self.emb_dim))
-        deep_in = self.Concat((deep_in, b))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
-        deep_out = self.dense_layer_4(deep_in)
+        deep_in = self.dense_layer_4(deep_in)
+        deep_out = self.dense_layer_5(deep_in)
         out = linear_out + fm_out + deep_out
-        return out, fm_id_weight, fm_id_embs
+        return out, self.fm_w, self.embedding_table
 
 
 class NetWithLossClass(nn.Cell):
@@ -278,7 +293,7 @@ class TrainStepWrap(nn.Cell):
     """
     TrainStepWrap definition
     """
-    def __init__(self, network, lr=5e-8, eps=1e-8, loss_scale=1000.0):
+    def __init__(self, network, lr, eps, loss_scale=1000.0):
         super(TrainStepWrap, self).__init__(auto_prefix=False)
         self.network = network
         self.network.set_train()
@@ -288,11 +303,24 @@ class TrainStepWrap(nn.Cell):
         self.grad = C.GradOperation(get_by_list=True, sens_param=True)
         self.sens = loss_scale
+        self.reducer_flag = False
+        self.grad_reducer = None
+        parallel_mode = _get_parallel_mode()
+        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
+            self.reducer_flag = True
+        if self.reducer_flag:
+            mean = _get_gradients_mean()
+            degree = _get_device_num()
+            self.grad_reducer = DistributedGradReducer(self.optimizer.parameters, mean, degree)
 
     def construct(self, batch_ids, batch_wts, label):
         weights = self.weights
         loss = self.network(batch_ids, batch_wts, label)
-        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)  #
+        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
         grads = self.grad(self.network, weights)(batch_ids, batch_wts, label, sens)
+        if self.reducer_flag:
+            # apply grad reducer on grads
+            grads = self.grad_reducer(grads)
         return F.depend(loss, self.optimizer(grads))
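A hedged sketch of driving the wrapper under data parallelism; the setup calls follow MindSpore's standard distributed workflow, and `config`, the learning-rate/epsilon values, and the batch tensors are placeholders:

```python
from mindspore import context
from mindspore.context import ParallelMode
from mindspore.communication.management import init, get_group_size

init()  # start the collective backend (HCCL/NCCL)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                  gradients_mean=True,
                                  device_num=get_group_size())

loss_net = NetWithLossClass(DeepFMModel(config))
train_net = TrainStepWrap(loss_net, lr=1e-5, eps=1e-8, loss_scale=1024.0)
train_net.set_train()
# Each step scales the loss by loss_scale for backprop; under DATA_PARALLEL the
# DistributedGradReducer built in __init__ averages gradients across devices.
loss = train_net(batch_ids, batch_wts, label)
```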