1. update test_split_var: replace split with slice

wangkuiyi-patch-1
minqiyang 7 years ago
parent b33ea7be2d
commit eacac49bcd

@ -14,14 +14,14 @@
import math import math
import unittest import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_variable from paddle.fluid.transpiler.distribute_transpiler import slice_variable
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import random import random
class TestSplitVar(unittest.TestCase): class TestSliceVar(unittest.TestCase):
def check_split_output(self, shapes, expected_sizes, min_size): def check_slice_output(self, shapes, expected_sizes, min_size):
var_list = [] var_list = []
program = fluid.Program() program = fluid.Program()
for shape in shapes: for shape in shapes:
@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR, # dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape) shape=shape)
var_list.append(var) var_list.append(var)
blocks = split_variable(var_list, 10, min_size) blocks = slice_variable(var_list, 10, min_size)
all_sizes = [] all_sizes = []
for s in expected_sizes: for s in expected_sizes:
for s2 in s: for s2 in s:
@ -49,7 +49,7 @@ class TestSplitVar(unittest.TestCase):
[1150, 1150, 1150, 1150, 1150, 1150, 1100] [1150, 1150, 1150, 1150, 1150, 1150, 1100]
] ]
self.check_split_output(shapes, expected_sizes, 1024) self.check_slice_output(shapes, expected_sizes, 1024)
def test_check_output_8k(self): def test_check_output_8k(self):
shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10], shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10],
@ -57,7 +57,7 @@ class TestSplitVar(unittest.TestCase):
expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000], expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000],
[35937, 35937, 35937, 35937, 35937, 35937]] [35937, 35937, 35937, 35937, 35937, 35937]]
self.check_split_output(shapes, expected_sizes, 8192) self.check_slice_output(shapes, expected_sizes, 8192)
if __name__ == '__main__': if __name__ == '__main__':
Loading…
Cancel
Save