1. Remove 'top 1'(or argmax) from CPU and GPU kernel 2. Add a new test case 3. Refine docadd_depthwiseConv_op_gpu
parent
579f684661
commit
281e93bcbb
@ -0,0 +1,62 @@
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
from test_softmax_op import stable_softmax
|
||||
|
||||
|
||||
def CTCDecode(input, lod, blank, merge_repeated):
|
||||
lod0 = lod[0]
|
||||
result = []
|
||||
for i in range(len(lod0) - 1):
|
||||
prev_token = -1
|
||||
for j in range(lod0[i], lod0[i + 1]):
|
||||
token = input[j][0]
|
||||
if (token != blank) and not (merge_repeated and
|
||||
token == prev_token):
|
||||
result.append(token)
|
||||
prev_token = token
|
||||
result = np.array(result).reshape([len(result), 1]).astype("int32")
|
||||
return result
|
||||
|
||||
|
||||
class TestCTCDecodeOp(OpTest):
|
||||
def config(self):
|
||||
self.op_type = "ctc_greedy_decode"
|
||||
self.input_lod = [[0, 11, 18]]
|
||||
self.blank = 0
|
||||
self.merge_repeated = False
|
||||
self.input = np.array(
|
||||
[0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
|
||||
[18, 1]).astype("int32")
|
||||
|
||||
def setUp(self):
|
||||
self.config()
|
||||
output = CTCDecode(self.input, self.input_lod, self.blank,
|
||||
self.merge_repeated)
|
||||
|
||||
self.inputs = {"Input": (self.input, self.input_lod), }
|
||||
self.outputs = {"Output": output}
|
||||
self.attrs = {
|
||||
"blank": self.blank,
|
||||
"merge_repeated": self.merge_repeated
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
pass
|
||||
|
||||
|
||||
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
|
||||
def config(self):
|
||||
self.op_type = "ctc_greedy_decode"
|
||||
self.input_lod = [[0, 11, 18]]
|
||||
self.blank = 0
|
||||
self.merge_repeated = True
|
||||
self.input = np.array(
|
||||
[0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
|
||||
[18, 1]).astype("int32")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
@ -1,56 +0,0 @@
|
||||
import sys
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
from test_softmax_op import stable_softmax
|
||||
|
||||
|
||||
def CTCGreedyDecode(softmax, blank, merge_repeated):
|
||||
prev_token = -1
|
||||
result = []
|
||||
for token in np.argmax(softmax, axis=1):
|
||||
if (token != blank) and not (merge_repeated and token == prev_token):
|
||||
result.append(token)
|
||||
return np.array(result).reshape([len(result), 1])
|
||||
|
||||
|
||||
class TestCTCGreedyDecodeOp(OpTest):
|
||||
def config(self):
|
||||
self.op_type = "ctc_greedy_decode"
|
||||
self.batch_size = 4
|
||||
self.num_classes = 8
|
||||
self.input_lod = [[0, 4, 5, 8, 11]]
|
||||
self.blank = 7
|
||||
self.merge_repeated = True
|
||||
|
||||
def setUp(self):
|
||||
self.config()
|
||||
input = np.random.uniform(
|
||||
0.1, 1.0,
|
||||
[self.input_lod[0][-1], self.num_classes]).astype("float32")
|
||||
softmax = np.apply_along_axis(stable_softmax, 1, input)
|
||||
output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated)
|
||||
|
||||
self.inputs = {"Input": (softmax, self.input_lod), }
|
||||
self.outputs = {"Output": output}
|
||||
self.attrs = {
|
||||
"blank": self.blank,
|
||||
"merge_repeated": self.merge_repeated
|
||||
}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp):
|
||||
def config(self):
|
||||
self.op_type = "ctc_greedy_decode"
|
||||
self.batch_size = 4
|
||||
self.num_classes = 1025
|
||||
self.input_lod = [[0, 4, 5, 8, 11]]
|
||||
self.blank = 0
|
||||
self.merge_repeated = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue