|
|
|
@ -25,28 +25,22 @@ class ExpandAsV2Op : public framework::OperatorWithKernel {
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "ExpandAsV2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("target_tensor"), "Input", "target_tensor",
|
|
|
|
|
"ExpandAsV2");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "ExpandAsV2");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto target_tensor_dims = ctx->GetInputDim("target_tensor");
|
|
|
|
|
auto target_shape = ctx->Attrs().Get<std::vector<int>>("target_shape");
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
target_tensor_dims.size(), static_cast<size_t>(x_dims.size()),
|
|
|
|
|
target_shape.size(), static_cast<size_t>(x_dims.size()),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(target_tensor) must be greater than or equal "
|
|
|
|
|
"The rank of target_shape must be greater than or equal "
|
|
|
|
|
"to the rank of Input(X). But received Input(X): input "
|
|
|
|
|
"rank %u, input shape [%s]; received Input(target_tensor): "
|
|
|
|
|
"input rank %u, input shape [%s].",
|
|
|
|
|
x_dims.size(), x_dims, target_tensor_dims.size(),
|
|
|
|
|
target_tensor_dims));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
target_tensor_dims.size(), MAX_RANK_SUPPORTED,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of Input(target_tensor) must not be less than or equal "
|
|
|
|
|
"to %d. But received: input rank %u, input shape [%s].",
|
|
|
|
|
MAX_RANK_SUPPORTED, x_dims.size(), x_dims));
|
|
|
|
|
std::vector<int> out_shape = framework::vectorize<int>(target_tensor_dims);
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
|
|
|
|
|
"rank %u; received target_shape: rank %u.",
|
|
|
|
|
x_dims.size(), target_shape.size()));
|
|
|
|
|
PADDLE_ENFORCE_LE(target_shape.size(), MAX_RANK_SUPPORTED,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The rank of target_shape must be less than or equal "
|
|
|
|
|
"to %d. But received: rank %u.",
|
|
|
|
|
MAX_RANK_SUPPORTED, target_shape.size()));
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(target_shape));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -62,23 +56,11 @@ class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"After expanding, size of each dimension of Output(Out) is equal "
|
|
|
|
|
"to size of the corresponding dimension of Input(X) multiplying "
|
|
|
|
|
"the corresponding value given by Attr(expand_times).");
|
|
|
|
|
AddInput("target_tensor", "Expand tensor's shape for each dimension.");
|
|
|
|
|
AddAttr<std::vector<int>>("target_shape",
|
|
|
|
|
"Expand shape for each dimension.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Expand the input by given times number. You should set times
|
|
|
|
|
number for each dimension by providing tensor 'expend_tensor'. The rank of X
|
|
|
|
|
should be in [1, 6]. Please note that size of 'expend_tensor' must be the same
|
|
|
|
|
with X's rank. Following is a using case:
|
|
|
|
|
Input(X) is a 3-D tensor with shape [2, 3, 1]:
|
|
|
|
|
[
|
|
|
|
|
[[1], [2], [3]],
|
|
|
|
|
[[4], [5], [6]]
|
|
|
|
|
]
|
|
|
|
|
target_tensors'shape: [2, 6, 2]
|
|
|
|
|
Output(Out) is a 3-D tensor with shape [2, 6, 2]:
|
|
|
|
|
[
|
|
|
|
|
[[1, 1], [2, 2], [3, 3], [1, 1], [2, 2], [3, 3]],
|
|
|
|
|
[[4, 4], [5, 5], [6, 6], [4, 4], [5, 5], [6, 6]]
|
|
|
|
|
]
|
|
|
|
|
Expand the input to the given shape.
|
|
|
|
|
)DOC");
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -117,7 +99,6 @@ class ExpandAsV2GradOpMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
void Apply(GradOpPtr<T> op) const override {
|
|
|
|
|
op->SetType("expand_as_v2_grad");
|
|
|
|
|
op->SetInput("X", this->Input("X"));
|
|
|
|
|
op->SetInput("target_tensor", this->Input("target_tensor"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
|
|
|
|
|
op->SetAttrMap(this->Attrs());
|
|
|
|
|