From 2a284a0575e948b4ed9b4890dac6469042a0f9ba Mon Sep 17 00:00:00 2001 From: hwjiaorui Date: Thu, 29 Oct 2020 22:08:46 +0800 Subject: [PATCH] register gatherv2 --- .../executor/tiling/op_tiling_calculater.cc | 1 + mindspore/core/abstract/infer_functions.h | 2 - mindspore/ops/_op_impl/tbe/__init__.py | 1 + .../ops/_op_impl/tbe/sparse_gather_v2_ds.py | 57 +++++++++++++++++++ 4 files changed, 59 insertions(+), 2 deletions(-) create mode 100644 mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py diff --git a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc index 4a3eb62677..d04d6fce58 100644 --- a/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc +++ b/mindspore/ccsrc/runtime/device/ascend/executor/tiling/op_tiling_calculater.cc @@ -143,6 +143,7 @@ std::string GetRealOpType(const std::string &op_type) { static const std::map kOpTypeMap = { {"SparseApplyFtrl", "SparseApplyFtrlD"}, {"SparseApplyProximalAdagrad", "SparseApplyProximalAdagradD"}, + {"SparseGatherV2", "GatherV2"}, }; auto iter = kOpTypeMap.find(op_type); if (iter == kOpTypeMap.end()) { diff --git a/mindspore/core/abstract/infer_functions.h b/mindspore/core/abstract/infer_functions.h index 13f6a370d8..df07fd56b8 100644 --- a/mindspore/core/abstract/infer_functions.h +++ b/mindspore/core/abstract/infer_functions.h @@ -88,8 +88,6 @@ AbstractBasePtr InferImplTensorAdd(const AnalysisEnginePtr &, const PrimitivePtr AbstractBasePtr InferImplSquare(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); -AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, - const AbstractBasePtrList &args_spec_list); AbstractBasePtr InferImplSqrt(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index 3fa5d0b896..7885ee8bfb 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -296,6 +296,7 @@ from .fused_mul_add_n_l2loss import _fused_mul_add_n_l2loss_tbe from .fused_mul_apply_momentum_extern import _fused_mul_apply_momentum_extern_tbe from .lamb_next_right import _lamb_next_right_tbe from .sparse_gather_v2 import _sparse_gather_v2_tbe +from .sparse_gather_v2_ds import _sparse_gather_v2_ds_tbe from .data_format_dim_map import _data_format_dim_map_tbe from .histogram_fixed_width import _histogram_fixed_width_tbe from .tensor_scatter_update import _tensor_scatter_update_tbe diff --git a/mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py b/mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py new file mode 100644 index 0000000000..a9ace0f0a5 --- /dev/null +++ b/mindspore/ops/_op_impl/tbe/sparse_gather_v2_ds.py @@ -0,0 +1,57 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""SparseGatherV2 op""" +from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType + +sparse_gather_v2_ds_op_info = TBERegOp("SparseGatherV2") \ + .fusion_type("OPAQUE") \ + .async_flag(False) \ + .binfile_name("gather_v2.so") \ + .compute_cost(10) \ + .kernel_name("gather_v2") \ + .partial_flag(True) \ + .dynamic_shape(True) \ + .input(0, "x", False, "required", "all") \ + .input(1, "indices", False, "required", "all") \ + .input(2, "axis", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ + .dtype_format(DataType.F16_Default, DataType.I32_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I32_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I32_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I32_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I32_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I32_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I32_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I32_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I32_Default, DataType.I32_Default, DataType.U64_Default) \ + .dtype_format(DataType.F16_Default, DataType.I64_Default, DataType.I32_Default, DataType.F16_Default) \ + .dtype_format(DataType.F32_Default, DataType.I64_Default, DataType.I32_Default, DataType.F32_Default) \ + .dtype_format(DataType.I8_Default, DataType.I64_Default, DataType.I32_Default, DataType.I8_Default) \ + .dtype_format(DataType.U8_Default, DataType.I64_Default, DataType.I32_Default, DataType.U8_Default) \ + .dtype_format(DataType.I32_Default, DataType.I64_Default, DataType.I32_Default, DataType.I32_Default) \ + .dtype_format(DataType.U32_Default, DataType.I64_Default, DataType.I32_Default, DataType.U32_Default) \ + .dtype_format(DataType.I16_Default, DataType.I64_Default, DataType.I32_Default, DataType.I16_Default) \ + .dtype_format(DataType.U16_Default, DataType.I64_Default, DataType.I32_Default, DataType.U16_Default) \ + .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I32_Default, DataType.I64_Default) \ + .dtype_format(DataType.U64_Default, DataType.I64_Default, DataType.I32_Default, DataType.U64_Default) \ + .get_op_info() + + +@op_info_register(sparse_gather_v2_ds_op_info) +def _sparse_gather_v2_ds_tbe(): + """SparseGatherV2 TBE register""" + return