!4581 modelzoo wide_and_deep_multitable

Merge pull request !4581 from yao_yf/modelzoo_wide_and_deep_mutitable
pull/4581/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 28755b2f1a

@ -69,8 +69,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr);
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
const size_t thread_num = 8;
std::thread threads[8];
const size_t thread_num = 16;
std::thread threads[16];
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num;
size_t i;
size_t task_offset = 0;

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""train_imagenet."""
"""train_dataset."""
import os

@ -164,9 +164,6 @@ class WideDeepModel(nn.Cell):
init_acts = [('Wide_b', [1], self.emb_init)]
var_map = init_var_dict(self.init_args, init_acts)
self.wide_b = var_map["Wide_b"]
if parameter_server:
self.wide_w.set_param_ps()
self.embedding_table.set_param_ps()
self.dense_layer_1 = DenseLayer(self.all_dim_list[0],
self.all_dim_list[1],
self.weight_bias_init,
@ -217,6 +214,8 @@ class WideDeepModel(nn.Cell):
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
self.embedding_table = self.deep_embeddinglookup.embedding_table
self.wide_w.set_param_ps()
self.embedding_table.set_param_ps()
else:
self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target='DEVICE')
self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target='DEVICE')

@ -0,0 +1,34 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# bash run_multinpu_train.sh
execute_path=$(pwd)
script_self=$(readlink -f "$0")
self_path=$(dirname "${script_self}")
export RANK_SIZE=$1
export EPOCH_SIZE=$2
export DATASET=$3
export RANK_TABLE_FILE=$4
for((i=0;i<$RANK_SIZE;i++));
do
rm -rf ${execute_path}/device_$i/
mkdir ${execute_path}/device_$i/
cd ${execute_path}/device_$i/ || exit
export RANK_ID=$i
export DEVICE_ID=$i
python -s ${self_path}/../train_and_eval_distribute.py --data_path=$DATASET --epochs=$EPOCH_SIZE >train_deep$i.log 2>&1 &
done

@ -0,0 +1,96 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
callbacks
"""
import time
from mindspore.train.callback import Callback
def add_write(file_path, out_str):
with open(file_path, 'a+', encoding="utf-8") as file_out:
file_out.write(out_str + "\n")
class LossCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, config, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.config = config
def step_end(self, run_context):
"""Monitor the loss in training."""
cb_params = run_context.original_args()
wide_loss, deep_loss = cb_params.net_outputs[0].asnumpy(), \
cb_params.net_outputs[1].asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
print("===loss===", cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss, deep_loss, flush=True)
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_file = open(self.config.loss_file_name, "a+")
loss_file.write(
"epoch: %s step: %s, wide_loss is %s, deep_loss is %s" %
(cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss,
deep_loss))
loss_file.write("\n")
loss_file.close()
print("epoch: %s step: %s, wide_loss is %s, deep_loss is %s" % (
cb_params.cur_epoch_num, cur_step_in_epoch, wide_loss,
deep_loss))
class EvalCallBack(Callback):
"""
Monitor the loss in training.
If the loss is NAN or INF terminating training.
Note:
If per_print_times is 0 do not print loss.
Args:
per_print_times (int): Print loss every times. Default: 1.
"""
def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1):
super(EvalCallBack, self).__init__()
if not isinstance(print_per_step, int) or print_per_step < 0:
raise ValueError("print_step must be int and >= 0.")
self.print_per_step = print_per_step
self.model = model
self.eval_dataset = eval_dataset
self.aucMetric = auc_metric
self.aucMetric.clear()
self.eval_file_name = config.eval_file_name
def epoch_end(self, run_context):
"""Monitor the auc in training."""
self.aucMetric.clear()
start_time = time.time()
out = self.model.eval(self.eval_dataset)
end_time = time.time()
eval_time = int(end_time - start_time)
time_str = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
out_str = "{}=====EvalCallBack model.eval(): {} ; eval_time:{}s".format(time_str, out.values(), eval_time)
print(out_str)
add_write(self.eval_file_name, out_str)

