From 0dd3919a21ee28942821504bb3b8ee2b205bb3ec Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Thu, 18 Jan 2018 10:58:07 +0800
Subject: [PATCH 1/9] Add python wrapper for ctc_evaluator

---
 python/paddle/v2/fluid/layers/nn.py | 49 +++++++++++++++++++++++++----
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 4e8fd407c9..8572b422e5 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -50,6 +50,7 @@ __all__ = [
     'sequence_last_step',
     'dropout',
     'split',
+    'greedy_ctc_evaluator',
 ]
 
 
@@ -1547,13 +1548,13 @@ def split(input, num_or_sections, dim=-1):
 
     Args:
         input (Variable): The input variable which is a Tensor or LoDTensor.
-        num_or_sections (int|list): If :attr:`num_or_sections` is an integer, 
-            then the integer indicates the number of equal sized sub-tensors 
-            that the tensor will be divided into. If :attr:`num_or_sections` 
-            is a list of integers, the length of list indicates the number of 
-            sub-tensors and the integers indicate the sizes of sub-tensors' 
+        num_or_sections (int|list): If :attr:`num_or_sections` is an integer,
+            then the integer indicates the number of equal sized sub-tensors
+            that the tensor will be divided into. If :attr:`num_or_sections`
+            is a list of integers, the length of list indicates the number of
+            sub-tensors and the integers indicate the sizes of sub-tensors'
             :attr:`dim` dimension orderly.
-        dim (int): The dimension along which to split. If :math:`dim < 0`, the 
+        dim (int): The dimension along which to split. If :math:`dim < 0`, the
             dimension to split along is :math:`rank(input) + dim`.
 
     Returns:
@@ -1597,3 +1598,39 @@ def split(input, num_or_sections, dim=-1):
             'axis': dim
         })
     return outs
+
+
+def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
+    """
+    """
+
+    helper = LayerHelper("greedy_ctc_evaluator", **locals())
+    # top 1 op
+    topk_out = helper.create_tmp_variable(dtype=input.dtype)
+    topk_indices = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="top_k",
+        inputs={"X": [input]},
+        outputs={"Out": [topk_out],
+                 "Indices": [topk_indices]},
+        attrs={"k": 1})
+
+    # ctc align op
+    ctc_out = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="ctc_align",
+        inputs={"Input": [topk_indices]},
+        outputs={"Out": [ctc_out]},
+        attrs={"merge_repeated": True,
+               "blank": blank})
+
+    # edit distance op
+    edit_distance_out = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="edit_distance",
+        inputs={"Hyps": [ctc_out],
+                "Refs": [label]},
+        outputs={"Out": [edit_distance_out]},
+        attrs={"normalized": normalized})
+
+    return edit_distance_out
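
For reference, a minimal sketch of how this new wrapper might be wired into a fluid
program. The layer names, shapes, and the `num_classes` value below are illustrative
assumptions, not part of the patch:

.. code-block:: python

    import paddle.v2.fluid as fluid

    num_classes = 20  # illustrative vocabulary size

    # Per-timestep unscaled class scores over num_classes + 1 labels
    # (the extra one being the CTC blank), as a 1-level LoDTensor.
    feature = fluid.layers.data(
        name='feature', shape=[num_classes + 1], dtype='float32', lod_level=1)
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int64', lod_level=1)

    # Composes top_k(k=1) -> ctc_align -> edit_distance, as the wrapper does.
    error = fluid.layers.greedy_ctc_evaluator(
        input=feature, label=label, blank=num_classes)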

From 082c302c3f1a2e289808829bcdd3db0a8eb5a853 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Thu, 18 Jan 2018 16:58:54 +0800
Subject: [PATCH 2/9] Add comments

---
 doc/api/v2/fluid/layers.rst         |  5 +++++
 python/paddle/v2/fluid/layers/nn.py | 31 ++++++++++++++++++++++++-----
 2 files changed, 31 insertions(+), 5 deletions(-)

diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index 62c154e65d..1b40a495d6 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -493,3 +493,8 @@ swish
 ------
 ..  autofunction:: paddle.v2.fluid.layers.swish
     :noindex:
+
+greedy_ctc_error
+----------------
+..  autofunction:: paddle.v2.fluid.layers.greedy_ctc_error
+    :noindex:
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 8572b422e5..c786f3128b 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -50,7 +50,7 @@ __all__ = [
     'sequence_last_step',
     'dropout',
     'split',
-    'greedy_ctc_evaluator',
+    'greedy_ctc_error',
 ]
 
 
@@ -1600,11 +1600,32 @@ def split(input, num_or_sections, dim=-1):
     return outs
 
 
-def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
-    """
+def greedy_ctc_error(input, label, blank, normalized=False, name=None):
     """
+    This evaluator calculates the sequence-to-sequence edit distance.
+
+    Args:
+
+        input(Variable): (LoDTensor, default: LoDTensor<float>), the unscaled probabilities of variable-length sequences, which is a 2-D Tensor with LoD information. Its shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' lengths and num_classes is the true number of classes (not including the blank label).
+
+        label(Variable): (LoDTensor, default: LoDTensor<int>), the ground truth of variable-length sequences, which is a 2-D Tensor with LoD information. It is of shape [Lg, 1], where Lg is the sum of all labels' lengths.
+
+        blank(int): the blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-open interval [0, num_classes + 1).
+
+        normalized(bool): Indicates whether to normalize the edit distance by the length of the reference string.
 
-    helper = LayerHelper("greedy_ctc_evaluator", **locals())
+    Returns:
+        Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1].
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[8], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+
+            cost = fluid.layers.greedy_ctc_error(input=x, label=y, blank=0)
+    """
+    helper = LayerHelper("greedy_ctc_error", **locals())
     # top 1 op
     topk_out = helper.create_tmp_variable(dtype=input.dtype)
     topk_indices = helper.create_tmp_variable(dtype="int64")
@@ -1620,7 +1641,7 @@ def greedy_ctc_evaluator(input, label, blank, normalized=False, name=None):
     helper.append_op(
         type="ctc_align",
         inputs={"Input": [topk_indices]},
-        outputs={"Out": [ctc_out]},
+        outputs={"Output": [ctc_out]},
         attrs={"merge_repeated": True,
                "blank": blank})
 

From 4673a4a9aa2c4c3d2cf487cacc841d59e817dfac Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Thu, 18 Jan 2018 20:40:47 +0800
Subject: [PATCH 3/9] Divide this operator into ctc_greedy_decoder and
 edit_distance_error.

---
 doc/api/v2/fluid/layers.rst         |  9 ++-
 python/paddle/v2/fluid/layers/nn.py | 99 +++++++++++++++++++++++------
 2 files changed, 85 insertions(+), 23 deletions(-)

diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index aae63a9ad0..f1e4e753c5 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -500,9 +500,14 @@ swish
 ..  autofunction:: paddle.v2.fluid.layers.swish
     :noindex:
 
-greedy_ctc_error
+edit_distance_error
 ---------------
-..  autofunction:: paddle.v2.fluid.layers.greedy_ctc_error
+..  autofunction:: paddle.v2.fluid.layers.edit_distance_error
+    :noindex:
+
+ctc_greedy_decoder
+------------------
+..  autofunction:: paddle.v2.fluid.layers.ctc_greedy_decoder
     :noindex:
 
 l2_normalize
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 60f2fd8e9d..72246304be 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -50,7 +50,8 @@ __all__ = [
     'sequence_last_step',
     'dropout',
     'split',
-    'greedy_ctc_error',
+    'ctc_greedy_decoder',
+    'edit_distance_error',
     'l2_normalize',
     'matmul',
 ]
@@ -1791,17 +1792,21 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out
 
 
-def greedy_ctc_error(input, label, blank, normalized=False, name=None):
+def edit_distance_error(input, label, normalized=False, name=None):
     """
-    This evaluator calculates the sequence-to-sequence edit distance.
+    EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
 
-    Args:
+       "kitten" -> "sitten" -> "sittin" -> "sitting"
 
-        input(Variable): (LoDTensor, default: LoDTensor<float>), the unscaled probabilities of variable-length sequences, which is a 2-D Tensor with LoD information. Its shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' lengths and num_classes is the true number of classes (not including the blank label).
+    Input(Hyps) is a LoDTensor consisting of all the hypothesis strings with the total number denoted by `batch_size`, and the separation is specified by the LoD information. The `batch_size` reference strings are arranged in the same order in the LoDTensor Input(Refs).
 
-        label(Variable): (LoDTensor, default: LoDTensor<int>), the ground truth of variable-length sequences, which is a 2-D Tensor with LoD information. It is of shape [Lg, 1], where Lg is the sum of all labels' lengths.
+    Output(Out) contains the `batch_size` results, each of which is the edit distance for a pair of strings. If Attr(normalized) is true, the edit distance is divided by the length of the reference string.
 
-        blank(int): the blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-open interval [0, num_classes + 1).
+    Args:
+
+        input(Variable): The indices for hypothesis strings.
+
+        label(Variable): The indices for reference strings.
 
         normalized(bool): Indicates whether to normalize the edit distance by the length of the reference string.
 
@@ -1812,11 +1817,73 @@ def greedy_ctc_error(input, label, blank, normalized=False, name=None):
         .. code-block:: python
 
             x = fluid.layers.data(name='x', shape=[8], dtype='float32')
-            y = fluid.layers.data(name='y', shape=[1], dtype='float32')
+            y = fluid.layers.data(name='y', shape=[7], dtype='float32')
+
+            cost = fluid.layers.edit_distance_error(input=x, label=y)
+    """
+    helper = LayerHelper("edit_distance_error", **locals())
+
+    # edit distance op
+    edit_distance_out = helper.create_tmp_variable(dtype="int64")
+    helper.append_op(
+        type="edit_distance",
+        inputs={"Hyps": [input],
+                "Refs": [label]},
+        outputs={"Out": [edit_distance_out]},
+        attrs={"normalized": normalized})
+
+    return edit_distance_out
+
+
+def ctc_greedy_decoder(input, blank, name=None):
+    """
+    This op decodes sequences with a greedy policy in two steps:
+    1. Get the index of the max value for each row of input, i.e. numpy.argmax(input, axis=1).
+    2. For each sequence in the result of step 1, merge repeated tokens between two blanks and delete all blanks.
+
+    A simple example is shown below:
+
+    .. code-block:: text
+
+        Given:
+
+        input.data = [[0.6, 0.1, 0.3, 0.1],
+                      [0.3, 0.2, 0.4, 0.1],
+                      [0.1, 0.5, 0.1, 0.3],
+                      [0.5, 0.1, 0.3, 0.1],
+
+                      [0.5, 0.1, 0.3, 0.1],
+                      [0.2, 0.2, 0.2, 0.4],
+                      [0.2, 0.2, 0.1, 0.5],
+                      [0.5, 0.1, 0.3, 0.1]]
+
+        input.lod = [[0, 4, 8]]
+
+        Then:
+
+        output.data = [[2],
+                       [1],
+                       [3]]
+
+        output.lod = [[0, 2, 3]]
+
+    Args:
+
+        input(Variable): (LoDTensor<float>), the probabilities of variable-length sequences, which is a 2-D Tensor with LoD information. Its shape is [Lp, num_classes + 1], where Lp is the sum of all input sequences' lengths and num_classes is the true number of classes (not including the blank label).
+
+        blank(int): the blank label index of Connectionist Temporal Classification (CTC) loss, which is in the half-open interval [0, num_classes + 1).
+
+    Returns:
+        Variable: CTC greedy decode result.
+
+    Examples:
+        .. code-block:: python
+
+            x = fluid.layers.data(name='x', shape=[8], dtype='float32')
 
-            cost = fluid.layers.greedy_ctc_error(input=x,label=y, blank=0)
+            cost = fluid.layers.ctc_greedy_decoder(input=x, blank=0)
     """
-    helper = LayerHelper("greedy_ctc_error", **locals())
+    helper = LayerHelper("ctc_greedy_decoder", **locals())
     # top 1 op
     topk_out = helper.create_tmp_variable(dtype=input.dtype)
     topk_indices = helper.create_tmp_variable(dtype="int64")
@@ -1835,14 +1902,4 @@ def greedy_ctc_error(input, label, blank, normalized=False, name=None):
         outputs={"Output": [ctc_out]},
         attrs={"merge_repeated": True,
                "blank": blank})
-
-    # edit distance op
-    edit_distance_out = helper.create_tmp_variable(dtype="int64")
-    helper.append_op(
-        type="edit_distance",
-        inputs={"Hyps": [ctc_out],
-                "Refs": [label]},
-        outputs={"Out": [edit_distance_out]},
-        attrs={"normalized": normalized})
-
-    return edit_distance_out
+    return ctc_out
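
As a cross-check of the decoding description above, a small NumPy reference (an
illustration of the two steps only, not the op itself) reproduces the docstring
example:

.. code-block:: python

    import numpy as np

    def ctc_greedy_decode(probs, lod, blank):
        # Step 1: per-row argmax; Step 2: merge repeats, then drop blanks.
        results = []
        for start, end in zip(lod[:-1], lod[1:]):
            tokens = np.argmax(probs[start:end], axis=1)
            merged = [t for i, t in enumerate(tokens)
                      if i == 0 or t != tokens[i - 1]]
            results.append([int(t) for t in merged if t != blank])
        return results

    probs = np.array([[0.6, 0.1, 0.3, 0.1],
                      [0.3, 0.2, 0.4, 0.1],
                      [0.1, 0.5, 0.1, 0.3],
                      [0.5, 0.1, 0.3, 0.1],
                      [0.5, 0.1, 0.3, 0.1],
                      [0.2, 0.2, 0.2, 0.4],
                      [0.2, 0.2, 0.1, 0.5],
                      [0.5, 0.1, 0.3, 0.1]])

    print(ctc_greedy_decode(probs, lod=[0, 4, 8], blank=0))  # [[2, 1], [3]]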

From 5846aab31730cb595f6210bed0758954529fc0f0 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Fri, 19 Jan 2018 14:53:46 +0800
Subject: [PATCH 4/9] 1. Rename 'edit_distance_error' to 'edit_distance' 2. Add
 edit distance evaluator to evaluator.py

---
 doc/api/v2/fluid/layers.rst         |  2 +-
 python/paddle/v2/fluid/evaluator.py | 32 +++++++++++++++++++++++++++++
 python/paddle/v2/fluid/layers/nn.py |  9 ++++----
 3 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/doc/api/v2/fluid/layers.rst b/doc/api/v2/fluid/layers.rst
index f1e4e753c5..2ae68d01d3 100644
--- a/doc/api/v2/fluid/layers.rst
+++ b/doc/api/v2/fluid/layers.rst
@@ -500,7 +500,7 @@ swish
 ..  autofunction:: paddle.v2.fluid.layers.swish
     :noindex:
 
-edit_distance_error
+edit_distance
 ---------------
 ..  autofunction:: paddle.v2.fluid.layers.edit_distance_error
     :noindex:
diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index adf174a07d..336d25929e 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -204,3 +204,35 @@ class ChunkEvaluator(Evaluator):
             [precision], dtype='float32'), np.array(
                 [recall], dtype='float32'), np.array(
                     [f1_score], dtype='float32')
+
+
+class EditDistance(Evaluator):
+    """
+    Average edit distance error for multiple mini-batches.
+    """
+
+    def __init__(self, input, label, k=1, **kwargs):
+        super(EditDistance, self).__init__("edit_distance", **kwargs)
+        main_program = self.helper.main_program
+        if main_program.current_block().idx != 0:
+            raise ValueError("You can only invoke Evaluator in root block")
+
+        self.total_error = self.create_state(
+            dtype='int64', shape=[1], suffix='total')
+        self.batch_num = 0
+        error = layers.edit_distance(input=input, label=label)
+        mean_error = layers.mean(input=error)
+        layers.sums(input=[self.total_error, mean_error], out=self.total_error)
+        self.metrics.append(mean_error)
+
+    def eval(self, executor, eval_program=None):
+        self.batch_num += 1
+        if eval_program is None:
+            eval_program = Program()
+        block = eval_program.current_block()
+        with program_guard(main_program=eval_program):
+            total_error = _clone_var_(block, self.total_error)
+            batch_num = layers.fill_constant(
+                shape=[1], value=self.batch_num, dtype="float32")
+            out = layers.elementwise_div(x=total_error, y=batch_num)
+        return np.array(executor.run(eval_program, fetch_list=[out])[0])
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 0c77b89065..8383e43dea 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -28,8 +28,7 @@ __all__ = [
     'batch_norm', 'beam_search_decode', 'conv2d_transpose', 'sequence_expand',
     'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max', 'reduce_min',
     'sequence_first_step', 'sequence_last_step', 'dropout', 'split',
-    'ctc_greedy_decoder', 'edit_distance_error', 'l2_normalize', 'matmul',
-    'warpctc'
+    'ctc_greedy_decoder', 'edit_distance', 'l2_normalize', 'matmul', 'warpctc'
 ]
 
 
@@ -1768,7 +1767,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out
 
 
-def edit_distance_error(input, label, normalized=False, name=None):
+def edit_distance(input, label, normalized=False, name=None):
     """
     EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
 
@@ -1795,9 +1794,9 @@ def edit_distance_error(input, label, normalized=False, name=None):
             x = fluid.layers.data(name='x', shape=[8], dtype='float32')
             y = fluid.layers.data(name='y', shape=[7], dtype='float32')
 
-            cost = fluid.layers.edit_distance_error(input=x, label=y)
+            cost = fluid.layers.edit_distance(input=x, label=y)
     """
-    helper = LayerHelper("edit_distance_error", **locals())
+    helper = LayerHelper("edit_distance", **locals())
 
     # edit distance op
     edit_distance_out = helper.create_tmp_variable(dtype="int64")
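
The edit_distance layer wraps the edit_distance operator; a plain dynamic-programming
Levenshtein reference (an illustration only, assuming the op follows the standard
definition quoted in the docstring) looks like:

.. code-block:: python

    def levenshtein(hyp, ref):
        # dp[i][j] = distance between hyp[:i] and ref[:j].
        m, n = len(hyp), len(ref)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        for i in range(m + 1):
            dp[i][0] = i
        for j in range(n + 1):
            dp[0][j] = j
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                cost = 0 if hyp[i - 1] == ref[j - 1] else 1
                dp[i][j] = min(dp[i - 1][j] + 1,         # deletion
                               dp[i][j - 1] + 1,         # insertion
                               dp[i - 1][j - 1] + cost)  # substitution
        return dp[m][n]

    print(levenshtein("kitten", "sitting"))                   # 3
    print(levenshtein("kitten", "sitting") / len("sitting"))  # normalized case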

From a8f118ca839a03d84aead834759679948e41f6f5 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Sat, 20 Jan 2018 09:57:34 +0800
Subject: [PATCH 5/9] Add EditDistance to evaluator.py

---
 python/paddle/v2/fluid/evaluator.py | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 336d25929e..351db4f12d 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -218,21 +218,23 @@ class EditDistance(Evaluator):
             raise ValueError("You can only invoke Evaluator in root block")
 
         self.total_error = self.create_state(
-            dtype='int64', shape=[1], suffix='total')
-        self.batch_num = 0
+            dtype='float32', shape=[1], suffix='total')
+        self.batch_num = self.create_state(
+            dtype='float32', shape=[1], suffix='total')
         error = layers.edit_distance(input=input, label=label)
-        mean_error = layers.mean(input=error)
+        error = layers.cast(x=error, dtype='float32')
+        mean_error = layers.mean(x=error)
         layers.sums(input=[self.total_error, mean_error], out=self.total_error)
+        const1 = layers.fill_constant(shape=[1], value=1.0, dtype="float32")
+        layers.sums(input=[self.batch_num, const1], out=self.batch_num)
         self.metrics.append(mean_error)
 
     def eval(self, executor, eval_program=None):
-        self.batch_num += 1
         if eval_program is None:
             eval_program = Program()
         block = eval_program.current_block()
         with program_guard(main_program=eval_program):
             total_error = _clone_var_(block, self.total_error)
-            batch_num = layers.fill_constant(
-                shape=[1], value=self.batch_num, dtype="float32")
+            batch_num = _clone_var_(block, self.batch_num)
             out = layers.elementwise_div(x=total_error, y=batch_num)
         return np.array(executor.run(eval_program, fetch_list=[out])[0])

From 0b854bdb8b0aad6360cf2c15b1ca40b52a94d40c Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Mon, 22 Jan 2018 09:37:23 +0800
Subject: [PATCH 6/9] Add sequence_erase option into edit distance python API

---
 python/paddle/v2/fluid/layers/nn.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 5d05046bba..c57811df1d 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -1864,7 +1864,7 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out
 
 
-def edit_distance(input, label, normalized=False, name=None):
+def edit_distance(input, label, normalized=False, tokens=None, name=None):
     """
     EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
 
@@ -1882,6 +1882,8 @@ def edit_distance(input, label, normalized=False, name=None):
 
         normalized(bool): Indicates whether to normalize the edit distance by the length of the reference string.
 
+        tokens(list): Tokens that should be removed before calculating edit distance.
+
     Returns:
         Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1].
 
@@ -1895,6 +1897,25 @@ def edit_distance(input, label, normalized=False, name=None):
     """
     helper = LayerHelper("edit_distance", **locals())
 
+    # remove some tokens from input and labels
+    if tokens is not None and len(tokens) > 0:
+        erased_input = helper.create_tmp_variable(dtype="int64")
+        erased_label = helper.create_tmp_variable(dtype="int64")
+
+        helper.append_op(
+            type="sequence_erase",
+            inputs={"X": [input]},
+            outputs={"Out": [erased_input]},
+            attrs={"tokens": tokens})
+        input = erased_input
+
+        helper.append_op(
+            type="sequence_erase",
+            inputs={"X": [label]},
+            outputs={"Out": [erase_label]},
+            attrs={"tokens": tokens})
+        label = erased_label
+
     # edit distance op
     edit_distance_out = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
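
The new `tokens` option erases the listed token ids from both sequences before the
distance is computed; the effect is equivalent to the following pure-Python filtering
(an illustration only, not the sequence_erase op itself):

.. code-block:: python

    def erase(seq, tokens):
        # Keep only the elements that are not in the removal list,
        # preserving the original order.
        return [t for t in seq if t not in tokens]

    hyp = [0, 2, 2, 0, 3, 0]   # e.g. a decoded sequence that still contains blanks (0)
    ref = [2, 3]

    print(erase(hyp, tokens=[0]))   # [2, 2, 3]
    # The edit distance is then computed between [2, 2, 3] and [2, 3].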

From 1bc8de32091d11a57cda7af0b38b1766d51a06d5 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Mon, 22 Jan 2018 16:59:54 +0800
Subject: [PATCH 7/9] 1. Add sequence_num as edit distance op's output 2. Fix
 evaluator to use 'reduce_sum' op instead of 'mean' op

---
 paddle/operators/CMakeLists.txt               |  1 +
 paddle/operators/edit_distance_op.cc          |  4 ++++
 paddle/operators/edit_distance_op.cu          |  9 +++++++-
 paddle/operators/edit_distance_op.h           |  4 +++-
 python/paddle/v2/fluid/evaluator.py           | 22 +++++++++----------
 python/paddle/v2/fluid/layers/nn.py           |  6 +++--
 .../v2/fluid/tests/test_edit_distance_op.py   |  6 +++--
 7 files changed, 35 insertions(+), 17 deletions(-)

diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 6745a8da17..15f7cb6b56 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -156,6 +156,7 @@ op_library(parallel_do_op DEPS executor)
 # Regist multiple Kernel to pybind
 if (WITH_GPU)
 op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
+op_library(edit_distance_op SRCS edit_distance_op.cc edit_distance_op.cu DEPS math_function)
 op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
 op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
   conv_transpose_cudnn_op.cu.cc DEPS vol2col)
diff --git a/paddle/operators/edit_distance_op.cc b/paddle/operators/edit_distance_op.cc
index 62a1fcebe7..7e7dfc79eb 100644
--- a/paddle/operators/edit_distance_op.cc
+++ b/paddle/operators/edit_distance_op.cc
@@ -25,6 +25,8 @@ class EditDistanceOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Hyps"), "Input(Hyps) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasInput("Refs"), "Input(Refs) shouldn't be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("SequenceNum"),
+                   "Output(SequenceNum) shouldn't be null.");
     auto hyp_dims = ctx->GetInputDim("Hyps");
     auto ref_dims = ctx->GetInputDim("Refs");
     PADDLE_ENFORCE(hyp_dims.size() == 2 && hyp_dims[1] == 1,
@@ -34,6 +36,7 @@ class EditDistanceOp : public framework::OperatorWithKernel {
                    "Input(Refs) must be a 2-D LoDTensor with the 2nd dimension "
                    "equal to 1.");
     ctx->SetOutputDim("Out", ctx->GetInputDim("Refs"));
+    ctx->SetOutputDim("SequenceNum", {1});
   }
 
  protected:
@@ -54,6 +57,7 @@ class EditDistanceOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Refs",
              "(2-D LoDTensor<int64_t>, 2nd dim. equal to 1) "
              "The indices for reference strings.");
+    AddOutput("SequenceNum", "The sequence count of current batch");
     AddAttr<bool>("normalized",
                   "(bool, default false) Indicated whether to normalize "
                   "the edit distance by the length of reference string.")
diff --git a/paddle/operators/edit_distance_op.cu b/paddle/operators/edit_distance_op.cu
index 338fd79bcc..c3e116af08 100644
--- a/paddle/operators/edit_distance_op.cu
+++ b/paddle/operators/edit_distance_op.cu
@@ -14,6 +14,7 @@ limitations under the License. */
 
 #include <algorithm>
 #include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
 #include "paddle/platform/cuda_helper.h"
 #include "paddle/platform/gpu_info.h"
 
@@ -72,6 +73,8 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
 
     auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
     auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
+    auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
+    sequence_num->mutable_data<int64_t>(ctx.GetPlace());
 
     auto normalized = ctx.Attr<bool>("normalized");
     auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
@@ -88,7 +91,11 @@ class EditDistanceGPUKernel : public framework::OpKernel<T> {
                      "Reference string %d is empty.", i);
     }
 
-    auto num_strs = hyp_lod.size() - 1;
+    const size_t num_strs = hyp_lod.size() - 1;
+    math::SetConstant<platform::CUDADeviceContext, int64_t> set_constant;
+    set_constant(ctx.template device_context<platform::CUDADeviceContext>(),
+                 sequence_num, static_cast<int64_t>(num_strs));
+
     out_t->Resize({static_cast<int64_t>(num_strs), 1});
     out_t->mutable_data<T>(ctx.GetPlace());
     auto out = out_t->data<T>();
diff --git a/paddle/operators/edit_distance_op.h b/paddle/operators/edit_distance_op.h
index 4c5a29813c..974299e604 100644
--- a/paddle/operators/edit_distance_op.h
+++ b/paddle/operators/edit_distance_op.h
@@ -16,7 +16,6 @@ limitations under the License. */
 #include <algorithm>
 #include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-
 namespace paddle {
 namespace operators {
 
@@ -28,6 +27,8 @@ class EditDistanceKernel : public framework::OpKernel<T> {
 
     auto* x1_t = ctx.Input<framework::LoDTensor>("Hyps");
     auto* x2_t = ctx.Input<framework::LoDTensor>("Refs");
+    auto* sequence_num = ctx.Output<framework::Tensor>("SequenceNum");
+    int64_t* seq_num_data = sequence_num->mutable_data<int64_t>(ctx.GetPlace());
 
     auto normalized = ctx.Attr<bool>("normalized");
 
@@ -41,6 +42,7 @@ class EditDistanceKernel : public framework::OpKernel<T> {
                      "Reference string %d is empty.", i);
     }
     auto num_strs = hyp_lod.size() - 1;
+    *seq_num_data = static_cast<int64_t>(num_strs);
 
     out_t->Resize({static_cast<int64_t>(num_strs), 1});
     out_t->mutable_data<float>(ctx.GetPlace());
diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 351db4f12d..67e99a70ad 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -219,15 +219,14 @@ class EditDistance(Evaluator):
 
         self.total_error = self.create_state(
             dtype='float32', shape=[1], suffix='total')
-        self.batch_num = self.create_state(
-            dtype='float32', shape=[1], suffix='total')
-        error = layers.edit_distance(input=input, label=label)
-        error = layers.cast(x=error, dtype='float32')
-        mean_error = layers.mean(x=error)
-        layers.sums(input=[self.total_error, mean_error], out=self.total_error)
-        const1 = layers.fill_constant(shape=[1], value=1.0, dtype="float32")
-        layers.sums(input=[self.batch_num, const1], out=self.batch_num)
-        self.metrics.append(mean_error)
+        self.seq_num = self.create_state(
+            dtype='int64', shape=[1], suffix='total')
+        error, seq_num = layers.edit_distance(input=input, label=label)
+        #error = layers.cast(x=error, dtype='float32')
+        sum_error = layers.reduce_sum(error)
+        layers.sums(input=[self.total_error, sum_error], out=self.total_error)
+        layers.sums(input=[self.seq_num, seq_num], out=self.seq_num)
+        self.metrics.append(sum_error)
 
     def eval(self, executor, eval_program=None):
         if eval_program is None:
@@ -235,6 +234,7 @@ class EditDistance(Evaluator):
         block = eval_program.current_block()
         with program_guard(main_program=eval_program):
             total_error = _clone_var_(block, self.total_error)
-            batch_num = _clone_var_(block, self.batch_num)
-            out = layers.elementwise_div(x=total_error, y=batch_num)
+            seq_num = _clone_var_(block, self.seq_num)
+            seq_num = layers.cast(x=seq_num, dtype='float32')
+            out = layers.elementwise_div(x=total_error, y=seq_num)
         return np.array(executor.run(eval_program, fetch_list=[out])[0])
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index c57811df1d..9a1fc2f120 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -1918,14 +1918,16 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
 
     # edit distance op
     edit_distance_out = helper.create_tmp_variable(dtype="int64")
+    sequence_num = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
         type="edit_distance",
         inputs={"Hyps": [input],
                 "Refs": [label]},
-        outputs={"Out": [edit_distance_out]},
+        outputs={"Out": [edit_distance_out],
+                 "SequenceNum": [sequence_num]},
         attrs={"normalized": normalized})
 
-    return edit_distance_out
+    return edit_distance_out, sequence_num
 
 
 def ctc_greedy_decoder(input, blank, name=None):
diff --git a/python/paddle/v2/fluid/tests/test_edit_distance_op.py b/python/paddle/v2/fluid/tests/test_edit_distance_op.py
index 5f5634e297..01e7e64d05 100644
--- a/python/paddle/v2/fluid/tests/test_edit_distance_op.py
+++ b/python/paddle/v2/fluid/tests/test_edit_distance_op.py
@@ -60,6 +60,7 @@ class TestEditDistanceOp(OpTest):
 
         num_strs = len(x1_lod) - 1
         distance = np.zeros((num_strs, 1)).astype("float32")
+        sequence_num = np.array(2).astype("int64")
         for i in range(0, num_strs):
             distance[i] = Levenshtein(
                 hyp=x1[x1_lod[i]:x1_lod[i + 1]],
@@ -69,7 +70,7 @@ class TestEditDistanceOp(OpTest):
                 distance[i] = distance[i] / len_ref
         self.attrs = {'normalized': normalized}
         self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
-        self.outputs = {'Out': distance}
+        self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
 
     def test_check_output(self):
         self.check_output()
@@ -88,6 +89,7 @@ class TestEditDistanceOpNormalized(OpTest):
 
         num_strs = len(x1_lod) - 1
         distance = np.zeros((num_strs, 1)).astype("float32")
+        sequence_num = np.array(3).astype("int64")
         for i in range(0, num_strs):
             distance[i] = Levenshtein(
                 hyp=x1[x1_lod[i]:x1_lod[i + 1]],
@@ -97,7 +99,7 @@ class TestEditDistanceOpNormalized(OpTest):
                 distance[i] = distance[i] / len_ref
         self.attrs = {'normalized': normalized}
         self.inputs = {'Hyps': (x1, [x1_lod]), 'Refs': (x2, [x2_lod])}
-        self.outputs = {'Out': distance}
+        self.outputs = {'Out': distance, 'SequenceNum': sequence_num}
 
     def test_check_output(self):
         self.check_output()
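
The point of exposing SequenceNum: with it the evaluator can weight every sequence
equally instead of averaging per-batch means, which differ once batch sizes vary.
A small numeric illustration with made-up distances:

.. code-block:: python

    # Two batches with different sequence counts (illustrative numbers).
    batch_errors = [[3.0, 1.0, 2.0], [4.0]]

    # Averaging per-batch means (the previous evaluator behaviour):
    per_batch_means = [sum(b) / len(b) for b in batch_errors]    # [2.0, 4.0]
    print(sum(per_batch_means) / len(per_batch_means))           # 3.0

    # Sequence-weighted average (total_error / SequenceNum, as in this patch):
    total_error = sum(sum(b) for b in batch_errors)              # 10.0
    total_seqs = sum(len(b) for b in batch_errors)               # 4
    print(total_error / total_seqs)                              # 2.5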

From 8143a42667d3dd158a464449e3492b7b0acf55c7 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Mon, 22 Jan 2018 17:34:45 +0800
Subject: [PATCH 8/9] 1. Add more comments

---
 python/paddle/v2/fluid/evaluator.py | 36 +++++++++++++++++++++++++----
 python/paddle/v2/fluid/layers/nn.py | 16 ++++++++-----
 2 files changed, 41 insertions(+), 11 deletions(-)

diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 67e99a70ad..5dde8d623a 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -208,20 +208,46 @@ class ChunkEvaluator(Evaluator):
 
 class EditDistance(Evaluator):
     """
-    Average edit distance error for multiple mini-batches.
+    Accumulate the edit distance sum and sequence number from mini-batches and
+    compute the average edit distance of all batches.
+
+    Args:
+        input: the sequences predicted by the network
+        label: the target sequences, which must have the same sequence count
+        with input.
+        ignored_tokens(list of int): Tokens that should be removed before
+        calculating edit distance.
+
+    Example:
+
+        exe = fluid.Executor(place)
+        distance_evaluator = fluid.evaluator.EditDistance(input, label)
+        for epoch in range(PASS_NUM):
+            distance_evaluator.reset(exe)
+            for data in batches:
+                loss, sum_distance = exe.run(fetch_list=[cost] + distance_evaluator.metrics)
+                avg_distance = distance_evaluator.eval(exe)
+            pass_distance = distance_evaluator.eval(exe)
+
+        In the above example:
+        'sum_distance' is the sum of the current batch's edit distances.
+        'avg_distance' is the average edit distance from the first batch to the current batch.
+        'pass_distance' is the average edit distance over the whole pass.
+
     """
 
-    def __init__(self, input, label, k=1, **kwargs):
+    def __init__(self, input, label, ignored_tokens=None, **kwargs):
         super(EditDistance, self).__init__("edit_distance", **kwargs)
         main_program = self.helper.main_program
         if main_program.current_block().idx != 0:
             raise ValueError("You can only invoke Evaluator in root block")
 
         self.total_error = self.create_state(
-            dtype='float32', shape=[1], suffix='total')
+            dtype='float32', shape=[1], suffix='total_error')
         self.seq_num = self.create_state(
-            dtype='int64', shape=[1], suffix='total')
-        error, seq_num = layers.edit_distance(input=input, label=label)
+            dtype='int64', shape=[1], suffix='seq_num')
+        error, seq_num = layers.edit_distance(
+            input=input, label=label, ignored_tokens=ignored_tokens)
         #error = layers.cast(x=error, dtype='float32')
         sum_error = layers.reduce_sum(error)
         layers.sums(input=[self.total_error, sum_error], out=self.total_error)
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 9a1fc2f120..7dd77aca95 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -1864,7 +1864,11 @@ def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
     return out
 
 
-def edit_distance(input, label, normalized=False, tokens=None, name=None):
+def edit_distance(input,
+                  label,
+                  normalized=False,
+                  ignored_tokens=None,
+                  name=None):
     """
     EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
 
@@ -1882,10 +1886,10 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
 
         normalized(bool): Indicates whether to normalize the edit distance by the length of the reference string.
 
-        tokens(list): Tokens that should be removed before calculating edit distance.
+        ignored_tokens(list of int): Tokens that should be removed before calculating edit distance.
 
     Returns:
-        Variable: sequence-to-sequence edit distance loss in shape [batch_size, 1].
+        Variable: sequence-to-sequence edit distance in shape [batch_size, 1].
 
     Examples:
         .. code-block:: python
@@ -1898,7 +1902,7 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
     helper = LayerHelper("edit_distance", **locals())
 
     # remove some tokens from input and labels
-    if tokens is not None and len(tokens) > 0:
+    if ignored_tokens is not None and len(ignored_tokens) > 0:
         erased_input = helper.create_tmp_variable(dtype="int64")
         erased_label = helper.create_tmp_variable(dtype="int64")
 
@@ -1906,14 +1910,14 @@ def edit_distance(input, label, normalized=False, tokens=None, name=None):
             type="sequence_erase",
             inputs={"X": [input]},
             outputs={"Out": [erased_input]},
-            attrs={"tokens": tokens})
+            attrs={"tokens": ignored_tokens})
         input = erased_input
 
         helper.append_op(
             type="sequence_erase",
             inputs={"X": [label]},
             outputs={"Out": [erase_label]},
-            attrs={"tokens": tokens})
+            attrs={"tokens": ignored_tokens})
         label = erased_label
 
     # edit distance op
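
With the rename, a graph-level usage of the layer might look as follows (the `blank`
value and variable names below are illustrative assumptions):

.. code-block:: python

    import paddle.v2.fluid as fluid

    blank = 0  # illustrative blank index to strip before comparison

    hyp = fluid.layers.data(name='hyp', shape=[1], dtype='int64', lod_level=1)
    ref = fluid.layers.data(name='ref', shape=[1], dtype='int64', lod_level=1)

    # Returns the per-sequence distances and the sequence count of the batch.
    distance, seq_num = fluid.layers.edit_distance(
        input=hyp, label=ref, normalized=False, ignored_tokens=[blank])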

From d9d9be1bac627d5314accdf89a4367bc3a2f0294 Mon Sep 17 00:00:00 2001
From: wanghaoshuang <wanghaoshuang@baidu.com>
Date: Mon, 22 Jan 2018 19:14:47 +0800
Subject: [PATCH 9/9] Fix white space in comments.

---
 python/paddle/v2/fluid/evaluator.py | 2 +-
 python/paddle/v2/fluid/layers/nn.py | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/paddle/v2/fluid/evaluator.py b/python/paddle/v2/fluid/evaluator.py
index 5dde8d623a..933f91dcfe 100644
--- a/python/paddle/v2/fluid/evaluator.py
+++ b/python/paddle/v2/fluid/evaluator.py
@@ -212,7 +212,7 @@ class EditDistance(Evaluator):
     compute the average edit distance of all batches.
 
     Args:
-        input: the sequences predicted by the network
+        input: the sequences predicted by the network.
         label: the target sequences, which must have the same sequence count
         with input.
         ignored_tokens(list of int): Tokens that should be removed before
diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py
index 7dd77aca95..5b53f5d64e 100644
--- a/python/paddle/v2/fluid/layers/nn.py
+++ b/python/paddle/v2/fluid/layers/nn.py
@@ -1870,7 +1870,7 @@ def edit_distance(input,
                   ignored_tokens=None,
                   name=None):
     """
-    EditDistance operator computes the edit distances between a batch of hypothesis strings and their references.Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
+    EditDistance operator computes the edit distances between a batch of hypothesis strings and their references. Edit distance, also called Levenshtein distance, measures how dissimilar two strings are by counting the minimum number of operations required to transform one string into another. Here the operations include insertion, deletion, and substitution. For example, given hypothesis string A = "kitten" and reference B = "sitting", the edit distance is 3, since transforming A into B takes at least two substitutions and one insertion:
 
        "kitten" -> "sitten" -> "sittin" -> "sitting"
 
@@ -2028,7 +2028,7 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
          Temporal Classification (CTC) loss, which is in the
          half-opened interval [0, num_classes + 1).
        norm_by_times: (bool, default: false), whether to normalize
-       the gradients by the number of time-step,which is also the
+       the gradients by the number of time-step, which is also the
        sequence's length. There is no need to normalize the gradients
       if warpctc layer was followed by a mean_op.
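
Putting the series together, a hedged end-to-end sketch (names, shapes, and label
dtypes are assumptions; `fluid.evaluator.EditDistance` is the evaluator added above)
could combine the warpctc training loss with the new decoding and evaluation layers:

.. code-block:: python

    import paddle.v2.fluid as fluid

    num_classes = 20  # illustrative

    feature = fluid.layers.data(
        name='feature', shape=[num_classes + 1], dtype='float32', lod_level=1)
    # warpctc is assumed to take int32 labels here, while edit_distance wants
    # int64 indices, hence the cast below; adjust if the ops expect otherwise.
    label = fluid.layers.data(
        name='label', shape=[1], dtype='int32', lod_level=1)

    # Training branch: CTC loss over the per-timestep scores.
    cost = fluid.layers.warpctc(
        input=feature, label=label, blank=num_classes, norm_by_times=False)
    avg_cost = fluid.layers.mean(x=cost)

    # Evaluation branch: greedy decoding followed by the edit distance evaluator.
    decoded = fluid.layers.ctc_greedy_decoder(input=feature, blank=num_classes)
    casted_label = fluid.layers.cast(x=label, dtype='int64')
    error_evaluator = fluid.evaluator.EditDistance(
        input=decoded, label=casted_label)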