|
|
|
@ -21,10 +21,14 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx->Inputs("Ids").size(), 1UL,
|
|
|
|
|
"Inputs(Ids) of PullBoxSparseOp should not be empty.");
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
|
|
|
|
|
"Outputs(Out) of PullBoxSparseOp should not be empty.");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
ctx->Inputs("Ids").size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Inputs(Ids) of PullBoxSparseOp should not be empty."));
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
ctx->Outputs("Out").size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Outputs(Out) of PullBoxSparseOp should not be empty."));
|
|
|
|
|
auto hidden_size = static_cast<int64_t>(ctx->Attrs().Get<int>("size"));
|
|
|
|
|
auto all_ids_dim = ctx->GetInputsDim("Ids");
|
|
|
|
|
const size_t n_ids = all_ids_dim.size();
|
|
|
|
@ -34,9 +38,10 @@ class PullBoxSparseOp : public framework::OperatorWithKernel {
|
|
|
|
|
const auto ids_dims = all_ids_dim[i];
|
|
|
|
|
int ids_rank = ids_dims.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(ids_dims[ids_rank - 1], 1,
|
|
|
|
|
"Shape error in %lu id, the last dimension of the "
|
|
|
|
|
"'Ids' tensor must be 1.",
|
|
|
|
|
i);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Shape error in %lu id, the last dimension of the "
|
|
|
|
|
"'Ids' tensor must be 1.",
|
|
|
|
|
i));
|
|
|
|
|
auto out_dim = framework::vectorize(
|
|
|
|
|
framework::slice_ddim(ids_dims, 0, ids_rank - 1));
|
|
|
|
|
out_dim.push_back(hidden_size);
|
|
|
|
|