|
|
|
@ -80,9 +80,9 @@ def test_double_star_graph():
|
|
|
|
|
|
|
|
|
|
_executor.compile(net, x, y, z, w, phase='train')
|
|
|
|
|
strategies = _executor._get_strategy(net)
|
|
|
|
|
expected_strategies = {'Default/network-Net/Cast-op1': [[8, 1]],
|
|
|
|
|
'Default/network-Net/Cast-op3': [[1, 8]],
|
|
|
|
|
'Default/network-Net/MatMul-op2': [[8, 1], [1, 1]],
|
|
|
|
|
expected_strategies = {'Default/network-Net/Cast-op0': [[8, 1]],
|
|
|
|
|
'Default/network-Net/Cast-op1': [[1, 8]],
|
|
|
|
|
'Default/network-Net/MatMul-op3': [[8, 1], [1, 1]],
|
|
|
|
|
'Default/network-Net/MatMul-op4': [[1, 1], [1, 8]],
|
|
|
|
|
'Default/network-Net/MatMul-op0': [[1, 8], [8, 1]]}
|
|
|
|
|
'Default/network-Net/MatMul-op2': [[1, 8], [8, 1]]}
|
|
|
|
|
assert strategies == expected_strategies
|
|
|
|
|