@@ -14,7 +14,6 @@ limitations under the License. */
 
 #pragma once
 
-#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence_padding.h"
@@ -209,12 +208,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     auto* logits_grad = ctx.Output<LoDTensor>(framework::GradVarName("Logits"));
     const Tensor* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
 
-    // LOG(ERROR) << "loss_grad_dims: " << loss_grad_dims;
-    // for (int i=0; i<loss_grad->numel();i++) {
-    //   LOG(ERROR) << "loss_grad: " << loss_grad_data[i];
-    //}
-
-    // T* logits_grad_data =
     logits_grad->mutable_data<T>(ctx.GetPlace());
     bool norm_by_times = ctx.Attr<bool>("norm_by_times");
     math::UnpaddingLoDTensorFunctor<DeviceContext, T>()(
@@ -226,18 +219,6 @@ class WarpCTCGradKernel : public framework::OpKernel<T> {
     math::ScaleLoDTensorFunctor<DeviceContext, T>()(
         ctx.template device_context<DeviceContext>(), *logits_grad,
         loss_grad_data, num_seq);
 
-    /*
-    int level = 0;
-    auto logits_grad_lod = framework::ToAbsOffset(logits_grad->lod());
-    const size_t num_sequences = logits_grad_lod[level].size() - 1;
-    for (int seq_index = 0; seq_index < num_sequences; ++seq_index) {
-      for (int token_index = logits_grad_lod[level][seq_index];
-           token_index < logits_grad_lod[level][seq_index + 1];
-           ++token_index) {
-        logits_grad_data[token_index] *= loss_grad_data[seq_index];
-      }
-    }
-    */
   }
 };
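For reference, the per-sequence scaling that math::ScaleLoDTensorFunctor applies to logits_grad appears to correspond to the manual loop kept above in comments (and deleted by this hunk): each token's gradient is multiplied by the loss gradient of the sequence that token belongs to. A minimal standalone sketch of that scaling, using illustrative names (ScaleByLossGrad, plain std::vector buffers) rather than the Paddle tensor types, and treating the gradient as a single value per token:

#include <cstddef>
#include <vector>

// Sketch only: scale each token's gradient by the loss gradient of the
// sequence it belongs to. The lod vector holds absolute sequence offsets,
// e.g. {0, 3, 7} means sequence 0 covers tokens [0, 3) and sequence 1
// covers tokens [3, 7), mirroring the commented-out loop removed above.
void ScaleByLossGrad(std::vector<float>* logits_grad,
                     const std::vector<float>& loss_grad,
                     const std::vector<std::size_t>& lod) {
  const std::size_t num_sequences = lod.size() - 1;
  for (std::size_t seq = 0; seq < num_sequences; ++seq) {
    for (std::size_t token = lod[seq]; token < lod[seq + 1]; ++token) {
      (*logits_grad)[token] *= loss_grad[seq];
    }
  }
}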