fix tests warpctc (#27639)

revert-27356-init_low_level_gloo
Li Fuchen 4 years ago committed by GitHub
parent c9a8801325
commit 516d84b22a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -394,8 +394,7 @@ foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP}) py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP) endforeach(TEST_OP)
py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4) py_test_modules(test_adam_op_multi_thread MODULES test_adam_op ENVS FLAGS_inner_op_parallelism=4)
# disable test_warpctc_op py_test_modules(test_warpctc_op MODULES test_warpctc_op)
# py_test_modules(test_warpctc_op MODULES test_warpctc_op)
py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS}) py_test_modules(test_bilinear_interp_op MODULES test_bilinear_interp_op ENVS ${GC_ENVS})
py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS}) py_test_modules(test_nearest_interp_op MODULES test_nearest_interp_op ENVS ${GC_ENVS})
py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS py_test_modules(test_imperative_resnet MODULES test_imperative_resnet ENVS

@ -24,6 +24,8 @@ from paddle.fluid import Program, program_guard
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
paddle.enable_static()
CUDA_BLOCK_SIZE = 32 CUDA_BLOCK_SIZE = 32
@ -490,8 +492,8 @@ class TestWarpCTCOpError(unittest.TestCase):
logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32") logits = np.random.uniform(0.1, 1.0, [20, 15]).astype("float32")
# labels should not be blank # labels should not be blank
labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32") labels = np.random.randint(0, 15 - 1, [15, 1], dtype="int32")
softmax = paddle.to_variable(logits) softmax = paddle.to_tensor(logits)
labels = paddle.to_variable(labels) labels = paddle.to_tensor(labels)
fluid.layers.warpctc(input=softmax, label=labels) fluid.layers.warpctc(input=softmax, label=labels)

Loading…
Cancel
Save