|
|
@ -16,6 +16,8 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
|
|
#include <string.h>
|
|
|
|
#include <string.h>
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
#include "paddle/operators/math/math_function.h"
|
|
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
namespace paddle {
|
|
|
|
namespace operators {
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
|
@ -65,9 +67,14 @@ class CTCAlignKernel : public framework::OpKernel<T> {
|
|
|
|
framework::LoD output_lod;
|
|
|
|
framework::LoD output_lod;
|
|
|
|
output_lod.push_back(output_lod0);
|
|
|
|
output_lod.push_back(output_lod0);
|
|
|
|
output->set_lod(output_lod);
|
|
|
|
output->set_lod(output_lod);
|
|
|
|
|
|
|
|
|
|
|
|
// resize output dims
|
|
|
|
// resize output dims
|
|
|
|
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
|
|
|
|
output->Resize({static_cast<int64_t>(output_lod0.back()), 1});
|
|
|
|
|
|
|
|
// for empty sequence
|
|
|
|
|
|
|
|
if (output_lod0.back() == 0) {
|
|
|
|
|
|
|
|
output->Resize({1, 1});
|
|
|
|
|
|
|
|
output_data = output->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
output_data[0] = -1;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|