From 31819bb4a77bd5f78ddd23aeb31574f864893fcb Mon Sep 17 00:00:00 2001 From: yao_yf Date: Tue, 10 Nov 2020 15:20:09 +0800 Subject: [PATCH] support forward unique --- .../host/dynamic_shape_kernel.cc | 2 +- .../pass/convert_const_input_to_attr.cc | 2 +- .../ccsrc/backend/session/ascend_session.cc | 1 - .../ccsrc/runtime/device/kernel_runtime.cc | 14 ++++++ mindspore/common/parameter.py | 10 +++++ mindspore/core/abstract/prim_arrays.cc | 11 +++++ mindspore/core/abstract/prim_others.cc | 5 +-- mindspore/nn/layer/embedding.py | 45 ++++++++++++++++--- mindspore/nn/optim/optimizer.py | 5 +++ mindspore/ops/_op_impl/tbe/__init__.py | 1 + mindspore/ops/_op_impl/tbe/relu_ds.py | 40 +++++++++++++++++ mindspore/ops/operations/comm_ops.py | 6 ++- mindspore/train/serialization.py | 2 +- .../script/run_auto_parallel_train_cluster.sh | 2 +- .../wide_and_deep/script/start_cluster.sh | 2 +- .../recommend/wide_and_deep/src/callbacks.py | 6 +-- .../recommend/wide_and_deep/src/config.py | 3 ++ .../wide_and_deep/src/wide_and_deep.py | 33 +++++++++----- .../train_and_eval_auto_parallel.py | 8 ++-- .../python_file_for_ci/config.py | 8 ++++ .../ut/python/parallel/test_dynamic_shape.py | 4 +- 21 files changed, 173 insertions(+), 37 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/relu_ds.py diff --git a/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc index 580ac9a1e5..4b58de5fbc 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/host/dynamic_shape_kernel.cc @@ -27,7 +27,7 @@ void DynamicShapeKernel::Execute() { } auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0); - auto output_shape = std::vector(SizeToLong(prev_output_shape.size())); + std::vector output_shape = {SizeToLong(prev_output_shape.size())}; auto output_type = TypeId::kNumberTypeInt64; diff --git a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc index bd284c1c96..640d8d9801 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/convert_const_input_to_attr.cc @@ -62,7 +62,7 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An continue; } } - if (AnfAlgo::IsNodeDynamicShape(cnode) && + if (AnfAlgo::IsDynamicShape(cnode) && DynamicShapeConstInputToAttr.find(AnfAlgo::GetCNodeName(cnode)) == DynamicShapeConstInputToAttr.end()) { MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope(); continue; diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index f7d3185da0..8a9fd75926 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -541,7 +541,6 @@ void AscendSession::BuildDynamicKernel(const std::shared_ptr &kerne void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Start!"; MS_EXCEPTION_IF_NULL(kernel_graph); - opt::RemoveNopNode(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); runtime_instance->AssignMemory(kernel_graph); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 5732df7f91..2f09dab5bc 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -471,7 +471,21 @@ bool KernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { return false; } DeviceAddressPtr KernelRuntime::PreAssignCNodeMemory(const AnfNodePtr &anf_node, size_t index) { MS_EXCEPTION_IF_NULL(anf_node); + if (!anf_node->isa()) { + MS_LOG(EXCEPTION) << "anf_node should be a cnode"; + } + auto cnode = anf_node->cast(); + if (opt::IsNopNode(cnode)) { + size_t kNopNodeInputSize = 2; + size_t kNopNodeRealInputIndex = 1; + if (cnode->size() != kNopNodeInputSize) { + MS_LOG(EXCEPTION) << cnode->fullname_with_scope() << " has invalid input size: " << cnode->size(); + } + auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index); + return PreAssignCNodeMemory(cnode->input(kNopNodeRealInputIndex), input_node_with_index.second); + } auto kernel_mod = AnfAlgo::GetKernelMod(anf_node); + MS_EXCEPTION_IF_NULL(kernel_mod); auto output_sizes = kernel_mod->GetOutputSizeList(); if (output_sizes.size() <= index) { MS_LOG(EXCEPTION) << "Previous node output size < node index"; diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index 76caacba7b..d40a4b4abe 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -126,6 +126,7 @@ class Parameter(MetaTensor_): self.is_param_ps = False self._cast_type = None self.init_in_server = False + self._unique = False self.is_in_parallel = _is_in_parallel_mode() @staticmethod @@ -238,6 +239,15 @@ class Parameter(MetaTensor_): def sliced(self, sliced_): self._sliced = sliced_ + @property + def unique(self): + """whether the parameter is already unique or not.""" + return self._unique + + @unique.setter + def unique(self, unique_): + self._unique = unique_ + @property def is_init(self): """ diff --git a/mindspore/core/abstract/prim_arrays.cc b/mindspore/core/abstract/prim_arrays.cc index 8ddf36ac75..2ccd3345e1 100644 --- a/mindspore/core/abstract/prim_arrays.cc +++ b/mindspore/core/abstract/prim_arrays.cc @@ -433,6 +433,17 @@ AbstractBasePtr InferImplTranspose(const AnalysisEnginePtr &, const PrimitivePtr result_shp.push_back(input_shp[idx]); indices.insert(idx); } + ShapeVector max_shp; + ShapeVector min_shp; + if (input->shape()->max_shape().size() == input_shp.size() && + input->shape()->min_shape().size() == input_shp.size()) { + for (size_t i = 0; i < perm_vec.size(); i++) { + size_t idx = static_cast(perm_vec[i]); + max_shp.push_back(input->shape()->max_shape()[idx]); + min_shp.push_back(input->shape()->min_shape()[idx]); + } + return std::make_shared(input->element(), std::make_shared(result_shp, min_shp, max_shp)); + } return std::make_shared(input->element(), std::make_shared(result_shp)); } diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index 08e0449ef8..b4ec64ad7b 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -405,10 +405,9 @@ AbstractBasePtr InferImplAllGather(const AnalysisEnginePtr &, const PrimitivePtr if (tmp_shape.empty()) { MS_LOG(EXCEPTION) << "shape size is 0"; } - if (tmp_shape[0] % rank_size != 0) { - MS_LOG(EXCEPTION) << "first dimension of x should be divided by rank_size"; + if (tmp_shape[0] > 0) { + tmp_shape[0] = tmp_shape[0] * rank_size; } - tmp_shape[0] = tmp_shape[0] / rank_size; return std::make_shared(x->element(), std::make_shared(tmp_shape)); } diff --git a/mindspore/nn/layer/embedding.py b/mindspore/nn/layer/embedding.py index 795cbd0372..e5b6e34fc7 100755 --- a/mindspore/nn/layer/embedding.py +++ b/mindspore/nn/layer/embedding.py @@ -149,6 +149,7 @@ class EmbeddingLookup(Cell): max_norm (Union[float, None]): A maximum clipping value. The data type must be float16, float32 or None. Default: None sparse (bool): Using sparse mode. When 'target' is set to 'CPU', 'sparse' has to be true. Default: True. + Inputs: - **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`. Specifies the indices of elements of the original Tensor. Values can be out of range of embedding_table, @@ -161,6 +162,8 @@ class EmbeddingLookup(Cell): Examples: >>> input_indices = Tensor(np.array([[1, 0], [3, 2]]), mindspore.int32) >>> out = nn.EmbeddingLookup(4,2)(input_indices) + >>> output.shape + (2, 2, 2) """ BATCH_SLICE = "batch_slice" FIELD_SLICE = "field_slice" @@ -188,6 +191,12 @@ class EmbeddingLookup(Cell): name='embedding_table') parallel_mode = _get_parallel_mode() is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) + self.forward_unique = False + self.gather_revert = P.GatherV2() + self.unique = P.Unique().shard(((1,),)) + self.reshape = P.Reshape() + self.shape = P.Shape() + indices_shape_size = 2 if slice_mode == "field_slice" and is_auto_parallel: if not manual_shapes: raise ValueError("in slice field mode, the manual_shapes should not be none") @@ -200,18 +209,32 @@ class EmbeddingLookup(Cell): self.gatherv2.shard(((get_group_size(), 1), (1, get_group_size()))) self.embeddinglookup.shard(((get_group_size(), 1), (1, get_group_size()))) elif slice_mode == "table_row_slice" and is_auto_parallel: - self.gatherv2.shard(((get_group_size(), 1), (1, 1))) - self.embeddinglookup.shard(((get_group_size(), 1), (1, 1))) + if target == 'DEVICE': + indices_shape_size = 1 + self.gather_revert.shard(((1, 1), (1,))) + self.forward_unique = True + indices_strategy = (1,)*indices_shape_size + self.gatherv2.shard(((get_group_size(), 1), indices_strategy)) + self.embeddinglookup.shard(((get_group_size(), 1), indices_strategy)) elif slice_mode == "table_column_slice" and is_auto_parallel: - self.gatherv2.shard(((1, get_group_size()), (1, 1))) - self.embeddinglookup.shard(((1, get_group_size()), (1, 1))) + if target == 'DEVICE': + indices_shape_size = 1 + self.gather_revert.shard(((1, get_group_size()), (1,))) + self.forward_unique = True + indices_strategy = (1,)*indices_shape_size + self.gatherv2.shard(((1, get_group_size()), indices_strategy)) + self.embeddinglookup.shard(((1, get_group_size()), indices_strategy)) elif slice_mode == "batch_slice" and is_auto_parallel: - self.gatherv2.shard(((1, 1), (get_group_size(), 1))) - self.embeddinglookup.shard(((1, 1), (get_group_size(), 1))) + indices_strategy = [get_group_size()] + indices_strategy.extend([1]*(indices_shape_size - 1)) + indices_strategy = tuple(indices_strategy) + self.gatherv2.shard(((1, 1), indices_strategy)) + self.embeddinglookup.shard(((1, 1), indices_strategy)) else: if is_auto_parallel: raise ValueError("slice_mode should support mode in nn.EmbeddingLookup, but get " + str(slice_mode)) + self.embedding_table.unique = self.forward_unique self.max_norm = max_norm if self.max_norm is not None: self.max_norm = validator.check_positive_float(self.max_norm, 'max_norm', self.cls_name) @@ -221,7 +244,15 @@ class EmbeddingLookup(Cell): if self.target == "CPU": out = self.embeddinglookup(self.embedding_table, indices, 0) else: - out = self.gatherv2(self.embedding_table, indices, 0) + if self.forward_unique: + shp = self.shape(indices) + (self.embedding_size,) + indices_flatten = self.reshape(indices, (-1,)) + unique_id, unique_idx = self.unique(indices_flatten) + weight_unique = self.gatherv2(unique_id) + weight_flatten = self.gather_revert(weight_unique, unique_idx, 0) + out = self.reshape(weight_flatten, shp) + else: + out = self.gatherv2(self.embedding_table, indices, 0) if self.max_norm is not None: axis = _make_axis_range(F.rank(indices), F.rank(out)) clip_by_norm = ClipByNorm(axis) diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py index 54eceabcaa..c57f647d1b 100755 --- a/mindspore/nn/optim/optimizer.py +++ b/mindspore/nn/optim/optimizer.py @@ -144,6 +144,11 @@ class Optimizer(Cell): decay_filter = lambda x: 'beta' not in x.name and 'gamma' not in x.name self.decay_flags = tuple(decay_filter(x) for x in self.parameters) self.exec_weight_decay = self.weight_decay > 0 + # when a parameter has been unique, there is no need do another unique in optimizer. + for param in self.parameters: + if param.unique: + self._unique = False + break ps_filter = lambda x: x.is_param_ps self.ps_parameters = tuple(ps_filter(x) for x in self.parameters) self.reciprocal_scale = 1.0 / loss_scale diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 108b3ae257..f1415f21cd 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -67,6 +67,7 @@ from .mul_ds import _mul_ds_tbe from .real_div import _real_div_tbe from .real_div_ds import _real_div_ds_tbe from .relu import _relu_tbe +from .relu_ds import _relu_ds_tbe from .relu_grad import _relu_grad_tbe from .relu6 import _relu6_tbe from .relu6_grad import _relu6_grad_tbe diff --git a/mindspore/ops/_op_impl/tbe/relu_ds.py b/mindspore/ops/_op_impl/tbe/relu_ds.py new file mode 100644 index 0000000000..da229c2b5a --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/relu_ds.py @@ -0,0 +1,40 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""ReLU op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +relu_op_info = TBERegOp("ReLU") \ + .fusion_type("ELEMWISE") \ + .async_flag(False) \ + .binfile_name("relu.so") \ + .compute_cost(10) \ + .kernel_name("relu") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .op_pattern("formatAgnostic") \ + .dtype_format(DataType.I8_None, DataType.I8_None) \ + .dtype_format(DataType.I32_None, DataType.I32_None) \ + .dtype_format(DataType.F16_None, DataType.F16_None) \ + .dtype_format(DataType.F32_None, DataType.F32_None) \ + .get_op_info() + + +@op_info_register(relu_op_info) +def _relu_ds_tbe(): + """Relu TBE register""" + return diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py index 9ce3a69f0e..edbd72775f 100644 --- a/mindspore/ops/operations/comm_ops.py +++ b/mindspore/ops/operations/comm_ops.py @@ -163,7 +163,8 @@ class AllGather(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) - x_shape[0] = x_shape[0] * self.rank_size + if x_shape[0] > 0: + x_shape[0] = x_shape[0] * self.rank_size return x_shape def infer_dtype(self, x_dtype): @@ -213,7 +214,8 @@ class _HostAllGather(PrimitiveWithInfer): def infer_shape(self, x_shape): validator.check_positive_int(len(x_shape), "x shape", self.name) - x_shape[0] = x_shape[0] * self.group_size + if x_shape[0] > 0: + x_shape[0] = x_shape[0] * self.group_size return x_shape def infer_dtype(self, x_dtype): diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 388d71e44f..e1a9ec1a6e 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -720,7 +720,7 @@ def _merge_param_with_strategy(sliced_data, parameter_name, strategy, is_even): if field_size > 0: from mindspore.parallel._tensor import _reshape_param_data_with_weight - merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, [field_size]) + merged_tensor = _reshape_param_data_with_weight(all_gather_tensor, dev_mat, field_size) else: from mindspore.parallel._tensor import _reshape_param_data diff --git a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh index 5789da7fd7..d7fdab21ff 100644 --- a/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh +++ b/model_zoo/official/recommend/wide_and_deep/script/run_auto_parallel_train_cluster.sh @@ -43,7 +43,7 @@ do python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 >train_deep$i.log 2>&1 & elif [ $MODE == "field_slice_host_device_mix" ]; then python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=1 --full_batch=1 --field_slice=1 >train_deep$i.log 2>&1 & - elif [ $MODE == "backward_unique" ]; then + elif [ $MODE == "forward_unique" ]; then python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --sparse=1 >train_deep$i.log 2>&1 & else python -s ${self_path}/../train_and_eval_auto_parallel.py --data_path=$DATASET --epochs=$EPOCH_SIZE --vocab_size=$VOCAB_SIZE --emb_dim=$EMB_DIM --dropout_flag=1 --host_device_mix=0 >train_deep$i.log 2>&1 & diff --git a/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh index 697f81711f..7963aa89bf 100644 --- a/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh +++ b/model_zoo/official/recommend/wide_and_deep/script/start_cluster.sh @@ -38,7 +38,7 @@ do user=$(get_node_user ${cluster_config_path} ${node}) passwd=$(get_node_passwd ${cluster_config_path} ${node}) echo "------------------${user}@${node}---------------------" - if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "backward_unique" ]; then + if [ $MODE == "host_device_mix" ] || [ $MODE == "field_slice_host_device_mix" ] || [ $MODE == "forward_unique" ]; then ssh_pass ${node} ${user} ${passwd} "mkdir -p ${execute_path}; cd ${execute_path}; bash ${SCRIPTPATH}/run_auto_parallel_train_cluster.sh ${RANK_SIZE} ${RANK_START} ${EPOCH_SIZE} ${VOCAB_SIZE} ${EMB_DIM} ${DATASET} ${ENV_SH} ${MODE} ${RANK_TABLE_FILE}" else echo "[ERROR] mode is wrong" diff --git a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py index 2e358a6cda..bebefc9ca0 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/callbacks.py +++ b/model_zoo/official/recommend/wide_and_deep/src/callbacks.py @@ -88,7 +88,7 @@ class EvalCallBack(Callback): Args: print_per_step (int): Print loss every times. Default: 1. """ - def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1, host_device_mix=False): + def __init__(self, model, eval_dataset, auc_metric, config, print_per_step=1): super(EvalCallBack, self).__init__() if not isinstance(print_per_step, int) or print_per_step < 0: raise ValueError("print_per_step must be int and >= 0.") @@ -99,7 +99,7 @@ class EvalCallBack(Callback): self.aucMetric.clear() self.eval_file_name = config.eval_file_name self.eval_values = [] - self.host_device_mix = host_device_mix + self.sparse = config.sparse self.config = config def epoch_end(self, run_context): @@ -116,7 +116,7 @@ class EvalCallBack(Callback): ParallelMode.DATA_PARALLEL): rank_id = get_rank() start_time = time.time() - out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.host_device_mix)) + out = self.model.eval(self.eval_dataset, dataset_sink_mode=(not self.sparse)) end_time = time.time() eval_time = int(end_time - start_time) diff --git a/model_zoo/official/recommend/wide_and_deep/src/config.py b/model_zoo/official/recommend/wide_and_deep/src/config.py index 02e080d734..7834baf91e 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/config.py +++ b/model_zoo/official/recommend/wide_and_deep/src/config.py @@ -48,6 +48,7 @@ def argparse_init(): parser.add_argument("--parameter_server", type=int, default=0, help="Open parameter server of not") parser.add_argument("--field_slice", type=int, default=0, help="Enable split field mode or not") parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") + parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice") return parser @@ -86,6 +87,7 @@ class WideDeepConfig(): self.field_slice = False self.manual_shape = None self.sparse = False + self.deep_table_slice_mode = "column_slice" def argparse_init(self): """ @@ -121,5 +123,6 @@ class WideDeepConfig(): self.parameter_server = args.parameter_server self.field_slice = bool(args.field_slice) self.sparse = bool(args.sparse) + self.deep_table_slice_mode = args.deep_table_slice_mode if self.host_device_mix == 1: self.sparse = True diff --git a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py index e243961bd2..d179224ead 100644 --- a/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py +++ b/model_zoo/official/recommend/wide_and_deep/src/wide_and_deep.py @@ -198,19 +198,29 @@ class WideDeepModel(nn.Cell): self.tile = P.Tile() self.concat = P.Concat(axis=1) self.cast = P.Cast() + self.unique = P.Unique().shard(((1,),)) + self.wide_gatherv2 = P.GatherV2() + self.deep_gatherv2 = P.GatherV2() if is_auto_parallel and sparse and not is_field_slice: - self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) - self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) - self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) target = 'DEVICE' if host_device_mix: target = 'CPU' - self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, - slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) self.wide_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, 1, target=target, slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) - self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) - self.deep_reshape.add_prim_attr("skip_redistribution", True) + if target == 'DEVICE': + self.wide_mul.shard(((1, 1, 1), (1, 1, 1))) + if config.deep_table_slice_mode == "column_slice": + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, + slice_mode=nn.EmbeddingLookup.TABLE_COLUMN_SLICE) + self.dense_layer_1.dropout.dropout_do_mask.shard(((1, get_group_size()),)) + self.dense_layer_1.dropout.dropout.shard(((1, get_group_size()),)) + self.dense_layer_1.matmul.shard(((1, get_group_size()), (get_group_size(), 1))) + self.dense_layer_1.matmul.add_prim_attr("field_size", self.field_size) + self.deep_mul.shard(((1, 1, get_group_size()), (1, 1, 1))) + self.deep_reshape.add_prim_attr("skip_redistribution", True) + else: + self.deep_embeddinglookup = nn.EmbeddingLookup(self.vocab_size, self.emb_dim, target=target, + slice_mode=nn.EmbeddingLookup.TABLE_ROW_SLICE) self.reduce_sum.add_prim_attr("cross_batch", True) self.embedding_table = self.deep_embeddinglookup.embedding_table elif is_auto_parallel and host_device_mix and is_field_slice and config.full_batch and config.manual_shape: @@ -247,13 +257,15 @@ class WideDeepModel(nn.Cell): id_hldr: batch ids; wt_hldr: batch weights; """ - mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) # Wide layer wide_id_weight = self.wide_embeddinglookup(id_hldr) + # Deep layer + deep_id_embs = self.deep_embeddinglookup(id_hldr) + mask = self.reshape(wt_hldr, (self.batch_size, self.field_size, 1)) + # Wide layer wx = self.wide_mul(wide_id_weight, mask) wide_out = self.reshape(self.reduce_sum(wx, 1) + self.wide_b, (-1, 1)) # Deep layer - deep_id_embs = self.deep_embeddinglookup(id_hldr) vx = self.deep_mul(deep_id_embs, mask) deep_in = self.deep_reshape(vx, (-1, self.field_size * self.emb_dim)) deep_in = self.dense_layer_1(deep_in) @@ -333,7 +345,8 @@ class TrainStepWrap(nn.Cell): parameter_server (Bool): Whether run in parameter server mode. Default: False """ - def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, sparse=False): + def __init__(self, network, sens=1024.0, host_device_mix=False, parameter_server=False, + sparse=False): super(TrainStepWrap, self).__init__() parallel_mode = context.get_auto_parallel_context("parallel_mode") is_auto_parallel = parallel_mode in (ParallelMode.SEMI_AUTO_PARALLEL, ParallelMode.AUTO_PARALLEL) diff --git a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py index f855e76517..6aca38a273 100644 --- a/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py +++ b/model_zoo/official/recommend/wide_and_deep/train_and_eval_auto_parallel.py @@ -40,8 +40,8 @@ def get_WideDeep_net(config): WideDeep_net = WideDeepModel(config) loss_net = NetWithLossClass(WideDeep_net, config) loss_net = VirtualDatasetCellTriple(loss_net) - train_net = TrainStepWrap( - loss_net, host_device_mix=bool(config.host_device_mix), sparse=config.sparse) + train_net = TrainStepWrap(loss_net, host_device_mix=bool(config.host_device_mix), + sparse=config.sparse) eval_net = PredictWithSigmoid(WideDeep_net) eval_net = VirtualDatasetCellTriple(eval_net) return train_net, eval_net @@ -122,7 +122,7 @@ def train_and_eval(config): metrics={"auc": auc_metric}) eval_callback = EvalCallBack( - model, ds_eval, auc_metric, config, host_device_mix=host_device_mix) + model, ds_eval, auc_metric, config) callback = LossCallBack(config=config, per_print_times=20) ckptconfig = CheckpointConfig(save_checkpoint_steps=ds_train.get_dataset_size()*epochs, @@ -146,7 +146,7 @@ if __name__ == "__main__": context.set_context(variable_memory_max_size="24GB") context.set_context(enable_sparse=True) init() - if wide_deep_config.host_device_mix == 1: + if wide_deep_config.sparse: context.set_auto_parallel_context( parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=True) else: diff --git a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py index 4a2ed56e53..ecf7c2d621 100644 --- a/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py +++ b/tests/st/model_zoo_tests/wide_and_deep/python_file_for_ci/config.py @@ -37,6 +37,8 @@ def argparse_init(): parser.add_argument("--ckpt_path", type=str, default="./checkpoints/") parser.add_argument("--eval_file_name", type=str, default="eval.log") parser.add_argument("--loss_file_name", type=str, default="loss.log") + parser.add_argument("--sparse", type=int, default=0, help="Enable sparse or not") + parser.add_argument("--deep_table_slice_mode", type=str, default="column_slice", help="column_slice/row_slice") return parser @@ -66,6 +68,8 @@ class WideDeepConfig(): self.loss_file_name = "loss.log" self.ckpt_path = "./checkpoints/" self.stra_ckpt = "./strategy_train.ckpt" + self.sparse = False + self.deep_table_slice_mode = "column_slice" def argparse_init(self): """ @@ -94,3 +98,7 @@ class WideDeepConfig(): self.loss_file_name = args.loss_file_name self.ckpt_path = args.ckpt_path self.stra_ckpt = args.stra_ckpt + self.sparse = bool(args.sparse) + self.deep_table_slice_mode = args.deep_table_slice_mode + if self.host_device_mix == 1: + self.sparse = True diff --git a/tests/ut/python/parallel/test_dynamic_shape.py b/tests/ut/python/parallel/test_dynamic_shape.py index 86a52873b7..052ad6a1ad 100644 --- a/tests/ut/python/parallel/test_dynamic_shape.py +++ b/tests/ut/python/parallel/test_dynamic_shape.py @@ -93,7 +93,7 @@ def test_unique_row_split(): self.embedding_lookp = P.GatherV2().shard(((8, 1), (1,))) self.embedding_table = Parameter(initializer('normal', [2000, 128]), name='embedding_table') - self.gatherv2 = P.GatherV2().shard(((1, 1), (8,))) + self.gatherv2 = P.GatherV2().shard(((1, 1), (1,))) self.reshape = P.Reshape() self.matmul = P.MatMul() self.mul_weight = Parameter(Tensor(np.full([32, 64, 1], 0.5, dtype=np.float32)), name="mul_weight") @@ -108,7 +108,7 @@ def test_unique_row_split(): return vx size = 8 - context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="stand_alone") + context.set_auto_parallel_context(device_num=size, global_rank=0, parallel_mode="semi_auto_parallel") x = Tensor(np.ones([32, 64]), dtype=ms.int32) net = Net() optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)