diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index 11d2185c32..4064955526 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -69,7 +69,7 @@ static std::map tbe_func_adapter_map = { {"reduce_sum", "reduce_sum_d"}, {"one_hot", "one_hot_d"}, {"sum", "reduce_sum_d"}, - {"lamb_next_mv_with_decay_v1", "lamb_next_m_v_with_decay_v1"}, + {"lamb_next_mv_with_decay_v1", "lamb_next_m_v_with_decay"}, {"lamb_next_mv", "lamb_next_m_v"}, {"split", "split_d"}, {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"}, diff --git a/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/ops/_op_impl/tbe/__init__.py index cdacb22a28..14fdaeac48 100644 --- a/mindspore/ops/_op_impl/tbe/__init__.py +++ b/mindspore/ops/_op_impl/tbe/__init__.py @@ -129,7 +129,7 @@ from .confusion_transpose_d import _confusion_transpose_d_tbe from .confusion_softmax_grad import _confusion_softmax_grad_tbe from .lamb_update_with_lr_v2 import _lamb_update_with_lr_v2_tbe from .lamb_next_mv import _lamb_next_mv_tbe -from .lamb_next_mv_with_decay_v1 import _lamb_next_mv_with_decay_v1_tbe +from .lamb_next_mv_with_decay import _lamb_next_mv_with_decay_tbe from .lamb_update_with_lr import _lamb_update_with_lr_tbe from .rsqrt import _rsqrt_tbe from .sigmoid import _sigmoid_tbe diff --git a/mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay_v1.py b/mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py similarity index 59% rename from mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay_v1.py rename to mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py index aa135e5afe..380845d3e4 100644 --- a/mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay_v1.py +++ b/mindspore/ops/_op_impl/tbe/lamb_next_mv_with_decay.py @@ -13,33 +13,33 @@ # limitations under the License. # ============================================================================ -"""LambNextMVWithDecayV1 op""" +"""LambNextMVWithDecay op""" from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType -lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \ +lamb_next_m_v_with_decay_op_info = TBERegOp("LambNextMVWithDecay") \ .fusion_type("OPAQUE") \ .async_flag(False) \ - .binfile_name("lamb_next_m_v_with_decay_v1.so") \ + .binfile_name("lamb_next_m_v_with_decay.so") \ .compute_cost(10) \ - .kernel_name("lamb_next_m_v_with_decay_v1") \ + .kernel_name("lamb_next_m_v_with_decay") \ .partial_flag(True) \ - .input(0, "input1", False, "required", "all") \ - .input(1, "input2", False, "required", "all") \ - .input(2, "input3", False, "required", "all") \ - .input(3, "input4", False, "required", "all") \ - .input(4, "input5", False, "required", "all") \ - .input(5, "input6", False, "required", "all") \ - .input(6, "input7", False, "required", "all") \ - .input(7, "input8", False, "required", "all") \ - .input(8, "input9", False, "required", "all") \ - .input(9, "inputx0", False, "required", "all") \ - .input(10, "inputx1", False, "required", "all") \ - .input(11, "inputx2", False, "required", "all") \ - .input(12, "inputx3", False, "required", "all") \ - .output(0, "output1", False, "required", "all") \ - .output(1, "output2", False, "required", "all") \ - .output(2, "output3", False, "required", "all") \ - .output(3, "output4", False, "required", "all") \ + .input(0, "input_mul3", False, "required", "all") \ + .input(1, "input_mul2", False, "required", "all") \ + .input(2, "input_realdiv1", False, "required", "all") \ + .input(3, "input_mul1", False, "required", "all") \ + .input(4, "input_mul0", False, "required", "all") \ + .input(5, "input_realdiv0", False, "required", "all") \ + .input(6, "input_mul4", False, "required", "all") \ + .input(7, "mul0_x", False, "required", "all") \ + .input(8, "mul1_sub", False, "required", "all") \ + .input(9, "mul2_x", False, "required", "all") \ + .input(10, "mul3_sub1", False, "required", "all") \ + .input(11, "mul4_x", False, "required", "all") \ + .input(12, "add2_y", False, "required", "all") \ + .output(0, "y1", True, "required", "all") \ + .output(1, "y2", True, "required", "all") \ + .output(2, "y3", True, "required", "all") \ + .output(3, "y4", True, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, @@ -53,7 +53,7 @@ lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \ .get_op_info() -@op_info_register(lamb_next_m_v_with_decay_v1_op_info) -def _lamb_next_mv_with_decay_v1_tbe(): - """LambNextMVWithDecayV1 TBE register""" +@op_info_register(lamb_next_m_v_with_decay_op_info) +def _lamb_next_mv_with_decay_tbe(): + """LambNextMVWithDecay TBE register""" return