fix dygraph trace bug, test=develop (#21193)

revert-21172-masked_select_api
Zeng Jinle 5 years ago committed by GitHub
parent 7269ffe3cc
commit 0f30d3a213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -376,7 +376,7 @@ class TracedLayer(object):
if partial_vars is None: if partial_vars is None:
return all_vars return all_vars
return [all_vars[idx] for idx in feed] return [all_vars[idx] for idx in partial_vars]
with scope_guard(self._scope): with scope_guard(self._scope):
feeded_var_names = get_feed_fetch(self._feed_names, feed) feeded_var_names = get_feed_fetch(self._feed_names, feed)

@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
if i % 2 == 0: if i % 2 == 0:
outs, traced_layer = TracedLayer.trace( outs, traced_layer = TracedLayer.trace(
transformer, [enc_inputs, dec_inputs, label, weights]) transformer, [enc_inputs, dec_inputs, label, weights])
outs_static = traced_layer(enc_inputs + dec_inputs +
[label, weights]) ins_static = enc_inputs + dec_inputs + [label, weights]
outs_static = traced_layer(ins_static)
helper.assertEachVar(outs, outs_static) helper.assertEachVar(outs, outs_static)
if program is not None: if program is not None:
self.assertTrue( self.assertTrue(
@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
program = traced_layer.program program = traced_layer.program
traced_layer.save_inference_model( traced_layer.save_inference_model(
'./infer_imperative_transformer') './infer_imperative_transformer',
feed=range(len(ins_static)),
fetch=range(len(outs_static)))
else: else:
outs = transformer(enc_inputs, dec_inputs, label, weights) outs = transformer(enc_inputs, dec_inputs, label, weights)

Loading…
Cancel
Save