|
|
|
@ -1110,10 +1110,10 @@ def check_gnn_list_or_ndarray(param, param_name):
|
|
|
|
|
for m in param:
|
|
|
|
|
if not isinstance(m, int):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Each membor in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
|
|
|
|
"Each member in {0} should be of type int. Got {1}.".format(param_name, type(m)))
|
|
|
|
|
elif isinstance(param, np.ndarray):
|
|
|
|
|
if not param.dtype == np.int32:
|
|
|
|
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
|
|
|
|
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
|
|
|
|
param_name, param.dtype))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
|
|
|
@ -1196,15 +1196,15 @@ def check_gnn_get_sampled_neighbors(method):
|
|
|
|
|
# check neighbor_nums; required argument
|
|
|
|
|
neighbor_nums = param_dict.get("neighbor_nums")
|
|
|
|
|
check_gnn_list_or_ndarray(neighbor_nums, 'neighbor_nums')
|
|
|
|
|
if len(neighbor_nums) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
|
|
|
|
if not neighbor_nums or len(neighbor_nums) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
|
|
|
|
'neighbor_nums', len(neighbor_nums)))
|
|
|
|
|
|
|
|
|
|
# check neighbor_types; required argument
|
|
|
|
|
neighbor_types = param_dict.get("neighbor_types")
|
|
|
|
|
check_gnn_list_or_ndarray(neighbor_types, 'neighbor_types')
|
|
|
|
|
if len(neighbor_nums) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be less than or equal to 6, got {1}".format(
|
|
|
|
|
if not neighbor_types or len(neighbor_types) > 6:
|
|
|
|
|
raise ValueError("Wrong number of input members for {0}, should be between 1 and 6, got {1}".format(
|
|
|
|
|
'neighbor_types', len(neighbor_types)))
|
|
|
|
|
|
|
|
|
|
if len(neighbor_nums) != len(neighbor_types):
|
|
|
|
@ -1256,7 +1256,7 @@ def check_gnn_random_walk(method):
|
|
|
|
|
return new_method
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_aligned_list(param, param_name, membor_type):
|
|
|
|
|
def check_aligned_list(param, param_name, member_type):
|
|
|
|
|
"""Check whether the structure of each member of the list is the same."""
|
|
|
|
|
|
|
|
|
|
if not isinstance(param, list):
|
|
|
|
@ -1264,27 +1264,27 @@ def check_aligned_list(param, param_name, membor_type):
|
|
|
|
|
if not param:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Parameter {0} or its members are empty".format(param_name))
|
|
|
|
|
membor_have_list = None
|
|
|
|
|
member_have_list = None
|
|
|
|
|
list_len = None
|
|
|
|
|
for membor in param:
|
|
|
|
|
if isinstance(membor, list):
|
|
|
|
|
check_aligned_list(membor, param_name, membor_type)
|
|
|
|
|
if membor_have_list not in (None, True):
|
|
|
|
|
for member in param:
|
|
|
|
|
if isinstance(member, list):
|
|
|
|
|
check_aligned_list(member, param_name, member_type)
|
|
|
|
|
if member_have_list not in (None, True):
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
|
|
|
|
param_name))
|
|
|
|
|
if list_len is not None and len(membor) != list_len:
|
|
|
|
|
if list_len is not None and len(member) != list_len:
|
|
|
|
|
raise TypeError("The size of each member of parameter {0} is inconsistent".format(
|
|
|
|
|
param_name))
|
|
|
|
|
membor_have_list = True
|
|
|
|
|
list_len = len(membor)
|
|
|
|
|
member_have_list = True
|
|
|
|
|
list_len = len(member)
|
|
|
|
|
else:
|
|
|
|
|
if not isinstance(membor, membor_type):
|
|
|
|
|
raise TypeError("Each membor in {0} should be of type int. Got {1}.".format(
|
|
|
|
|
param_name, type(membor)))
|
|
|
|
|
if membor_have_list not in (None, False):
|
|
|
|
|
if not isinstance(member, member_type):
|
|
|
|
|
raise TypeError("Each member in {0} should be of type int. Got {1}.".format(
|
|
|
|
|
param_name, type(member)))
|
|
|
|
|
if member_have_list not in (None, False):
|
|
|
|
|
raise TypeError("The type of each member of the parameter {0} is inconsistent".format(
|
|
|
|
|
param_name))
|
|
|
|
|
membor_have_list = False
|
|
|
|
|
member_have_list = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_gnn_get_node_feature(method):
|
|
|
|
@ -1300,7 +1300,7 @@ def check_gnn_get_node_feature(method):
|
|
|
|
|
check_aligned_list(node_list, 'node_list', int)
|
|
|
|
|
elif isinstance(node_list, np.ndarray):
|
|
|
|
|
if not node_list.dtype == np.int32:
|
|
|
|
|
raise TypeError("Each membor in {0} should be of type int32. Got {1}.".format(
|
|
|
|
|
raise TypeError("Each member in {0} should be of type int32. Got {1}.".format(
|
|
|
|
|
node_list, node_list.dtype))
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError("Wrong input type for {0}, should be list or numpy.ndarray, got {1}".format(
|
|
|
|
|