|
|
|
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
import json
|
|
|
|
import sys
|
|
|
|
from paddle.utils import OpLastCheckpointChecker
|
|
|
|
from paddle.fluid.core import OpUpdateType
|
|
|
|
|
|
|
|
# Result value that means "the two compared objects are equal".
SAME = 0

# Top-level sections of an operator description.
INPUTS = "Inputs"
OUTPUTS = "Outputs"
ATTRS = "Attrs"

# Keys used in the diff dictionaries produced by `diff_vars` / `diff_attr`.
#
# `ADD` marks an item that has been added. Two more specific keys exist:
# `ADD_WITH_DEFAULT` relates to the rule that added attributes must carry a
# default value, and `ADD_DISPENSABLE` to the rule that added inputs or
# outputs must be dispensable (optional). In the diff results, these two
# keys list the added items that break the respective rule.
ADD_WITH_DEFAULT = "Add_with_default"
ADD_DISPENSABLE = "Add_dispensable"
ADD = "Add"

DELETE = "Delete"
CHANGE = "Change"

# Property names of an input/output description.
DUPLICABLE = "duplicable"
INTERMEDIATE = "intermediate"
DISPENSABLE = "dispensable"

# Property names of an attribute description.
TYPE = "type"
GENERATED = "generated"
DEFAULT_VALUE = "default_value"

# Module-wide flag; the diff functions set it to True as soon as they find an
# incompatible change.
error = False
|
|
|
|
|
|
|
|
# Maps an operator section (INPUTS / OUTPUTS / ATTRS) and a diff key
# (ADD / CHANGE) to the `OpUpdateType` value that is handed to
# `OpLastCheckpointChecker.filter_updates` when looking for a matching
# version registration.
version_update_map = {
    INPUTS: {
        ADD: OpUpdateType.kNewInput,
    },
    OUTPUTS: {
        ADD: OpUpdateType.kNewOutput,
    },
    ATTRS: {
        ADD: OpUpdateType.kNewAttr,
        CHANGE: OpUpdateType.kModifyAttr,
    },
}
|
|
|
|
|
|
|
|
|
|
|
|
def diff_vars(origin_vars, new_vars):
    """Compare the input (or output) descriptions of one operator.

    Args:
        origin_vars: dict mapping a variable name to a dict of its
            properties, taken from the original description.
        new_vars: the same structure, taken from the new description.

    Returns:
        A pair ``(var_error, var_diff_message)``. ``var_error`` is True when
        a variable was changed, deleted, or added without being dispensable.
        ``var_diff_message`` is a dict that may contain the keys ``ADD``
        (all added names), ``ADD_DISPENSABLE`` (added names that are not
        dispensable), ``CHANGE`` (name -> {property: (old, new)}) and
        ``DELETE`` (deleted names); empty groups are omitted.

    Side effects:
        Sets the module-level ``error`` flag to True whenever ``var_error``
        becomes True.
    """
    global error
    var_error = False

    changed = {}
    added = []
    added_not_dispensable = []
    deleted = []

    origin_names = set(origin_vars)
    new_names = set(new_vars)

    # Names are sorted so that the reported order is stable between runs.
    for var_name in sorted(origin_names & new_names):
        origin_var = origin_vars.get(var_name)
        new_var = new_vars.get(var_name)
        # Plain equality replaces `cmp(...) == SAME`; `cmp` no longer exists
        # in Python 3.
        if origin_var == new_var:
            continue

        error, var_error = True, True
        changed[var_name] = {}
        # Look at the properties of the original first, then at those that
        # exist only in the new description, so additions are reported too.
        arg_names = list(origin_var)
        arg_names.extend(arg for arg in new_var if arg not in origin_var)
        for arg_name in arg_names:
            origin_arg_value = origin_var.get(arg_name)
            new_arg_value = new_var.get(arg_name)
            if new_arg_value != origin_arg_value:
                changed[var_name][arg_name] = (origin_arg_value,
                                               new_arg_value)

    for var_name in sorted(origin_names - new_names):
        error, var_error = True, True
        deleted.append(var_name)

    for var_name in sorted(new_names - origin_names):
        added.append(var_name)
        # Adding a variable is only an error when it is not dispensable.
        if not new_vars.get(var_name).get(DISPENSABLE):
            error, var_error = True, True
            added_not_dispensable.append(var_name)

    var_diff_message = {}
    if added:
        var_diff_message[ADD] = added
    if added_not_dispensable:
        var_diff_message[ADD_DISPENSABLE] = added_not_dispensable
    if changed:
        var_diff_message[CHANGE] = changed
    if deleted:
        var_diff_message[DELETE] = deleted

    return var_error, var_diff_message
