diff --git a/mindspore/ops/_op_impl/tbe/assign.py b/mindspore/ops/_op_impl/tbe/assign.py index 2fbd152c78..ff673a03c4 100644 --- a/mindspore/ops/_op_impl/tbe/assign.py +++ b/mindspore/ops/_op_impl/tbe/assign.py @@ -23,31 +23,53 @@ assign_op_info = TBERegOp("Assign") \ .compute_cost(10) \ .kernel_name("assign") \ .partial_flag(True) \ - .input(0, "resource", False, "required", "all") \ + .input(0, "ref", False, "required", "all") \ .input(1, "value", False, "required", "all") \ - .output(0, "y", False, "required", "all") \ - .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ + .output(0, "ref", False, "required", "all") \ .dtype_format(DataType.BOOL_Default, DataType.BOOL_Default, DataType.BOOL_Default) \ + .dtype_format(DataType.BOOL_5HD, DataType.BOOL_5HD, DataType.BOOL_5HD) \ + .dtype_format(DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0, DataType.BOOL_C1HWNCoC0) \ + .dtype_format(DataType.BOOL_FracZ, DataType.BOOL_FracZ, DataType.BOOL_FracZ) \ + .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \ + .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ, DataType.I8_FracZ) \ .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \ + .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ, DataType.U8_FracZ) \ .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \ .dtype_format(DataType.I16_5HD, DataType.I16_5HD, DataType.I16_5HD) \ + .dtype_format(DataType.I16_C1HWNCoC0, DataType.I16_C1HWNCoC0, DataType.I16_C1HWNCoC0) \ + .dtype_format(DataType.I16_FracZ, DataType.I16_FracZ, DataType.I16_FracZ) \ .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \ .dtype_format(DataType.U16_5HD, DataType.U16_5HD, DataType.U16_5HD) \ + .dtype_format(DataType.U16_C1HWNCoC0, DataType.U16_C1HWNCoC0, DataType.U16_C1HWNCoC0) \ + .dtype_format(DataType.U16_FracZ, DataType.U16_FracZ, DataType.U16_FracZ) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \ .dtype_format(DataType.U32_5HD, DataType.U32_5HD, DataType.U32_5HD) \ + .dtype_format(DataType.U32_C1HWNCoC0, DataType.U32_C1HWNCoC0, DataType.U32_C1HWNCoC0) \ + .dtype_format(DataType.U32_FracZ, DataType.U32_FracZ, DataType.U32_FracZ) \ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ .dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \ + .dtype_format(DataType.I64_C1HWNCoC0, DataType.I64_C1HWNCoC0, DataType.I64_C1HWNCoC0) \ + .dtype_format(DataType.I64_FracZ, DataType.I64_FracZ, DataType.I64_FracZ) \ .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \ .dtype_format(DataType.U64_5HD, DataType.U64_5HD, DataType.U64_5HD) \ + .dtype_format(DataType.U64_C1HWNCoC0, DataType.U64_C1HWNCoC0, DataType.U64_C1HWNCoC0) \ + .dtype_format(DataType.U64_FracZ, DataType.U64_FracZ, DataType.U64_FracZ) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ - .dtype_format(DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/assign_add.py b/mindspore/ops/_op_impl/tbe/assign_add.py index 2b20a7781d..7ad23ff3bc 100644 --- a/mindspore/ops/_op_impl/tbe/assign_add.py +++ b/mindspore/ops/_op_impl/tbe/assign_add.py @@ -28,16 +28,28 @@ assign_add_op_info = TBERegOp("AssignAdd") \ .output(0, "ref", False, "required", "all") \ .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \ .dtype_format(DataType.I8_5HD, DataType.I8_5HD, DataType.I8_5HD) \ + .dtype_format(DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0, DataType.I8_C1HWNCoC0) \ + .dtype_format(DataType.I8_FracZ, DataType.I8_FracZ, DataType.I8_FracZ) \ .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \ .dtype_format(DataType.U8_5HD, DataType.U8_5HD, DataType.U8_5HD) \ + .dtype_format(DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0, DataType.U8_C1HWNCoC0) \ + .dtype_format(DataType.U8_FracZ, DataType.U8_FracZ, DataType.U8_FracZ) \ .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \ .dtype_format(DataType.I32_5HD, DataType.I32_5HD, DataType.I32_5HD) \ + .dtype_format(DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0, DataType.I32_C1HWNCoC0) \ + .dtype_format(DataType.I32_FracZ, DataType.I32_FracZ, DataType.I32_FracZ) \ .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \ .dtype_format(DataType.I64_5HD, DataType.I64_5HD, DataType.I64_5HD) \ + .dtype_format(DataType.I64_C1HWNCoC0, DataType.I64_C1HWNCoC0, DataType.I64_C1HWNCoC0) \ + .dtype_format(DataType.I64_FracZ, DataType.I64_FracZ, DataType.I64_FracZ) \ .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \ + .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \ + .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \ .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \ .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \ + .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \ + .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \ .get_op_info() diff --git a/mindspore/ops/_op_impl/tbe/relu6.py b/mindspore/ops/_op_impl/tbe/relu6.py index bbedfdeb0f..d9bd7f9f8e 100644 --- a/mindspore/ops/_op_impl/tbe/relu6.py +++ b/mindspore/ops/_op_impl/tbe/relu6.py @@ -23,8 +23,8 @@ relu6_op_info = TBERegOp("ReLU6") \ .compute_cost(10) \ .kernel_name("relu6") \ .partial_flag(True) \ - .input(0, "features", False, "required", "all") \ - .output(0, "activations", False, "required", "all") \ + .input(0, "x", False, "required", "all") \ + .output(0, "y", False, "required", "all") \ .dtype_format(DataType.F16_Default, DataType.F16_Default) \ .dtype_format(DataType.F16_5HD, DataType.F16_5HD) \ .dtype_format(DataType.F32_Default, DataType.F32_Default) \