1. Remove 'top 1'(or argmax) from CPU and GPU kernel 2. Add a new test case 3. Refine docadd_depthwiseConv_op_gpu
parent
579f684661
commit
281e93bcbb
@ -0,0 +1,62 @@
|
|||||||
|
import sys
|
||||||
|
import unittest
|
||||||
|
import numpy as np
|
||||||
|
from op_test import OpTest
|
||||||
|
from test_softmax_op import stable_softmax
|
||||||
|
|
||||||
|
|
||||||
|
def CTCDecode(input, lod, blank, merge_repeated):
|
||||||
|
lod0 = lod[0]
|
||||||
|
result = []
|
||||||
|
for i in range(len(lod0) - 1):
|
||||||
|
prev_token = -1
|
||||||
|
for j in range(lod0[i], lod0[i + 1]):
|
||||||
|
token = input[j][0]
|
||||||
|
if (token != blank) and not (merge_repeated and
|
||||||
|
token == prev_token):
|
||||||
|
result.append(token)
|
||||||
|
prev_token = token
|
||||||
|
result = np.array(result).reshape([len(result), 1]).astype("int32")
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class TestCTCDecodeOp(OpTest):
|
||||||
|
def config(self):
|
||||||
|
self.op_type = "ctc_greedy_decode"
|
||||||
|
self.input_lod = [[0, 11, 18]]
|
||||||
|
self.blank = 0
|
||||||
|
self.merge_repeated = False
|
||||||
|
self.input = np.array(
|
||||||
|
[0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
|
||||||
|
[18, 1]).astype("int32")
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.config()
|
||||||
|
output = CTCDecode(self.input, self.input_lod, self.blank,
|
||||||
|
self.merge_repeated)
|
||||||
|
|
||||||
|
self.inputs = {"Input": (self.input, self.input_lod), }
|
||||||
|
self.outputs = {"Output": output}
|
||||||
|
self.attrs = {
|
||||||
|
"blank": self.blank,
|
||||||
|
"merge_repeated": self.merge_repeated
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_check_output(self):
|
||||||
|
self.check_output()
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class TestCTCDecodeOpCase1(TestCTCDecodeOp):
|
||||||
|
def config(self):
|
||||||
|
self.op_type = "ctc_greedy_decode"
|
||||||
|
self.input_lod = [[0, 11, 18]]
|
||||||
|
self.blank = 0
|
||||||
|
self.merge_repeated = True
|
||||||
|
self.input = np.array(
|
||||||
|
[0, 1, 2, 2, 0, 4, 0, 4, 5, 0, 6, 6, 0, 0, 7, 7, 7, 0]).reshape(
|
||||||
|
[18, 1]).astype("int32")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
@ -1,56 +0,0 @@
|
|||||||
import sys
|
|
||||||
import unittest
|
|
||||||
import numpy as np
|
|
||||||
from op_test import OpTest
|
|
||||||
from test_softmax_op import stable_softmax
|
|
||||||
|
|
||||||
|
|
||||||
def CTCGreedyDecode(softmax, blank, merge_repeated):
|
|
||||||
prev_token = -1
|
|
||||||
result = []
|
|
||||||
for token in np.argmax(softmax, axis=1):
|
|
||||||
if (token != blank) and not (merge_repeated and token == prev_token):
|
|
||||||
result.append(token)
|
|
||||||
return np.array(result).reshape([len(result), 1])
|
|
||||||
|
|
||||||
|
|
||||||
class TestCTCGreedyDecodeOp(OpTest):
|
|
||||||
def config(self):
|
|
||||||
self.op_type = "ctc_greedy_decode"
|
|
||||||
self.batch_size = 4
|
|
||||||
self.num_classes = 8
|
|
||||||
self.input_lod = [[0, 4, 5, 8, 11]]
|
|
||||||
self.blank = 7
|
|
||||||
self.merge_repeated = True
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
self.config()
|
|
||||||
input = np.random.uniform(
|
|
||||||
0.1, 1.0,
|
|
||||||
[self.input_lod[0][-1], self.num_classes]).astype("float32")
|
|
||||||
softmax = np.apply_along_axis(stable_softmax, 1, input)
|
|
||||||
output = CTCGreedyDecode(softmax, self.blank, self.merge_repeated)
|
|
||||||
|
|
||||||
self.inputs = {"Input": (softmax, self.input_lod), }
|
|
||||||
self.outputs = {"Output": output}
|
|
||||||
self.attrs = {
|
|
||||||
"blank": self.blank,
|
|
||||||
"merge_repeated": self.merge_repeated
|
|
||||||
}
|
|
||||||
|
|
||||||
def test_check_output(self):
|
|
||||||
self.check_output()
|
|
||||||
|
|
||||||
|
|
||||||
class TestCTCGreedyDecodeOpCase1(TestCTCGreedyDecodeOp):
|
|
||||||
def config(self):
|
|
||||||
self.op_type = "ctc_greedy_decode"
|
|
||||||
self.batch_size = 4
|
|
||||||
self.num_classes = 1025
|
|
||||||
self.input_lod = [[0, 4, 5, 8, 11]]
|
|
||||||
self.blank = 0
|
|
||||||
self.merge_repeated = True
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
Loading…
Reference in new issue