fix dygraph trace bug, test=develop (#21193)

revert-21172-masked_select_api
Zeng Jinle 5 years ago committed by GitHub
parent 7269ffe3cc
commit 0f30d3a213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -376,7 +376,7 @@ class TracedLayer(object):
if partial_vars is None:
return all_vars
return [all_vars[idx] for idx in feed]
return [all_vars[idx] for idx in partial_vars]
with scope_guard(self._scope):
feeded_var_names = get_feed_fetch(self._feed_names, feed)

@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
if i % 2 == 0:
outs, traced_layer = TracedLayer.trace(
transformer, [enc_inputs, dec_inputs, label, weights])
outs_static = traced_layer(enc_inputs + dec_inputs +
[label, weights])
ins_static = enc_inputs + dec_inputs + [label, weights]
outs_static = traced_layer(ins_static)
helper.assertEachVar(outs, outs_static)
if program is not None:
self.assertTrue(
@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
program = traced_layer.program
traced_layer.save_inference_model(
'./infer_imperative_transformer')
'./infer_imperative_transformer',
feed=range(len(ins_static)),
fetch=range(len(outs_static)))
else:
outs = transformer(enc_inputs, dec_inputs, label, weights)

Loading…
Cancel
Save