adapt LambNextMVWithDecay

pull/976/head
zhaozhenlong 5 years ago
parent 8e06c2fe93
commit ef5f7306d3

@ -69,7 +69,7 @@ static std::map<string, string> tbe_func_adapter_map = {
{"reduce_sum", "reduce_sum_d"},
{"one_hot", "one_hot_d"},
{"sum", "reduce_sum_d"},
{"lamb_next_mv_with_decay_v1", "lamb_next_m_v_with_decay_v1"},
{"lamb_next_mv_with_decay_v1", "lamb_next_m_v_with_decay"},
{"lamb_next_mv", "lamb_next_m_v"},
{"split", "split_d"},
{"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},

@ -129,7 +129,7 @@ from .confusion_transpose_d import _confusion_transpose_d_tbe
from .confusion_softmax_grad import _confusion_softmax_grad_tbe
from .lamb_update_with_lr_v2 import _lamb_update_with_lr_v2_tbe
from .lamb_next_mv import _lamb_next_mv_tbe
from .lamb_next_mv_with_decay_v1 import _lamb_next_mv_with_decay_v1_tbe
from .lamb_next_mv_with_decay import _lamb_next_mv_with_decay_tbe
from .lamb_update_with_lr import _lamb_update_with_lr_tbe
from .rsqrt import _rsqrt_tbe
from .sigmoid import _sigmoid_tbe

@ -13,33 +13,33 @@
# limitations under the License.
# ============================================================================
"""LambNextMVWithDecayV1 op"""
"""LambNextMVWithDecay op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
lamb_next_m_v_with_decay_op_info = TBERegOp("LambNextMVWithDecay") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("lamb_next_m_v_with_decay_v1.so") \
.binfile_name("lamb_next_m_v_with_decay.so") \
.compute_cost(10) \
.kernel_name("lamb_next_m_v_with_decay_v1") \
.kernel_name("lamb_next_m_v_with_decay") \
.partial_flag(True) \
.input(0, "input1", False, "required", "all") \
.input(1, "input2", False, "required", "all") \
.input(2, "input3", False, "required", "all") \
.input(3, "input4", False, "required", "all") \
.input(4, "input5", False, "required", "all") \
.input(5, "input6", False, "required", "all") \
.input(6, "input7", False, "required", "all") \
.input(7, "input8", False, "required", "all") \
.input(8, "input9", False, "required", "all") \
.input(9, "inputx0", False, "required", "all") \
.input(10, "inputx1", False, "required", "all") \
.input(11, "inputx2", False, "required", "all") \
.input(12, "inputx3", False, "required", "all") \
.output(0, "output1", False, "required", "all") \
.output(1, "output2", False, "required", "all") \
.output(2, "output3", False, "required", "all") \
.output(3, "output4", False, "required", "all") \
.input(0, "input_mul3", False, "required", "all") \
.input(1, "input_mul2", False, "required", "all") \
.input(2, "input_realdiv1", False, "required", "all") \
.input(3, "input_mul1", False, "required", "all") \
.input(4, "input_mul0", False, "required", "all") \
.input(5, "input_realdiv0", False, "required", "all") \
.input(6, "input_mul4", False, "required", "all") \
.input(7, "mul0_x", False, "required", "all") \
.input(8, "mul1_sub", False, "required", "all") \
.input(9, "mul2_x", False, "required", "all") \
.input(10, "mul3_sub1", False, "required", "all") \
.input(11, "mul4_x", False, "required", "all") \
.input(12, "add2_y", False, "required", "all") \
.output(0, "y1", True, "required", "all") \
.output(1, "y2", True, "required", "all") \
.output(2, "y3", True, "required", "all") \
.output(3, "y4", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
@ -53,7 +53,7 @@ lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
.get_op_info()
@op_info_register(lamb_next_m_v_with_decay_v1_op_info)
def _lamb_next_mv_with_decay_v1_tbe():
"""LambNextMVWithDecayV1 TBE register"""
@op_info_register(lamb_next_m_v_with_decay_op_info)
def _lamb_next_mv_with_decay_tbe():
"""LambNextMVWithDecay TBE register"""
return
Loading…
Cancel
Save