|
|
|
|
@ -1971,13 +1971,14 @@ class DetectionOutputLayer(LayerBase):
|
|
|
|
|
|
|
|
|
|
@config_layer('roi_pool')
|
|
|
|
|
class ROIPoolLayer(LayerBase):
|
|
|
|
|
def __init__(self, name, inputs, pooled_width, pooled_height,
|
|
|
|
|
spatial_scale):
|
|
|
|
|
def __init__(self, name, inputs, pooled_width, pooled_height, spatial_scale,
|
|
|
|
|
num_channels, **xargs):
|
|
|
|
|
super(ROIPoolLayer, self).__init__(name, 'roi_pool', 0, inputs)
|
|
|
|
|
config_assert(len(inputs) == 2, 'ROIPoolLayer must have 2 inputs')
|
|
|
|
|
self.config.inputs[0].roi_pool_conf.pooled_width = pooled_width
|
|
|
|
|
self.config.inputs[0].roi_pool_conf.pooled_height = pooled_height
|
|
|
|
|
self.config.inputs[0].roi_pool_conf.spatial_scale = spatial_scale
|
|
|
|
|
self.set_cnn_layer(name, pooled_height, pooled_width, num_channels)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@config_layer('data')
|
|
|
|
|
|