|
|
|
@ -63,28 +63,23 @@ class TextTensorOperation(TensorOperation):
|
|
|
|
|
"""
|
|
|
|
|
Base class of Text Tensor Ops
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __call__(self, input_tensor):
|
|
|
|
|
if not isinstance(input_tensor, list):
|
|
|
|
|
input_list = [input_tensor]
|
|
|
|
|
else:
|
|
|
|
|
input_list = input_tensor
|
|
|
|
|
tensor_list = []
|
|
|
|
|
for tensor in input_list:
|
|
|
|
|
if not isinstance(tensor, str):
|
|
|
|
|
raise TypeError("Input should be string or list of strings, got {}.".format(type(tensor)))
|
|
|
|
|
tensor_list.append(cde.Tensor(np.asarray(tensor)))
|
|
|
|
|
def __call__(self, *tensor_list):
|
|
|
|
|
tensor_array = []
|
|
|
|
|
output_list = []
|
|
|
|
|
# Combine input tensor_list to a TensorRow
|
|
|
|
|
for input_tensor in tensor_list:
|
|
|
|
|
if not isinstance(input_tensor, (str, list)):
|
|
|
|
|
raise TypeError("Input should be string or list of strings, got {}.".format(type(input_tensor)))
|
|
|
|
|
tensor_array.append(cde.Tensor(np.asarray(input_tensor)))
|
|
|
|
|
callable_op = cde.Execute(self.parse())
|
|
|
|
|
output_list = callable_op(tensor_list)
|
|
|
|
|
output_list = callable_op(tensor_array)
|
|
|
|
|
for i, element in enumerate(output_list):
|
|
|
|
|
arr = element.as_array()
|
|
|
|
|
if arr.dtype.char == 'S':
|
|
|
|
|
output_list[i] = to_str(arr)
|
|
|
|
|
output_list[i] = np.char.decode(arr)
|
|
|
|
|
else:
|
|
|
|
|
output_list[i] = arr
|
|
|
|
|
if not isinstance(input_tensor, list) and len(output_list) == 1:
|
|
|
|
|
output_list = output_list[0]
|
|
|
|
|
return output_list
|
|
|
|
|
return output_list[0] if len(output_list) == 1 else output_list
|
|
|
|
|
|
|
|
|
|
def parse(self):
|
|
|
|
|
raise NotImplementedError("TextTensorOperation has to implement parse() method.")
|
|
|
|
|