|
|
@ -59,7 +59,9 @@ class RankAttentionOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
|
|
|
|
PADDLE_ENFORCE_EQ((rank_offset_dims[1] - 1) / 2, max_rank,
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
"Input(RankOffset) has wrong columns."));
|
|
|
|
"Input(RankOffset) has wrong columns, "
|
|
|
|
|
|
|
|
"except columns to be %d, but got %d",
|
|
|
|
|
|
|
|
max_rank, (rank_offset_dims[1] - 1) / 2));
|
|
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", {ins_num, para_col});
|
|
|
|
ctx->SetOutputDim("Out", {ins_num, para_col});
|
|
|
|
ctx->SetOutputDim("InputHelp", {ins_num, block_matrix_row});
|
|
|
|
ctx->SetOutputDim("InputHelp", {ins_num, block_matrix_row});
|
|
|
|