Add friendly Error message in save_inference_model (#25617)

fix_copy_if_different
Aurelius84 5 years ago committed by GitHub
parent ca1185d06b
commit dfe4e67e7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1064,6 +1064,13 @@ def prepend_feed_ops(inference_program,
persistable=True)
for i, name in enumerate(feed_target_names):
if not global_block.has_var(name):
raise ValueError(
"The feeded_var_names[{i}]: '{name}' doesn't exist in pruned inference program. "
"Please check whether '{name}' is a valid feed_var name, or remove it from feeded_var_names "
"if '{name}' is not involved in the target_vars calculation.".
format(
i=i, name=name))
out = global_block.var(name)
global_block._prepend_op(
type='feed',

@ -48,5 +48,26 @@ class TestSaveLoadAPIError(unittest.TestCase):
vars="vars")
class TestSaveInferenceModelAPIError(unittest.TestCase):
def test_useless_feeded_var_names(self):
start_prog = fluid.Program()
main_prog = fluid.Program()
with fluid.program_guard(main_prog, start_prog):
x = fluid.data(name='x', shape=[10, 16], dtype='float32')
y = fluid.data(name='y', shape=[10, 16], dtype='float32')
z = fluid.layers.fc(x, 4)
exe = fluid.Executor(fluid.CPUPlace())
exe.run(start_prog)
with self.assertRaisesRegexp(
ValueError, "not involved in the target_vars calculation"):
fluid.io.save_inference_model(
dirname='./model',
feeded_var_names=['x', 'y'],
target_vars=[z],
executor=exe,
main_program=main_prog)
if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save