parent
7923d7271f
commit
316636404f
@ -0,0 +1,118 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
from test_reorder_lod_tensor import convert_to_offset
|
||||
from test_seq_pool import compute_seqpool_sum, compute_seqpool_avg, compute_seqpool_sqrt
|
||||
|
||||
|
||||
class TestFusionSeqPoolConcatOp(OpTest):
|
||||
def setUp(self):
|
||||
self.w = 11
|
||||
self.lods = [[[2, 3, 5]], [[1, 5, 2]]]
|
||||
self.set_conf()
|
||||
self.set_pooltype()
|
||||
self.op_type = 'fusion_seqpool_concat'
|
||||
self.axis = 1
|
||||
bs = len(self.lods[0][0])
|
||||
inputs = []
|
||||
outs = []
|
||||
i = 0
|
||||
for lod in self.lods:
|
||||
assert bs == len(lod[0]), 'All lod size should be equal'
|
||||
x = np.random.uniform(0.1, 1,
|
||||
[sum(lod[0]), self.w]).astype('float32')
|
||||
offset = convert_to_offset(lod)
|
||||
out = np.zeros((bs, self.w)).astype('float32')
|
||||
if self.pooltype == "SUM":
|
||||
compute_seqpool_sum(x, offset, out)
|
||||
elif self.pooltype == "AVERAGE":
|
||||
compute_seqpool_avg(x, offset, out)
|
||||
elif self.pooltype == "SQRT":
|
||||
compute_seqpool_sqrt(x, offset, out)
|
||||
else:
|
||||
raise Exception("Unsupported pool type!")
|
||||
inputs.append(('x_{0}'.format(i), (x, lod)))
|
||||
outs.append(out)
|
||||
i = i + 1
|
||||
|
||||
self.inputs = {'X': inputs}
|
||||
self.outputs = {'Out': np.concatenate(outs, axis=self.axis)}
|
||||
self.attrs = {
|
||||
'pooltype': self.pooltype,
|
||||
'axis': self.axis,
|
||||
}
|
||||
|
||||
def set_pooltype(self):
|
||||
self.pooltype = "SUM"
|
||||
|
||||
def set_conf(self):
|
||||
pass
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestFusionSeqPoolConcatOpCase1(TestFusionSeqPoolConcatOp):
|
||||
def set_conf(self):
|
||||
self.lods = [[[1]]]
|
||||
|
||||
|
||||
class TestFusionSeqPoolConcatOpCase2(TestFusionSeqPoolConcatOp):
|
||||
def set_conf(self):
|
||||
self.lods = [[[1]], [[1]], [[1]]]
|
||||
|
||||
|
||||
class TestFusionSeqPoolConcatOpCase3(TestFusionSeqPoolConcatOp):
|
||||
def set_conf(self):
|
||||
self.lods = [[[1, 3, 4, 6]]]
|
||||
self.w = 10
|
||||
|
||||
|
||||
class TestFusionSeqPoolConcatOpCase4(TestFusionSeqPoolConcatOp):
|
||||
def set_conf(self):
|
||||
self.lods = [[[2, 13, 4]], [[1, 1, 1]], [[5, 3, 1]], [[9, 10, 3]]]
|
||||
self.w = 3
|
||||
|
||||
|
||||
## test avg pool and sqrt
|
||||
def create_test_avg_sqrt_class(parent):
|
||||
class TestSeqPoolAvgCase(parent):
|
||||
def set_pooltype(self):
|
||||
self.pooltype = "AVERAGE"
|
||||
|
||||
class TestSeqPoolSqrtCase(parent):
|
||||
def set_pooltype(self):
|
||||
self.pooltype = "SQRT"
|
||||
|
||||
cls_name_avg = "{0}_{1}".format(parent.__name__, "avg")
|
||||
cls_name_sqrt = "{0}_{1}".format(parent.__name__, "sqrt")
|
||||
TestSeqPoolAvgCase.__name__ = cls_name_avg
|
||||
TestSeqPoolSqrtCase.__name__ = cls_name_sqrt
|
||||
globals()[cls_name_avg] = TestSeqPoolAvgCase
|
||||
globals()[cls_name_sqrt] = TestSeqPoolSqrtCase
|
||||
|
||||
|
||||
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOp)
|
||||
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase1)
|
||||
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase2)
|
||||
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase3)
|
||||
create_test_avg_sqrt_class(TestFusionSeqPoolConcatOpCase4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue