@@ -23,8 +23,7 @@ from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_
     cancel_new_parameter
 from mindspore.common.api import _generate_pip_args
 from mindspore._c_expression import generate_key, Executor_
-from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
-    NewParameter, Imm
+from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
 
 context.set_context(mode=context.GRAPH_MODE)
 
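
The import hunk above is the core of the change: the pattern classes exported by `mindspore.graph_utils.graph_pattern` are renamed, and the `should_replace` keyword disappears from the pattern constructors (for `NewParameter`, an explicit `gen_new_parameter`/`cancel_new_parameter` pair takes its place, as the last hunk shows). A comment-only sketch of the correspondence, inferred from the hunks below and not part of the diff:

```python
# Pattern-class renames applied throughout this diff (old -> new):
#   AnyPattern()              -> Any()              # wildcard, matches any input
#   IsPrimTypeOf(prim_or_str) -> Prim(prim_or_str)  # match a primitive type
#   CallWith(pat, inputs)     -> Call(pat, inputs)  # match a call on given inputs
#   IsIn([p0, p1])            -> OneOf([p0, p1])    # match any one listed pattern
#   IsNot(pat)                -> NoneOf(pat)        # match anything except pat
```
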
@@ -50,11 +49,9 @@ def test_softmax_relu():
 
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        pattern = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
-        target = CallWith(relu_pattern, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
+        target = Call(P.ReLU(), inputs=[x])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
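
Judging by this hunk, `Call` now accepts a raw primitive and wraps it itself, so the old `IsPrimTypeOf` plus `CallWith` two-step collapses into one line; the next hunk keeps an explicit `Prim` wrapper where the pattern node is bound to a name first. A minimal sketch of the two spellings under that assumption:

```python
from mindspore.graph_utils.graph_pattern import Any, Call, Prim
from mindspore.ops import operations as P

x = Any()
# Compact spelling: hand the primitive straight to Call.
pattern_compact = Call(P.Softmax(), inputs=[x])
# Explicit spelling: wrap the primitive in Prim first, then build the call.
pattern_explicit = Call(Prim(P.Softmax()), inputs=[x])
```
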
@@ -74,13 +71,13 @@ def test_softmax_relu_sigmoid():
 
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        pattern = CallWith(softmax_pattern, inputs=[x])
-        sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
-        call_sigmoid = CallWith(sigmoid_pattern, [x])
-        relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
-        target = CallWith(relu_pattern, inputs=[call_sigmoid])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        pattern = Call(softmax_pattern, inputs=[x])
+        sigmoid_pattern = Prim(P.Sigmoid())
+        call_sigmoid = Call(sigmoid_pattern, [x])
+        relu_pattern = Prim(P.ReLU())
+        target = Call(relu_pattern, inputs=[call_sigmoid])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
@@ -99,15 +96,15 @@ def test_isin_pattern_0():
 
     @registe_pass(run_only_once=True)
     def softmax_relu_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        call_softmax = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU())
-        call_relu = CallWith(relu_pattern, inputs=[x])
-
-        pattern = IsIn([call_softmax, call_relu])
-        relu6_pattern = IsPrimTypeOf(P.ReLU6(), should_replace=False)
-        target = CallWith(relu6_pattern, inputs=[x])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        call_softmax = Call(softmax_pattern, inputs=[x])
+        relu_pattern = Prim(P.ReLU())
+        call_relu = Call(relu_pattern, inputs=[x])
+
+        pattern = OneOf([call_softmax, call_relu])
+        relu6_pattern = Prim(P.ReLU6())
+        target = Call(relu6_pattern, inputs=[x])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
     unregiste_pass(softmax_relu_pass)
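
`OneOf` (formerly `IsIn`) matches a node satisfying any one of its sub-patterns, so this single pass rewrites both `Softmax(x)` and `ReLU(x)` calls to `ReLU6(x)`. A sketch of the same idea in the new API, assuming the imports from the file header:

```python
@registe_pass(run_only_once=True)
def softmax_or_relu_to_relu6():
    x = Any()
    # Either a Softmax call or a ReLU call on the same wildcard input.
    either = OneOf([Call(Prim(P.Softmax()), inputs=[x]),
                    Call(Prim(P.ReLU()), inputs=[x])])
    target = Call(Prim(P.ReLU6()), inputs=[x])
    return either, target
```
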
@@ -123,18 +120,17 @@ def test_isin_pattern_1():
 
     @registe_pass(run_only_once=True)
     def softmax_neg_pass():
-        x = AnyPattern()
-        softmax_pattern = IsPrimTypeOf(P.Softmax())
-        call_softmax = CallWith(softmax_pattern, inputs=[x])
-        relu_pattern = IsPrimTypeOf(P.ReLU())
-        call_relu = CallWith(relu_pattern, inputs=[x])
-
-        pattern = IsIn([call_softmax, call_relu])
-        neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
-        target = CallWith(neg_ops, inputs=[pattern])
+        x = Any()
+        softmax_pattern = Prim(P.Softmax())
+        call_softmax = Call(softmax_pattern, inputs=[x])
+        relu_pattern = Prim(P.ReLU())
+        call_relu = Call(relu_pattern, inputs=[x])
+
+        pattern = OneOf([call_softmax, call_relu])
+        neg_ops = Prim(P.Neg())
+        target = Call(neg_ops, inputs=[pattern])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
     print(transformed_repr)
     unregiste_pass(softmax_neg_pass)
     assert "Neg" in transformed_repr
     assert "Softmax" in transformed_repr
@@ -167,11 +163,11 @@ def test_isnot_pattern_0():
         """
         Sub a BN which does NOT take Conv as inputs to ReLU6.
         """
-        conv2d_prim = IsPrimTypeOf("Conv2D")
-        conv2d = CallWith(conv2d_prim)
-        pattern_0 = IsNot(conv2d)
-        pattern = CallWith(P.BatchNorm(), inputs=[pattern_0])
-        target = CallWith(P.ReLU6(), inputs=[pattern_0])
+        conv2d_prim = Prim("Conv2D")
+        conv2d = Call(conv2d_prim)
+        pattern_0 = NoneOf(conv2d)
+        pattern = Call(P.BatchNorm(), inputs=[pattern_0])
+        target = Call(P.ReLU6(), inputs=[pattern_0])
         return pattern, target
 
     @registe_pass(run_only_once=True)
@@ -179,10 +175,8 @@ def test_isnot_pattern_0():
         """
         Sub a BN to Softmax.
         """
-        bn = P.BatchNorm()
-        pattern = CallWith(bn)
-        softmax = P.Softmax()
-        target = CallWith(softmax, should_replace=False)
+        pattern = Call(P.BatchNorm())
+        target = Call(P.Softmax())
         return pattern, target
 
     transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
@@ -205,12 +199,12 @@ def test_isnot_pattern_1():
         """
         Sub a BN which does NOT take MatMul as inputs to ReLU6.
         """
-        matmul = IsPrimTypeOf("MatMul")
-        pattern_0 = IsNot(matmul)
+        matmul = Prim("MatMul")
+        pattern_0 = NoneOf(matmul)
         softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[pattern_0])
+        pattern = Call(softmax, inputs=[pattern_0])
         relu6 = P.ReLU6()
-        target = CallWith(relu6, inputs=[pattern_0], should_replace=False)
+        target = Call(relu6, inputs=[pattern_0])
         return pattern, target
 
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
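
`NoneOf` (formerly `IsNot`) inverts a pattern: `NoneOf(matmul)` matches any input node that is not a `MatMul` call, so only calls not fed by `MatMul` are rewritten. A short sketch of the negative match, assuming the same imports:

```python
matmul_call = Call(Prim("MatMul"))   # Prim also accepts an op type name as a string
not_matmul = NoneOf(matmul_call)     # matches any node except a MatMul call
pattern = Call(Prim(P.Softmax()), inputs=[not_matmul])  # Softmax not fed by MatMul
```
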
@@ -228,14 +222,12 @@ def test_newtensor_pattern():
 
     @registe_pass(run_only_once=True)
     def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
 
         weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
         new_weight = NewTensor(weight_tensor)
-        addn_ops = P.AddN()
-        target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
+        target = Call(P.AddN(), inputs=[x, new_weight])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
     unregiste_pass(softmax_addn_pass)
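
`NewTensor` injects a constant tensor into the target graph, so the pass above turns `Softmax(x)` into `AddN([x, w])` where `w` is a fresh zero tensor. A sketch of the target construction in isolation:

```python
weight_tensor = Tensor(np.zeros([42]), mindspore.float16)
new_weight = NewTensor(weight_tensor)             # materializes as a constant node
target = Call(P.AddN(), inputs=[x, new_weight])   # matched input plus the constant
```
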
@@ -252,25 +244,23 @@ def test_newparameter_pattern():
 
     @registe_pass(run_only_once=True)
     def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
 
         default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
         default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
         new_para_0 = NewParameter("Merlin", default_tensor0)
         new_para_1 = NewParameter("Arthur", default_tensor1)
-        target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
-        target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
+        target_0 = Call(P.MatMul(), inputs=[new_para_0, new_para_1])
+        target = Call("make_tuple", inputs=[target_0])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     print(transformed_repr)
     unregiste_pass(softmax_addn_pass)
     assert "MatMul" in transformed_repr
     assert "make_tuple" in transformed_repr
     assert "Softmax" not in transformed_repr
 
-def test_imm_pattern():
+def test_imm_target():
     """
     Test NewParameter pattern in the target
     """
@@ -278,17 +268,15 @@ def test_imm_pattern():
     softmax_model = nn.Softmax()
 
     @registe_pass(run_only_once=True)
-    def softmax_addn_pass():
-        x = AnyPattern()
-        softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+    def softmax_pass():
+        x = Any()
+        pattern = Call(P.Softmax(), inputs=[x])
         imm = Imm(0)
-        target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
-        target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
+        target_0 = Call("make_tuple", inputs=[pattern])
+        target = Call("tuple_getitem", inputs=[target_0, imm])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     print(transformed_repr)
-    unregiste_pass(softmax_addn_pass)
+    unregiste_pass(softmax_pass)
     assert "make_tuple" in transformed_repr
     assert "tuple_getitem" in transformed_repr
     assert "Softmax" in transformed_repr
@@ -301,21 +289,19 @@ def test_gen_new_parameter():
     softmax_model = nn.Softmax()
 
     default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
-    new_para = NewParameter("Merlin", default_tensor, should_replace=True)
+    new_para = NewParameter("Merlin", default_tensor)
     gen_new_parameter(new_para)
     @registe_pass(run_only_once=True)
     def softmax_make_tuple_pass():
-        x = AnyPattern()
+        x = Any()
         softmax = P.Softmax()
-        pattern = CallWith(softmax, inputs=[x])
+        pattern = Call(softmax, inputs=[x])
 
-        target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
+        target = Call("make_tuple", inputs=[pattern, new_para])
         return pattern, target
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     print(transformed_repr)
     assert "Merlin" in transformed_repr
     unregiste_pass(softmax_make_tuple_pass)
     cancel_new_parameter(new_para)
     transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
     print(transformed_repr)
     assert "Merlin" not in transformed_repr