|
|
@ -31,15 +31,15 @@ void FusionSeqPoolCVMConcatOp::InferShape(
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
"Output(Out) of FusionSeqPoolCVMConcatOp should not be null."));
|
|
|
|
"Output(Out) of FusionSeqPoolCVMConcatOp should not be null."));
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
int axis = ctx->Attrs().Get<int>("axis");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(axis, 1, paddle::platform::errors::InvalidArgument(
|
|
|
|
axis, 1,
|
|
|
|
"FusionSeqPoolCVMConcatOp only supports "
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
"concat axis=1 yet, but received %d.",
|
|
|
|
"FusionSeqPoolCVMConcatOp only supports concat axis=1 yet."));
|
|
|
|
axis));
|
|
|
|
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
|
|
|
|
bool use_cvm = ctx->Attrs().Get<bool>("use_cvm");
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
PADDLE_ENFORCE_EQ(use_cvm, true, paddle::platform::errors::InvalidArgument(
|
|
|
|
use_cvm, true,
|
|
|
|
"FusionSeqPoolCVMConcatOp only supports "
|
|
|
|
paddle::platform::errors::InvalidArgument(
|
|
|
|
"use_cvm is true yet, but received %d.",
|
|
|
|
"FusionSeqPoolCVMConcatOp only supports use_cvm is true yet."));
|
|
|
|
use_cvm));
|
|
|
|
|
|
|
|
|
|
|
|
auto ins_dims = ctx->GetInputsDim("X");
|
|
|
|
auto ins_dims = ctx->GetInputsDim("X");
|
|
|
|
const size_t n = ins_dims.size();
|
|
|
|
const size_t n = ins_dims.size();
|
|
|
|