diff --git a/tools/check_api_approvals.sh b/tools/check_api_approvals.sh index b0237b3f50..1276fd9402 100644 --- a/tools/check_api_approvals.sh +++ b/tools/check_api_approvals.sh @@ -72,7 +72,7 @@ fi op_type_spec_diff=`python ${PADDLE_ROOT}/tools/check_op_register_type.py ${PADDLE_ROOT}/paddle/fluid/OP_TYPE_DEV.spec ${PADDLE_ROOT}/paddle/fluid/OP_TYPE_PR.spec` if [ "$op_type_spec_diff" != "" ]; then - echo_line="More data_type of new operator should be regitered in your PR. Please make sure that both float/double (or int/int64_t) have been regitered. You must have one RD (Aurelius84 or liym27 or zhhsplendid)approval for the data_type registration of new operator.\n" + echo_line="You must have one RD (Aurelius84 (Recommend) or liym27 or zhhsplendid)approval for the data_type registration of new operator. More data_type of new operator should be registered in your PR. Please make sure that both float/double (or int/int64_t) have been registered.\n For more details, please click [https://github.com/PaddlePaddle/Paddle/wiki/Data-types-of-generic-Op-must-be-fully-registered].\n" check_approval 1 9301846 33742067 7913861 fi diff --git a/tools/check_op_register_type.py b/tools/check_op_register_type.py index 8f45838bed..b32eff0573 100644 --- a/tools/check_op_register_type.py +++ b/tools/check_op_register_type.py @@ -25,6 +25,9 @@ import difflib import collections import paddle.fluid as fluid +INTS = set(['int', 'int64_t']) +FLOATS = set(['float', 'double']) + def get_all_kernels(): all_kernels_info = fluid.core._get_all_register_op_kernels() @@ -54,8 +57,15 @@ def read_file(file_path): return content -INTS = set(['int', 'int64_t']) -FLOATS = set(['float', 'double']) +def print_diff(op_type, register_types): + lack_types = set() + if len(INTS - register_types) == 1: + lack_types |= INTS - register_types + if len(FLOATS - register_types) == 1: + lack_types |= FLOATS - register_types + + print("{} only supports [{}] now, but lacks [{}].".format(op_type, " ".join( + register_types), " ".join(lack_types))) def check_add_op_valid(): @@ -72,7 +82,7 @@ def check_add_op_valid(): register_types = set(op_info[1:]) if len(FLOATS - register_types) == 1 or \ len(INTS - register_types) == 1: - print(each_diff) + print_diff(op_info[0], register_types) if len(sys.argv) == 1: