|
|
|
@ -31,31 +31,33 @@ class TransposeOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Transpose");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Transpose");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
|
|
|
|
|
size_t x_rank = x_dims.size();
|
|
|
|
|
size_t axis_size = axis.size();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_rank, axis_size,
|
|
|
|
|
"ShapeError: The input tensor's dimension "
|
|
|
|
|
"should be equal to the axis's size. "
|
|
|
|
|
"But received input tensor's dimension is %d, "
|
|
|
|
|
"axis's size is %d",
|
|
|
|
|
x_rank, axis_size);
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The input tensor's dimension "
|
|
|
|
|
"should be equal to the axis's size. "
|
|
|
|
|
"But received input tensor's dimension is %d, "
|
|
|
|
|
"axis's size is %d",
|
|
|
|
|
x_rank, axis_size));
|
|
|
|
|
|
|
|
|
|
std::vector<int> count(axis_size, 0);
|
|
|
|
|
for (size_t i = 0; i < axis_size; i++) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1,
|
|
|
|
|
"ValueError: Each element of Attribute axis should "
|
|
|
|
|
"be a unique value range from 0 to (dims - 1), "
|
|
|
|
|
"where the dims is the axis's size, "
|
|
|
|
|
"unique value means this axis value can appear only once. "
|
|
|
|
|
"But received axis[%d] is %d, axis_size is %d, "
|
|
|
|
|
"count[axis[%d]] is %d",
|
|
|
|
|
i, axis[i], axis_size, i, count[axis[i]]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
axis[i] < static_cast<int>(axis_size) && ++count[axis[i]] == 1, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Each element of Attribute axis should "
|
|
|
|
|
"be a unique value range from 0 to (dims - 1), "
|
|
|
|
|
"where the dims is the axis's size, "
|
|
|
|
|
"unique value means this axis value can appear only once. "
|
|
|
|
|
"But received axis[%d] is %d, axis_size is %d, "
|
|
|
|
|
"count[axis[%d]] is %d",
|
|
|
|
|
i, axis[i], axis_size, i, count[axis[i]]));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
framework::DDim out_dims(x_dims);
|
|
|
|
@ -149,9 +151,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "TransposeOpGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "TransposeOpGrad");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
@ -193,8 +195,7 @@ class Transpose2Op : public TransposeOp {
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
TransposeOp::InferShape(ctx);
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
|
|
|
|
|
"Output(XShape) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("XShape"), "Output", "XShape", "Transpose2");
|
|
|
|
|
const auto &in_dims = ctx->GetInputDim("X");
|
|
|
|
|
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
|
|
|
|
|
x_shape_dim[0] = 0;
|
|
|
|
@ -259,9 +260,10 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("XShape"), "Input(XShape) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("XShape"), "Input", "XShape",
|
|
|
|
|
"Transpose2OpGrad");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
framework::GradVarName("Out"), "Transpose2OpGrad");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
auto xshape_dim = ctx->GetInputDim("XShape");
|
|
|
|
|
auto x_shape_dim =
|
|
|
|
|