Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into feature/add_prior_box_py

emailweixu-patch-1
chengduoZH 7 years ago
commit dff1bf33c9

@ -106,9 +106,11 @@ class Vector {
// std::vector iterator methods. Based on CPU data access method // std::vector iterator methods. Based on CPU data access method
size_t size() const { return size_; } size_t size() const { return size_; }
T* begin() { return &this->operator[](0); } T* begin() { return capacity() == 0 ? &EmptyDummy() : &this->operator[](0); }
T* end() { return &this->operator[](size()); } T* end() {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
T& front() { return *begin(); } T& front() { return *begin(); }
@ -118,8 +120,13 @@ class Vector {
return *it; return *it;
} }
const T* begin() const { return &this->operator[](0); } const T* begin() const {
const T* end() const { return &this->operator[](size()); } return capacity() == 0 ? &EmptyDummy() : &this->operator[](0);
}
const T* end() const {
return capacity() == 0 ? &EmptyDummy() : &this->operator[](size());
}
const T* cbegin() const { return begin(); } const T* cbegin() const { return begin(); }
@ -358,6 +365,11 @@ class Vector {
} }
} }
static T& EmptyDummy() {
static T dummy = T();
return dummy;
}
mutable int flag_; mutable int flag_;
mutable Tensor cpu_vec_; mutable Tensor cpu_vec_;
mutable Tensor cuda_vec_; mutable Tensor cuda_vec_;

@ -98,3 +98,9 @@ TEST(mixed_vector, InitWithCount) {
ASSERT_EQ(vec[i], 10); ASSERT_EQ(vec[i], 10);
} }
} }
TEST(mixed_vector, ForEach) {
vec<int> tmp;
for (auto& v : tmp) {
}
}

@ -29,6 +29,6 @@ inference_test(image_classification ARGS vgg resnet)
inference_test(label_semantic_roles) inference_test(label_semantic_roles)
inference_test(recognize_digits ARGS mlp) inference_test(recognize_digits ARGS mlp)
inference_test(recommender_system) inference_test(recommender_system)
inference_test(rnn_encoder_decoder) #inference_test(rnn_encoder_decoder)
inference_test(understand_sentiment) inference_test(understand_sentiment)
inference_test(word2vec) inference_test(word2vec)

