add crop_tensor_op, test=develop, test=document_preview (#19314)

add crop_tensor op. The main difference with crop is :

1. If the argument shape is a list, each element is an integer or a tensor variable with shape: [1]. This way is suitable for the case that the shape may be changed each iteration.

2. If the argument shape is a variable. Its rank must be 1. In crop op, the rank of shape must be the same as x

offsets can be a list, in which each element is an integer or a tensor variavle with shape: [1].
expand_as_op_1
Zhang Ting 6 years ago committed by Aurelius84
parent bf8367367e
commit b38889413d

@ -204,7 +204,8 @@ paddle.fluid.layers.mean_iou (ArgSpec(args=['input', 'label', 'num_classes'], va
paddle.fluid.layers.relu (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '0942c174f4f6fb274976d4357356f6a2'))
paddle.fluid.layers.selu (ArgSpec(args=['x', 'scale', 'alpha', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'f93c61f5b0bf933cd425a64dca2c4fdd'))
paddle.fluid.layers.log (ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '02f668664e3bfc4df6c00d7363467140'))
paddle.fluid.layers.crop (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ddf9837ee83e549119210a3d714d5f44'))
paddle.fluid.layers.crop (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'ba3621917d5beffd3d022b88fbf6dc46'))
paddle.fluid.layers.crop_tensor (ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None)), ('document', 'cb855453e3506bf54c5c013616ffddfb'))
paddle.fluid.layers.rank_loss (ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '8eb36596bb43d7a907d3397c7aedbdb3'))
paddle.fluid.layers.margin_rank_loss (ArgSpec(args=['label', 'left', 'right', 'margin', 'name'], varargs=None, keywords=None, defaults=(0.1, None)), ('document', '6fc86ed23b420c8a0f6c043563cf3937'))
paddle.fluid.layers.elu (ArgSpec(args=['x', 'alpha', 'name'], varargs=None, keywords=None, defaults=(1.0, None)), ('document', '9af1926c06711eacef9e82d7a9e4d308'))

File diff suppressed because it is too large Load Diff

@ -0,0 +1,24 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/crop_tensor_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
crop_tensor,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
crop_tensor_grad,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::CropTensorGradKernel<paddle::platform::CUDADeviceContext, double>);

File diff suppressed because it is too large Load Diff

