|
|
|
@ -988,8 +988,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
|
|
|
|
|
if i % 2 == 0:
|
|
|
|
|
outs, traced_layer = TracedLayer.trace(
|
|
|
|
|
transformer, [enc_inputs, dec_inputs, label, weights])
|
|
|
|
|
outs_static = traced_layer(enc_inputs + dec_inputs +
|
|
|
|
|
[label, weights])
|
|
|
|
|
|
|
|
|
|
ins_static = enc_inputs + dec_inputs + [label, weights]
|
|
|
|
|
outs_static = traced_layer(ins_static)
|
|
|
|
|
helper.assertEachVar(outs, outs_static)
|
|
|
|
|
if program is not None:
|
|
|
|
|
self.assertTrue(
|
|
|
|
@ -997,7 +998,9 @@ class TestDygraphTransformerSortGradient(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
program = traced_layer.program
|
|
|
|
|
traced_layer.save_inference_model(
|
|
|
|
|
'./infer_imperative_transformer')
|
|
|
|
|
'./infer_imperative_transformer',
|
|
|
|
|
feed=range(len(ins_static)),
|
|
|
|
|
fetch=range(len(outs_static)))
|
|
|
|
|
else:
|
|
|
|
|
outs = transformer(enc_inputs, dec_inputs, label, weights)
|
|
|
|
|
|
|
|
|
|