diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py
index da5ac85859..7f15ad62c5 100644
--- a/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/ops/_op_impl/tbe/__init__.py
@@ -104,6 +104,7 @@ from .tanh_grad import _tanh_grad_tbe
 from .softmax import _softmax_tbe
 from .softplus import _softplus_tbe
 from .softplus_grad import _softplus_grad_tbe
+from .softmax_grad_ext import _softmax_grad_ext_tbe
 from .square import _square_tbe
 from .sqrt import _sqrt_tbe
 from .transpose_d import _transpose_d_tbe
diff --git a/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py b/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
new file mode 100644
index 0000000000..51060d717b
--- /dev/null
+++ b/mindspore/ops/_op_impl/tbe/softmax_grad_ext.py
@@ -0,0 +1,47 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SoftmaxGradExt op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+softmax_grad_ext_op_info = TBERegOp("SoftmaxGradExt") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("softmax_grad_ext.so") \
+    .compute_cost(10) \
+    .kernel_name("softmax_grad_ext") \
+    .partial_flag(True) \
+    .dynamic_format(True) \
+    .attr("axes", "required", "listInt", "all") \
+    .attr("keep_dims", "required", "bool", "all") \
+    .input(0, "grad", False, "required", "all") \
+    .input(1, "x1", False, "required", "all") \
+    .input(2, "x2", False, "required", "all") \
+    .output(0, "y", True, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default,
+                  DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD,
+                  DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default,
+                  DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD,
+                  DataType.F32_5HD, DataType.F32_5HD) \
+    .get_op_info()
+
+
+@op_info_register(softmax_grad_ext_op_info)
def _softmax_grad_ext_tbe():
+    """SoftmaxGradExt TBE register"""
+    return
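
The diff only registers op info; the compiled kernel ships separately as softmax_grad_ext.so. For reference, below is a minimal NumPy sketch of the semantics this registration is assumed to cover: the standard softmax backward fused with an elementwise scale, with grad as the upstream gradient, x1 the forward softmax output, and x2 the scale. The helper name softmax_grad_ext_ref and the exact formula are assumptions for illustration, not part of the patch.

    import numpy as np

    def softmax_grad_ext_ref(grad, x1, x2, axes=-1):
        """Assumed reference: y = (grad - sum(grad * x1, axes)) * x1 * x2.

        keepdims stays True here so the reduced sum broadcasts back over
        `axes`; the op's keep_dims attr makes the same choice in the kernel.
        """
        s = np.sum(grad * x1, axis=axes, keepdims=True)
        return (grad - s) * x1 * x2

    # Usage: with x2 = 1 this reduces to the ordinary softmax gradient.
    x = np.random.randn(2, 8).astype(np.float32)
    y = np.exp(x) / np.exp(x).sum(-1, keepdims=True)  # forward softmax
    g = np.random.randn(2, 8).astype(np.float32)
    out = softmax_grad_ext_ref(g, y, np.float32(1.0))

The four dtype_format entries in the registration correspond to this computation in float16 and float32, each in both the default layout and the 5HD (NC1HWC0) layout.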