|
|
|
@ -122,7 +122,7 @@ class FusionGroupPassTestFP64(FusionGroupPassTest):
|
|
|
|
|
self.fused_op_type = "fusion_group"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FusionGroupPassTestFP16(FusionGroupPassTest):
|
|
|
|
|
class FusionGroupPassTestCastAndFP16(FusionGroupPassTest):
|
|
|
|
|
def build_program(self, dtype):
|
|
|
|
|
with fluid.program_guard(self.main_program, self.startup_program):
|
|
|
|
|
self.feed_vars = self._prepare_feed_vars([32, 128], dtype, 2)
|
|
|
|
@ -132,7 +132,7 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
|
|
|
|
|
|
|
|
|
|
# subgraph with 2 op nodes
|
|
|
|
|
tmp_0 = self.feed_vars[0] * self.feed_vars[1]
|
|
|
|
|
tmp_1 = layers.cast(tmp_0, dtype="float16")
|
|
|
|
|
tmp_1 = layers.softmax(layers.cast(tmp_0, dtype="float16"))
|
|
|
|
|
tmp_2 = layers.mul(tmp_0, self.feed_vars[2])
|
|
|
|
|
# subgraph with 4 op nodes
|
|
|
|
|
tmp_3 = layers.cast(tmp_2, dtype="float16")
|
|
|
|
@ -141,7 +141,7 @@ class FusionGroupPassTestFP16(FusionGroupPassTest):
|
|
|
|
|
|
|
|
|
|
self.append_gradients(tmp_5)
|
|
|
|
|
|
|
|
|
|
self.num_fused_ops = 3
|
|
|
|
|
self.num_fused_ops = 4
|
|
|
|
|
self.fetch_list = [tmp_5, self.grad(tmp_0)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|