Fix bug for flowers dataset and row_conv.

gangliao-patch-1
dangqingqing 8 years ago
parent c5dc0b7329
commit 0c70f34c60

@ -2082,10 +2082,10 @@ class MaxOutLayer(LayerBase):
class RowConvLayer(LayerBase):
def __init__(self, name, inputs, context_length, **xargs):
super(RowConvLayer, self).__init__(
name, 'maxout', 0, inputs=inputs, **xargs)
name, 'row_conv', 0, inputs=inputs, **xargs)
config_assert(
len(self.inputs) == 1,
'TransLayer must have one and only one input')
'row convolution layer must have one and only one input.')
input_layer = self.get_input_layer(0)
row_conv_conf = self.config.inputs[0].row_conv_conf
row_conv_conf.context_length = context_length

@ -7,7 +7,7 @@ layers {
}
layers {
name: "__row_conv_layer_0__"
type: "maxout"
type: "row_conv"
size: 2560
active_type: "relu"
inputs {

@ -30,6 +30,7 @@ http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
import functools
from common import download
import tarfile
import scipy.io as scio
@ -54,21 +55,25 @@ TEST_FLAG = 'trnid'
VALID_FLAG = 'valid'
def default_mapper(sample):
def default_mapper(is_train, sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = load_image_bytes(img)
img = simple_transform(img, 256, 224, True)
img = simple_transform(img, 256, 224, is_train)
return img.flatten().astype('float32'), label
train_mapper = functools.partial(default_mapper, True)
test_mapper = functools.partial(default_mapper, False)
def reader_creator(data_file,
label_file,
setid_file,
dataset_name,
mapper=default_mapper,
mapper,
buffered_size=1024,
use_xmap=True):
'''
@ -118,7 +123,7 @@ def reader_creator(data_file,
return map_readers(mapper, reader)
def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def train(mapper=train_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
@ -141,7 +146,7 @@ def train(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def test(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
@ -164,7 +169,7 @@ def test(mapper=default_mapper, buffered_size=1024, use_xmap=True):
buffered_size, use_xmap)
def valid(mapper=default_mapper, buffered_size=1024, use_xmap=True):
def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is

Loading…
Cancel
Save