|
|
@ -94,7 +94,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
|
|
|
|
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
|
|
|
|
PADDLE_ENFORCE(!ctx.MultiOutputVar(framework::GradVarName("X")).empty(),
|
|
|
|
"Output(X@Grad) should not be null.");
|
|
|
|
"Output(X@Grad) should not be null.");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
"Input(Out@GRAD) shouldn't be null.");
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto d_ins = ctx.MultiOutput<Tensor>(framework::GradVarName("X"));
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
auto ins = ctx.MultiInput<Tensor>("X");
|
|
|
|
// No need to compute gradient for Input(Ids)
|
|
|
|
// No need to compute gradient for Input(Ids)
|
|
|
|