|
|
|
|
|
|
|
|
|
|
|
|
def diff_attr(ori_attrs, new_attrs):
    """Compare the attribute descriptions of one operator.

    Args:
        ori_attrs: dict mapping an attribute name to a dict of its
            properties, taken from the original description.
        new_attrs: the same structure, taken from the new description.

    Returns:
        A pair ``(attr_error, attr_diff_message)``. ``attr_error`` is True
        when an attribute was changed, deleted, or added without a default
        value. ``attr_diff_message`` is a dict that may contain the keys
        ``ADD`` (all added names), ``ADD_WITH_DEFAULT`` (added names whose
        default value is missing), ``CHANGE``
        (name -> {property: (old, new)}) and ``DELETE`` (deleted names);
        empty groups are omitted.

    Side effects:
        Sets the module-level ``error`` flag to True whenever ``attr_error``
        becomes True.
    """
    global error
    attr_error = False

    changed = {}
    added = []
    added_without_default = []
    deleted = []

    origin_names = set(ori_attrs)
    new_names = set(new_attrs)

    # Names are sorted so that the reported order is stable between runs.
    for attr_name in sorted(origin_names & new_names):
        origin_attr = ori_attrs.get(attr_name)
        new_attr = new_attrs.get(attr_name)
        # Plain equality replaces `cmp(...) == SAME`; `cmp` no longer exists
        # in Python 3.
        if origin_attr == new_attr:
            continue

        error, attr_error = True, True
        changed[attr_name] = {}
        # Look at the properties of the original first, then at those that
        # exist only in the new description, so additions are reported too.
        arg_names = list(origin_attr)
        arg_names.extend(arg for arg in new_attr if arg not in origin_attr)
        for arg_name in arg_names:
            origin_arg_value = origin_attr.get(arg_name)
            new_arg_value = new_attr.get(arg_name)
            if new_arg_value != origin_arg_value:
                changed[attr_name][arg_name] = (origin_arg_value,
                                                new_arg_value)

    for attr_name in sorted(origin_names - new_names):
        error, attr_error = True, True
        deleted.append(attr_name)

    for attr_name in sorted(new_names - origin_names):
        added.append(attr_name)
        # Adding an attribute is only an error when it has no default value.
        if new_attrs.get(attr_name).get(DEFAULT_VALUE) is None:
            error, attr_error = True, True
            added_without_default.append(attr_name)

    attr_diff_message = {}
    if added:
        attr_diff_message[ADD] = added
    if added_without_default:
        attr_diff_message[ADD_WITH_DEFAULT] = added_without_default
    if changed:
        attr_diff_message[CHANGE] = changed
    if deleted:
        attr_diff_message[DELETE] = deleted

    return attr_error, attr_diff_message
|
|
|
|
|
|
|
|
|
|
|
|
def check_io_registry(io_type, op, diff):
    """Find added inputs/outputs that have no matching version registration.

    Args:
        io_type: ``INPUTS`` or ``OUTPUTS``; selects the entry of
            ``version_update_map`` to use.
        op: operator type name.
        diff: diff dict as returned by ``diff_vars``.

    Returns:
        A dict mapping the update type (only ``ADD`` is checked) to a tuple
        ``(op, item, io_type)`` for an item for which
        ``checker.filter_updates`` returned a falsy result.
    """
    checker = OpLastCheckpointChecker()
    results = {}
    for update_type in [ADD]:
        expected_update = version_update_map[io_type][update_type]
        for item in diff.get(update_type, {}):
            infos = checker.filter_updates(op, expected_update, item)
            if not infos:
                # NOTE: the entry is overwritten on every miss, so only the
                # last unregistered item per update type is kept.
                results[update_type] = (op, item, io_type)
    return results
