|
|
|
@ -161,13 +161,9 @@ class WideDeepModel(nn.Cell):
|
|
|
|
|
self.layer_dims = self.deep_layer_dims_list + [1]
|
|
|
|
|
self.all_dim_list = [self.deep_input_dims] + self.layer_dims
|
|
|
|
|
|
|
|
|
|
init_acts = [('Wide_w', [self.vocab_size, 1], self.emb_init),
|
|
|
|
|
('V_l2', [self.vocab_size, self.emb_dim], self.emb_init),
|
|
|
|
|
('Wide_b', [1], self.emb_init)]
|
|
|
|
|
init_acts = [('Wide_b', [1], self.emb_init)]
|
|
|
|
|
var_map = init_var_dict(self.init_args, init_acts)
|
|
|
|
|
self.wide_w = var_map["Wide_w"]
|
|
|
|
|
self.wide_b = var_map["Wide_b"]
|
|
|
|
|
self.embedding_table = var_map["V_l2"]
|
|
|
|
|
if parameter_server:
|
|
|
|
|
self.wide_w.set_param_ps()
|
|
|
|
|
self.embedding_table.set_param_ps()
|
|
|
|
|