|
|
|
@ -230,10 +230,12 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Emission"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Emission"), emission_exps_dims);
|
|
|
|
|
ctx->ShareLoD("Emission", framework::GradVarName("Emission"));
|
|
|
|
|
}
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Transition"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Transition"),
|
|
|
|
|
transition_exps_dims);
|
|
|
|
|
ctx->ShareLoD("Transition", framework::GradVarName("Transition"));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|