|
|
|
@ -109,21 +109,14 @@ class DeformableConvOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
|
|
|
|
"Input(Input) of DeformableConvOp "
|
|
|
|
|
"should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Offset"),
|
|
|
|
|
"Input(Offset) of DeformableConvOp "
|
|
|
|
|
"should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Mask"),
|
|
|
|
|
"Input(Mask) of DeformableConvOp "
|
|
|
|
|
"should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Filter"),
|
|
|
|
|
"Input(Filter) of DeformableConvOp "
|
|
|
|
|
"should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Output"),
|
|
|
|
|
"Output(Output) of DeformableConvOp "
|
|
|
|
|
"should not be null.");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "deformable_conv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Offset"), "Input", "Offset",
|
|
|
|
|
"deformable_conv)");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "deformable_conv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Filter"), "Input", "Filter",
|
|
|
|
|
"deformable_conv");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output",
|
|
|
|
|
"deformable_conv");
|
|
|
|
|
|
|
|
|
|
auto in_dims = ctx->GetInputDim("Input");
|
|
|
|
|
auto filter_dims = ctx->GetInputDim("Filter");
|
|
|
|
@ -138,39 +131,56 @@ class DeformableConvOp : public framework::OperatorWithKernel {
|
|
|
|
|
int deformable_groups = ctx->Attrs().Get<int>("deformable_groups");
|
|
|
|
|
int im2col_step = ctx->Attrs().Get<int>("im2col_step");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(in_dims.size() == 4,
|
|
|
|
|
"Conv input should be 4-D tensor, get %u", in_dims.size());
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims.size(), filter_dims.size(),
|
|
|
|
|
"Conv input dimension and filter dimension should be the same.");
|
|
|
|
|
in_dims.size(), 4,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Conv input should be 4-D tensor, get %u", in_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims.size(), filter_dims.size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Conv input dimension and filter dimension should be "
|
|
|
|
|
"the same. The diff is [%d] vs [%d]",
|
|
|
|
|
in_dims.size(), filter_dims.size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims.size() - strides.size(), 2U,
|
|
|
|
|
"Conv input dimension and strides dimension should be consistent.");
|
|
|
|
|
platform::errors::InvalidArgument("Conv input dimension and strides "
|
|
|
|
|
"dimension should be consistent."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(paddings.size(), strides.size(),
|
|
|
|
|
"Conv paddings dimension and Conv strides dimension "
|
|
|
|
|
"should be the same.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Conv paddings dimension and Conv strides dimension "
|
|
|
|
|
"should be the same. The diff is [%d] vs [%d]",
|
|
|
|
|
paddings.size(), strides.size()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
|
|
|
|
|
"The number of input channels should be equal to filter "
|
|
|
|
|
"channels * groups.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims[1], filter_dims[1] * groups,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of input channels should be equal to filter "
|
|
|
|
|
"channels * groups. The diff is [%d] vs [%d]",
|
|
|
|
|
in_dims[1], filter_dims[1] * groups));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
filter_dims[0] % groups, 0,
|
|
|
|
|
"The number of output channels should be divided by groups.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(filter_dims[0] % deformable_groups, 0,
|
|
|
|
|
"The number of output channels should be "
|
|
|
|
|
"divided by deformable groups.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of output channels should be divided by groups."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
filter_dims[0] % deformable_groups, 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of output channels should be "
|
|
|
|
|
"divided by deformable groups. The diff is [%d] vs [%d]",
|
|
|
|
|
filter_dims[0] % groups, 0));
|
|
|
|
|
|
|
|
|
|
if (in_dims[0] > im2col_step) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dims[0] % im2col_step, 0U,
|
|
|
|
|
"Input batchsize must be smaller than or divide im2col_step");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input batchsize must be smaller than or divide im2col_step"));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < strides.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_GT(strides[i], 0U, "stride %d size incorrect", i);
|
|
|
|
|
PADDLE_ENFORCE_GT(strides[i], 0U, platform::errors::InvalidArgument(
|
|
|
|
|
"stride %d size incorrect", i));
|
|
|
|
|
}
|
|
|
|
|
for (size_t i = 0; i < dilations.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_GT(dilations[i], 0U, "dilation %d size incorrect", i);
|
|
|
|
|
PADDLE_ENFORCE_GT(dilations[i], 0U, platform::errors::InvalidArgument(
|
|
|
|
|
"dilation %d size incorrect", i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
|
|
|
|
@ -185,29 +195,49 @@ class DeformableConvOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[1] % deformable_groups, 0U,
|
|
|
|
|
"output num_filter must divide deformable group size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
output_shape[1] % deformable_groups, 0U,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"output num_filter must divide deformable group size."));
|
|
|
|
|
|
|
|
|
|
if (ctx->IsRuntime()) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[2], offset_dims[2],
|
|
|
|
|
"output height must equal to offset map height.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"output height must equal to offset map height. "
|
|
|
|
|
"The diff is [%d] vs [%d]",
|
|
|
|
|
output_shape[2], offset_dims[2]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[3], offset_dims[3],
|
|
|
|
|
"output width must equal to offset map width.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
|
|
|
|
|
"offset filter must divide deformable group size.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
|
|
|
|
|
deformable_groups,
|
|
|
|
|
"offset filter must divide deformable group size.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"output width must equal to offset map width. The "
|
|
|
|
|
"diff is [%d] vs [%d]",
|
|
|
|
|
output_shape[3], offset_dims[3]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
offset_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"offset filter must divide deformable group size."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
offset_dims[1] / (2 * filter_dims[2] * filter_dims[3]),
|
|
|
|
|
deformable_groups,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"offset filter must divide deformable group size."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[2], mask_dims[2],
|
|
|
|
|
"output height must equal to mask map height.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"output height must equal to mask map height. The "
|
|
|
|
|
"diff is [%d] vs [%d]",
|
|
|
|
|
output_shape[2], mask_dims[2]));
|
|
|
|
|
PADDLE_ENFORCE_EQ(output_shape[3], mask_dims[3],
|
|
|
|
|
"output width must equal to mask map width.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"output width must equal to mask map width. The "
|
|
|
|
|
"diff is [%d] vs [%d]",
|
|
|
|
|
output_shape[3], mask_dims[3]));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dims[1] % (filter_dims[2] * filter_dims[3]), 0U,
|
|
|
|
|
"mask filter must divide deformable group size.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"mask filter must divide deformable group size."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(mask_dims[1] / (filter_dims[2] * filter_dims[3]),
|
|
|
|
|
deformable_groups,
|
|
|
|
|
"mask filter must divide deformable group size.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"mask filter must divide deformable group size."));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
|
|
|
|
@ -255,8 +285,8 @@ class DeformableConvGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
auto offset_dims = ctx->GetInputDim("Offset");
|
|
|
|
|
auto mask_dims = ctx->GetInputDim("Mask");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
|
|
|
|
|
"the gradient of output(Out) must not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")), "Input",
|
|
|
|
|
"Output@Grad", "deformable_conv_grad");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("Input"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("Input"), in_dims);
|
|
|
|
|
}
|
|
|
|
|