|
|
@ -48,8 +48,8 @@ def test_get_parameter_layout():
|
|
|
|
net.set_auto_parallel()
|
|
|
|
net.set_auto_parallel()
|
|
|
|
exe = me._executor
|
|
|
|
exe = me._executor
|
|
|
|
exe.compile(net, x, auto_parallel_mode=True)
|
|
|
|
exe.compile(net, x, auto_parallel_mode=True)
|
|
|
|
x_layout = ([2, 4], [1, -1]) # device_arrangement = [2, 4], tensor_map = [1, -1]
|
|
|
|
x_layout = [[2, 4], [1, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [1, -1]
|
|
|
|
weight_layout = ([2, 4], [0, -1]) # device_arrangement = [2, 4], tensor_map = [0, -1]
|
|
|
|
weight_layout = [[2, 4], [0, -1], [16, 32]] # device_arrangement = [2, 4], tensor_map = [0, -1]
|
|
|
|
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
|
|
|
expect_dict = {'x': x_layout, 'w1': weight_layout}
|
|
|
|
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
|
|
|
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
|
|
|
|
assert (net.parameter_layout_dict == expect_dict)
|
|
|
|
assert (net.parameter_layout_dict == expect_dict)
|
|
|
|