@@ -831,12 +831,12 @@ def crf_decoding(input, param_attr, label=None):
     return viterbi_path


-def cos_sim(X, Y, **kwargs):
+def cos_sim(X, Y):
     """
     This function computes the cosine similarity between two tensors
     X and Y and returns the result as the output.
     """
-    helper = LayerHelper('cos_sim', **kwargs)
+    helper = LayerHelper('cos_sim', **locals())
     out = helper.create_tmp_variable(dtype=X.dtype)
     xnorm = helper.create_tmp_variable(dtype=X.dtype)
     ynorm = helper.create_tmp_variable(dtype=X.dtype)
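The change above is the pattern this whole patch applies: each layer declares its arguments explicitly instead of accepting an opaque `**kwargs`, and passes `**locals()` to `LayerHelper`, which packages exactly the declared arguments as keyword arguments. A minimal sketch of why that works, using a hypothetical stand-in rather than the real LayerHelper:

    # FakeHelper is a stand-in for illustration, not the fluid LayerHelper.
    class FakeHelper(object):
        def __init__(self, layer_type, **kwargs):
            self.layer_type = layer_type
            self.kwargs = kwargs  # the caller's locals land here

    def cos_sim(X, Y):
        # At this point locals() == {'X': X, 'Y': Y}, so the helper sees
        # exactly the declared arguments and nothing undeclared.
        return FakeHelper('cos_sim', **locals())

    helper = cos_sim('x_var', 'y_var')
    assert helper.kwargs == {'X': 'x_var', 'Y': 'y_var'}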
@@ -850,7 +850,7 @@ def cos_sim(X, Y, **kwargs):
     return out


-def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
+def dropout(x, dropout_prob, is_test=False, seed=None):
     """
     Computes dropout.

@@ -879,7 +879,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
         droped = fluid.layers.dropout(input=x, dropout_prob=0.5)
     """

-    helper = LayerHelper('dropout', **kwargs)
+    helper = LayerHelper('dropout', **locals())
     out = helper.create_tmp_variable(dtype=x.dtype)
     mask = helper.create_tmp_variable(dtype=x.dtype, stop_gradient=True)
     helper.append_op(
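For reference, a numpy sketch of the semantics behind this op: elements are zeroed with probability `dropout_prob` at training time. Whether scaling happens at train or test time is decided by the C++ kernel, not by this patch; the test-time `(1 - dropout_prob)` factor below is an assumption:

    import numpy as np

    def dropout_ref(x, dropout_prob, is_test=False, seed=None):
        # assumed convention: keep-at-scale in training, shrink at test time
        if is_test:
            return x * (1.0 - dropout_prob)
        rng = np.random.RandomState(seed)
        mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
        return x * mask

    x = np.ones((2, 4), dtype='float32')
    print(dropout_ref(x, dropout_prob=0.5, seed=1))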
@@ -896,7 +896,7 @@ def dropout(x, dropout_prob, is_test=False, seed=None, **kwargs):
     return out


-def cross_entropy(input, label, **kwargs):
+def cross_entropy(input, label, soft_label=False):
     """
     **Cross Entropy Layer**

@@ -905,15 +905,15 @@ def cross_entropy(input, label, **kwargs):
     computation.

     1) One-hot cross-entropy:
-        `soft_label = False`, `Label[i, 0]` indicates the class index for sample i:
+        `soft_label = False`, `Label[i, 0]` indicates the class index for sample i:

     .. math::

         Y[i] = -\log(X[i, Label[i]])

     2) Soft-label cross-entropy:
-        `soft_label = True`, `Label[i, j]` indicates the soft label of class j
-        for sample i:
+        `soft_label = True`, `Label[i, j]` indicates the soft label of class j
+        for sample i:

     .. math::

@@ -923,8 +923,8 @@ def cross_entropy(input, label, **kwargs):
        equals one.

     3) One-hot cross-entropy with vectorized `label`:
-        As a special case of 2), when each row of 'label' has only one
-        non-zero element which is equal to 1, soft-label cross-entropy degenerates
+        As a special case of 2), when each row of 'label' has only one
+        non-zero element which is equal to 1, soft-label cross-entropy degenerates
        to a one-hot cross-entropy with one-hot label representation.

     Args:
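A numpy rendering of the formulas above; the soft-label expression `Y[i] = -sum_j Label[i, j] * log(X[i, j])` is assumed from the surrounding description, since the exact line sits outside this hunk:

    import numpy as np

    X = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])               # softmax outputs, N=2, D=3

    # 1) one-hot: Label[i, 0] is the class index
    hard = np.array([[0], [1]])
    y_hard = -np.log(X[np.arange(2), hard[:, 0]])  # Y[i] = -log(X[i, Label[i]])

    # 2) soft labels: each row sums to one
    soft = np.array([[0.9, 0.1, 0.0],
                     [0.0, 1.0, 0.0]])
    y_soft = -(soft * np.log(X + 1e-20)).sum(axis=1)

    # case 3) of the docstring: with a one-hot soft label, both paths agree
    np.testing.assert_allclose(y_soft[1], y_hard[1], rtol=1e-6)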
@@ -938,7 +938,7 @@ def cross_entropy(input, label, **kwargs):
                                tensor<int64> with shape [N x 1]. When
                                `soft_label` is set to `True`, `label` is a
                                tensor<float/double> with shape [N x D].
-        soft_label (bool, via `**kwargs`): a flag indicating whether to
+        soft_label (bool): a flag indicating whether to
                                            interpret the given labels as soft
                                            labels, default `False`.
@@ -958,18 +958,18 @@ def cross_entropy(input, label, **kwargs):
         predict = fluid.layers.fc(input=net, size=classdim, act='softmax')
         cost = fluid.layers.cross_entropy(input=predict, label=label)
     """
-    helper = LayerHelper('cross_entropy', **kwargs)
+    helper = LayerHelper('cross_entropy', **locals())
     out = helper.create_tmp_variable(dtype=input.dtype)
     helper.append_op(
         type='cross_entropy',
         inputs={'X': [input],
                 'Label': [label]},
         outputs={'Y': [out]},
-        attrs=kwargs)
+        attrs={"soft_label": soft_label})
     return out


-def square_error_cost(input, label, **kwargs):
+def square_error_cost(input, label):
     """
     **Square error cost layer**

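With the explicit signature, soft-label use is now visible at the call site instead of hidden in `**kwargs`. A usage sketch in the style of the docstring examples (`predict` and `soft_labels` are placeholder variables):

    # Before this patch, soft_label reached the op attrs only because it
    # happened to be in **kwargs; now it is a declared parameter.
    cost = fluid.layers.cross_entropy(
        input=predict,          # [N x D] softmax output
        label=soft_labels,      # [N x D] rows summing to one
        soft_label=True)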
@@ -1004,7 +1004,7 @@ def square_error_cost(input, label, **kwargs):
         cost = layers.square_error_cost(input=y_predict, label=y)

     """
-    helper = LayerHelper('square_error_cost', **kwargs)
+    helper = LayerHelper('square_error_cost', **locals())
     minus_out = helper.create_tmp_variable(dtype=input.dtype)
     helper.append_op(
         type='elementwise_sub',
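Per element this layer computes `Out = (input - label)^2`: the `elementwise_sub` shown here, followed by a squaring step (the function returns `square_out` in the next hunk). A short numpy check of that arithmetic:

    import numpy as np

    y_predict = np.array([1.0, 2.0, 3.0])
    y = np.array([1.5, 2.0, 2.0])
    cost = np.square(y_predict - y)    # elementwise_sub followed by square
    print(cost)                        # [0.25 0.   1.  ]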
@@ -1019,12 +1019,12 @@ def square_error_cost(input, label, **kwargs):
     return square_out


-def accuracy(input, label, k=1, correct=None, total=None, **kwargs):
+def accuracy(input, label, k=1, correct=None, total=None):
     """
     This function computes the accuracy using the input and label.
     The output is the top_k inputs and their indices.
     """
-    helper = LayerHelper("accuracy", **kwargs)
+    helper = LayerHelper("accuracy", **locals())
     topk_out = helper.create_tmp_variable(dtype=input.dtype)
     topk_indices = helper.create_tmp_variable(dtype="int64")
     helper.append_op(
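What `accuracy` computes from these `top_k` outputs, sketched in numpy: a sample counts as correct when its true class appears among the top-k indices:

    import numpy as np

    def accuracy_ref(scores, label, k=1):
        # indices of the k highest scores per row, as the top_k op produces
        topk_indices = np.argsort(-scores, axis=1)[:, :k]
        correct = (topk_indices == label.reshape(-1, 1)).any(axis=1)
        return correct.mean()

    scores = np.array([[0.1, 0.7, 0.2],
                       [0.6, 0.3, 0.1]])
    label = np.array([1, 2])
    print(accuracy_ref(scores, label, k=1))   # 0.5
    print(accuracy_ref(scores, label, k=3))   # 1.0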
@@ -1057,13 +1057,12 @@ def chunk_eval(input,
                label,
                chunk_scheme,
                num_chunk_types,
-               excluded_chunk_types=None,
-               **kwargs):
+               excluded_chunk_types=None):
     """
     This function computes and outputs the precision, recall and
     F1-score of chunk detection.
     """
-    helper = LayerHelper("chunk_eval", **kwargs)
+    helper = LayerHelper("chunk_eval", **locals())

     # prepare output
     precision = helper.create_tmp_variable(dtype="float32")
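The three outputs prepared here are related by the usual definitions; a worked example with hypothetical chunk counts:

    # Hypothetical counts: 6 chunks predicted, 8 labeled, 4 exactly matched.
    num_infer, num_label, num_correct = 6.0, 8.0, 4.0
    precision = num_correct / num_infer                 # ~0.667
    recall = num_correct / num_label                    # 0.5
    f1 = 2 * precision * recall / (precision + recall)
    assert abs(f1 - 4.0 / 7.0) < 1e-12                  # ~0.571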
@@ -1295,7 +1294,7 @@ def conv2d(input,
     return helper.append_activation(pre_act)


-def sequence_pool(input, pool_type, **kwargs):
+def sequence_pool(input, pool_type):
     """
     This function adds the operator for sequence pooling.
     It pools features of all time-steps of each instance, and is applied
|
|
|
|
|
sqrt_x = fluid.layers.sequence_pool(input=x, pool_type='sqrt')
|
|
|
|
|
max_x = fluid.layers.sequence_pool(input=x, pool_type='max')
|
|
|
|
|
"""
|
|
|
|
|
helper = LayerHelper('sequence_pool', input=input, **kwargs)
|
|
|
|
|
helper = LayerHelper('sequence_pool', **locals())
|
|
|
|
|
dtype = helper.input_dtype()
|
|
|
|
|
pool_out = helper.create_tmp_variable(dtype)
|
|
|
|
|
max_index = helper.create_tmp_variable(dtype)
|
|
|
|
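A numpy sketch of what the four pool types in the docstring example do to a single sequence. Fluid sequences are ragged LoD batches; one plain array stands in here, and the `sqrt` convention (sum scaled by the square root of the length) is an assumption:

    import numpy as np

    seq = np.array([1.0, 3.0, 2.0])        # one instance, three time-steps

    print(seq.mean())                      # 'average' -> 2.0
    print(seq.sum())                       # 'sum'     -> 6.0
    print(seq.sum() / np.sqrt(len(seq)))   # 'sqrt'    -> 6.0 / sqrt(3)
    print(seq.max())                       # 'max'     -> 3.0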
@@ -1365,7 +1364,7 @@ def sequence_pool(input, pool_type, **kwargs):
     return pool_out


-def sequence_first_step(input, **kwargs):
+def sequence_first_step(input):
     """
     This function gets the first step of a sequence.

@@ -1398,7 +1397,7 @@ def sequence_first_step(input, **kwargs):
     return sequence_pool(input=input, pool_type="first")


-def sequence_last_step(input, **kwargs):
+def sequence_last_step(input):
     """
     This function gets the last step of a sequence.

@@ -2338,7 +2337,8 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
         normed = fluid.layers.l2_normalize(x=data, axis=1)
     """

-    if len(x.shape) == 1: axis = 0
+    if len(x.shape) == 1:
+        axis = 0

     helper = LayerHelper("l2_normalize", **locals())

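A numpy sketch of the normalization performed along `axis`, with `epsilon` guarding the division; the formula `y = x / sqrt(max(sum(x^2), epsilon))` is assumed from the signature rather than quoted from the file:

    import numpy as np

    def l2_normalize_ref(x, axis, epsilon=1e-12):
        # assumed semantics: scale slices to unit L2 norm along `axis`
        sq_sum = np.sum(np.square(x), axis=axis, keepdims=True)
        return x / np.sqrt(np.maximum(sq_sum, epsilon))

    data = np.array([[3.0, 4.0], [0.0, 0.0]])
    print(l2_normalize_ref(data, axis=1))   # first row -> [0.6, 0.8]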
@@ -2656,7 +2656,7 @@ def ctc_greedy_decoder(input, blank, name=None):
     return ctc_out


-def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
+def warpctc(input, label, blank=0, norm_by_times=False):
    """
    An operator integrating the open source Warp-CTC library
    (https://github.com/baidu-research/warp-ctc)
@@ -2697,7 +2697,7 @@ def warpctc(input, label, blank=0, norm_by_times=False, **kwargs):
         cost = layers.warpctc(input=y_predict, label=y)

     """
-    helper = LayerHelper('warpctc', **kwargs)
+    helper = LayerHelper('warpctc', **locals())
     loss_out = helper.create_tmp_variable(dtype=input.dtype)
     grad_out = helper.create_tmp_variable(dtype=input.dtype)
     helper.append_op(