|
|
|
|
|
|
|
|
|
|
|
|
def check_attr_registry(op, diff):
    """Find added/changed attributes that have no matching version registration.

    Args:
        op: operator type name.
        diff: diff dict as returned by ``diff_attr``.

    Returns:
        A dict mapping the update type (``ADD`` or ``CHANGE``) to a tuple
        ``(op, item)`` for an attribute for which
        ``checker.filter_updates`` returned a falsy result.
    """
    checker = OpLastCheckpointChecker()
    results = {}
    for update_type in [ADD, CHANGE]:
        expected_update = version_update_map[ATTRS][update_type]
        for item in diff.get(update_type, {}):
            infos = checker.filter_updates(op, expected_update, item)
            if not infos:
                # NOTE: the entry is overwritten on every miss, so only the
                # last unregistered attribute per update type is kept.
                results[update_type] = (op, item)
    return results
|
|
|
|
|
|
|
|
|
|
|
|
def compare_op_desc(origin_op_desc, new_op_desc):
    """Compare two serialized operator descriptions.

    Args:
        origin_op_desc: JSON text of the original operator descriptions.
        new_op_desc: JSON text of the new operator descriptions.

    Returns:
        A pair ``(desc_error_message, version_error_message)``. Both are
        dicts keyed by operator type, then by ``INPUTS`` / ``OUTPUTS`` /
        ``ATTRS``. The first holds the diffs reported as errors by
        ``diff_vars`` / ``diff_attr``; the second holds the results of
        ``check_io_registry`` / ``check_attr_registry``. Operators without
        findings are absent.

    Raises:
        Whatever ``json.loads`` raises when either argument is not valid
        JSON.
    """
    origin = json.loads(origin_op_desc)
    new = json.loads(new_op_desc)
    desc_error_message = {}
    version_error_message = {}

    # Identical text means there is nothing to compare. Plain equality
    # replaces `cmp(...) == SAME`; `cmp` no longer exists in Python 3.
    if origin_op_desc == new_op_desc:
        return desc_error_message, version_error_message

    for op_type in origin:
        # no need to compare if the operator is deleted
        if op_type not in new:
            continue

        origin_info = origin.get(op_type, {})
        new_info = new.get(op_type, {})

        origin_inputs = origin_info.get(INPUTS, {})
        new_inputs = new_info.get(INPUTS, {})
        ins_error, ins_diff = diff_vars(origin_inputs, new_inputs)
        ins_version_errors = check_io_registry(INPUTS, op_type, ins_diff)

        origin_outputs = origin_info.get(OUTPUTS, {})
        new_outputs = new_info.get(OUTPUTS, {})
        outs_error, outs_diff = diff_vars(origin_outputs, new_outputs)
        outs_version_errors = check_io_registry(OUTPUTS, op_type, outs_diff)

        origin_attrs = origin_info.get(ATTRS, {})
        new_attrs = new_info.get(ATTRS, {})
        attrs_error, attrs_diff = diff_attr(origin_attrs, new_attrs)
        attrs_version_errors = check_attr_registry(op_type, attrs_diff)

        if ins_error:
            desc_error_message.setdefault(op_type, {})[INPUTS] = ins_diff
        if outs_error:
            desc_error_message.setdefault(op_type, {})[OUTPUTS] = outs_diff
        if attrs_error:
            desc_error_message.setdefault(op_type, {})[ATTRS] = attrs_diff

        if ins_version_errors:
            version_error_message.setdefault(
                op_type, {})[INPUTS] = ins_version_errors
        if outs_version_errors:
            version_error_message.setdefault(
                op_type, {})[OUTPUTS] = outs_version_errors
        if attrs_version_errors:
            version_error_message.setdefault(
                op_type, {})[ATTRS] = attrs_version_errors

    return desc_error_message, version_error_message
|
|
|
|
|
|
|
|
|
|
|
|
def _print_item_errors(item_errors, label, added_key, added_template):
    """Print the added/deleted/changed errors of one section of an operator.

    Args:
        item_errors: diff dict of the section (may be empty).
        label: word used for the item in the messages ("Input", "Output"
            or "attr").
        added_key: diff key that lists the invalid additions.
        added_template: format string with one placeholder for the name of
            an invalid addition.
    """
    for name in item_errors.get(added_key, {}):
        print(added_template.format(name))

    for name in item_errors.get(DELETE, {}):
        print(" * The {} '{}' is deleted.".format(label, name))

    changed_items = item_errors.get(CHANGE, {})
    for name in changed_items:
        changed_args = changed_items.get(name, {})
        for arg in changed_args:
            ori_value, new_value = changed_args.get(arg)
            print(
                " * The arg '{}' of {} '{}' is changed: from '{}' to '{}'.".
                format(arg, label, name, ori_value, new_value))


