|
|
|
@ -3361,24 +3361,14 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
|
|
|
|
|
helper = LayerHelper('switch_case', **locals())
|
|
|
|
|
|
|
|
|
|
def _check_args(branch_index, branch_fns, default):
|
|
|
|
|
if not isinstance(branch_index, Variable):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
_error_message("The type", "branch_index", "switch_case",
|
|
|
|
|
"Variable", type(branch_index)))
|
|
|
|
|
|
|
|
|
|
if convert_dtype(branch_index.dtype) not in ["uint8", "int32", "int64"]:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
_error_message("The data type", "branch_index", "switch_case",
|
|
|
|
|
"uint8, int32 or int64",
|
|
|
|
|
convert_dtype(branch_index.dtype)))
|
|
|
|
|
check_variable_and_dtype(branch_index, 'branch_index',
|
|
|
|
|
['uint8', 'int32', 'int64'], 'switch_case')
|
|
|
|
|
|
|
|
|
|
if convert_dtype(branch_index.dtype) != "int64":
|
|
|
|
|
branch_index = cast(branch_index, "int64")
|
|
|
|
|
|
|
|
|
|
if not isinstance(branch_fns, (list, tuple, dict)):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
_error_message("The type", "branch_fns", "switch_case",
|
|
|
|
|
"dict, tuple or list", type(branch_fns)))
|
|
|
|
|
check_type(branch_fns, 'branch_fns', (list, tuple, dict), 'switch_case')
|
|
|
|
|
|
|
|
|
|
branch_fns = branch_fns.items() if isinstance(branch_fns,
|
|
|
|
|
dict) else branch_fns
|
|
|
|
@ -3391,7 +3381,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
|
|
|
|
|
if not isinstance(index_fn_pair, tuple):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
_error_message("The elements' type", "branch_fns",
|
|
|
|
|
"switch_case", "tuple", type(branch_fns)))
|
|
|
|
|
"switch_case", tuple, type(branch_fns)))
|
|
|
|
|
|
|
|
|
|
if len(index_fn_pair) != 2:
|
|
|
|
|
raise TypeError(
|
|
|
|
@ -3404,7 +3394,7 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
|
|
|
|
|
if not isinstance(key, int):
|
|
|
|
|
raise TypeError(
|
|
|
|
|
_error_message("The key's type", "branch_fns",
|
|
|
|
|
"switch_case", "int", type(key)))
|
|
|
|
|
"switch_case", int, type(key)))
|
|
|
|
|
|
|
|
|
|
if key in keys_of_fns:
|
|
|
|
|
raise ValueError(
|
|
|
|
|