|
|
@ -26,7 +26,7 @@ import numpy
|
|
|
|
import warnings
|
|
|
|
import warnings
|
|
|
|
import six
|
|
|
|
import six
|
|
|
|
from functools import reduce, partial
|
|
|
|
from functools import reduce, partial
|
|
|
|
from ..data_feeder import convert_dtype, check_variable_and_dtype
|
|
|
|
from ..data_feeder import convert_dtype, check_variable_and_dtype, check_type
|
|
|
|
from ... import compat as cpt
|
|
|
|
from ... import compat as cpt
|
|
|
|
from ..backward import _infer_var_data_type_shape_
|
|
|
|
from ..backward import _infer_var_data_type_shape_
|
|
|
|
|
|
|
|
|
|
|
@ -2251,16 +2251,13 @@ def case(pred_fn_pairs, default=None, name=None):
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
|
|
|
|
Check arguments pred_fn_pairs and default. Return canonical pre_fn_pairs and default.
|
|
|
|
'''
|
|
|
|
'''
|
|
|
|
if not isinstance(pred_fn_pairs, (list, tuple)):
|
|
|
|
check_type(pred_fn_pairs, 'pred_fn_pairs', (list, tuple), 'case')
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
_error_message("The type", "pred_fn_pairs", "case",
|
|
|
|
|
|
|
|
"list or tuple", type(pred_fn_pairs)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for pred_fn in pred_fn_pairs:
|
|
|
|
for pred_fn in pred_fn_pairs:
|
|
|
|
if not isinstance(pred_fn, tuple):
|
|
|
|
if not isinstance(pred_fn, tuple):
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
_error_message("The elements' type", "pred_fn_pairs",
|
|
|
|
_error_message("The elements' type", "pred_fn_pairs",
|
|
|
|
"case", "tuple", type(pred_fn)))
|
|
|
|
"case", tuple, type(pred_fn)))
|
|
|
|
if len(pred_fn) != 2:
|
|
|
|
if len(pred_fn) != 2:
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
_error_message("The tuple's size", "pred_fn_pairs", "case",
|
|
|
|
_error_message("The tuple's size", "pred_fn_pairs", "case",
|
|
|
|