def print_desc_error_message(error_message):
    """Print the description errors collected by ``compare_op_desc``.

    Args:
        error_message: dict keyed by operator type, then by ``INPUTS`` /
            ``OUTPUTS`` / ``ATTRS``, holding the diff dicts produced by
            ``diff_vars`` / ``diff_attr``.
    """
    print("\n======================= \n"
          "Op desc error for the changes of Inputs/Outputs/Attrs of OPs:\n")
    for op_name in error_message:
        print("For OP '{}':".format(op_name))
        op_errors = error_message.get(op_name, {})

        # 1. print inputs error message
        _print_item_errors(
            op_errors.get(INPUTS, {}), "Input", ADD_DISPENSABLE,
            " * The added Input '{}' is not dispensable.")

        # 2. print outputs error message
        _print_item_errors(
            op_errors.get(OUTPUTS, {}), "Output", ADD_DISPENSABLE,
            " * The added Output '{}' is not dispensable.")

        # 3. print attrs error message
        _print_item_errors(
            op_errors.get(ATTRS, {}), "attr", ADD_WITH_DEFAULT,
            " * The added attr '{}' doesn't set default value.")
|
|
|
|
|
|
|
|
|
|
|
|
def print_version_error_message(error_message):
    """Print the version-registration errors collected by ``compare_op_desc``.

    Args:
        error_message: dict keyed by operator type, then by ``INPUTS`` /
            ``OUTPUTS`` / ``ATTRS``, holding the dicts produced by
            ``check_io_registry`` / ``check_attr_registry``. Each entry maps
            an update type to a tuple whose second element is the name of
            the unregistered item.
    """
    print(
        "\n======================= \n"
        "Operator registration error for the changes of Inputs/Outputs/Attrs of OPs:\n"
    )
    for op_name in error_message:
        print("For OP '{}':".format(op_name))
        op_errors = error_message.get(op_name, {})

        # 1. print inputs error message
        inputs_error = op_errors.get(INPUTS, {})
        unregistered = inputs_error.get(ADD, {})
        if unregistered:
            print(" * The added input '{}' is not yet registered.".format(
                unregistered[1]))

        # 2. print outputs error message
        outputs_error = op_errors.get(OUTPUTS, {})
        unregistered = outputs_error.get(ADD, {})
        if unregistered:
            print(" * The added output '{}' is not yet registered.".format(
                unregistered[1]))

        # 3. print attrs error message
        attrs_error = op_errors.get(ATTRS, {})
        unregistered = attrs_error.get(ADD, {})
        if unregistered:
            print(" * The added attribute '{}' is not yet registered.".format(
                unregistered[1]))
        unregistered = attrs_error.get(CHANGE, {})
        if unregistered:
            print(" * The change of attribute '{}' is not yet registered.".
                  format(unregistered[1]))
|
|
|
|
|
|
|
|
|
|
|
|
def print_repeat_process():
    """Print the steps needed to reproduce the op desc check locally."""
    print(
        "Tips:"
        " If you want to repeat the process, please follow these steps:\n"
        "\t1. Compile and install paddle from develop branch \n"
        "\t2. Run: python tools/print_op_desc.py > OP_DESC_DEV.spec \n"
        "\t3. Compile and install paddle from PR branch \n"
        "\t4. Run: python tools/print_op_desc.py > OP_DESC_PR.spec \n"
        "\t5. Run: python tools/check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec"
    )
|
|
|
|
|
|
|
|
|
|
|
|
if len(sys.argv) == 3:
    # Compare the op_desc files generated by branch DEV and branch PR,
    # and print the error messages.
    with open(sys.argv[1], 'r') as f:
        origin_op_desc = f.read()

    with open(sys.argv[2], 'r') as f:
        new_op_desc = f.read()

    desc_error_message, version_error_message = compare_op_desc(
        origin_op_desc, new_op_desc)
    if error:
        print("-" * 30)
        print_desc_error_message(desc_error_message)
        print_version_error_message(version_error_message)
        print("-" * 30)
    elif version_error_message:
        # Registration problems can exist without any description error
        # (e.g. a dispensable input was added but never registered); they
        # must not be dropped silently.
        print("-" * 30)
        print_version_error_message(version_error_message)
        print("-" * 30)
else:
    print("Usage: python check_op_desc.py OP_DESC_DEV.spec OP_DESC_PR.spec")
|