|
|
|
@ -13,33 +13,33 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ============================================================================
|
|
|
|
|
|
|
|
|
|
"""LambNextMVWithDecayV1 op"""
|
|
|
|
|
"""LambNextMVWithDecay op"""
|
|
|
|
|
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
|
|
|
|
|
|
|
|
|
lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
|
|
|
|
|
lamb_next_m_v_with_decay_op_info = TBERegOp("LambNextMVWithDecay") \
|
|
|
|
|
.fusion_type("OPAQUE") \
|
|
|
|
|
.async_flag(False) \
|
|
|
|
|
.binfile_name("lamb_next_m_v_with_decay_v1.so") \
|
|
|
|
|
.binfile_name("lamb_next_m_v_with_decay.so") \
|
|
|
|
|
.compute_cost(10) \
|
|
|
|
|
.kernel_name("lamb_next_m_v_with_decay_v1") \
|
|
|
|
|
.kernel_name("lamb_next_m_v_with_decay") \
|
|
|
|
|
.partial_flag(True) \
|
|
|
|
|
.input(0, "input1", False, "required", "all") \
|
|
|
|
|
.input(1, "input2", False, "required", "all") \
|
|
|
|
|
.input(2, "input3", False, "required", "all") \
|
|
|
|
|
.input(3, "input4", False, "required", "all") \
|
|
|
|
|
.input(4, "input5", False, "required", "all") \
|
|
|
|
|
.input(5, "input6", False, "required", "all") \
|
|
|
|
|
.input(6, "input7", False, "required", "all") \
|
|
|
|
|
.input(7, "input8", False, "required", "all") \
|
|
|
|
|
.input(8, "input9", False, "required", "all") \
|
|
|
|
|
.input(9, "inputx0", False, "required", "all") \
|
|
|
|
|
.input(10, "inputx1", False, "required", "all") \
|
|
|
|
|
.input(11, "inputx2", False, "required", "all") \
|
|
|
|
|
.input(12, "inputx3", False, "required", "all") \
|
|
|
|
|
.output(0, "output1", False, "required", "all") \
|
|
|
|
|
.output(1, "output2", False, "required", "all") \
|
|
|
|
|
.output(2, "output3", False, "required", "all") \
|
|
|
|
|
.output(3, "output4", False, "required", "all") \
|
|
|
|
|
.input(0, "input_mul3", False, "required", "all") \
|
|
|
|
|
.input(1, "input_mul2", False, "required", "all") \
|
|
|
|
|
.input(2, "input_realdiv1", False, "required", "all") \
|
|
|
|
|
.input(3, "input_mul1", False, "required", "all") \
|
|
|
|
|
.input(4, "input_mul0", False, "required", "all") \
|
|
|
|
|
.input(5, "input_realdiv0", False, "required", "all") \
|
|
|
|
|
.input(6, "input_mul4", False, "required", "all") \
|
|
|
|
|
.input(7, "mul0_x", False, "required", "all") \
|
|
|
|
|
.input(8, "mul1_sub", False, "required", "all") \
|
|
|
|
|
.input(9, "mul2_x", False, "required", "all") \
|
|
|
|
|
.input(10, "mul3_sub1", False, "required", "all") \
|
|
|
|
|
.input(11, "mul4_x", False, "required", "all") \
|
|
|
|
|
.input(12, "add2_y", False, "required", "all") \
|
|
|
|
|
.output(0, "y1", True, "required", "all") \
|
|
|
|
|
.output(1, "y2", True, "required", "all") \
|
|
|
|
|
.output(2, "y3", True, "required", "all") \
|
|
|
|
|
.output(3, "y4", True, "required", "all") \
|
|
|
|
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
|
|
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
|
|
|
|
DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
|
|
|
|
@ -53,7 +53,7 @@ lamb_next_m_v_with_decay_v1_op_info = TBERegOp("LambNextMVWithDecayV1") \
|
|
|
|
|
.get_op_info()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@op_info_register(lamb_next_m_v_with_decay_v1_op_info)
|
|
|
|
|
def _lamb_next_mv_with_decay_v1_tbe():
|
|
|
|
|
"""LambNextMVWithDecayV1 TBE register"""
|
|
|
|
|
@op_info_register(lamb_next_m_v_with_decay_op_info)
|
|
|
|
|
def _lamb_next_mv_with_decay_tbe():
|
|
|
|
|
"""LambNextMVWithDecay TBE register"""
|
|
|
|
|
return
|