@ -0,0 +1,95 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" config. """
import argparse
def argparse_init():
"""
argparse_init
"""
parser = argparse.ArgumentParser(description='WideDeep')
parser.add_argument("--data_path", type=str, default="./test_raw_data/") # The location of the input data.
parser.add_argument("--epochs", type=int, default=200) # The number of epochs used to train.
parser.add_argument("--batch_size", type=int, default=131072) # Batch size for training and evaluation
parser.add_argument("--eval_batch_size", type=int, default=131072) # The batch size used for evaluation.
parser.add_argument("--deep_layers_dim", type=int, nargs='+', default=[1024, 512, 256, 128]) # The sizes of hidden layers for MLP
parser.add_argument("--deep_layers_act", type=str, default='relu') # The act of hidden layers for MLP
parser.add_argument("--keep_prob", type=float, default=1.0) # The Embedding size of MF model.
parser.add_argument("--adam_lr", type=float, default=0.003) # The Adam lr
parser.add_argument("--ftrl_lr", type=float, default=0.1) # The ftrl lr.
parser.add_argument("--l2_coef", type=float, default=0.0) # The l2 coefficient.
parser.add_argument("--is_tf_dataset", type=bool, default=True) # The l2 coefficient.
parser.add_argument("--output_path", type=str, default="./output/") # The location of the output file.
parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") # The location of the checkpoints file.
parser.add_argument("--eval_file_name", type=str, default="eval.log") # Eval output file.
parser.add_argument("--loss_file_name", type=str, default="loss.log") # Loss output file.
return parser
class WideDeepConfig():
"""
WideDeepConfig
"""
def __init__(self):
self.data_path = ''
self.epochs = 200
self.batch_size = 131072
self.eval_batch_size = 131072
self.deep_layers_act = 'relu'
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.keep_prob = 1.0
self.l2_coef = 0.0
self.adam_lr = 0.003
self.ftrl_lr = 0.1
self.is_tf_dataset = True
self.input_emb_dim = 0
self.output_path = "./output/"
self.eval_file_name = "eval.log"
self.loss_file_name = "loss.log"
self.ckpt_path = "./checkpoints/"
def argparse_init(self):
"""
argparse_init
"""
parser = argparse_init()
args, _ = parser.parse_known_args()
self.data_path = args.data_path
self.epochs = args.epochs
self.batch_size = args.batch_size
self.eval_batch_size = args.eval_batch_size
self.deep_layers_act = args.deep_layers_act
self.keep_prob = args.keep_prob
self.weight_bias_init = ['normal', 'normal']
self.emb_init = 'normal'
self.init_args = [-0.01, 0.01]
self.dropout_flag = False
self.l2_coef = args.l2_coef
self.ftrl_lr = args.ftrl_lr
self.adam_lr = args.adam_lr
self.is_tf_dataset = args.is_tf_dataset
self.output_path = args.output_path
self.eval_file_name = args.eval_file_name
self.loss_file_name = args.loss_file_name
self.ckpt_path = args.ckpt_path

@ -0,0 +1,153 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Area under cure metric
"""
import time
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, average_precision_score
from mindspore.nn.metrics import Metric
def groupby_df_v1(test_df, gb_key):
"""
groupby_df_v1
"""
data_groups = test_df.groupby(gb_key)
return data_groups
def _compute_metric_v1(batch_groups, topk):
"""
_compute_metric_v1
"""
results = []
for df in batch_groups:
df = df.sort_values(by="preds", ascending=False)
if df.shape[0] > topk:
df = df.head(topk)
preds = df["preds"].values
labels = df["labels"].values
if np.sum(labels) > 0:
results.append(average_precision_score(labels, preds))
else:
results.append(0.0)
return results
def mean_AP_topk(batch_labels, batch_preds, topk=12):
"""
mean_AP_topk
"""
def ap_score(label, y_preds, topk):
ind_list = np.argsort(y_preds)[::-1]
ind_list = ind_list[:topk]
if label not in set(ind_list):
return 0.0
rank = list(ind_list).index(label)
return 1.0 / (rank + 1)
mAP_list = []
for label, preds in zip(batch_labels, batch_preds):
mAP = ap_score(label, preds, topk)
mAP_list.append(mAP)
return mAP_list
def new_compute_mAP(test_df, gb_key="display_ids", top_k=12):
"""
new_compute_mAP
"""
total_start = time.time()
display_ids = test_df["display_ids"]
labels = test_df["labels"]
predictions = test_df["preds"]
test_df.sort_values(by=[gb_key], inplace=True, ascending=True)
display_ids = test_df["display_ids"]
labels = test_df["labels"]
predictions = test_df["preds"]
_, display_ids_idx = np.unique(display_ids, return_index=True)
preds = np.split(predictions.tolist(), display_ids_idx.tolist()[1:])
labels = np.split(labels.tolist(), display_ids_idx.tolist()[1:])
def pad_fn(ele_l):
res_list = ele_l + [0.0 for i in range(30 - len(ele_l))]
return res_list
preds = list(map(lambda x: pad_fn(x.tolist()), preds))
labels = [np.argmax(l) for l in labels]
result_list = []
batch_size = 100000
for idx in range(0, len(labels), batch_size):
batch_labels = labels[idx:idx + batch_size]
batch_preds = preds[idx:idx + batch_size]
meanAP = mean_AP_topk(batch_labels, batch_preds, topk=top_k)
result_list.extend(meanAP)
mean_AP = np.mean(result_list)
print("compute time: {}".format(time.time() - total_start))
print("mean_AP: {}".format(mean_AP))
return mean_AP
class AUCMetric(Metric):
"""
AUCMetric
"""
def __init__(self):
super(AUCMetric, self).__init__()
self.index = 1
def clear(self):
"""Clear the internal evaluation result."""
self.true_labels = []
self.pred_probs = []
self.display_id = []
def update(self, *inputs):
"""
update
"""
all_predict = inputs[1].asnumpy() # predict
all_label = inputs[2].asnumpy() # label
all_display_id = inputs[3].asnumpy() # label
self.true_labels.extend(all_label.flatten().tolist())
self.pred_probs.extend(all_predict.flatten().tolist())
self.display_id.extend(all_display_id.flatten().tolist())
def eval(self):
"""
eval
"""
if len(self.true_labels) != len(self.pred_probs):
raise RuntimeError(
'true_labels.size() is not equal to pred_probs.size()')
result_df = pd.DataFrame({
"display_ids": self.display_id,
"preds": self.pred_probs,
"labels": self.true_labels,
})
auc = roc_auc_score(self.true_labels, self.pred_probs)
MAP = new_compute_mAP(result_df, gb_key="display_ids", top_k=12)
print("=====" * 20 + " auc_metric end ")
print("=====" * 20 + " auc: {}, map: {}".format(auc, MAP))
return auc

@ -0,0 +1,107 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" training_and_evaluating """
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, compute_emb_dim
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def get_WideDeep_net(config):
"""
Get network of wide&deep model.
"""
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net, config)
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder.
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def train_and_eval(config):
"""
train_and_eval.
"""
data_path = config.data_path
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset)
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset)
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
callback, ckpoint_cb])
if __name__ == "__main__":
wide_and_deep_config = WideDeepConfig()
wide_and_deep_config.argparse_init()
compute_emb_dim(wide_and_deep_config)
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True)
train_and_eval(wide_and_deep_config)

