@@ -19,6 +19,29 @@ limitations under the License. */
namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

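// Reads the target shape from Input("ShapeTensor"): a list of 1-element
// int32 tensors, one per output dimension. Values living on the GPU are
// copied to the CPU before being read.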
inline std::vector<int> get_new_shape(
    const std::vector<const Tensor *> &list_new_shape_tensor) {
  // get the target shape value from each shape tensor
  std::vector<int> vec_new_shape;
  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
    auto tensor = list_new_shape_tensor[i];
    PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
                      "shape of dim tensor should be [1]");
    if (platform::is_gpu_place(tensor->place())) {
      framework::Tensor temp;
      TensorCopySync(*tensor, platform::CPUPlace(), &temp);

      vec_new_shape.push_back(static_cast<int32_t>(*temp.data<int32_t>()));
    } else {
      vec_new_shape.push_back(static_cast<int32_t>(*tensor->data<int32_t>()));
    }
  }

  return vec_new_shape;
}

class ReshapeOp : public framework::OperatorWithKernel {
 public:
  ReshapeOp(const std::string &type, const framework::VariableNameMap &inputs,
@@ -32,17 +55,24 @@ class ReshapeOp : public framework::OperatorWithKernel {
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of ReshapeOp should not be null.");

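    // The target shape can come from Input(ShapeTensor), Input(Shape) or
    // attr(shape), checked in that order of priority.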
    if (ctx->HasInputs("ShapeTensor")) {
      // top priority shape
      auto inputs_name = ctx->Inputs("ShapeTensor");
      PADDLE_ENFORCE(inputs_name.size() > 0, "shape tensor size can't be zero");
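      // The values of the shape tensors are unknown at compile time, so every
      // output dim is marked -1 here; the real shape is resolved at runtime in
      // ReshapeKernel.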
      auto out_dims = std::vector<int>(inputs_name.size(), -1);
      ctx->SetOutputDim("Out", framework::make_ddim(out_dims));

      return;
    }
    if (ctx->HasInput("Shape") && ctx->IsRuntime()) {
      // If true, set the shape of Output(Out) according to Input(Shape) in
      // ReshapeKernel with ExecutionContext. Also check LoD in ReshapeKernel.
      ctx->ShareLoD("X", /*->*/ "Out");
      return;
    }

    const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
    PADDLE_ENFORCE(!shape.empty(),
                   "The shape information must be set by Attr(shape).");
    auto x_dims = ctx->GetInputDim("X");
    auto out_dims = ValidateShape(shape, x_dims);
    ctx->SetOutputDim("Out", out_dims);
@@ -114,6 +144,16 @@ class ReshapeOp : public framework::OperatorWithKernel {
    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                   ctx.device_context());
  }

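  // ShapeTensor is consumed directly by the kernel (see get_new_shape), so it
  // is deliberately excluded from the data transform implied by the expected
  // kernel type.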
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "ShapeTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -126,9 +166,18 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
             "the shape attribute, while the shape attribute still should be "
             "set correctly to guarantee shape inference in compile time.")
        .AsDispensable();
    AddInput(
        "ShapeTensor",
        "(vector<Tensor<int32>>, optional). If provided, reshape will use it "
        "as the target shape. The shape of each tensor in the vector must be "
        "[1]. It has the highest priority, above Input(Shape) and "
        "attr(shape).")
        .AsDuplicable()
        .AsDispensable();
    AddOutput("Out", "(Tensor). The output tensor of reshape operator.");
    AddAttr<std::vector<int>>(
        "shape", "(std::vector<int>) Target shape of reshape operator.")
        .SetDefault({});
    AddComment(R"DOC(
Reshape Operator.
@@ -202,24 +251,35 @@ class ReshapeKernel {
    auto *out = ctx.Output<framework::LoDTensor>("Out");
    auto *in = ctx.Input<framework::LoDTensor>("X");

    framework::DDim out_dims = out->dims();

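    // Pick the target shape source: Input(ShapeTensor) has the highest
    // priority, then Input(Shape); if neither is given, out_dims already holds
    // the shape inferred from attr(shape) at compile time.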
    auto list_new_shape_tensor =
        ctx.MultiInput<framework::Tensor>("ShapeTensor");
    if (list_new_shape_tensor.size() > 0) {
      // have shape tensor
      auto new_shape = get_new_shape(list_new_shape_tensor);
      out_dims = ReshapeOp::ValidateShape(new_shape, in->dims());

    } else {
      auto *shape_tensor = ctx.HasInput("Shape")
                               ? ctx.Input<framework::LoDTensor>("Shape")
                               : nullptr;

      if (shape_tensor) {
        auto *shape_data = shape_tensor->data<int>();
        framework::Tensor cpu_shape_tensor;
        if (platform::is_gpu_place(shape_tensor->place())) {
          TensorCopySync(*shape_tensor, platform::CPUPlace(),
                         &cpu_shape_tensor);
          shape_data = cpu_shape_tensor.data<int>();
        }
        auto shape =
            std::vector<int>(shape_data, shape_data + shape_tensor->numel());
        out_dims = ReshapeOp::ValidateShape(shape, in->dims());
      }
    }

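    // Allocate the output with the resolved dims and copy the input data.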
    out->Resize(out_dims);
    out->mutable_data(ctx.GetPlace(), in->type());
    framework::TensorCopy(
        *in, ctx.GetPlace(),
@@ -288,6 +348,7 @@ class Reshape2GradMaker : public framework::SingleGradOpDescMaker {
    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("reshape2_grad");
    grad_op->SetInput("XShape", Output("XShape"));
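    // ShapeTensor is also forwarded so the grad op's inputs mirror the forward
    // op; the gradient's shape itself is recovered from XShape.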
grad_op->SetInput("ShapeTensor", Input("ShapeTensor"));
|
|
|
|
|
grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
grad_op->SetAttrMap(Attrs());
|
|
|
|
@@ -320,6 +381,16 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }

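  // Same special case as in ReshapeOp: skip data transform for ShapeTensor.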
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name, const Tensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    if (var_name == "ShapeTensor") {
      return expected_kernel_type;
    }
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

class ReshapeOpInplaceInToOut : public framework::InplaceOpInference {