@ -133,6 +133,7 @@ __all__ = [
'selu',
'log',
'crop',
'crop_tensor',
'rank_loss',
'margin_rank_loss',
'elu',
@ -9119,6 +9120,11 @@ def crop(x, shape=None, offsets=None, name=None):
"""
Crop input into output, as specified by offsets and shape.
**Warning:** THIS FUNCTION IS DEPRECATED. It will be removed in a future version.
Instructions for updating: Use `fluid.layers.crop_tensor
<https://www.paddlepaddle.org.cn/documentation/docs/en/api/layers/nn.html#crop_tensor>`_
instead.
.. code-block:: text
* Case 1:
@ -9150,16 +9156,16 @@ def crop(x, shape=None, offsets=None, name=None):
Args:
x (Variable): The input tensor variable.
shape (Variable|list/tuple of integer): The output shape is specified
by `shape`, which can a Variable or a list/tupe of integer.
by `shape`, which can be a Variable or a list/tuple of integer.
If a tensor Variable, it's rank must be the same as `x`. This way
is suitable for the case that the output shape may be changed each
iteration. If a list/tupe of integer, it's length must be the same
iteration. If a list/tuple of integer, it's length must be the same
as the rank of `x`
offsets (Variable|list/tuple of integer|None): Specifies the cropping
offsets at each dimension. It can be a Variable or or a list/tupe
offsets at each dimension. It can be a Variable or a list/tuple
of integers. If a tensor Variable, it's rank must be the same as `x`.
This way is suitable for the case that the offsets may be changed
each iteration. If a list/tupe of integer, it's length must be the
each iteration. If a list/tuple of integer, it's length must be the
same as the rank of `x`. If None, the offsets are 0 at each
dimension.
name(str|None): A name for this layer(optional). If set None, the layer
@ -9214,6 +9220,188 @@ def crop(x, shape=None, offsets=None, name=None):
return out
def crop_tensor(x, shape=None, offsets=None, name=None):
"""
Crop input into output, as specified by offsets and shape.
.. code-block:: text
* Case 1:
Given
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]],
and
shape = [2, 2],
offsets = [0, 1],
output is:
Out = [[1, 2],
[3, 4]].
* Case 2:
Given
X = [[[0, 1, 2, 3]
[0, 5, 6, 7]
[0, 0, 0, 0]],
[[0, 3, 4, 5]
[0, 6, 7, 8]
[0, 0, 0, 0]]].
and
shape = [2, 2, 3],
offsets = [0, 0, 1],
output is:
Out = [[[1, 2, 3]
[5, 6, 7]],
[[3, 4, 5]
[6, 7, 8]]].
Args:
x (Variable): The input tensor variable.
shape (Variable|list|tuple of integer): The output shape is specified
by `shape`. It can be a 1-D tensor Variable or a list/tuple. If a
1-D tensor Variable, it's rank must be the same as `x`. If a
list/tuple, it's length must be the same as the rank of `x`. Each
element of list can be an integer or a tensor Variable of shape: [1].
If Variable contained, it is suitable for the case that the shape may
be changed each iteration. Only the first element of list/tuple can be
set to -1, it means that the first dimension of the output is the same
as the input.
offsets (Variable|list|tuple of integer|None): Specifies the cropping
offsets at each dimension. It can be a 1-D tensor Variable or a list/tuple.
If a 1-D tensor Variable, it's rank must be the same as `x`. If a list/tuple,
it's length must be the same as the rank of `x`. Each element of list can be
an integer or a tensor Variable of shape: [1]. If Variable contained, it is
suitable for the case that the offsets may be changed each iteration. If None,
the offsets are 0 at each dimension.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The cropped tensor variable.
Raises:
ValueError: If shape is not a list, tuple or Variable.
ValueError: If offsets is not None and not a list, tuple or Variable.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name="x", shape=[3, 5], dtype="float32")
# x.shape = [-1, 3, 5], where -1 indicates batch size, and it will get the exact value in runtime.
# shape is a 1-D tensor variable
crop_shape = fluid.layers.data(name="crop_shape", shape=[3], dtype="int32", append_batch_size=False)
crop0 = fluid.layers.crop_tensor(x, shape=crop_shape)
# crop0.shape = [-1, -1, -1], it means crop0.shape[0] = x.shape[0] in runtime.
# or shape is a list in which each element is a constant
crop1 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3])
# crop1.shape = [-1, 2, 3]
# or shape is a list in which each element is a constant or variable
y = fluid.layers.data(name="y", shape=[3, 8, 8], dtype="float32")
dim1 = fluid.layers.data(name="dim1", shape=[1], dtype="int32", append_batch_size=False)
crop2 = fluid.layers.crop_tensor(y, shape=[-1, 3, dim1, 4])
# crop2.shape = [-1, 3, -1, 4]
# offsets is a 1-D tensor variable
crop_offsets = fluid.layers.data(name="crop_offsets", shape=[3], dtype="int32", append_batch_size=False)
crop3 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3], offsets=crop_offsets)
# crop3.shape = [-1, 2, 3]
# offsets is a list in which each element is a constant or variable
offsets_var = fluid.layers.data(name="dim1", shape=[1], dtype="int32", append_batch_size=False)
crop4 = fluid.layers.crop_tensor(x, shape=[-1, 2, 3], offsets=[0, 1, offsets_var])
# crop4.shape = [-1, 2, 3]
"""
helper = LayerHelper('crop_tensor', **locals())
if not (isinstance(shape, list) or isinstance(shape, tuple) or \
isinstance(shape, Variable)):
raise ValueError("The shape should be a list, tuple or Variable.")
if offsets is None:
offsets = [0] * len(x.shape)
if not (isinstance(offsets, list) or isinstance(offsets, tuple) or \
isinstance(offsets, Variable)):
raise ValueError("The offsets should be a list, tuple or Variable.")
out = helper.create_variable_for_type_inference(x.dtype)
ipts = {'X': x}
attrs = {}
def contain_var(input_list):
for ele in input_list:
if isinstance(ele, Variable):
return True
return False
if isinstance(offsets, Variable):
offsets.stop_gradient = True
ipts['Offsets'] = offsets
elif contain_var(offsets):
new_offsets_tensor = []
for dim in offsets:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_offsets_tensor.append(dim)
else:
assert (isinstance(dim, int))
assert dim >= 0, ("offsets should be greater or equal to zero.")
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant([1], 'int32', dim, force_cpu=True, out=temp_out)
new_offsets_tensor.append(temp_out)
ipts['OffsetsTensor'] = new_offsets_tensor
else:
attrs['offsets'] = offsets
unk_dim_idx = -1
if isinstance(shape, Variable):
shape.stop_gradient = True
ipts['Shape'] = shape
elif contain_var(shape):
new_shape_tensor = []
shape_attr = []
for dim_idx, dim_size in enumerate(shape):
if isinstance(dim_size, Variable):
dim_size.stop_gradient = True
new_shape_tensor.append(dim_size)
shape_attr.append(-1)
else:
assert (isinstance(dim_size, int))
if dim_size == -1:
assert unk_dim_idx == -1, (
"Only one element in shape can be unknown.")
assert dim_idx == 0, (
"Only the first element in shape can be -1.")
unk_dim_idx = dim_idx
else:
assert dim_size > 0, (
"Each dimension size given in shape must be greater than zero."
)
temp_out = helper.create_variable_for_type_inference('int32')
fill_constant(
[1], 'int32', dim_size, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
shape_attr.append(dim_size)
ipts['ShapeTensor'] = new_shape_tensor
attrs['shape'] = shape_attr
else:
attrs['shape'] = shape
helper.append_op(
type='crop_tensor',
inputs=ipts,
outputs={'Out': out},
attrs=None if len(attrs) == 0 else attrs)
return out
def affine_grid(theta, out_shape, name=None):
"""
It generates a grid of (x,y) coordinates using the parameters of