@ -19,7 +19,6 @@ from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import Variable
from tensor import concat from tensor import concat
from ops import reshape from ops import reshape
from operator import mul
import math import math
__all__ = [ __all__ = [
@ -143,43 +142,50 @@ def prior_box(inputs,
""" """
**Prior_boxes** **Prior_boxes**
Generate prior boxes for SSD(Single Shot MultiBox Detector) algorithm. Generate prior boxes for SSD(Single Shot MultiBox Detector)
The details of this algorithm, please refer the section 2.2 of SSD paper algorithm. The details of this algorithm, please refer the
(SSD: Single Shot MultiBox Detector)<https://arxiv.org/abs/1512.02325>`_ . section 2.2 of SSD paper (SSD: Single Shot MultiBox Detector)
<https://arxiv.org/abs/1512.02325>`_ .
Args: Args:
inputs(list): The list of input Variables, the format of all Variables is NCHW. inputs(list): The list of input Variables, the format
image(Variable): The input image data of PriorBoxOp, the layout is NCHW. of all Variables is NCHW.
image(Variable): The input image data of PriorBoxOp,
the layout is NCHW.
min_ratio(int): the min ratio of generated prior boxes. min_ratio(int): the min ratio of generated prior boxes.
max_ratio(int): the max ratio of generated prior boxes. max_ratio(int): the max ratio of generated prior boxes.
aspect_ratios(list): the aspect ratios of generated prior boxes. aspect_ratios(list): the aspect ratios of generated prior
The length of input and aspect_ratios must be equal. boxes. The length of input and aspect_ratios must be equal.
base_size(int): the base_size is used to get min_size and max_size base_size(int): the base_size is used to get min_size
according to min_ratio and max_ratio. and max_size according to min_ratio and max_ratio.
step_w(list, optional, default=None): Prior boxes step across width. step_w(list, optional, default=None): Prior boxes step
If step_w[i] == 0.0, the prior boxes step across width of the inputs[i] across width. If step_w[i] == 0.0, the prior boxes step
will be automatically calculated. across width of the inputs[i] will be automatically calculated.
step_h(list, optional, default=None): Prior boxes step across height, step_h(list, optional, default=None): Prior boxes step
If step_h[i] == 0.0, the prior boxes step across height of the inputs[i] across height, If step_h[i] == 0.0, the prior boxes
will be automatically calculated. step across height of the inputs[i] will be automatically calculated.
offset(float, optional, default=0.5): Prior boxes center offset. offset(float, optional, default=0.5): Prior boxes center offset.
variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances variance(list, optional, default=[0.1, 0.1, 0.1, 0.1]): the variances
to be encoded in prior boxes. to be encoded in prior boxes.
flip(bool, optional, default=False): Whether to flip aspect ratios. flip(bool, optional, default=False): Whether to flip
clip(bool, optional, default=False): Whether to clip out-of-boundary boxes. aspect ratios.
min_sizes(list, optional, default=None): If `len(inputs) <=2`, min_sizes must clip(bool, optional, default=False): Whether to clip
be set up, and the length of min_sizes should equal to the length of inputs. out-of-boundary boxes.
max_sizes(list, optional, default=None): If `len(inputs) <=2`, max_sizes must min_sizes(list, optional, default=None): If `len(inputs) <=2`,
be set up, and the length of min_sizes should equal to the length of inputs. min_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
max_sizes(list, optional, default=None): If `len(inputs) <=2`,
max_sizes must be set up, and the length of min_sizes
should equal to the length of inputs.
name(str, optional, None): Name of the prior box layer. name(str, optional, None): Name of the prior box layer.
Returns: Returns:
boxes(Variable): the output prior boxes of PriorBoxOp. The layout is boxes(Variable): the output prior boxes of PriorBoxOp.
[num_priors, 4]. num_priors is the total box count of each The layout is [num_priors, 4]. num_priors is the total
position of inputs. box count of each position of inputs.
Variances(Variable): the expanded variances of PriorBoxOp. The layout Variances(Variable): the expanded variances of PriorBoxOp.
is [num_priors, 4]. num_priors is the total box count of each The layout is [num_priors, 4]. num_priors is the total
position of inputs box count of each position of inputs
Examples: Examples:
.. code-block:: python .. code-block:: python
@ -235,10 +241,11 @@ def prior_box(inputs,
def _reshape_with_axis_(input, axis=1): def _reshape_with_axis_(input, axis=1):
if not (axis > 0 and axis < len(input.shape)): if not (axis > 0 and axis < len(input.shape)):
raise ValueError( raise ValueError("The axis should be smaller than "
"The axis should be smaller than the arity of input and bigger than 0." "the arity of input and bigger than 0.")
) new_shape = [
new_shape = [-1, reduce(mul, input.shape[axis:len(input.shape)], 1)] -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = reshape(x=input, shape=new_shape) out = reshape(x=input, shape=new_shape)
return out return out

@ -54,8 +54,12 @@ class TestBook(unittest.TestCase):
class TestPriorBox(unittest.TestCase): class TestPriorBox(unittest.TestCase):
def test_prior_box(self): def test_prior_box(self):
self.check_prior_box(use_cuda=False) data_shape = [3, 224, 224]
self.check_prior_box(use_cuda=True) box, var = self.prior_box_output(data_shape)
assert len(box.shape) == 2
assert box.shape == var.shape
assert box.shape[1] == 4
def prior_box_output(self, data_shape): def prior_box_output(self, data_shape):
images = fluid.layers.data( images = fluid.layers.data(
@ -104,32 +108,6 @@ class TestPriorBox(unittest.TestCase):
clip=True) clip=True)
return box, var return box, var
def check_prior_box(self, use_cuda):
if use_cuda: # prior_box only support CPU.
return
data_shape = [3, 224, 224]
box, var = self.prior_box_output(data_shape)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
batch = [4] # batch is not used in the prior_box.
assert box.shape[1] == 4
assert var.shape[1] == 4
assert box.shape == var.shape
assert len(box.shape) == 2
x = np.random.random(batch + data_shape).astype("float32")
tensor_x = core.LoDTensor()
tensor_x.set(x, place)
boxes, vars = exe.run(fluid.default_main_program(),
feed={'pixel': tensor_x},
fetch_list=[box, var])
assert vars.shape == var.shape
assert boxes.shape == box.shape
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save