|
|
|
@ -26,28 +26,39 @@ class MultiplexOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null.");
|
|
|
|
|
PADDLE_ENFORCE(!ctx->Inputs("X").empty(),
|
|
|
|
|
"MultiInput(X) shouldn't be empty.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Ids"), "Input", "Ids", "Multiplex");
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
ctx->Inputs("X").empty(), true,
|
|
|
|
|
platform::errors::InvalidArgument("MultiInput(X) shouldn't be empty."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Multiplex");
|
|
|
|
|
auto ids_dim = ctx->GetInputDim("Ids");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
ids_dim.size() == 2 && ids_dim[1] == 1,
|
|
|
|
|
"The index tensor must be a vector with size batchSize x 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ids_dim.size(), 2,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The index tensor must be a vector with 2 dimensions"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ids_dim[1], 1,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The index tensor must be a vector with batchSize x 1."));
|
|
|
|
|
|
|
|
|
|
auto ins_dims = ctx->GetInputsDim("X");
|
|
|
|
|
auto num_ins = ins_dims.size();
|
|
|
|
|
PADDLE_ENFORCE(num_ins > 1,
|
|
|
|
|
"multiplex operator should have more than "
|
|
|
|
|
"one candidate input tensors.");
|
|
|
|
|
PADDLE_ENFORCE_GT(num_ins, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"multiplex operator should have more than "
|
|
|
|
|
"one candidate input tensors."));
|
|
|
|
|
|
|
|
|
|
auto in_dim = ins_dims[0];
|
|
|
|
|
PADDLE_ENFORCE(in_dim.size() >= 2,
|
|
|
|
|
"The rank of candidate tensors must be not less than 2.");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
in_dim.size(), 2,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of candidate tensors must be not less than 2."));
|
|
|
|
|
for (size_t i = 1; i < num_ins; i++) {
|
|
|
|
|
auto dim = ins_dims[i];
|
|
|
|
|
PADDLE_ENFORCE(in_dim == dim,
|
|
|
|
|
"All the candidate tensors must have the same size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dim, dim,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"All the candidate tensors must have the same size."));
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Out", in_dim);
|
|
|
|
|
}
|
|
|
|
@ -115,9 +126,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
auto dxs = ctx->Outputs(framework::GradVarName("X"));
|
|
|
|
|
PADDLE_ENFORCE(!dxs.empty(), "Output(X@Grad) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(dxs.empty(), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(X@Grad) should not be null."));
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "MultiplexGrad");
|
|
|
|
|
auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
ctx->SetOutputsDim(framework::GradVarName("X"),
|
|
|
|
|
std::vector<framework::DDim>(dxs.size(), dout_dim));
|
|
|
|
|