|
|
|
@ -179,15 +179,15 @@ class MergeLoDTensorInferShape : public framework::InferShapeBase {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(framework::InferShapeContext *context) const override {
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("X"),
|
|
|
|
|
"MergeLoDTensorOp must has input X.");
|
|
|
|
|
"MergeLoDTensorOp must have input X.");
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("Mask"),
|
|
|
|
|
"MergeLoDTensorOp must has input Mask.");
|
|
|
|
|
"MergeLoDTensorOp must have input Mask.");
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("InTrue"),
|
|
|
|
|
"MergeLoDTensorOp must has input InTrue.");
|
|
|
|
|
"MergeLoDTensorOp must have input InTrue.");
|
|
|
|
|
PADDLE_ENFORCE(context->HasInput("InFalse"),
|
|
|
|
|
"MergeLoDTensorOp must has input InFalse.");
|
|
|
|
|
"MergeLoDTensorOp must have input InFalse.");
|
|
|
|
|
PADDLE_ENFORCE(context->HasOutput("Out"),
|
|
|
|
|
"MergeLoDTensorOp must has output Out");
|
|
|
|
|
"MergeLoDTensorOp must have output Out");
|
|
|
|
|
|
|
|
|
|
auto mask_dim = context->GetInputDim("Mask");
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dim.size(), 2,
|
|
|
|
|