diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 014c778eee..9f7fb01851 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -450,7 +450,7 @@ def interpolate(x, for i in range(len(x.shape) - 2): scale_list.append(scale) attrs['scale'] = list(map(float, scale_list)) - elif isinstance(scale, list) or isinstance(scale, float): + elif isinstance(scale, list) or isinstance(scale, tuple): if len(scale) != len(x.shape) - 2: raise ValueError("scale_shape length should be {} for " "input {}-D tensor.".format(