You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/python/paddle/fluid/tests/unittests/test_switch_case.py

310 lines
12 KiB

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
from paddle.fluid.framework import Program, program_guard
from functools import partial
class TestAPISwitchCase(unittest.TestCase):
def test_return_single_var(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
def fn_2():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_5 = layers.fill_constant(shape=[1], dtype='int32', value=5)
# call fn_1
out_0 = layers.switch_case(
branch_index=index_1, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
# call fn_2 : branch_fns={0: fn_1, 1:fn_2, 2:fn_3}
out_1 = layers.switch_case(
branch_index=index_1, branch_fns=(fn_1, fn_2, fn_3))
# call default fn_3
out_2 = layers.switch_case(
branch_index=index_5,
branch_fns=((1, fn_1), (2, fn_2)),
default=fn_3)
# no default, call fn_2
out_3 = layers.switch_case(
branch_index=index_2, branch_fns=[(1, fn_1), (2, fn_2)])
# no default, call fn_2 but branch_index is 5
out_4 = layers.switch_case(
branch_index=index_5,
branch_fns=[(1, fn_1), (3, fn_2), (2, fn_3)])
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program,
fetch_list=[out_0, out_1, out_2, out_3, out_4])
self.assertTrue(
np.allclose(res[0], 1),
"result is {} but answer is {}".format(res[0], 1))
self.assertTrue(
np.allclose(res[1], 2),
"result is {} but answer is {}".format(res[0], 2))
self.assertTrue(
np.allclose(res[2], 3),
"result is {} but answer is {}".format(res[0], 3))
self.assertTrue(
np.allclose(res[3], 2),
"result is {} but answer is {}".format(res[0], 2))
self.assertTrue(
np.allclose(res[4], 2),
"result is {} but answer is {}".format(res[0], 2))
def test_return_var_tuple(self):
def fn_1():
return layers.fill_constant(
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
shape=[2, 3], dtype='float32', value=2)
def fn_2():
return layers.fill_constant(
shape=[3, 4], dtype='int32', value=3), layers.fill_constant(
shape=[4, 5], dtype='float32', value=4)
def fn_3():
return layers.fill_constant(
shape=[5], dtype='int32', value=5), layers.fill_constant(
shape=[5, 6], dtype='float32', value=6)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = layers.fill_constant(shape=[1], dtype='int32', value=1)
out = layers.switch_case(index_1, ((1, fn_1), (2, fn_2)), fn_3)
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
ret = exe.run(main_program, fetch_list=out)
self.assertTrue(
np.allclose(np.asarray(ret[0]), np.full((1, 2), 1, np.int32)))
self.assertTrue(
np.allclose(
np.asarray(ret[1]), np.full((2, 3), 2, np.float32)))
class TestAPISwitchCase_Nested(unittest.TestCase):
def test_nested_switch_case(self):
def fn_1(x=1):
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=x),
branch_fns={
1: partial(
layers.fill_constant, shape=[1], dtype='int32',
value=1),
x: partial(
layers.fill_constant, shape=[2], dtype='int32', value=x)
})
return out
def fn_2(x=2):
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=2),
branch_fns={
1: partial(
layers.fill_constant,
shape=[4, 3],
dtype='int32',
value=1),
2: partial(
fn_1, x=x)
})
return out
def fn_3():
out = layers.switch_case(
branch_index=layers.fill_constant(
shape=[1], dtype='int32', value=3),
branch_fns={
1: partial(
layers.fill_constant,
shape=[4, 3],
dtype='int32',
value=1),
3: partial(
fn_2, x=3)
})
return out
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
index_1 = fluid.data(name="index_1", shape=[1], dtype='uint8')
index_2 = layers.fill_constant(shape=[1], dtype='int32', value=2)
index_3 = layers.fill_constant(shape=[1], dtype='int64', value=3)
out_1 = layers.switch_case(
branch_index=index_1, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
out_2 = layers.switch_case(
branch_index=index_2, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
out_3 = layers.switch_case(
branch_index=index_3, branch_fns={1: fn_1,
2: fn_2,
3: fn_3})
place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
) else fluid.CPUPlace()
exe = fluid.Executor(place)
res = exe.run(main_program,
feed={"index_1": np.array(
[1], dtype="uint8")},
fetch_list=[out_1, out_2, out_3])
self.assertTrue(
np.allclose(res[0], 1),
"result is {} but answer is {}".format(res[0], 1))
self.assertTrue(
np.allclose(res[1], 2),
"result is {} but answer is {}".format(res[1], 2))
self.assertTrue(
np.allclose(res[2], 3),
"result is {} but answer is {}".format(res[2], 3))
# test TypeError and ValueError of api switch_case
class TestAPISwitchCase_Error(unittest.TestCase):
def test_error(self):
def fn_1():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=1)
def fn_2():
return layers.fill_constant(shape=[4, 2], dtype='int32', value=2)
def fn_3():
return layers.fill_constant(shape=[4, 3], dtype='int32', value=3)
main_program = Program()
startup_program = Program()
with program_guard(main_program, startup_program):
key_float32 = layers.fill_constant(
shape=[1], dtype='float32', value=0.23)
key_int32 = layers.fill_constant(
shape=[1], dtype='int32', value=0.23)
# The type of 'branch_index' in Op(switch_case) must be Variable
def type_error_branch_index():
layers.switch_case(
branch_index=1, branch_fns=[(1, fn_1)], default=fn_3)
self.assertRaises(TypeError, type_error_branch_index)
# The data type of 'branch_index' in Op(switch_case) must be int32, int64 or uint8
def dtype_error_branch_index():
layers.switch_case(
branch_index=key_float32,
branch_fns=[(1, fn_1)],
default=fn_3)
self.assertRaises(TypeError, dtype_error_branch_index)
# The type of 'branch_fns' in Op(switch_case) must be list, tuple or dict
def type_error_branch_fns():
layers.switch_case(
branch_index=key_int32, branch_fns=1, default=fn_3)
self.assertRaises(TypeError, type_error_branch_fns)
# The elements' type of 'branch_fns' in Op(switch_case) must be tuple
def type_error_index_fn_pair_1():
layers.switch_case(
branch_index=key_int32, branch_fns=[1], default=fn_3)
self.assertRaises(TypeError, type_error_index_fn_pair_1)
# The tuple's size of 'branch_fns' in Op(switch_case) must be 2
def type_error_index_fn_pair_2():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, 2, 3)],
default=fn_3)
self.assertRaises(TypeError, type_error_index_fn_pair_2)
# The key's type of 'branch_fns' in Op(switch_case) must be int
def type_error_key():
layers.switch_case(
branch_index=key_int32, branch_fns=[(2.3, 2)], default=fn_3)
self.assertRaises(TypeError, type_error_key)
# The key in 'branch_fns' must be unique
def value_error_key():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(2, fn_1), (2, fn_2)],
default=fn_3)
self.assertRaises(ValueError, value_error_key)
# The type of function in 'branch_fns' must be callable
def type_error_fn():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, 1), (2, fn_2)],
default=fn_3)
self.assertRaises(TypeError, type_error_fn)
# The default in Op(case) must be callable
def type_error_default():
layers.switch_case(
branch_index=key_int32,
branch_fns=[(1, fn_1), (2, fn_2)],
default=1)
self.assertRaises(TypeError, type_error_default)
if __name__ == '__main__':
unittest.main()