|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/expand_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -30,9 +31,12 @@ class ExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
|
|
|
|
|
|
|
|
|
|
std::vector<int> expand_times =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expand_times");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int> expand_times(x_dims.size(), -1);
|
|
|
|
|
|
|
|
|
|
if (!ctx->HasInputs("expand_times_tensor")) {
|
|
|
|
|
expand_times = ctx->Attrs().Get<std::vector<int>>("expand_times");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
|
|
|
|
|
"The number of Attr(expand_times)'s value must be equal "
|
|
|
|
@ -42,15 +46,11 @@ class ExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> out_shape(x_dims.size());
|
|
|
|
|
for (size_t i = 0; i < expand_times.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_GE(expand_times[i], 1,
|
|
|
|
|
"Each value of Attr(expand_times) should not be "
|
|
|
|
|
"less than 1.");
|
|
|
|
|
out_shape[i] = x_dims[i] * expand_times[i];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// set the first dim to -1 in compile time
|
|
|
|
|
if (!ctx->IsRuntime() && x_dims[0] < 0) {
|
|
|
|
|
out_shape[0] = x_dims[0];
|
|
|
|
|
if (x_dims[i] == -1 || expand_times[i] == -1) {
|
|
|
|
|
out_shape[i] = -1;
|
|
|
|
|
} else {
|
|
|
|
|
out_shape[i] = x_dims[i] * expand_times[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
|
|
|
|
@ -58,6 +58,23 @@ class ExpandOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->ShareLoD("X", "Out");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
|
const std::string& var_name, const Tensor& tensor,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const override {
|
|
|
|
|
if (var_name == "expand_times_tensor") {
|
|
|
|
|
return expected_kernel_type;
|
|
|
|
|
}
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
@ -66,6 +83,9 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
|
|
|
|
|
"X is the input to be expanded.");
|
|
|
|
|
AddInput("expand_times_tensor", "(Tensor Tensor<int>), epxand times for X")
|
|
|
|
|
.AsDuplicable()
|
|
|
|
|
.AsDispensable();
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
|
|
|
|
|
"The rank of Output(Out) have the same with Input(X). "
|
|
|
|
@ -73,7 +93,8 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"to size of the corresponding dimension of Input(X) multiplying "
|
|
|
|
|
"the corresponding value given by Attr(expand_times).");
|
|
|
|
|
AddAttr<std::vector<int>>("expand_times",
|
|
|
|
|
"Expand times number for each dimension.");
|
|
|
|
|
"Expand times number for each dimension.")
|
|
|
|
|
.SetDefault({});
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Expand operator tiles the input by given times number. You should set times
|
|
|
|
|
number for each dimension by providing attribute 'expand_times'. The rank of X
|
|
|
|
@ -113,6 +134,7 @@ class ExpandGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int> expand_times =
|
|
|
|
|
ctx->Attrs().Get<std::vector<int>>("expand_times");
|
|
|
|
|
|
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
|
|
|
|
|
size_t start_pos = 0u;
|
|
|
|
@ -137,6 +159,23 @@ class ExpandGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
ctx->SetOutputDim(x_grad_name, x_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.device_context());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetKernelTypeForVar(
|
|
|
|
|
const std::string& var_name, const Tensor& tensor,
|
|
|
|
|
const framework::OpKernelType& expected_kernel_type) const override {
|
|
|
|
|
if (var_name == "expand_times_tensor") {
|
|
|
|
|
return expected_kernel_type;
|
|
|
|
|
}
|
|
|
|
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
|
|
|
|
tensor.place(), tensor.layout());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
@ -150,6 +189,7 @@ class ExpandGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op->SetInput("expand_times_tensor", Input("expand_times_tensor"));
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|