|
|
|
@ -24,6 +24,7 @@ make_tuple = Primitive('make_tuple')
|
|
|
|
|
tuple_getitem = Primitive('tuple_getitem')
|
|
|
|
|
depend = Primitive('depend')
|
|
|
|
|
BatchNorm = P.BatchNorm()
|
|
|
|
|
Cast = P.Cast()
|
|
|
|
|
BNTrainingReduce = Primitive('BNTrainingReduce')
|
|
|
|
|
BNTrainingUpdate = Primitive('BNTrainingUpdate')
|
|
|
|
|
constant0 = Tensor(0.1, mstype.float32)
|
|
|
|
@ -59,6 +60,21 @@ def test_fused_batch_norm_fusion(tag):
|
|
|
|
|
output = tuple_getitem(outputs, 0)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
@fns
|
|
|
|
|
def before_mix_precision(input0, input1, input2, input3, input4, var0, var1):
|
|
|
|
|
batch_norm = BatchNorm(input0, input1, input2, input3, input4)
|
|
|
|
|
sub0 = Sub(Cast(var0, mstype.float32), tuple_getitem(batch_norm, 1))
|
|
|
|
|
sub1 = Sub(Cast(var1, mstype.float32), tuple_getitem(batch_norm, 2))
|
|
|
|
|
mul0 = Mul(sub0, constant0)
|
|
|
|
|
mul1 = Mul(sub1, constant1)
|
|
|
|
|
assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
|
|
|
|
|
assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
|
|
|
|
|
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
|
|
|
|
depend1 = depend(depend0, assign_sub1)
|
|
|
|
|
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
|
|
|
|
|
output = tuple_getitem(outputs, 0)
|
|
|
|
|
return output
|
|
|
|
|
|
|
|
|
|
@fns
|
|
|
|
|
def after(input0, input1, input2, input3, input4, var0, var1):
|
|
|
|
|
bn_training_reduce = BNTrainingReduce(input0)
|
|
|
|
|