|
|
|
@ -164,7 +164,7 @@ dimension value will be copied from Input(X) at runtime. Note that the index of
|
|
|
|
|
[2, 3, 4], Attr(shape) = [2, 3, 2, 0] is an invalid input.
|
|
|
|
|
|
|
|
|
|
3. Input(Shape) has a higher priority than Attr(shape) if it is provided, while
|
|
|
|
|
Attr(shape) still should be set correctly to gurantee shape inference in
|
|
|
|
|
Attr(shape) still should be set correctly to gurantee shape inference in
|
|
|
|
|
compile-time.
|
|
|
|
|
|
|
|
|
|
)DOC");
|
|
|
|
@ -195,6 +195,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ReshapeKernel {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
@ -227,12 +228,15 @@ class ReshapeKernel {
|
|
|
|
|
"sequence_reshape op.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
out->mutable_data(ctx.GetPlace(), in->type());
|
|
|
|
|
framework::TensorCopySync(*in, ctx.GetPlace(), out);
|
|
|
|
|
if (in->data<T>() !=
|
|
|
|
|
reinterpret_cast<T *>(out->mutable_data(ctx.GetPlace(), in->type()))) {
|
|
|
|
|
framework::TensorCopySync(*in, ctx.GetPlace(), out);
|
|
|
|
|
}
|
|
|
|
|
out->Resize(out_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class ReshapeGradKernel {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const framework::ExecutionContext &ctx) const {
|
|
|
|
@ -240,8 +244,9 @@ class ReshapeGradKernel {
|
|
|
|
|
auto *d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto in_dims = d_x->dims();
|
|
|
|
|
|
|
|
|
|
d_x->mutable_data(ctx.GetPlace(), d_out->type());
|
|
|
|
|
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
|
|
|
|
|
if (d_out->data<T>() != d_x->mutable_data(ctx.GetPlace(), d_out->type())) {
|
|
|
|
|
framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
|
|
|
|
|
}
|
|
|
|
|
d_x->Resize(in_dims);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -259,7 +264,6 @@ class Reshape2Op : public ReshapeOp {
|
|
|
|
|
: ReshapeOp(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
ReshapeOp::InferShape(ctx);
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
|
|
|
|
|
"Output(XShape) of ReshapeOp should not be null.");
|
|
|
|
|
const auto &x_dims = ctx->GetInputDim("X");
|
|
|
|
@ -270,6 +274,8 @@ class Reshape2Op : public ReshapeOp {
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("XShape", framework::make_ddim(xshape_dims));
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "XShape");
|
|
|
|
|
|
|
|
|
|
ReshapeOp::InferShape(ctx);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -335,38 +341,46 @@ namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(reshape, ops::ReshapeOp, ops::ReshapeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(reshape_grad, ops::ReshapeGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
|
|
|
|
|
double, ops::ReshapeKernel<double>, int,
|
|
|
|
|
ops::ReshapeKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape_grad, float,
|
|
|
|
|
ops::ReshapeGradKernel<float>, double,
|
|
|
|
|
ops::ReshapeGradKernel<double>, int,
|
|
|
|
|
ops::ReshapeGradKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel<int64_t>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OPERATOR(reshape2, ops::Reshape2Op, ops::Reshape2OpMaker,
|
|
|
|
|
ops::Reshape2GradMaker);
|
|
|
|
|
REGISTER_OPERATOR(reshape2_grad, ops::Reshape2GradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
|
|
|
|
|
double, ops::ReshapeKernel<double>, int,
|
|
|
|
|
ops::ReshapeKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(reshape2_grad, float,
|
|
|
|
|
ops::ReshapeGradKernel<float>, double,
|
|
|
|
|
ops::ReshapeGradKernel<double>, int,
|
|
|
|
|
ops::ReshapeGradKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel<int64_t>);
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel, double,
|
|
|
|
|
ops::ReshapeKernel, int, ops::ReshapeKernel,
|
|
|
|
|
int64_t, ops::ReshapeKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float, ops::ReshapeGradKernel,
|
|
|
|
|
double, ops::ReshapeGradKernel, int,
|
|
|
|
|
ops::ReshapeGradKernel, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape, float, ops::ReshapeKernel<float>,
|
|
|
|
|
double, ops::ReshapeKernel<double>, int,
|
|
|
|
|
ops::ReshapeKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape_grad, float,
|
|
|
|
|
ops::ReshapeGradKernel<float>, double,
|
|
|
|
|
ops::ReshapeGradKernel<double>, int,
|
|
|
|
|
ops::ReshapeGradKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2, float, ops::ReshapeKernel<float>,
|
|
|
|
|
double, ops::ReshapeKernel<double>, int,
|
|
|
|
|
ops::ReshapeKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeKernel<int64_t>);
|
|
|
|
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(reshape2_grad, float,
|
|
|
|
|
ops::ReshapeGradKernel<float>, double,
|
|
|
|
|
ops::ReshapeGradKernel<double>, int,
|
|
|
|
|
ops::ReshapeGradKernel<int>, int64_t,
|
|
|
|
|
ops::ReshapeGradKernel<int64_t>);
|
|
|
|
|
#endif
|
|
|
|
|