|
|
|
@ -35,7 +35,10 @@ inline std::vector<int> get_new_data(
|
|
|
|
|
auto tensor = list_new_tensor[i];
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
tensor->dims(), framework::make_ddim({1}),
|
|
|
|
|
"The tensor's shape in list of Op(crop_tensor) should be [1].");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The tensor's shape in list of Op(crop_tensor) should be [1], "
|
|
|
|
|
"but the value received is %d.",
|
|
|
|
|
tensor->dims()));
|
|
|
|
|
if (platform::is_gpu_place(tensor->place())) {
|
|
|
|
|
framework::Tensor temp;
|
|
|
|
|
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
|
|
|
|
@ -56,18 +59,23 @@ static framework::DDim ValidateShape(const std::vector<int> shape,
|
|
|
|
|
auto shape_size = shape.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_dim_size, shape_size,
|
|
|
|
|
"Attr(shape)'s size of Op(crop_tensor) should be equal "
|
|
|
|
|
"to that of input Tensor. "
|
|
|
|
|
"Please check the Attr(shape)'s size of Op(fluid.layers.crop_tensor).");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of elements (%d) for shape of Op(crop_tensor) should be "
|
|
|
|
|
"equal to the number of dimensions (%d) of the input tensor.",
|
|
|
|
|
shape_size, in_dim_size));
|
|
|
|
|
std::vector<int64_t> output_shape(shape.size(), 0);
|
|
|
|
|
for (size_t i = 0; i < shape.size(); ++i) {
|
|
|
|
|
if (shape[i] <= 0 && in_dims[i] > 0) {
|
|
|
|
|
PADDLE_ENFORCE_NE(
|
|
|
|
|
shape[i], 0,
|
|
|
|
|
"The element in Attr(shape) of Op(crop_tensor) should not be zero.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(shape[i], -1,
|
|
|
|
|
"When the element in Attr(shape) of Op(crop_tensor) is "
|
|
|
|
|
"negative, only -1 is supported.");
|
|
|
|
|
PADDLE_ENFORCE_NE(shape[i], 0,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The value (%d) of the %uth element for shape of "
|
|
|
|
|
"Op(crop_tensor) should not be zero.",
|
|
|
|
|
shape[i], i));
|
|
|
|
|
PADDLE_ENFORCE_EQ(shape[i], -1, platform::errors::InvalidArgument(
|
|
|
|
|
"When the value (%d) of the %uth "
|
|
|
|
|
"element for shape of Op(crop_tensor)"
|
|
|
|
|
" is negative, only -1 is supported.",
|
|
|
|
|
shape[i], i));
|
|
|
|
|
output_shape[i] = in_dims[i] - offsets[i];
|
|
|
|
|
} else {
|
|
|
|
|
output_shape[i] = static_cast<int64_t>(shape[i]);
|
|
|
|
@ -83,9 +91,13 @@ static std::vector<int> GetShape(const framework::ExecutionContext& ctx) {
|
|
|
|
|
auto list_new_shape_tensor = ctx.MultiInput<framework::Tensor>("ShapeTensor");
|
|
|
|
|
if (list_new_shape_tensor.size() > 0) {
|
|
|
|
|
// have offsets tensor list
|
|
|
|
|
PADDLE_ENFORCE_EQ(list_new_shape_tensor.size(), rank,
|
|
|
|
|
"Input(ShapeTensor)'s length of Op(crop_tensor) should "
|
|
|
|
|
"be equal to dimension size of input tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
list_new_shape_tensor.size(), rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of tensors (%d) for the input ShapeTensor of "
|
|
|
|
|
"Op(crop_tensor) must be equal to the number of "
|
|
|
|
|
"dimensions (%d) of the input.",
|
|
|
|
|
list_new_shape_tensor.size(), rank));
|
|
|
|
|
res = get_new_data(list_new_shape_tensor);
|
|
|
|
|
|
|
|
|
|
return res;
|
|
|
|
@ -122,13 +134,21 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
|
|
|
|
|
if (ctx.HasInput("Offsets")) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx.Attr<std::vector<int>>("offsets").empty(), true,
|
|
|
|
|
"Input 'Offsets' and attribute 'offsets' should not be used "
|
|
|
|
|
"at the same time.");
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input 'Offsets' and attribute 'offsets' for Op(crop_tensor) "
|
|
|
|
|
"cannot be used at the same time."));
|
|
|
|
|
const auto* offsets_tensor = ctx.Input<Tensor>("Offsets");
|
|
|
|
|
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rank, offsets_tensor->dims()[0],
|
|
|
|
|
"Offsets size should be equal to dimension size of input tensor.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(offsets_tensor->dims().size(), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of dimensions of input 'Offsets' must "
|
|
|
|
|
"be 1, but the value received is: %d.",
|
|
|
|
|
offsets_tensor->dims().size()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(rank, offsets_tensor->dims()[0],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of elements (%d) for "
|
|
|
|
|
"input 'Offsets' must be equal to "
|
|
|
|
|
"the number of dimensions (%d) of the input tensor.",
|
|
|
|
|
offsets_tensor->dims()[0], rank));
|
|
|
|
|
const int* offsets_data;
|
|
|
|
|
framework::Tensor cpu_tmp_tensor;
|
|
|
|
|
if (platform::is_cpu_place(offsets_tensor->place())) {
|
|
|
|
@ -143,7 +163,11 @@ static std::vector<int> GetOffsets(const framework::ExecutionContext& ctx) {
|
|
|
|
|
res = ctx.Attr<std::vector<int>>("offsets");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
rank, static_cast<int>(res.size()),
|
|
|
|
|
"Offsets size should be equal to dimension size of input tensor.");
|
|
|
|
|
platform::errors::InvalidArgument("The number of elements (%d) for "
|
|
|
|
|
"input 'Offsets' must be equal to "
|
|
|
|
|
"the number of dimensions (%d) "
|
|
|
|
|
"of the input tensor.",
|
|
|
|
|
static_cast<int>(res.size()), rank));
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
@ -168,10 +192,13 @@ void CropTensorFunction(const framework::ExecutionContext& context) {
|
|
|
|
|
out_dims = ValidateShape(shape, offsets, x->dims());
|
|
|
|
|
out->mutable_data<T>(out_dims, context.GetPlace());
|
|
|
|
|
for (size_t i = 0; i < offsets.size(); ++i) {
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
offsets[i] + shape[i], x_dims[i],
|
|
|
|
|
"The sum of the Attr(offsets) and Attr(shape) of Op(crop_tensor) "
|
|
|
|
|
"should be less than or equal to corresponding input dimension size.");
|
|
|
|
|
PADDLE_ENFORCE_LE(offsets[i] + shape[i], x_dims[i],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The sum of the %uth elements of "
|
|
|
|
|
"offsets (%d) and shape (%d) of Op(crop_tensor) "
|
|
|
|
|
"should be less than or "
|
|
|
|
|
"equal to the size of %uth dimension of the input.",
|
|
|
|
|
i, offsets[i], shape[i], i));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto x_tensor = EigenTensor<T, D>::From(*x);
|
|
|
|
@ -192,6 +219,19 @@ class CropTensorKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
int rank = context.Input<Tensor>("X")->dims().size();
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
rank, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of dimensions of the input 'x' for "
|
|
|
|
|
"Op(crop_tensor) must be greater than or equal to 1, but the "
|
|
|
|
|
"value received is %d.",
|
|
|
|
|
rank));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
rank, 6, platform::errors::InvalidArgument(
|
|
|
|
|
"The number of dimensions of the input 'x' for "
|
|
|
|
|
"Op(crop_tensor) must be less than or equal to 6, but the "
|
|
|
|
|
"value received is %d.",
|
|
|
|
|
rank));
|
|
|
|
|
switch (rank) {
|
|
|
|
|
case 1:
|
|
|
|
|
CropTensorFunction<DeviceContext, T, 1>(context);
|
|
|
|
@ -211,10 +251,6 @@ class CropTensorKernel : public framework::OpKernel<T> {
|
|
|
|
|
case 6:
|
|
|
|
|
CropTensorFunction<DeviceContext, T, 6>(context);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"CropTensorOp only support tensors with no more than 6 "
|
|
|
|
|
"dimensions.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -246,6 +282,20 @@ class CropTensorGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
size_t rank =
|
|
|
|
|
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
|
|
|
|
|
PADDLE_ENFORCE_GE(
|
|
|
|
|
rank, 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of dimensions of the input 'Out@GRAD' for "
|
|
|
|
|
"Op(crop_tensor_grad) must be greater than or equal to 1, but the "
|
|
|
|
|
"value received is %d.",
|
|
|
|
|
rank));
|
|
|
|
|
PADDLE_ENFORCE_LE(
|
|
|
|
|
rank, 6,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The number of dimensions of the input 'Out@GRAD' for "
|
|
|
|
|
"Op(crop_tensor_grad) must be less than or equal to 6, but the "
|
|
|
|
|
"value received is %d.",
|
|
|
|
|
rank));
|
|
|
|
|
switch (rank) {
|
|
|
|
|
case 1:
|
|
|
|
|
CropTensorGradFunction<DeviceContext, T, 1>(context);
|
|
|
|
@ -265,10 +315,6 @@ class CropTensorGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
case 6:
|
|
|
|
|
CropTensorGradFunction<DeviceContext, T, 6>(context);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
"CropTensorOp only support tensors with no more than 6 "
|
|
|
|
|
"dimensions.");
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|