|
|
|
@ -80,9 +80,9 @@ def test_common_parameter():
|
|
|
|
|
|
|
|
|
|
_executor.compile(net, x, y, z, w, phase='train')
|
|
|
|
|
strategies = _executor._get_strategy(net)
|
|
|
|
|
expected_strategies = {'Default/network-Net/MatMul-op8': [[1, 1], [1, 8]],
|
|
|
|
|
'Default/network-Net/MatMul-op9': [[1, 1], [1, 8]],
|
|
|
|
|
'Default/network-Net/Cast-op10': [[1, 8]],
|
|
|
|
|
'Default/network-Net/MatMul-op0': [[1, 1], [1, 8]],
|
|
|
|
|
'Default/network-Net/Cast-op11': [[1, 8]]}
|
|
|
|
|
assert strategies == expected_strategies
|
|
|
|
|
expected_strategies = {'Default/network-Net/MatMul-op6': [[8, 1], [1, 1]],
|
|
|
|
|
'Default/network-Net/MatMul-op8': [[8, 1], [1, 1]],
|
|
|
|
|
'Default/network-Net/Cast-op7': [[1, 1]],
|
|
|
|
|
'Default/network-Net/MatMul-op0': [[8, 1], [1, 1]],
|
|
|
|
|
'Default/network-Net/Cast-op9': [[1, 1]]}
|
|
|
|
|
assert strategies == expected_strategies
|
|
|
|
|