From cca383cfba49fcf9b9a137922c4112623a80bc28 Mon Sep 17 00:00:00 2001
From: caoying03
Date: Fri, 27 Oct 2017 13:35:39 +0800
Subject: [PATCH] follow comments.

---
 paddle/operators/linear_chain_crf_op.cc | 324 +-----------------------
 paddle/operators/linear_chain_crf_op.h  | 297 +++++++++++++++++++++-
 2 files changed, 295 insertions(+), 326 deletions(-)

diff --git a/paddle/operators/linear_chain_crf_op.cc b/paddle/operators/linear_chain_crf_op.cc
index 9caa2dc742..65bbfff0f8 100644
--- a/paddle/operators/linear_chain_crf_op.cc
+++ b/paddle/operators/linear_chain_crf_op.cc
@@ -17,26 +17,6 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-namespace {
-template <typename T>
-T NormalizeL1(T* x, size_t len) {
-  T sum = 0.;
-  for (size_t i = 0; i < len; ++i) sum += x[i];
-  // (This comment is from the old LinearChainCRFLayer.)
-  // Right now, we just bet that sum won't be zero. If this really happens, we
-  // will figure out what should be done then.
-  PADDLE_ENFORCE(sum,
-                 "The unnormalized probabilities of all possible unfinished "
-                 "sequences must be greater than 0.");
-  T s = 1. / sum;
-  for (size_t i = 0; i < len; ++i) x[i] *= s;
-  return sum;
-}
-}  // namespace
-
-using framework::LoDTensor;
-using framework::LoD;
-
 class LinearChainCRFOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LinearChainCRFOpMaker(framework::OpProto* proto,
@@ -206,145 +186,6 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename T>
-class LinearChainCRFOpKernel<platform::CPUPlace, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
-                   "This kernel only runs on CPU.");
-    auto* emission_weights = ctx.Input<LoDTensor>("Emission");
-    auto* transition_weights = ctx.Input<Tensor>("Transition");
-    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
-    emission_exps->mutable_data<T>(platform::CPUPlace());
-    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
-    transition_exps->mutable_data<T>(platform::CPUPlace());
-    auto* label = ctx.Input<LoDTensor>("Label");
-
-    auto in_lod = emission_weights->lod();
-    PADDLE_ENFORCE(in_lod.size(), "Input(Emission) is not a sequence.");
-
-    // TODO(caoying) The checks related to LoD information should be
-    // moved into InferShape once after the InferShape is refactored.
-    PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
-                      "The Input(Emission) should be a sequence.");
-    PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
-                      "The Input(Label) should be a sequence.");
-    const size_t level = 0;
-
-    auto emission_dims = emission_weights->dims();
-    const size_t batch_size = emission_dims[0];
-    const size_t tag_num = emission_dims[1];
-    const size_t seq_num = in_lod[level].size() - 1;
-
-    Tensor emission_row_max;
-    emission_row_max.mutable_data<T>(
-        framework::make_ddim({static_cast<int>(batch_size), 1}),
-        platform::CPUPlace());
-
-    auto place = ctx.GetEigenDevice<platform::CPUPlace>();
-    auto x = EigenMatrix<T>::From(*emission_weights);
-    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
-    x_row_max.device(place) =
-        x.maximum(Eigen::DSizes<int, 1>(1))
-            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));
-
-    auto x_exps = EigenMatrix<T>::From(*emission_exps);
-    x_exps.device(place) =
-        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
-
-    auto w = EigenMatrix<T>::From(*transition_weights);
-    auto w_exps = EigenMatrix<T>::From(*transition_exps);
-    w_exps.device(place) = w.exp();
-
-    auto* alpha = ctx.Output<Tensor>("Alpha");
-    alpha->mutable_data<T>(platform::CPUPlace());
-    auto* ll = ctx.Output<Tensor>("LogLikelihood");
-    // resize the output tensor to the correct dimension.
-    ll->Resize({static_cast<int>(seq_num), 1});
-    T* log_likelihood = ll->mutable_data<T>(platform::CPUPlace());
-    for (size_t i = 0; i < seq_num; ++i) {
-      int start_pos = static_cast<int>(in_lod[level][i]);
-      int end_pos = static_cast<int>(in_lod[level][i + 1]);
-      if (end_pos == start_pos) {
-        // If an empty input sequence is given, pad 0 for its cost.
-        log_likelihood[i] = 0.;
-        continue;
-      }
-
-      const Tensor one_seq = emission_weights->Slice(start_pos, end_pos);
-      Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
-      Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos);
-      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
-      Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
-
-      log_likelihood[i] = ForwardOneSequence(
-          &one_seq, &one_seq_row_max, &one_seq_exps, transition_weights,
-          transition_exps, &one_seq_label, &one_seq_alpha);
-    }
-  }
-
- protected:
-  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
-                       const Tensor* emission_exps, const Tensor* trans_weights,
-                       const Tensor* trans_weight_exps, const Tensor* label,
-                       Tensor* alpha) const {
-    const T* x = emission->data<T>();
-    const T* x_row_max = emission_row_max->data<T>();
-    const T* x_exps = emission_exps->data<T>();
-    const T* w = trans_weights->data<T>();
-    const T* w_exps = trans_weight_exps->data<T>();
-    T* alpha_value = alpha->data<T>();
-
-    auto x_dims = emission->dims();
-    const size_t seq_length = x_dims[0];
-    const size_t tag_num = x_dims[1];
-    // The 1st row of w are transition weights for start mask.
-    // The 2nd row of w are transition weights for end mask.
-    // Transition weights among other tags begin from the 3rd row of w.
-    const size_t state_trans_base_idx = 2;
-
-    for (size_t i = 0; i < tag_num; ++i) {
-      alpha_value[i] = w_exps[i] * x_exps[i];
-    }
-    T ll = -x_row_max[0] - std::log(NormalizeL1(alpha_value, tag_num));
-
-    for (size_t k = 1; k < seq_length; ++k) {
-      for (size_t i = 0; i < tag_num; ++i) {
-        T sum = 0.;
-        for (size_t j = 0; j < tag_num; ++j) {
-          sum += alpha_value[(k - 1) * tag_num + j] *
-                 w_exps[(j + state_trans_base_idx) * tag_num + i];
-        }
-        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
-      }
-      // NormalizeL1 is to avoid underflow or overflow at (*).
-      ll -= x_row_max[k] +
-            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
-    }
-    T sum = 0.;
-    for (size_t i = 0; i < tag_num; ++i) {
-      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
-    }
-    ll -= std::log(sum);
-    // Now ll is equal to -log(Z).
-
-    const int* lbl = label->data<int>();
-    PADDLE_ENFORCE_LT(
-        *std::max_element(lbl, lbl + seq_length), tag_num,
-        "An invalid tag label that execesses the largest tag number.");
-
-    // Calculate the nominator part, which depends on the label sequence.
-    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
-          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
-    for (size_t k = 1; k < seq_length; ++k) {
-      ll += x[k * tag_num + lbl[k]] +
-            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
-    }
-    return -ll;
-  }
-};
-
 class LinearChainCRFGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -357,11 +198,6 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("LogLikelihood")),
                    "Input(LogLikelihood@GRAD) shoudl be not null.");
 
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Emission")),
-                   "Output(Emission@GRAD) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Transition")),
-                   "Output(Transition@GRAD) should be not null.");
-
     auto emission_exps_dims = ctx->GetInputDim("EmissionExps");
     PADDLE_ENFORCE_EQ(emission_exps_dims.size(), 2UL,
                       "The Input(EmissionExps) should be a 2-D tensor.");
@@ -390,168 +226,24 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
         "The height of Input(EmissionExps) and the height of Input(Label) "
         "should be the same.");
 
-    ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
-    ctx->SetOutputDim(framework::GradVarName("Transition"),
-                      transition_exps_dims);
+    if (ctx->HasOutput(framework::GradVarName("Emission"))) {
+      ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
+    }
+    if (ctx->HasOutput(framework::GradVarName("Transition"))) {
+      ctx->SetOutputDim(framework::GradVarName("Transition"),
+                        transition_exps_dims);
+    }
   }
 
  protected:
   // Explicitly set that the data type of output of the linear_chain_crf_grad
-  // operator is determined by its input "EmissionExps".
+  // operator is determined by its input: the gradients of LogLikelihood.
   framework::DataType IndicateDataType(
       const framework::ExecutionContext& ctx) const override {
     return framework::ToDataType(ctx.Input<LoDTensor>("LogLikelihood")->type());
   }
 };
 
-template <typename T>
-class LinearChainCRFGradOpKernel<platform::CPUPlace, T>
-    : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE(platform::is_cpu_place(platform::CPUPlace()),
-                   "This kernel only runs on CPU.");
-    auto* label = ctx.Input<LoDTensor>("Label");
-    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
-    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
-    auto* alpha = ctx.Input<Tensor>("Alpha");
-    const T* ll_grad =
-        ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
-
-    auto* emission_grad =
-        ctx.Output<Tensor>(framework::GradVarName("Emission"));
-    emission_grad->mutable_data<T>(platform::CPUPlace());
-
-    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
-    if (trans_grad) trans_grad->mutable_data<T>(platform::CPUPlace());
-
-    auto emission_dims = emission_exps->dims();
-
-    // Beta is the memo table used in dynamic programming to calculate the
-    // backwark vectors. For a backward vector i (the i-th row of beta), it
-    // captures the unnormalized probabilities of partial sequences starting at
-    // position i.
-    Tensor beta;
-    beta.mutable_data<T>(emission_dims, platform::CPUPlace());
-
-    const size_t level = 0;  // currently, only support sequence.
-    auto lod = label->lod();
-    PADDLE_ENFORCE(lod.size(), "Input(Label) is not a sequence.");
-
-    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
-      int start_pos = static_cast<int>(lod[level][i]);
-      int end_pos = static_cast<int>(lod[level][i + 1]);
-      if (end_pos == start_pos) continue;
-
-      const Tensor one_seq_emission_exps =
-          emission_exps->Slice(start_pos, end_pos);
-      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
-      const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
-      Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
-      Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos);
-
-      BackwardOneSequence(ctx.device_context(), ll_grad[i],
-                          &one_seq_emission_exps, transition_exps,
-                          &one_seq_alpha, &one_seq_label, &one_seq_beta,
-                          trans_grad, &one_seq_emission_grad);
-    }
-  }
-
- protected:
-  void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad,
-                           const Tensor* emission_exps,
-                           const Tensor* transition_exps, const Tensor* alpha,
-                           const Tensor* label, Tensor* beta,
-                           Tensor* transition_grad,
-                           Tensor* emission_grad) const {
-    const T* w_exps = transition_exps->data<T>();
-    const T* x_exps = emission_exps->data<T>();
-    const int* label_value = label->data<int>();
-    T* beta_value = beta->data<T>();
-
-    auto x_dims = emission_exps->dims();
-    const size_t seq_length = x_dims[0];
-    const size_t tag_num = x_dims[1];
-    const size_t state_trans_base_idx = 2;
-
-    // Calculate the backward vectors: beta.
-    // First, calculate the initialition state.
-    for (size_t i = 0; i < tag_num; ++i) {
-      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
-    }
-    NormalizeL1(beta_value + (seq_length - 1) * tag_num, tag_num);
-
-    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
-      for (size_t i = 0; i < tag_num; ++i) {
-        T sum = 0.;
-        for (size_t j = 0; j < tag_num; ++j) {
-          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                 x_exps[(k + 1) * tag_num + j] *
-                 beta_value[(k + 1) * tag_num + j];
-        }
-        beta_value[k * tag_num + i] = sum;
-      }
-      // NormalizeL1 is to avoid underflow or overflow at (**).
-      NormalizeL1(beta_value + k * tag_num, tag_num);
-    }
-
-    auto alpha_mat = EigenMatrix<T>::From(*alpha);
-    auto beta_mat = EigenMatrix<T>::From(*beta);
-    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
-    auto* place = ctx.GetEigenDevice<platform::CPUPlace>();
-    auto prob = alpha_mat * beta_mat;
-    auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
-                       .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                       .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-    x_grad_mat.device(*place) = prob / row_sum;
-
-    for (size_t k = 0; k < seq_length; ++k) {
-      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
-    }
-
-    if (transition_grad) {
-      T* trans_grad = transition_grad->data<T>();
-      for (size_t k = 0; k < tag_num; ++k) {
-        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
-        trans_grad[tag_num + k] +=
-            x_grad_mat(/*to end state*/ seq_length - 1, k);
-      }
-
-      auto x_exps_mat = EigenMatrix<T>::From(*emission_exps);
-
-      // TODO(caoying): Fix this to avoid using this local variable.
-      Tensor tmp;
-      tmp.mutable_data<T>(beta->dims(), platform::CPUPlace());
-      auto tmp_mat = EigenMatrix<T>::From(tmp);
-      auto prob = beta_mat * x_exps_mat;
-      auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
-                         .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
-                         .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
-      tmp_mat.device(*place) = prob / row_sum;
-
-      for (size_t k = 1; k < seq_length; ++k) {
-        T sum = 0.;
-        for (size_t i = 0; i < tag_num; ++i) {
-          for (size_t j = 0; j < tag_num; ++j) {
-            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
-                   alpha_mat(k - 1, i) * tmp_mat(k, j);
-          }
-        }
-        sum = 1. / sum;
-        for (size_t i = 0; i < tag_num; ++i) {
-          for (size_t j = 0; j < tag_num; ++j) {
-            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
-                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
-                alpha_mat(k - 1, i) * tmp_mat(k, j);
-          }
-        }
-        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
-                   label_value[k]] -= static_cast<T>(1.);
-      }
-    }
-  }
-};
-
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/linear_chain_crf_op.h b/paddle/operators/linear_chain_crf_op.h
index 3175252c66..f028b6554e 100644
--- a/paddle/operators/linear_chain_crf_op.h
+++ b/paddle/operators/linear_chain_crf_op.h
@@ -19,6 +19,25 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+namespace {
+template <typename T>
+T NormalizeL1(T* x, size_t len) {
+  T sum = 0.;
+  for (size_t i = 0; i < len; ++i) sum += x[i];
+  // (This comment is from the old LinearChainCRFLayer.)
+  // Right now, we just bet that sum won't be zero. If this really happens, we
+  // will figure out what should be done then.
+  PADDLE_ENFORCE(sum,
+                 "The unnormalized probabilities of all possible unfinished "
+                 "sequences must be greater than 0.");
+  T s = 1. / sum;
+  for (size_t i = 0; i < len; ++i) x[i] *= s;
+  return sum;
+}
+}  // namespace
+
+using framework::LoDTensor;
+using framework::LoD;
 using framework::Tensor;
 
 template <typename T, int MajorType = Eigen::RowMajor,
@@ -27,27 +46,285 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
 template <typename Place, typename T>
 class LinearChainCRFOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* emission_weights = ctx.Input<LoDTensor>("Emission");
+    auto* transition_weights = ctx.Input<Tensor>("Transition");
+    auto* emission_exps = ctx.Output<LoDTensor>("EmissionExps");
+    emission_exps->mutable_data<T>(ctx.GetPlace());
+    auto* transition_exps = ctx.Output<Tensor>("TransitionExps");
+    transition_exps->mutable_data<T>(ctx.GetPlace());
+    auto* label = ctx.Input<LoDTensor>("Label");
+
+    auto in_lod = emission_weights->lod();
+    PADDLE_ENFORCE(in_lod.size(), "Input(Emission) is not a sequence.");
+
+    // TODO(caoying) The checks related to LoD information should be
+    // moved into InferShape once InferShape is refactored.
+    PADDLE_ENFORCE_EQ(emission_weights->NumLevels(), 1UL,
+                      "The Input(Emission) should be a sequence.");
+    PADDLE_ENFORCE_EQ(label->NumLevels(), 1UL,
+                      "The Input(Label) should be a sequence.");
+    const size_t level = 0;
+
+    auto emission_dims = emission_weights->dims();
+    const size_t batch_size = emission_dims[0];
+    const size_t tag_num = emission_dims[1];
+    const size_t seq_num = in_lod[level].size() - 1;
+
+    Tensor emission_row_max;
+    emission_row_max.mutable_data<T>(
+        framework::make_ddim({static_cast<int>(batch_size), 1}),
+        ctx.GetPlace());
+
+    auto place = ctx.GetEigenDevice<Place>();
+    auto x = EigenMatrix<T>::From(*emission_weights);
+    auto x_row_max = EigenMatrix<T>::From(emission_row_max);
+    x_row_max.device(place) =
+        x.maximum(Eigen::DSizes<int, 1>(1))
+            .reshape(Eigen::DSizes<int, 2>(int(batch_size), 1));
+
+    auto x_exps = EigenMatrix<T>::From(*emission_exps);
+    x_exps.device(place) =
+        (x - x_row_max.broadcast(Eigen::DSizes<int, 2>(1, tag_num))).exp();
+
+    auto w = EigenMatrix<T>::From(*transition_weights);
+    auto w_exps = EigenMatrix<T>::From(*transition_exps);
+    w_exps.device(place) = w.exp();
+
+    auto* alpha = ctx.Output<Tensor>("Alpha");
+    alpha->mutable_data<T>(ctx.GetPlace());
+    auto* ll = ctx.Output<Tensor>("LogLikelihood");
+    // resize the output tensor to the correct dimension.
+    ll->Resize({static_cast<int>(seq_num), 1});
+    T* log_likelihood = ll->mutable_data<T>(ctx.GetPlace());
+    for (size_t i = 0; i < seq_num; ++i) {
+      int start_pos = static_cast<int>(in_lod[level][i]);
+      int end_pos = static_cast<int>(in_lod[level][i + 1]);
+      if (end_pos == start_pos) {
+        // If an empty input sequence is given, pad 0 for its cost.
+        log_likelihood[i] = 0.;
+        continue;
+      }
+
+      const Tensor one_seq = emission_weights->Slice(start_pos, end_pos);
+      Tensor one_seq_row_max = emission_row_max.Slice(start_pos, end_pos);
+      Tensor one_seq_exps = emission_exps->Slice(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
+      Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
+
+      log_likelihood[i] = ForwardOneSequence(
+          one_seq, one_seq_row_max, one_seq_exps, *transition_weights,
+          *transition_exps, one_seq_label, &one_seq_alpha);
+    }
+  };
 
  protected:
-  T ForwardOneSequence(const Tensor* emission, const Tensor* emission_row_max,
-                       const Tensor* emission_exps, const Tensor* trans_weights,
-                       const Tensor* trans_weight_exps, const Tensor* label,
-                       Tensor* alpha) const;
+  T ForwardOneSequence(const Tensor& emission, const Tensor& emission_row_max,
+                       const Tensor& emission_exps, const Tensor& trans_weights,
+                       const Tensor& trans_weight_exps, const Tensor& label,
+                       Tensor* alpha) const {
+    const T* x = emission.data<T>();
+    const T* x_row_max = emission_row_max.data<T>();
+    const T* x_exps = emission_exps.data<T>();
+    const T* w = trans_weights.data<T>();
+    const T* w_exps = trans_weight_exps.data<T>();
+    T* alpha_value = alpha->data<T>();
+
+    auto x_dims = emission.dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+    // The 1st row of w holds the transition weights for the start mask.
+    // The 2nd row of w holds the transition weights for the end mask.
+    // Transition weights between other tags begin from the 3rd row of w.
+    const size_t state_trans_base_idx = 2;
+
+    for (size_t i = 0; i < tag_num; ++i) {
+      alpha_value[i] = w_exps[i] * x_exps[i];
+    }
+    T ll = -x_row_max[0] - std::log(NormalizeL1(alpha_value, tag_num));
+
+    for (size_t k = 1; k < seq_length; ++k) {
+      for (size_t i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (size_t j = 0; j < tag_num; ++j) {
+          sum += alpha_value[(k - 1) * tag_num + j] *
+                 w_exps[(j + state_trans_base_idx) * tag_num + i];
+        }
+        alpha_value[k * tag_num + i] = x_exps[k * tag_num + i] * sum;
+      }
+      // NormalizeL1 is to avoid underflow or overflow at (*).
+      ll -= x_row_max[k] +
+            std::log(NormalizeL1(alpha_value + k * tag_num, tag_num));
+    }
+    T sum = 0.;
+    for (size_t i = 0; i < tag_num; ++i) {
+      sum += alpha_value[(seq_length - 1) * tag_num + i] * w_exps[tag_num + i];
+    }
+    ll -= std::log(sum);
+    // Now ll is equal to -log(Z).
+
+    const int* lbl = label.data<int>();
+    PADDLE_ENFORCE_LT(
+        *std::max_element(lbl, lbl + seq_length), tag_num,
+        "An invalid tag label that exceeds the largest tag number.");
+
+    // Calculate the numerator part, which depends on the label sequence.
+    ll += w[lbl[0]] /*start transition*/ + x[lbl[0]] +
+          w[tag_num + lbl[seq_length - 1]] /*end transition*/;
+    for (size_t k = 1; k < seq_length; ++k) {
+      ll += x[k * tag_num + lbl[k]] +
+            w[(lbl[k - 1] + state_trans_base_idx) * tag_num + lbl[k]];
+    }
+    return -ll;
+  };
 };
 
 template <typename Place, typename T>
 class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& ctx) const override;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* label = ctx.Input<LoDTensor>("Label");
+    auto* emission_exps = ctx.Input<LoDTensor>("EmissionExps");
+    auto* transition_exps = ctx.Input<Tensor>("TransitionExps");
+    auto* alpha = ctx.Input<Tensor>("Alpha");
+    const T* ll_grad =
+        ctx.Input<Tensor>(framework::GradVarName("LogLikelihood"))->data<T>();
+
+    auto place = ctx.GetPlace();
+    auto* emission_grad =
+        ctx.Output<Tensor>(framework::GradVarName("Emission"));
+    emission_grad->mutable_data<T>(place);
+
+    auto* trans_grad = ctx.Output<Tensor>(framework::GradVarName("Transition"));
+    if (trans_grad) {
+      trans_grad->mutable_data<T>(place);
+    }
+
+    auto emission_dims = emission_exps->dims();
+
+    // Beta is the memo table used in dynamic programming to calculate the
+    // backward vectors. For a backward vector i (the i-th row of beta), it
+    // captures the unnormalized probabilities of partial sequences starting at
+    // position i.
+    Tensor beta;
+    beta.mutable_data<T>(emission_dims, place);
+
+    const size_t level = 0;  // Currently, only sequence input is supported.
+    auto lod = label->lod();
+    PADDLE_ENFORCE(lod.size(), "Input(Label) is not a sequence.");
+
+    for (size_t i = 0; i < lod[level].size() - 1; ++i) {
+      int start_pos = static_cast<int>(lod[level][i]);
+      int end_pos = static_cast<int>(lod[level][i + 1]);
+      if (end_pos == start_pos) continue;
+
+      const Tensor one_seq_emission_exps =
+          emission_exps->Slice(start_pos, end_pos);
+      const Tensor one_seq_label = label->Slice(start_pos, end_pos);
+      const Tensor one_seq_alpha = alpha->Slice(start_pos, end_pos);
+      Tensor one_seq_beta = beta.Slice(start_pos, end_pos);
+      Tensor one_seq_emission_grad = emission_grad->Slice(start_pos, end_pos);
+
+      BackwardOneSequence(ctx.device_context(), ll_grad[i],
+                          one_seq_emission_exps, *transition_exps,
+                          one_seq_alpha, one_seq_label, &one_seq_beta,
+                          trans_grad, &one_seq_emission_grad);
+    }
+  };
 
  protected:
   void BackwardOneSequence(const platform::DeviceContext& ctx, const T ll_grad,
-                           const Tensor* emission_exps,
-                           const Tensor* transition_exps, const Tensor* alpha,
-                           const Tensor* label, Tensor* beta,
+                           const Tensor& emission_exps,
+                           const Tensor& transition_exps, const Tensor& alpha,
+                           const Tensor& label, Tensor* beta,
                            Tensor* transition_grad,
-                           Tensor* emission_grad) const;
+                           Tensor* emission_grad) const {
+    const T* w_exps = transition_exps.data<T>();
+    const T* x_exps = emission_exps.data<T>();
+    const int* label_value = label.data<int>();
+    T* beta_value = beta->data<T>();
+
+    auto x_dims = emission_exps.dims();
+    const size_t seq_length = x_dims[0];
+    const size_t tag_num = x_dims[1];
+    const size_t state_trans_base_idx = 2;
+
+    // Calculate the backward vectors: beta.
+    // First, calculate the initialization state.
+    for (size_t i = 0; i < tag_num; ++i) {
+      beta_value[(seq_length - 1) * tag_num + i] = w_exps[tag_num + i];
+    }
+    NormalizeL1(beta_value + (seq_length - 1) * tag_num, tag_num);
+
+    for (int k = static_cast<int>(seq_length) - 2; k >= 0; --k) {
+      for (size_t i = 0; i < tag_num; ++i) {
+        T sum = 0.;
+        for (size_t j = 0; j < tag_num; ++j) {
+          sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                 x_exps[(k + 1) * tag_num + j] *
+                 beta_value[(k + 1) * tag_num + j];
+        }
+        beta_value[k * tag_num + i] = sum;
+      }
+      // NormalizeL1 is to avoid underflow or overflow at (**).
+      NormalizeL1(beta_value + k * tag_num, tag_num);
+    }
+
+    auto alpha_mat = EigenMatrix<T>::From(alpha);
+    auto beta_mat = EigenMatrix<T>::From(*beta);
+    auto x_grad_mat = EigenMatrix<T>::From(*emission_grad);
+    auto* place = ctx.GetEigenDevice<Place>();
+    auto prob = alpha_mat * beta_mat;
+    auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
+                       .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                       .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+    x_grad_mat.device(*place) = prob / row_sum;
+
+    for (size_t k = 0; k < seq_length; ++k) {
+      x_grad_mat(k, label_value[k]) -= static_cast<T>(1.);
+    }
+
+    if (transition_grad) {
+      T* trans_grad = transition_grad->data<T>();
+      for (size_t k = 0; k < tag_num; ++k) {
+        trans_grad[k] += x_grad_mat(/*from start state*/ 0, k);
+        trans_grad[tag_num + k] +=
+            x_grad_mat(/*to end state*/ seq_length - 1, k);
+      }
+
+      auto x_exps_mat = EigenMatrix<T>::From(emission_exps);
+
+      // TODO(caoying): Fix this to avoid using this local variable.
+      Tensor tmp;
+      tmp.mutable_data<T>(beta->dims(), ctx.GetPlace());
+      auto tmp_mat = EigenMatrix<T>::From(tmp);
+      auto prob = beta_mat * x_exps_mat;
+      auto row_sum = prob.sum(Eigen::DSizes<int, 1>(1))
+                         .reshape(Eigen::DSizes<int, 2>(seq_length, 1))
+                         .broadcast(Eigen::DSizes<int, 2>(1, tag_num));
+      tmp_mat.device(*place) = prob / row_sum;
+
+      for (size_t k = 1; k < seq_length; ++k) {
+        T sum = 0.;
+        for (size_t i = 0; i < tag_num; ++i) {
+          for (size_t j = 0; j < tag_num; ++j) {
+            sum += w_exps[(i + state_trans_base_idx) * tag_num + j] *  // (**)
+                   alpha_mat(k - 1, i) * tmp_mat(k, j);
+          }
+        }
+        sum = 1. / sum;
+        for (size_t i = 0; i < tag_num; ++i) {
+          for (size_t j = 0; j < tag_num; ++j) {
+            trans_grad[(i + state_trans_base_idx) * tag_num + j] +=
+                sum * w_exps[(i + state_trans_base_idx) * tag_num + j] *
+                alpha_mat(k - 1, i) * tmp_mat(k, j);
+          }
+        }
+        trans_grad[(label_value[k - 1] + state_trans_base_idx) * tag_num +
+                   label_value[k]] -= static_cast<T>(1.);
+      }
+    }
+  };
 };
 
 }  // namespace operators
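
For reference, the forward pass that ForwardOneSequence implements can be written as a small standalone function. The sketch below is illustrative and not part of the patch: crf_forward_ll, the std::vector-based layout, and double precision are assumptions made for the example. It follows the same conventions as the kernel: transition row 0 holds the start weights, row 1 the end weights, and rows 2 onward the tag-to-tag weights. The per-row max subtraction and the L1 rescaling are purely numerical safeguards; both are folded back into log(Z), so they do not change the result.

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Hypothetical standalone reimplementation, for illustration only.
// emission:   seq_len x tag_num unnormalized scores (the "Emission" input).
// transition: (tag_num + 2) x tag_num weights; row 0 = start, row 1 = end,
//             rows 2.. = tag-to-tag (the "Transition" input).
// Returns log P(label | emission).
double crf_forward_ll(const std::vector<std::vector<double>>& emission,
                      const std::vector<std::vector<double>>& transition,
                      const std::vector<int>& label) {
  const size_t seq_len = emission.size();
  const size_t tag_num = emission[0].size();
  assert(seq_len > 0 && label.size() == seq_len);

  std::vector<double> alpha(tag_num, 0.);
  double log_z = 0.;  // accumulates log(Z) piece by piece.

  for (size_t k = 0; k < seq_len; ++k) {
    // Subtract the per-row max before exponentiation (emission_row_max).
    const double row_max =
        *std::max_element(emission[k].begin(), emission[k].end());
    log_z += row_max;  // the shift is added back into log(Z) here.

    std::vector<double> next(tag_num, 0.);
    for (size_t i = 0; i < tag_num; ++i) {
      const double x_exp = std::exp(emission[k][i] - row_max);
      if (k == 0) {
        // First step: only the start-transition row contributes.
        next[i] = std::exp(transition[0][i]) * x_exp;
      } else {
        double sum = 0.;
        for (size_t j = 0; j < tag_num; ++j) {
          sum += alpha[j] * std::exp(transition[2 + j][i]);
        }
        next[i] = x_exp * sum;
      }
    }
    // Rescale to unit L1 norm and fold the scale into log(Z): this is
    // exactly the role NormalizeL1 plays in the kernel.
    double scale = 0.;
    for (double v : next) scale += v;
    for (double& v : next) v /= scale;
    log_z += std::log(scale);
    alpha = next;
  }

  // Fold in the end-transition row to finish log(Z).
  double end_sum = 0.;
  for (size_t i = 0; i < tag_num; ++i) {
    end_sum += alpha[i] * std::exp(transition[1][i]);
  }
  log_z += std::log(end_sum);

  // Numerator: the log-space score of the gold label path.
  double gold = transition[0][label.front()] + emission[0][label[0]] +
                transition[1][label.back()];
  for (size_t k = 1; k < seq_len; ++k) {
    gold += emission[k][label[k]] + transition[2 + label[k - 1]][label[k]];
  }
  return gold - log_z;
}

int main() {
  // Toy example: 3 steps, 2 tags; transition has 2 + 2 = 4 rows.
  std::vector<std::vector<double>> emission = {{1., 0.}, {0.5, 2.}, {0., 1.}};
  std::vector<std::vector<double>> transition = {
      {0.1, 0.2}, {0.3, 0.1}, {0.0, 0.5}, {0.4, 0.0}};
  std::vector<int> label = {0, 1, 1};
  std::printf("log-likelihood: %f\n",
              crf_forward_ll(emission, transition, label));
  return 0;
}

Note that the operator stores the cost, i.e. the negative of this value, in its LogLikelihood output, so the sketch should agree with the kernel up to sign and floating-point differences.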
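The emission gradient in BackwardOneSequence relies on the standard CRF identity: the derivative of the negative log-likelihood with respect to emission score (k, i) equals the posterior marginal P(tag at position k = i | input) minus the one-hot indicator of the gold label, and that marginal is the row-normalized elementwise product of the forward and backward vectors. Because the normalization is per row, every NormalizeL1 rescaling factor (and every row-max shift) cancels in the division, which is why the kernel is free to rescale alpha and beta at each step. A minimal sketch under the same illustrative assumptions as above (emission_grad_sketch is a hypothetical name, not the operator's API):

#include <vector>

// alpha, beta: seq_len x tag_num forward/backward vectors, each row rescaled
// arbitrarily (e.g. by NormalizeL1); label: gold tag per position.
// Returns d(-log-likelihood)/d(emission), mirroring x_grad_mat above.
std::vector<std::vector<double>> emission_grad_sketch(
    const std::vector<std::vector<double>>& alpha,
    const std::vector<std::vector<double>>& beta,
    const std::vector<int>& label) {
  const size_t seq_len = alpha.size();
  const size_t tag_num = alpha[0].size();
  std::vector<std::vector<double>> grad(seq_len,
                                        std::vector<double>(tag_num, 0.));
  for (size_t k = 0; k < seq_len; ++k) {
    // Row-normalizing alpha .* beta yields the tag marginal at position k;
    // any per-row rescaling of alpha or beta cancels in the division.
    double row_sum = 0.;
    for (size_t i = 0; i < tag_num; ++i) {
      grad[k][i] = alpha[k][i] * beta[k][i];
      row_sum += grad[k][i];
    }
    for (size_t i = 0; i < tag_num; ++i) grad[k][i] /= row_sum;
    // Subtract the one-hot gold label: marginal - indicator.
    grad[k][label[k]] -= 1.;
  }
  return grad;
}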