|
|
|
@ -19,27 +19,31 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
using framework::Tensor;
|
|
|
|
|
template <typename T, int MajorType = Eigen::RowMajor,
|
|
|
|
|
typename IndexType = Eigen::DenseIndex>
|
|
|
|
|
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class LinearChainCrfOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on CPU.");
|
|
|
|
|
}
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
T ForwardOneSequence(const platform::DeviceContext& ctx,
|
|
|
|
|
const Tensor& emission, Tensor& emission_row_max,
|
|
|
|
|
Tensor& emission_exps, const Tensor& trans_weights,
|
|
|
|
|
Tensor& trans_weight_exps, const Tensor& label,
|
|
|
|
|
Tensor& a) const;
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
T NormalizeL1(T* x, size_t len) const;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class LinearChainCrfGradOpKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
|
|
|
|
|
"This kernel only runs on CPU.");
|
|
|
|
|
}
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|