1. update test_split_var: replace split with slice

wangkuiyi-patch-1
minqiyang 7 years ago
parent b33ea7be2d
commit eacac49bcd

@ -14,14 +14,14 @@
import math
import unittest
from paddle.fluid.transpiler.distribute_transpiler import split_variable
from paddle.fluid.transpiler.distribute_transpiler import slice_variable
import paddle.fluid as fluid
import paddle.fluid.core as core
import random
class TestSplitVar(unittest.TestCase):
def check_split_output(self, shapes, expected_sizes, min_size):
class TestSliceVar(unittest.TestCase):
def check_slice_output(self, shapes, expected_sizes, min_size):
var_list = []
program = fluid.Program()
for shape in shapes:
@ -31,7 +31,7 @@ class TestSplitVar(unittest.TestCase):
# dtype=core.VarDesc.VarType.LOD_TENSOR,
shape=shape)
var_list.append(var)
blocks = split_variable(var_list, 10, min_size)
blocks = slice_variable(var_list, 10, min_size)
all_sizes = []
for s in expected_sizes:
for s2 in s:
@ -49,7 +49,7 @@ class TestSplitVar(unittest.TestCase):
[1150, 1150, 1150, 1150, 1150, 1150, 1100]
]
self.check_split_output(shapes, expected_sizes, 1024)
self.check_slice_output(shapes, expected_sizes, 1024)
def test_check_output_8k(self):
shapes = [[3, 5], [1024], [28, 784], [8, 1020], [800, 10],
@ -57,7 +57,7 @@ class TestSplitVar(unittest.TestCase):
expected_sizes = [[15], [1024], [10976, 10976], [8160], [8000],
[35937, 35937, 35937, 35937, 35937, 35937]]
self.check_split_output(shapes, expected_sizes, 8192)
self.check_slice_output(shapes, expected_sizes, 8192)
if __name__ == '__main__':
Loading…
Cancel
Save