|
|
|
@ -157,7 +157,9 @@ class SplitLoDTensorInferShape : public framework::InferShapeBase {
|
|
|
|
|
|
|
|
|
|
auto mask_dim = context->GetInputDim("Mask");
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dim.size(), 2);
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
|
|
|
|
|
if (context->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dim[1], 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
context->SetOutputDim("OutTrue", context->GetInputDim("X"));
|
|
|
|
|
context->SetOutputDim("OutFalse", context->GetInputDim("X"));
|
|
|
|
|