From 78386683bf9361c4cce8e52564982a853c784096 Mon Sep 17 00:00:00 2001
From: yao_yf
Date: Thu, 27 Aug 2020 09:46:55 +0800
Subject: [PATCH] wide_and_deep field slice

---
 .../script/run_auto_parallel_train_cluster.sh |   2 +
 .../wide_and_deep/script/start_cluster.sh     |   2 +-
 .../recommend/wide_and_deep/src/config.py     |   8 +-
 .../recommend/wide_and_deep/src/datasets.py   | 137 +++++++++++++++---
 .../wide_and_deep/src/wide_and_deep.py        |  23 ++-
 .../train_and_eval_auto_parallel.py           |  45 ++++--
 .../train_and_eval_distribute.py              |   8 +-
 .../train_and_eval_parameter_server.py        |  13 +-
 8 files changed, 184 insertions(+), 54 deletions(-)

diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh
index 495c4fc40d..a6ff83a6af 100644
--- a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh
+++ b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh
@@ -41,6 +41,8 @@ do
   cd ${execute_path}/device_$RANK_ID || exit
   if [ $MODE == "host_device_mix" ]; then
     python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 &
+  elif [ $MODE == "field_slice_host_device_mix" ]; then
+    python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 &
   else
     python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 &
   fi
diff --git a/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh
index 06c9e4dfb5..8499642ccc 100644
--- a/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh
+++ b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh
@@ -38,7 +38,7 @@ do
   user=$(get_node_user ${cluster_config_path} ${node})
   passwd=$(get_node_passwd ${cluster_config_path} ${node})
   echo "------------------${user}@${node}---------------------"
-  if [ $MODE == "host_device_mix" ]; then
+  if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ]; then
     ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}"
   else
     echo "[ERROR] mode is wrong"
diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py
index 5464110cb1..c022ce3382 100644
--- a/model_zoo/official/recommend/wide_and_deep/src/config.py
+++ b/model_zoo/official/recommend/wide_and_deep/src/config.py
@@ -25,7 +25,7 @@ def argparse_init():
     parser.add_argument("--data_path", type=str, default="./test_raw_data/",
                         help="This should be set to the same directory given to the data_download's data_dir argument")
     parser.add_argument("--epochs", type=int, default=15, help="Total train epochs")
-    parser.add_argument("--full_batch", type=bool, default=False, help="Enable loading the full batch ")
+    parser.add_argument("--full_batch", type=int, default=0, help="Enable loading the full batch ")
full batch ") parser.add_argument("--batch_size", type=int, default=16000, help="Training batch size.") parser.add_argument("--eval_batch_size", type=int, default=16000, help="Eval batch size.") parser.add_argument("--field_size", type=int, default=39, help="The number of features.") @@ -46,6 +46,7 @@ def argparse_init(): parser.add_argument("--host_device_mix", type=int, default=0, help="Enable host device mode or not") parser.add_argument("--dataset_type", type=str, default="tfrecord", help="tfrecord/mindrecord/hd5") parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not") + parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not") return parser @@ -81,6 +82,8 @@ class WideDeepConfig(): self.host_device_mix = 0 self.dataset_type = "tfrecord" self.parameter_server = 0 + self.field_slice = False + self.manual_shape = None def argparse_init(self): """ @@ -91,7 +94,7 @@ class WideDeepConfig(): self.device_target = args.device_target self.data_path = args.data_path self.epochs = args.epochs - self.full_batch = args.full_batch + self.full_batch = bool(args.full_batch) self.batch_size = args.batch_size self.eval_batch_size = args.eval_batch_size self.field_size = args.field_size @@ -114,3 +117,4 @@ class WideDeepConfig(): self.host_device_mix = args.host_device_mix self.dataset_type = args.dataset_type self.parameter_server = args.parameter_server + self.field_slice = bool(args.field_slice) diff --git a/model_zoo/official/recommend/wide_and_deep/src/datasets.py b/model_zoo/official/recommend/wide_and_deep/src/datasets.py index 8ed4ba2375..1a28ac5327 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/datasets.py +++ b/model_zoo/official/recommend/wide_and_deep/src/datasets.py @@ -23,6 +23,7 @@ import pandas as pd import mindspore.dataset.engine as de import mindspore.common.dtype as mstype + class DataType(Enum): """ Enumerate supported dataset format. 
@@ -83,9 +84,9 @@ class H5Dataset():
                 yield os.path.join(self._hdf_data_dir,
                                    self._file_prefix + '_input_part_' + str(
                                        p) + '.h5'), \
-                      os.path.join(self._hdf_data_dir,
-                                   self._file_prefix + '_output_part_' + str(
-                                       p) + '.h5'), i + 1 == len(parts)
+                    os.path.join(self._hdf_data_dir,
+                                 self._file_prefix + '_output_part_' + str(
+                                     p) + '.h5'), i + 1 == len(parts)

     def _generator(self, X, y, batch_size, shuffle=True):
         """
@@ -169,8 +170,41 @@ def _get_h5_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000):
     return ds


+def _padding_func(batch_size, manual_shape, target_column, field_size=39):
+    """
+    Get the padding function that pads features from field_size up to target_column.
+    """
+    if manual_shape:
+        generate_concat_offset = [item[0]+item[1] for item in manual_shape]
+        part_size = int(target_column / len(generate_concat_offset))
+        filled_value = []
+        for i in range(field_size, target_column):
+            filled_value.append(generate_concat_offset[i//part_size]-1)
+        print("Filled value:", filled_value)
+
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+
+            x_id = np.ones((batch_size, target_column - field_size),
+                           dtype=np.int32) * filled_value
+            x_id = np.concatenate([x, x_id.astype(dtype=np.int32)], axis=1)
+            mask = np.concatenate(
+                [y, np.zeros((batch_size, target_column - field_size), dtype=np.float32)], axis=1)
+            return (x_id, mask, z)
+    else:
+        def padding_func(x, y, z):
+            x = np.array(x).flatten().reshape(batch_size, field_size)
+            y = np.array(y).flatten().reshape(batch_size, field_size)
+            z = np.array(z).flatten().reshape(batch_size, 1)
+            return (x, y, z)
+    return padding_func
+
+
 def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                    line_per_sample=1000, rank_size=None, rank_id=None):
+                    line_per_sample=1000, rank_size=None, rank_id=None,
+                    manual_shape=None, target_column=40):
     """
     get_tf_dataset
     """
@@ -189,21 +223,22 @@ def _get_tf_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
         ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8,
                                 num_shards=rank_size, shard_id=rank_id, shard_equal_rows=True)
     else:
-        ds = de.TFRecordDataset(dataset_files=dataset_files, shuffle=shuffle, schema=schema, num_parallel_workers=8)
+        ds = de.TFRecordDataset(dataset_files=dataset_files,
+                                shuffle=shuffle, schema=schema, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
-    ds = ds.map(operations=(lambda x, y, z: (
-        np.array(x).flatten().reshape(batch_size, 39),
-        np.array(y).flatten().reshape(batch_size, 39),
-        np.array(z).flatten().reshape(batch_size, 1))),
+
+    ds = ds.map(operations=_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
-    #if train_mode:
+    # if train_mode:
     ds = ds.repeat(epochs)
     return ds

+
 def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=1000,
-                            line_per_sample=1000, rank_size=None, rank_id=None):
+                            line_per_sample=1000, rank_size=None, rank_id=None,
+                            manual_shape=None, target_column=40):
     """
     Get dataset with mindrecord format.
@@ -233,9 +268,7 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100
                               columns_list=['feat_ids', 'feat_vals', 'label'],
                               shuffle=shuffle, num_parallel_workers=8)
     ds = ds.batch(int(batch_size / line_per_sample), drop_remainder=True)
-    ds = ds.map(operations=(lambda x, y, z: (np.array(x).flatten().reshape(batch_size, 39),
-                                             np.array(y).flatten().reshape(batch_size, 39),
-                                             np.array(z).flatten().reshape(batch_size, 1))),
+    ds = ds.map(_padding_func(batch_size, manual_shape, target_column),
                 input_columns=['feat_ids', 'feat_vals', 'label'],
                 columns_order=['feat_ids', 'feat_vals', 'label'], num_parallel_workers=8)
     ds = ds.repeat(epochs)
@@ -243,18 +276,84 @@ def _get_mindrecord_dataset(directory, train_mode=True, epochs=1, batch_size=100


+def _get_vocab_size(target_column_number, worker_size, total_vocab_size, multiply=False, per_vocab_size=None):
+    """
+    get_vocab_size
+    """
+    # Per-field vocabulary sizes of the original 39 features
+    individual_vocabs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 691, 540, 20855, 23639, 182, 15,
+                         10091, 347, 4, 16366, 4494, 21293, 3103, 27, 6944, 22366, 11, 3267, 1610,
+                         5, 21762, 14, 15, 15030, 61, 12220]
+
+    new_vocabs = individual_vocabs + [1] * \
+        (target_column_number - len(individual_vocabs))
+    part_size = int(target_column_number / worker_size)
+
+    # According to the worker size, merge consecutive fields into the same part
+    new_vocab_size = []
+    for i in range(0, target_column_number, part_size):
+        new_vocab_size.append(sum(new_vocabs[i: i + part_size]))
+
+    index_offsets = [0]
+
+    # The original per-part feature sizes are used to calculate the offsets
+    features = [item for item in new_vocab_size]
+
+    # If per_vocab_size is given, use it as the vocab size of every part
+    if per_vocab_size is not None:
+        new_vocab_size = [per_vocab_size] * worker_size
+    else:
+        # Expand the vocabulary of each part proportionally to total_vocab_size
+        if multiply is True:
+            cur_sum = sum(new_vocab_size)
+            k = total_vocab_size/cur_sum
+            new_vocab_size = [
+                math.ceil(int(item*k)/worker_size)*worker_size for item in new_vocab_size]
+            new_vocab_size = [(item // 8 + 1)*8 for item in new_vocab_size]

+        else:
+            if total_vocab_size > sum(new_vocab_size):
+                new_vocab_size[-1] = total_vocab_size - \
+                    sum(new_vocab_size[:-1])
+                new_vocab_size = [item for item in new_vocab_size]
+            else:
+                raise ValueError(
+                    "Please provide the correct vocab size, now is {}".format(total_vocab_size))
+
+    for i in range(worker_size-1):
+        off = index_offsets[i] + features[i]
+        index_offsets.append(off)
+
+    print("the offset: ", index_offsets)
+    manual_shape = tuple(
+        ((new_vocab_size[i], index_offsets[i]) for i in range(worker_size)))
+    vocab_total = sum(new_vocab_size)
+    return manual_shape, vocab_total
+
+
+def compute_manual_shape(config, worker_size):
+    target_column = (config.field_size // worker_size + 1) * worker_size
+    config.field_size = target_column
+    manual_shape, vocab_total = _get_vocab_size(target_column, worker_size, total_vocab_size=config.vocab_size,
+                                                per_vocab_size=None, multiply=False)
+    config.manual_shape = manual_shape
+    config.vocab_size = int(vocab_total)
+
+
 def create_dataset(data_dir, train_mode=True, epochs=1, batch_size=1000,
-                   data_type=DataType.TFRECORD, line_per_sample=1000, rank_size=None, rank_id=None):
+                   data_type=DataType.TFRECORD, line_per_sample=1000,
+                   rank_size=None, rank_id=None, manual_shape=None, target_column=40):
     """
     create_dataset
     """
     if data_type == DataType.TFRECORD:
         return _get_tf_dataset(data_dir, train_mode, epochs, batch_size,
-                               line_per_sample, rank_size=rank_size, rank_id=rank_id)
+                               line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                               manual_shape=manual_shape, target_column=target_column)
     if data_type == DataType.MINDRECORD:
-        return _get_mindrecord_dataset(data_dir, train_mode, epochs,
-                                       batch_size, line_per_sample,
-                                       rank_size, rank_id)
+        return _get_mindrecord_dataset(data_dir, train_mode, epochs, batch_size,
+                                       line_per_sample, rank_size=rank_size, rank_id=rank_id,
+                                       manual_shape=manual_shape, target_column=target_column)
     if rank_size > 1:
         raise RuntimeError("please use tfrecord dataset.")
diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
index e2350723e7..e873c15ad0 100644
--- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
+++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py
@@ -143,6 +143,7 @@ class WideDeepModel(nn.Cell):
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
         if is_auto_parallel:
             self.batch_size = self.batch_size * get_group_size()
+        is_field_slice = config.field_slice
         self.field_size = config.field_size
         self.vocab_size = config.vocab_size
         self.emb_dim = config.emb_dim
@@ -196,11 +197,10 @@ class WideDeepModel(nn.Cell):
         self.tile = P.Tile()
         self.concat = P.Concat(axis=1)
         self.cast = P.Cast()
-        if is_auto_parallel and host_device_mix:
+        if is_auto_parallel and host_device_mix and not is_field_slice:
             self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
             self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
-            self.dense_layer_1.matmul.add_prim_attr("field_size", config.field_size)
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
                                                            slice_mode=nn.EmbeddingLookUpSplitMode.TABLE_COLUMN_SLICE)
             self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
@@ -209,9 +209,20 @@ class WideDeepModel(nn.Cell):
             self.deep_reshape.add_prim_attr("skip_redistribution", True)
             self.reduce_sum.add_prim_attr("cross_batch", True)
             self.embedding_table = self.deep_embeddinglookup.embedding_table
-        elif host_device_mix:
-            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
-            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1)
+        elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape:
+            manual_shapes = tuple((s[0] for s in config.manual_shape))
+            self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
+                                                           manual_shapes=manual_shapes)
+            self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1,
+                                                           slice_mode=nn.EmbeddingLookUpSplitMode.FIELD_SLICE,
+                                                           manual_shapes=manual_shapes)
+            self.deep_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
+            self.wide_mul.set_strategy(((1, get_group_size(), 1), (1, get_group_size(), 1)))
+            self.reduce_sum.set_strategy(((1, get_group_size(), 1),))
+            self.dense_layer_1.dropout.dropout_do_mask.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.dropout.dropout.set_strategy(((1, get_group_size()),))
+            self.dense_layer_1.matmul.set_strategy(((1, get_group_size()), (get_group_size(), 1)))
             self.embedding_table = self.deep_embeddinglookup.embedding_table
         elif parameter_server:
             self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim)
@@ -263,7 +274,7 @@ class NetWithLossClass(nn.Cell):
         parameter_server = bool(config.parameter_server)
         parallel_mode = context.get_auto_parallel_context("parallel_mode")
         is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL)
-        self.no_l2loss = (is_auto_parallel if host_device_mix else parameter_server)
+        self.no_l2loss = (is_auto_parallel if (host_device_mix or config.field_slice) else parameter_server)
         self.network = network
         self.l2_coef = config.l2_coef
         self.loss = P.SigmoidCrossEntropyWithLogits()
diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
index b40139a1eb..5440873d2c 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py
@@ -27,12 +27,13 @@ from mindspore.nn.wrap.cell_wrapper import VirtualDatasetCellTriple
 from src.wide_and_deep import PredictWithSigmoid, TrainStepWrap, NetWithLossClass, WideDeepModel
 from src.callbacks import LossCallBack, EvalCallBack
-from src.datasets import create_dataset, DataType
+from src.datasets import create_dataset, DataType, compute_manual_shape
 from src.metrics import AUCMetric
 from src.config import WideDeepConfig

 sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

+
 def get_WideDeep_net(config):
     """
     Get network of wide&deep model.
@@ -40,7 +41,8 @@ def get_WideDeep_net(config):
     WideDeep_net = WideDeepModel(config)
     loss_net = NetWithLossClass(WideDeep_net, config)
     loss_net = VirtualDatasetCellTriple(loss_net)
-    train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix))
+    train_net = TrainStepWrap(
+        loss_net, host_device_mix=bool(config.host_device_mix))
     eval_net = PredictWithSigmoid(WideDeep_net)
     eval_net = VirtualDatasetCellTriple(eval_net)
     return train_net, eval_net
@@ -50,6 +52,7 @@ class ModelBuilder():
     """
     ModelBuilder
     """
+
     def __init__(self):
         pass

@@ -86,10 +89,19 @@ def train_and_eval(config):
     if config.full_batch:
         context.set_auto_parallel_context(full_batch=True)
         de.config.set_seed(1)
-        ds_train = create_dataset(data_path, train_mode=True, epochs=1,
-                                  batch_size=batch_size*get_group_size(), data_type=dataset_type)
-        ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
-                                 batch_size=batch_size*get_group_size(), data_type=dataset_type)
+        if config.field_slice:
+            compute_manual_shape(config, get_group_size())
+            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                                      batch_size=batch_size*get_group_size(), data_type=dataset_type,
+                                      manual_shape=config.manual_shape, target_column=config.field_size)
+            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                                     batch_size=batch_size*get_group_size(), data_type=dataset_type,
+                                     manual_shape=config.manual_shape, target_column=config.field_size)
+        else:
+            ds_train = create_dataset(data_path, train_mode=True, epochs=1,
+                                      batch_size=batch_size*get_group_size(), data_type=dataset_type)
+            ds_eval = create_dataset(data_path, train_mode=False, epochs=1,
+                                     batch_size=batch_size*get_group_size(), data_type=dataset_type)
     else:
         ds_train = create_dataset(data_path, train_mode=True, epochs=1,
                                   batch_size=batch_size, rank_id=get_rank(),
@@ -106,9 +118,11 @@ def train_and_eval(config):
     train_net.set_train()
     auc_metric = AUCMetric()

-    model = Model(train_net, eval_network=eval_net, metrics={"auc": auc_metric})
+    model = Model(train_net, eval_network=eval_net,
+                  metrics={"auc": auc_metric})

-    eval_callback = EvalCallBack(model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
+    eval_callback = EvalCallBack(
+        model, ds_eval, auc_metric, config, host_device_mix=host_device_mix)
     callback = LossCallBack(config=config, per_print_times=20)

     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs,
@@ -116,16 +130,19 @@ def train_and_eval(config):
     ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
                                  directory=config.ckpt_path, config=ckptconfig)
     context.set_auto_parallel_context(strategy_ckpt_save_file=config.stra_ckpt)
-    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
+    callback_list = [TimeMonitor(
+        ds_train.get_dataset_size()), eval_callback, callback]
     if not host_device_mix:
         callback_list.append(ckpoint_cb)
-    model.train(epochs, ds_train, callbacks=callback_list, dataset_sink_mode=(not host_device_mix))
+    model.train(epochs, ds_train, callbacks=callback_list,
+                dataset_sink_mode=(not host_device_mix))


 if __name__ == "__main__":
     wide_deep_config = WideDeepConfig()
     wide_deep_config.argparse_init()
-    context.set_context(mode=context.GRAPH_MODE, device_target=wide_deep_config.device_target, save_graphs=True)
+    context.set_context(mode=context.GRAPH_MODE,
+                        device_target=wide_deep_config.device_target, save_graphs=True)
     context.set_context(variable_memory_max_size="24GB")
     context.set_context(enable_sparse=True)
     set_multi_subgraphs()
@@ -134,7 +151,9 @@ if __name__ == "__main__":
     elif wide_deep_config.device_target == "GPU":
         init("nccl")
     if wide_deep_config.host_device_mix == 1:
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
+        context.set_auto_parallel_context(
+            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, mirror_mean=True)
     else:
-        context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
+        context.set_auto_parallel_context(
+            parallel_mode=ParallelMode.AUTO_PARALLEL, mirror_mean=True)
     train_and_eval(wide_deep_config)
diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py
index 9e70cd1d68..7e99aa72bb 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_distribute.py
@@ -101,12 +101,8 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                   keep_checkpoint_max=5)
-    if config.device_target == "Ascend":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                     directory=config.ckpt_path, config=ckptconfig)
-    elif config.device_target == "GPU":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
-                                     directory=config.ckpt_path, config=ckptconfig)
+    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
+                                 directory=config.ckpt_path, config=ckptconfig)
     out = model.eval(ds_eval)
     print("=====" * 5 + "model.eval() initialized: {}".format(out))
     callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
index bab19acdc4..5f93e202da 100644
--- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
+++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_parameter_server.py
@@ -103,14 +103,13 @@ def train_and_eval(config):
     callback = LossCallBack(config=config)
     ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size(),
                                   keep_checkpoint_max=5)
-    if config.device_target == "Ascend":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
-                                     directory=config.ckpt_path, config=ckptconfig)
-    elif config.device_target == "GPU":
-        ckpoint_cb = ModelCheckpoint(prefix='widedeep_train_' + str(get_rank()),
-                                     directory=config.ckpt_path, config=ckptconfig)
+    ckpoint_cb = ModelCheckpoint(prefix='widedeep_train',
+                                 directory=config.ckpt_path, config=ckptconfig)
+    callback_list = [TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback]
+    if get_rank() == 0:
+        callback_list.append(ckpoint_cb)
     model.train(epochs, ds_train,
-                callbacks=[TimeMonitor(ds_train.get_dataset_size()), eval_callback, callback, ckpoint_cb],
+                callbacks=callback_list,
                 dataset_sink_mode=(not parameter_server))
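
Reviewer note, not part of the patch: a minimal sketch of how the new field-slice helpers introduced above fit together, assuming the patched src/config.py and src/datasets.py are importable and TFRecord data exists under config.data_path. The worker count of 8 is illustrative only; the training script obtains it from get_group_size() at runtime.

    # Hypothetical usage of compute_manual_shape()/create_dataset() from this patch.
    from src.config import WideDeepConfig
    from src.datasets import DataType, compute_manual_shape, create_dataset

    config = WideDeepConfig()   # field_size defaults to 39 in this config
    worker_size = 8             # illustrative; real value comes from get_group_size()

    # Pads field_size up to a multiple of worker_size (39 -> 40 for 8 workers) and
    # stores one (vocab_slice_size, index_offset) pair per worker in config.manual_shape.
    # _get_vocab_size() raises ValueError if config.vocab_size is smaller than the
    # summed per-field vocab sizes.
    compute_manual_shape(config, worker_size)
    print(config.field_size, config.manual_shape)

    ds_train = create_dataset(config.data_path, train_mode=True, epochs=1,
                              batch_size=config.batch_size * worker_size,
                              data_type=DataType.TFRECORD,
                              manual_shape=config.manual_shape,
                              target_column=config.field_size)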