@ -0,0 +1,218 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
def crop(data, offsets, crop_shape):
def indexOf(shape, index):
result = []
for dim in reversed(shape):
result.append(index % dim)
index = index / dim
return result[::-1]
result = []
for i, value in enumerate(data.flatten()):
index = indexOf(data.shape, i)
selected = True
if len(index) == len(offsets):
for j, offset in enumerate(offsets):
selected = selected and index[j] >= offset and index[
j] < crop_shape[j] + offset
if selected:
result.append(value)
return np.array(result).reshape(crop_shape)
class TestCropTensorOp(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.crop_by_1D_shape = False
self.offset_by_input = False
self.unk_dim_idx = -1
self.attrs = {}
self.initTestCase()
if self.crop_by_1D_shape:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Shape': np.array(self.crop_shape).astype("int32")
}
else:
self.attrs['shape'] = self.crop_shape
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
}
if self.offset_by_input:
self.inputs['Offsets'] = np.array(self.offsets).astype('int32')
else:
self.attrs['offsets'] = self.offsets
if self.unk_dim_idx != -1:
self.crop_shape[self.unk_dim_idx] = self.x_shape[self.unk_dim_idx]
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = [2, 2]
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.006)
class TestCase1(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (100)
self.crop_shape = [64]
self.offsets = [13]
class TestCase2(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (12, 24)
self.crop_shape = [-1, 8] #only the first dimension (batch) can be -1
self.offsets = [0, 0]
self.unk_dim_idx = 0
class TestCase3(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_1D_shape = True
class TestCase4(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (8, 3, 6, 6)
self.crop_shape = [-1, 3, 4, 4]
self.offsets = [0, 0, 0, 0]
self.crop_by_1D_shape = True
self.unk_dim_idx = 0
class TestCase5(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (2, 4, 5, 8, 8)
self.crop_shape = [1, 1, 2, 4, 4]
self.offsets = [1, 0, 0, 2, 2]
self.offset_by_input = True
class TestCase6(TestCropTensorOp):
def initTestCase(self):
self.x_shape = (2, 2, 4, 4, 4, 2)
self.crop_shape = [1, 1, 4, 2, 2, 2]
self.offsets = [0, 0, 0, 0, 0, 0]
self.crop_by_1D_shape = True
self.offset_by_input = True
class TestCropTensorOp_attr_tensor(OpTest):
def setUp(self):
self.op_type = "crop_tensor"
self.mixed_type = False
self.OffsetsTensor = False
self.ShapeTensor = True
self.attrs = {}
self.initTestCase()
if self.ShapeTensor:
shape_tensor = []
for index, ele in enumerate(self.crop_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'ShapeTensor': shape_tensor
}
if self.mixed_type:
self.attrs['shape'] = self.shape_attr
if self.OffsetsTensor:
offsets_tensor = []
for index, ele in enumerate(self.offsets):
offsets_tensor.append(("x" + str(index), np.ones(
(1)).astype('int32') * ele))
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'OffsetsTensor': offsets_tensor
}
else:
self.attrs['offsets'] = self.offsets
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(["X"], "Out", max_relative_error=0.006)
class TestCropTensorOp_attr_tensor_case1(TestCropTensorOp_attr_tensor):
def init_data(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
class TestCropTensorOp_attr_tensor_case2(TestCropTensorOp_attr_tensor):
def init_data(self):
self.x_shape = (4, 8, 16, 8)
self.crop_shape = [2, 2, 3, 4]
self.offsets = [1, 5, 3, 0]
self.shape_attr = [-1, -1, 3, 4]
self.mixed_type = True
class TestCropTensorOp_attr_tensor_case3(TestCropTensorOp_attr_tensor):
def init_data(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.ShapeTensor = False
self.OffsetsTensor = True
class TestCropTensorOp_attr_tensor_case4(TestCropTensorOp_attr_tensor):
def init_data(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.OffsetsTensor = True
if __name__ == '__main__':
unittest.main()

@ -1100,6 +1100,34 @@ class TestLayer(LayerTest):
for i in range(len(static_ret5)):
self.assertTrue(dcond5.numpy()[i] == static_ret5[i])
def test_crop_tensor(self):
with self.static_graph():
x = fluid.layers.data(name="x1", shape=[6, 5, 8])
dim1 = fluid.layers.data(
name="dim1", shape=[1], append_batch_size=False)
dim2 = fluid.layers.data(
name="dim2", shape=[1], append_batch_size=False)
crop_shape1 = (1, 2, 4, 4)
crop_shape2 = fluid.layers.data(
name="crop_shape", shape=[4], append_batch_size=False)
crop_shape3 = [-1, dim1, dim2, 4]
crop_offsets1 = [0, 0, 1, 0]
crop_offsets2 = fluid.layers.data(
name="crop_offset", shape=[4], append_batch_size=False)
crop_offsets3 = [0, dim1, dim2, 0]
out1 = fluid.layers.crop_tensor(
x, shape=crop_shape1, offsets=crop_offsets1)
out2 = fluid.layers.crop_tensor(
x, shape=crop_shape2, offsets=crop_offsets2)
out3 = fluid.layers.crop_tensor(
x, shape=crop_shape3, offsets=crop_offsets3)
self.assertIsNotNone(out1)
self.assertIsNotNone(out2)
self.assertIsNotNone(out3)
class TestBook(LayerTest):
def test_all_layers(self):

Loading…
Cancel
Save