parent
7923d7271f
commit
316636404f
@ -0,0 +1,118 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from op_test import OpTest
|
||||||
|
from test_reorder_lod_tensor import convert_to_offset
|
||||||
|
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusionSeqPoolConcatOp(OpTest):
|
||||||
|
def setUp(self):
|
||||||
|
self.w = 11
|
||||||
|
self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
|
||||||
|
self.set_conf()
|
||||||
|
self.set_pooltype()
|
||||||
|
self.op_type = 'fusion_seqpool_concat'
|
||||||
|
self.axis = 1
|
||||||
|
bs = len(self.lods[0][0])
|
||||||
|
inputs = []
|
||||||
|
outs = []
|
||||||
|
i = 0
|
||||||
|
for lod in self.lods:
|
||||||
|
assert bs == len(lod[0]), 'All lod size should be equal'
|
||||||
|
x = np.random.uniform(0.1, 1,
|
||||||
|
[sum(lod[0]), self.w]).astype('float32')
|
||||||
|
offset = convert_to_offset(lod)
|
||||||
|
out = np.zeros((bs, self.w)).astype('float32')
|
||||||
|
if self.pooltype == "SUM":
|
||||||
|
compute_seqpool_sum(x, offset, out)
|
||||||
|
elif self.pooltype == "AVERAGE":
|
||||||
|
compute_seqpool_avg(x, offset, out)
|
||||||
|
elif self.pooltype == "SQRT":
|
||||||
|
compute_seqpool_sqrt(x, offset, out)
|
||||||
|
else:
|
||||||
|
raise Exception("Unsupported pool type!")
|
||||||
|
inputs.append(('x_{0}'.format(i), (x, lod)))
|
||||||
|
outs.append(out)
|
||||||
|
i = i + 1
|
||||||
|
|
||||||
|
self.inputs = {'X': inputs}
|
||||||
|
self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
|
||||||
|
self.attrs = {
|
||||||
|
'pooltype': self.pooltype,
|
||||||
|
'axis': self.axis,
|
||||||
|
}
|
||||||
|
|
||||||
|
def set_pooltype(self):
|
||||||
|
self.pooltype = "SUM"
|
||||||
|
|
||||||
|
def set_conf(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusionSeqPoolConcatOpCase1(TestFusionSeqPoolConcatOp):
|
||||||
|
def set_conf(self):
|
||||||
|
self.lods = [[[1]]]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusionSeqPoolConcatOpCase2(TestFusionSeqPoolConcatOp):
|
||||||
|
def set_conf(self):
|
||||||
|
self.lods = [[[1]], [[1]], [[1]]]
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusionSeqPoolConcatOpCase3(TestFusionSeqPoolConcatOp):
|
||||||
|
def set_conf(self):
|
||||||
|
self.lods = [[[1, 3, 4, 6]]]
|
||||||
|
self.w = 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestFusionSeqPoolConcatOpCase4(TestFusionSeqPoolConcatOp):
|
||||||
|
def set_conf(self):
|
||||||
|
self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
|
||||||
|
self.w = 3
|
||||||
|
|
||||||
|
|
||||||
|
## test avg pool and sqrt
|
||||||
|
def create_test_avg_sqrt_class(parent):
|
||||||
|
class TestSeqPoolAvgCase(parent):
|
||||||
|
def set_pooltype(self):
|
||||||
|
self.pooltype = "AVERAGE"
|
||||||
|
|
||||||
|
class TestSeqPoolSqrtCase(parent):
|
||||||
|
def set_pooltype(self):
|
||||||
|
self.pooltype = "SQRT"
|
||||||
|
|
||||||
|
cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
|
||||||
|
cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
|
||||||
|
TestSeqPoolAvgCase.__name__ = cls_name_avg
|
||||||
|
TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
|
||||||
|
globals()[cls_name_avg] = TestSeqPoolAvgCase
|
||||||
|
globals()[cls_name_sqrt] = TestSeqPoolSqrtCase
|
||||||
|
|
||||||
|
|
||||||
|
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOp)
|
||||||
|
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase1)
|
||||||
|
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase2)
|
||||||
|
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase3)
|
||||||
|
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase4)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
Reference in new issue