@ -0,0 +1,113 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" training_multinpu"""
import os
import sys
from mindspore import Model, context
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.callback import TimeMonitor
from mindspore.train import ParallelMode
from mindspore.communication.management import get_rank, get_group_size, init
from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
from src.callbacks import LossCallBack, EvalCallBack
from src.datasets import create_dataset, compute_emb_dim
from src.metrics import AUCMetric
from src.config import WideDeepConfig
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
def get_WideDeep_net(config):
"""
get_WideDeep_net
"""
WideDeep_net = WideDeepModel(config)
loss_net = NetWithLossClass(WideDeep_net, config)
train_net = TrainStepWrap(loss_net, config)
eval_net = PredictWithSigmoid(WideDeep_net)
return train_net, eval_net
class ModelBuilder():
"""
ModelBuilder
"""
def __init__(self):
pass
def get_hook(self):
pass
def get_train_hook(self):
hooks = []
callback = LossCallBack()
hooks.append(callback)
if int(os.getenv('DEVICE_ID')) == 0:
pass
return hooks
def get_net(self, config):
return get_WideDeep_net(config)
def train_and_eval(config):
"""
train_and_eval
"""
data_path = config.data_path
epochs = config.epochs
print("epochs is {}".format(epochs))
ds_train = create_dataset(data_path, train_mode=True, epochs=1,
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset,
rank_id=get_rank(), rank_size=get_group_size())
ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
batch_size=config.batch_size, is_tf_dataset=config.is_tf_dataset,
rank_id=get_rank(), rank_size=get_group_size())
print("ds_train.size: {}".format(ds_train.get_dataset_size()))
print("ds_eval.size: {}".format(ds_eval.get_dataset_size()))
net_builder = ModelBuilder()
train_net, eval_net = net_builder.get_net(config)
train_net.set_train()
auc_metric = AUCMetric()
model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
eval_callback = EvalCallBack(model, ds_eval, auc_metric, config)
callback = LossCallBack(config)
ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
keep_checkpoint_max=10)
ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
directory=config.ckpt_path, config=ckptconfig)
model.train(epochs, ds_train, callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback,
callback, ckpoint_cb])
if __name__ == "__main__":
wide_and_deep_config = WideDeepConfig()
wide_and_deep_config.argparse_init()
compute_emb_dim(wide_and_deep_config)
context.set_context(mode=context.GRAPH_MODE, device_target="Davinci",
save_graphs=True)
init()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=get_group_size())
train_and_eval(wide_and_deep_config)
Loading…
Cancel
Save