@@ -20,7 +20,7 @@ from mindspore.ops import functional as F
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.nn import Dropout
-from mindspore.nn.optim import Adam, FTRL
+from mindspore.nn.optim import Adam, FTRL, LazyAdam
 # from mindspore.nn.metrics import Metric
 from mindspore.common.initializer import Uniform, initializer
 # from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
@@ -82,7 +82,7 @@ class DenseLayer(nn.Cell):
     """

     def __init__(self, input_dim, output_dim, weight_bias_init, act_str,
-                 keep_prob=0.7, use_activation=True, convert_dtype=True, drop_out=False):
+                 keep_prob=0.5, use_activation=True, convert_dtype=True, drop_out=False):
         super(DenseLayer, self).__init__()
         weight_init, bias_init = weight_bias_init
         self.weight = init_method(
@@ -137,8 +137,10 @@ class WideDeepModel(nn.Cell):
     def __init__(self, config):
         super(WideDeepModel, self).__init__()
         self.batch_size = config.batch_size
+        host_device_mix = bool(config.host_device_mix)
         parallel_mode = _get_parallel_mode()
-        if parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL):
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
         self.field_size = config.field_size
         self.vocab_size = config.vocab_size
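
Note: with config.host_device_mix read once up front, the (semi-)auto-parallel branch still scales the per-card batch size up to the whole device group, so the rest of the graph is shaped for the global batch. A minimal sketch with made-up numbers:

    per_card_batch = 16000      # hypothetical config.batch_size on one card
    group_size = 8              # stand-in for get_group_size()
    global_batch = per_card_batch * group_size
    print(global_batch)         # 128000 rows flow through one training step
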
@@ -187,16 +189,29 @@ class WideDeepModel(nn.Cell):
                                         self.weight_bias_init,
                                         self.deep_layer_act,
                                         use_activation=False, convert_dtype=True, drop_out=config.dropout_flag)

-        self.embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
-        self.mul = P.Mul()
+        self.wide_mul = P.Mul()
+        self.deep_mul = P.Mul()
         self.reduce_sum = P.ReduceSum(keep_dims=False)
         self.reshape = P.Reshape()
+        self.deep_reshape = P.Reshape()
         self.square = P.Square()
         self.shape = P.Shape()
         self.tile = P.Tile()
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
+        if is_auto_parallel and host_device_mix:
+            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
+            self.deep_embeddinglookup = nn.EmbeddingLookup()
+            self.deep_embeddinglookup.embeddinglookup.set_strategy(((1, get_group_size()), (1, 1)))
+            self.wide_embeddinglookup = nn.EmbeddingLookup()
+            self.wide_embeddinglookup.embeddinglookup.set_strategy(((get_group_size(), 1), (1, 1)))
+            self.deep_mul.set_strategy(((1, 1, get_group_size()), (1, 1, 1)))
+            self.deep_reshape.add_prim_attr("skip_redistribution", True)
+            self.reduce_sum.add_prim_attr("cross_batch", True)
+        else:
+            self.deep_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')
+            self.wide_embeddinglookup = nn.EmbeddingLookup(target='DEVICE')

     def construct(self, id_hldr, wt_hldr):
         """
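
Note: as I read the set_strategy tuples, each inner tuple says how many pieces every dimension of that input is cut into across the device group, so the deep table is sharded along its embedding dimension and the wide table along its vocab dimension, while skip_redistribution and cross_batch are hints to the parallel planner. A minimal sketch of the resulting per-device slice shapes (the helper, group size and table shapes are hypothetical, not from the patch):

    # Hypothetical helper: the slice one device holds when a strategy tuple
    # cuts each dimension into `cuts` equal pieces.
    def shard_shape(shape, strategy):
        assert all(dim % cuts == 0 for dim, cuts in zip(shape, strategy))
        return tuple(dim // cuts for dim, cuts in zip(shape, strategy))

    group_size = 8                                    # stand-in for get_group_size()
    deep_table = (200000, 80)                         # hypothetical (vocab_size, emb_dim)
    wide_table = (200000, 1)                          # wide side keeps one weight per id
    # ((1, group_size), (1, 1)): deep table split along the embedding dimension
    print(shard_shape(deep_table, (1, group_size)))   # (200000, 10)
    # ((group_size, 1), (1, 1)): wide table split along the vocab dimension instead
    print(shard_shape(wide_table, (group_size, 1)))   # (25000, 1)
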
@@ -206,13 +221,13 @@ class WideDeepModel(nn.Cell):
         """
         mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1))
         # Wide layer
-        wide_id_weight = self.embeddinglookup(self.wide_w, id_hldr)
-        wx = self.mul(wide_id_weight, mask)
+        wide_id_weight = self.wide_embeddinglookup(self.wide_w, id_hldr)
+        wx = self.wide_mul(wide_id_weight, mask)
         wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1))
         # Deep layer
-        deep_id_embs = self.embeddinglookup(self.embedding_table, id_hldr)
-        vx = self.mul(deep_id_embs, mask)
-        deep_in = self.reshape(vx, (-1, self.field_size * self.emb_dim))
+        deep_id_embs = self.deep_embeddinglookup(self.embedding_table, id_hldr)
+        vx = self.deep_mul(deep_id_embs, mask)
+        deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim))
         deep_in = self.dense_layer_1(deep_in)
         deep_in = self.dense_layer_2(deep_in)
         deep_in = self.dense_layer_3(deep_in)
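
Note: giving the wide and deep branches their own lookup, Mul and Reshape primitives lets each carry its own parallel strategy; the arithmetic itself is unchanged. A minimal NumPy sketch of what this forward pass computes (all sizes and values are made up):

    import numpy as np

    batch, fields, emb_dim, vocab = 4, 39, 8, 1000   # hypothetical sizes
    ids = np.random.randint(0, vocab, size=(batch, fields))
    wts = np.random.rand(batch, fields).astype(np.float32)

    wide_w = np.random.rand(vocab, 1).astype(np.float32)       # one scalar weight per id
    wide_b = np.float32(0.1)
    emb_table = np.random.rand(vocab, emb_dim).astype(np.float32)

    mask = wts.reshape(batch, fields, 1)
    # Wide branch: weighted sum of per-id scalars, one logit contribution per sample
    wide_out = (wide_w[ids] * mask).sum(axis=1) + wide_b                 # (batch, 1)
    # Deep branch: weighted embeddings flattened into the first dense layer
    deep_in = (emb_table[ids] * mask).reshape(batch, fields * emb_dim)   # (batch, 312)
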
@@ -233,19 +248,28 @@ class NetWithLossClass(nn.Cell):

     def __init__(self, network, config):
         super(NetWithLossClass, self).__init__(auto_prefix=False)
+        host_device_mix = bool(config.host_device_mix)
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
+        self.no_l2loss = host_device_mix and is_auto_parallel
         self.network = network
         self.l2_coef = config.l2_coef
         self.loss = P.SigmoidCrossEntropyWithLogits()
         self.square = P.Square()
         self.reduceMean_false = P.ReduceMean(keep_dims=False)
+        if is_auto_parallel:
+            self.reduceMean_false.add_prim_attr("cross_batch", True)
         self.reduceSum_false = P.ReduceSum(keep_dims=False)

     def construct(self, batch_ids, batch_wts, label):
         predict, embedding_table = self.network(batch_ids, batch_wts)
         log_loss = self.loss(predict, label)
         wide_loss = self.reduceMean_false(log_loss)
-        l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
-        deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v
+        if self.no_l2loss:
+            deep_loss = wide_loss
+        else:
+            l2_loss_v = self.reduceSum_false(self.square(embedding_table)) / 2
+            deep_loss = self.reduceMean_false(log_loss) + self.l2_coef * l2_loss_v

         return wide_loss, deep_loss

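
Note: wide_loss is the batch mean of the sigmoid cross-entropy, and deep_loss adds an L2 penalty on the returned embedding table unless the new no_l2loss flag (host-device mix under auto parallel) turns it off. A minimal NumPy sketch of both values, using the usual numerically stable cross-entropy form and made-up numbers:

    import numpy as np

    def sigmoid_xent_with_logits(logits, labels):
        # stable form of -labels*log(sigmoid(x)) - (1-labels)*log(1-sigmoid(x))
        return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

    logits = np.array([[0.3], [-1.2], [2.0]], dtype=np.float32)   # hypothetical predictions
    labels = np.array([[1.0], [0.0], [1.0]], dtype=np.float32)
    emb_table = np.random.rand(1000, 8).astype(np.float32)
    l2_coef = 8e-5                                                # hypothetical config.l2_coef

    wide_loss = sigmoid_xent_with_logits(logits, labels).mean()
    l2_loss_v = np.square(emb_table).sum() / 2
    deep_loss = wide_loss + l2_coef * l2_loss_v                   # collapses to wide_loss when no_l2loss
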
@@ -267,12 +291,15 @@ class TrainStepWrap(nn.Cell):
     Append Adam and FTRL optimizers to the training network after that construct
     function can be called to create the backward graph.
     Args:
-        network (Cell): the training network. Note that loss function should have been added.
-        sens (Number): The adjust parameter. Default: 1000.0
+        network (Cell): The training network. Note that loss function should have been added.
+        sens (Number): The adjust parameter. Default: 1024.0
+        host_device_mix (Bool): Whether run in host and device mix mode. Default: False
     """

-    def __init__(self, network, sens=1024.0):
+    def __init__(self, network, sens=1024.0, host_device_mix=False):
         super(TrainStepWrap, self).__init__()
+        parallel_mode = _get_parallel_mode()
+        is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         self.network = network
         self.network.set_train()
         self.trainable_params = network.trainable_params()
@@ -285,10 +312,19 @@ class TrainStepWrap(nn.Cell):
                 weights_d.append(params)
         self.weights_w = ParameterTuple(weights_w)
         self.weights_d = ParameterTuple(weights_d)
-        self.optimizer_w = FTRL(learning_rate=1e-2, params=self.weights_w,
-                                l1=1e-8, l2=1e-8, initial_accum=1.0)
-        self.optimizer_d = Adam(
-            self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+        if host_device_mix and is_auto_parallel:
+            self.optimizer_d = LazyAdam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
+            self.optimizer_w.sparse_opt.add_prim_attr("primitive_target", "CPU")
+            self.optimizer_d.sparse_opt.add_prim_attr("primitive_target", "CPU")
+        else:
+            self.optimizer_d = Adam(
+                self.weights_d, learning_rate=3.5e-4, eps=1e-8, loss_scale=sens)
+            self.optimizer_w = FTRL(learning_rate=5e-2, params=self.weights_w,
+                                    l1=1e-8, l2=1e-8, initial_accum=1.0, loss_scale=sens)
         self.hyper_map = C.HyperMap()
         self.grad_w = C.GradOperation('grad_w', get_by_list=True,
                                       sens_param=True)
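
Note: the wide parameters keep FTRL while the deep parameters get LazyAdam with its sparse update pinned to CPU in the host-device-mix auto-parallel branch, and plain Adam otherwise; both optimizers now take loss_scale=sens so the sens-scaled gradients are divided back inside the update. A minimal sketch of the intent (the name-based weights_w/weights_d routing and the parameter names below are assumptions; the split loop itself sits above this hunk):

    # Hypothetical parameter names; anything containing "wide" is assumed to go
    # to the FTRL group, everything else to the Adam/LazyAdam group.
    trainable = ["wide_w", "wide_b", "embedding_table", "dense_layer_1.weight"]
    weights_w = [p for p in trainable if "wide" in p]        # -> FTRL
    weights_d = [p for p in trainable if "wide" not in p]    # -> Adam or LazyAdam

    # loss_scale=sens undoes the sens-scaled sensitivities fed to the backward pass
    sens = 1024.0
    raw_grad = 0.002
    scaled_grad = raw_grad * sens
    assert scaled_grad / sens == raw_grad
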