Merge branch 'develop' into fix_config_parsing_bug

gangliao-patch-1
caoying03 8 years ago
commit 892b1f9ff6

@ -30,7 +30,8 @@ RUN apt-get update && \
python-numpy python-matplotlib gcc g++ \
automake locales clang-format-3.8 swig doxygen cmake \
liblapack-dev liblapacke-dev libboost-dev \
clang-3.8 llvm-3.8 libclang-3.8-dev && \
clang-3.8 llvm-3.8 libclang-3.8-dev \
net-tools && \
apt-get clean -y
# Install Go

@ -59,7 +59,7 @@ macro(add_style_check_target TARGET_NAME)
"--filter=${STYLE_FILTER}"
"--write-success=${CUR_GEN}" ${filename}
DEPENDS ${filename}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
endif()
endforeach()
endif()

@ -11,23 +11,16 @@ find_path(CUDNN_INCLUDE_DIR cudnn.h
get_filename_component(__libpath_hist ${CUDA_CUDART_LIBRARY} PATH)
if(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
execute_process(
COMMAND uname -m COMMAND tr -d '\n'
OUTPUT_VARIABLE HOST_ARCH
RESULT_VARIABLE UNAME_RESULT)
if(${UNAME_RESULT})
set(HOST_ARCH "x86_64")
endif(${UNAME_RESULT})
else(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
set(HOST_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR})
endif(NOT ${CMAKE_HOST_SYSTEM_PROCESSOR})
set(TARGET_ARCH "x86_64")
if(NOT ${CMAKE_SYSTEM_PROCESSOR})
set(TARGET_ARCH ${CMAKE_SYSTEM_PROCESSOR})
endif()
list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDNN_ROOT}
${CUDNN_ROOT}/lib64
${CUDNN_ROOT}/lib
${CUDNN_ROOT}/lib/${HOST_ARCH}-linux-gnu
${CUDNN_ROOT}/lib/${TARGET_ARCH}-linux-gnu
$ENV{CUDNN_ROOT}
$ENV{CUDNN_ROOT}/lib64
$ENV{CUDNN_ROOT}/lib

@ -24,20 +24,25 @@ IF(NOT ${CBLAS_FOUND})
SET(CBLAS_LIBRARIES "${CBLAS_INSTALL_DIR}/lib/${LIBRARY_PREFIX}openblas${STATIC_LIBRARY_SUFFIX}"
CACHE FILEPATH "openblas library." FORCE)
SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1)
SET(COMMON_ARGS CC=${CMAKE_C_COMPILER} NO_SHARED=1 NO_LAPACK=1 libs)
IF(ANDROID)
# arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0 libs)
ELSEIF(RPI)
# use hardfp
SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0 libs)
IF(CMAKE_CROSSCOMPILING)
IF(ANDROID)
# arm_soft_fp_abi branch of OpenBLAS to support softfp
# https://github.com/xianyi/OpenBLAS/tree/arm_soft_fp_abi
SET(OPENBLAS_COMMIT "b5c96fcfcdc82945502a2303116a64d89985daf5")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 ARM_SOFTFP_ABI=1 USE_THREAD=0)
ELSEIF(RPI)
# use hardfp
SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPTIONAL_ARGS HOSTCC=${HOST_C_COMPILER} TARGET=ARMV7 USE_THREAD=0)
ENDIF()
ELSE()
SET(OPENBLAS_COMMIT "v0.2.19")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 libs NUM_THREADS=64)
SET(OPTIONAL_ARGS "")
IF(CMAKE_SYSTEM_PROCESSOR MATCHES "^x86(_64)?$")
SET(OPTIONAL_ARGS DYNAMIC_ARCH=1 NUM_THREADS=64)
ENDIF()
ENDIF()
ExternalProject_Add(

@ -182,7 +182,7 @@ function(go_library TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${go_library_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME}_lib ALL DEPENDS ${TARGET_NAME}_timestamp ${go_library_DEPS})
add_library(${TARGET_NAME} STATIC IMPORTED)
set_property(TARGET ${TARGET_NAME} PROPERTY
@ -199,7 +199,7 @@ function(go_binary TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build
-o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_library_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_binary_DEPS})
install(PROGRAMS ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME} DESTINATION bin)
endfunction(go_binary)
@ -213,7 +213,7 @@ function(go_test TARGET_NAME)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test
-c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}"
${go_test_SRCS}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${TARGET_NAME} ALL DEPENDS ${TARGET_NAME}_timestamp ${go_test_DEPS})
add_test(${TARGET_NAME} ${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME})
endfunction(go_test)

@ -130,7 +130,7 @@ recurrent_group
---------------
.. autoclass:: paddle.v2.layer.recurrent_group
:noindex:
lstm_step
---------
.. autoclass:: paddle.v2.layer.lstm_step
@ -145,12 +145,12 @@ beam_search
------------
.. autoclass:: paddle.v2.layer.beam_search
:noindex:
get_output
----------
.. autoclass:: paddle.v2.layer.get_output
:noindex:
Mixed Layer
===========
@ -203,7 +203,7 @@ trans_full_matrix_projection
----------------------------
.. autoclass:: paddle.v2.layer.trans_full_matrix_projection
:noindex:
Aggregate Layers
================
@ -449,3 +449,11 @@ dropout
--------------
.. autoclass:: paddle.v2.layer.dropout
:noindex:
Activation with learnable parameter
===================================
prelu
--------
.. autoclass:: paddle.v2.layer.prelu
:noindex:

@ -7,4 +7,4 @@
build_and_install/index_cn.rst
concepts/use_concepts_cn.rst
- `深度学习入门课程 <http://book.paddlepaddle.org/>`_
- `深度学习入门课程 <http://book.paddlepaddle.org/index.cn.html>`_

@ -6,4 +6,4 @@ GET STARTED
build_and_install/index_en.rst
- `Deep Learning 101 <http://book.paddlepaddle.org/index.en.html>`_
- `Deep Learning 101 <http://book.paddlepaddle.org/index.html>`_

@ -39,7 +39,7 @@ function(GO_LIBRARY NAME BUILD_TYPE)
COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} build ${BUILD_MODE}
-o "${CMAKE_CURRENT_BINARY_DIR}/${LIB_NAME}"
${CMAKE_GO_FLAGS} ${GO_SOURCE}
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR})
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR})
add_custom_target(${NAME} ALL DEPENDS ${OUTPUT_DIR}/.timestamp ${ARGN})
add_dependencies(${NAME} goGet)

@ -8,6 +8,7 @@ add_subdirectory(gserver)
add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
add_subdirectory(strings)
# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)

@ -632,7 +632,7 @@ void Argument::printValueString(std::ostream& stream,
const std::string& prefix) const {
std::unordered_map<std::string, std::string> out;
getValueString(&out);
for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
for (auto field : {"value", "ids", "sequence pos", "sub-sequence pos"}) {
auto it = out.find(field);
if (it != out.end()) {
stream << prefix << field << ":\n" << it->second;

@ -383,20 +383,23 @@ void SocketClient::TcpClient(const std::string &serverAddr, int serverPort) {
setOption(sockfd);
/// Now connect to the server
int retry_second = 0;
int error = 0;
int retry_count = 0;
do {
error = connect(sockfd, (sockaddr *)&serv_addr, sizeof(serv_addr));
if (error == ECONNREFUSED) {
if (connect(sockfd, (sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) {
break;
}
if (errno == ECONNREFUSED) {
LOG(WARNING) << "connection refused by pserver, try again!";
if (retry_second++ >= 7) {
if (retry_count++ >= 7) {
LOG(FATAL) << "connection refused by pserver, maybe pserver failed!";
}
std::this_thread::sleep_for(std::chrono::seconds(1));
} else {
PCHECK(error >= 0) << "ERROR connecting to " << serverAddr;
PCHECK(errno != 0) << "ERROR connecting to " << serverAddr << ":"
<< serverPort << "errorno: " << errno;
}
} while (error == ECONNREFUSED);
} while (errno == ECONNREFUSED);
channel_.reset(new SocketChannel(sockfd, serverAddr));
tcpRdma_ = F_TCP;

@ -58,7 +58,7 @@ EOF
make -j `nproc`
if [ ${WITH_TESTING:-OFF} == "ON" ] && [ ${RUN_TEST:-OFF} == "ON" ] ; then
pip uninstall -y py-paddle paddle || true
ctest -V
ctest --output-on-failure
fi

@ -0,0 +1,2 @@
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)

@ -0,0 +1,141 @@
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/strings/stringpiece.h"
#include <string.h>
#include <algorithm>
#include <iosfwd>
#include <stdexcept>
namespace paddle {
StringPiece::StringPiece() : data_(NULL), size_(0) {}
StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) {
if (d == NULL && n != 0)
throw std::invalid_argument(
"StringPiece requires len to be 0 for NULL data");
}
StringPiece::StringPiece(const char* s) : data_(s) {
size_ = (s == NULL) ? 0 : strlen(s);
}
StringPiece::StringPiece(const std::string& s)
: data_(s.data()), size_(s.size()) {}
char StringPiece::operator[](size_t n) const {
if (n >= len())
throw std::invalid_argument("index out of StringPiece length");
return data_[n];
}
int Compare(StringPiece a, StringPiece b) {
const size_t min_len = (a.len() < b.len()) ? a.len() : b.len();
int r = memcmp(a.data(), b.data(), min_len);
if (r == 0) {
if (a.len() < b.len())
return -1;
else if (a.len() > b.len())
return 1;
}
return r;
}
bool operator==(StringPiece x, StringPiece y) {
return ((x.len() == y.len()) &&
(x.data() == y.data() || memcmp(x.data(), y.data(), x.len()) == 0));
}
bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
bool operator<(StringPiece x, StringPiece y) { return Compare(x, y) < 0; }
bool operator>(StringPiece x, StringPiece y) { return Compare(x, y) > 0; }
bool operator<=(StringPiece x, StringPiece y) { return Compare(x, y) <= 0; }
bool operator>=(StringPiece x, StringPiece y) { return Compare(x, y) >= 0; }
bool HasPrefix(StringPiece s, StringPiece x) {
return ((s.len() >= x.len()) && (memcmp(s.data(), x.data(), x.len()) == 0));
}
bool HasSuffix(StringPiece s, StringPiece x) {
return ((s.len() >= x.len()) &&
(memcmp(s.data() + (s.len() - x.len()), x.data(), x.len()) == 0));
}
StringPiece SkipPrefix(StringPiece s, size_t n) {
if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length");
return StringPiece(s.data() + n, s.len() - n);
}
StringPiece SkipSuffix(StringPiece s, size_t n) {
if (n > s.len())
throw std::invalid_argument("Skip distance larger than StringPiece length");
return StringPiece(s.data(), s.len() - n);
}
StringPiece TrimPrefix(StringPiece s, StringPiece x) {
return HasPrefix(s, x) ? SkipPrefix(s, x.len()) : s;
}
StringPiece TrimSuffix(StringPiece s, StringPiece x) {
return HasSuffix(s, x) ? SkipSuffix(s, x.len()) : s;
}
bool Contains(StringPiece s, StringPiece sub) {
return std::search(s.begin(), s.end(), sub.begin(), sub.end()) != s.end();
}
size_t Index(StringPiece s, StringPiece sub) {
auto e = std::search(s.begin(), s.end(), sub.begin(), sub.end());
return e != s.end() ? e - s.data() : StringPiece::npos;
}
size_t Find(StringPiece s, char c, size_t pos) {
if (pos >= s.len()) {
return StringPiece::npos;
}
const char* result =
reinterpret_cast<const char*>(memchr(s.data() + pos, c, s.len() - pos));
return result != nullptr ? result - s.data() : StringPiece::npos;
}
size_t RFind(StringPiece s, char c, size_t pos) {
if (s.len() == 0) return StringPiece::npos;
for (const char* p = s.data() + std::min(pos, s.len() - 1); p >= s.data();
p--) {
if (*p == c) {
return p - s.data();
}
}
return StringPiece::npos;
}
StringPiece SubStr(StringPiece s, size_t pos, size_t n) {
if (pos > s.len()) pos = s.len();
if (n > s.len() - pos) n = s.len() - pos;
return StringPiece(s.data() + pos, n);
}
std::ostream& operator<<(std::ostream& o, StringPiece piece) {
return o << piece.ToString();
}
} // namespace paddle

@ -0,0 +1,104 @@
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <string>
namespace paddle {
// StringPiece points into a std::string object but doesn't own the
// string. It is for efficient access to strings. Like Go's string
// type. Not that StringPiece doesn't mutate the underlying string,
// so it is thread-safe given that the underlying string doesn't
// change. Because StringPiece contains a little data members, and
// its syntax is simple as it doesn't own/manage the string, it is
// cheap to construct StringPieces and pass them around.
class StringPiece {
public:
static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit singleton constructors so users can
// pass in a "const char*" or a "string" wherever a "StringPiece"
// is expected. These contructors ensure that if data_ is NULL,
// size_ is 0.
StringPiece();
StringPiece(const char* d, size_t n);
StringPiece(const char* d);
StringPiece(const std::string& s);
const char* data() const { return data_; }
size_t len() const { return size_; }
char operator[](size_t n) const;
// StringPiece doesn't own the string, so both iterator and const
// iterator are const char* indeed.
typedef const char* const_iterator;
typedef const char* iterator;
iterator begin() const { return data_; }
iterator end() const { return data_ + size_; }
// Return a string that contains the copy of the referenced data.
std::string ToString() const { return std::string(data_, size_); }
private:
const char* data_;
size_t size_;
// Intentionally copyable
};
int Compare(StringPiece a, StringPiece b);
bool operator==(StringPiece x, StringPiece y);
bool operator!=(StringPiece x, StringPiece y);
bool operator<(StringPiece x, StringPiece y);
bool operator>(StringPiece x, StringPiece y);
bool operator<=(StringPiece x, StringPiece y);
bool operator>=(StringPiece x, StringPiece y);
bool HasPrefix(StringPiece s, StringPiece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix);
StringPiece SkipPrefix(StringPiece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n);
// Skip the prefix (or suffix) if it matches with the string.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix);
// Returns if s contains sub. Any s except for empty s contains an
// empty sub.
bool Contains(StringPiece s, StringPiece sub);
// Return the first occurrence of sub in s, or npos. If both s and
// sub is empty, it returns npos; otherwise, if only sub is empty, it
// returns 0.
size_t Index(StringPiece s, StringPiece sub);
// Return the first occurrence of c in s[pos:end], or npos.
size_t Find(StringPiece s, char c, size_t pos);
// Search range is [0..pos] inclusive. If pos == npos, search everything.
size_t RFind(StringPiece s, char c, size_t pos);
StringPiece SubStr(StringPiece s, size_t pos, size_t n);
// allow StringPiece to be logged
std::ostream& operator<<(std::ostream& o, StringPiece piece);
} // namespace paddle

File diff suppressed because it is too large Load Diff

@ -73,7 +73,6 @@ To use this from paddle_trainer, paddle_trainer should be called with
--config_args=extension_module_name=[MODULE_NAME]
'''
import copy
import logging
import os
@ -1731,9 +1730,10 @@ class ParameterReluLayer(LayerBase):
def __init__(self, name, inputs, partial_sum=1, **args):
super(ParameterReluLayer, self).__init__(
name, self.layer_type, 0, inputs=inputs, **args)
config_assert(len(self.inputs) == 1)
config_assert(self.input_layer.size % partial_sum == 0)
input_layer = self.get_input_layer(0)
config_assert(len(self.inputs) == 1, "prelu layer has only one input.")
config_assert(input_layer.size % partial_sum == 0,
"a wrong setting for partial_sum")
self.set_layer_size(input_layer.size)
self.create_input_parameter(0, input_layer.size / partial_sum)

File diff suppressed because it is too large Load Diff

@ -5,6 +5,7 @@ last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer)
test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
test_prelu_layer)
export whole_configs=(test_split_datasource)

@ -0,0 +1,36 @@
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "__prelu_layer_0__"
type: "prelu"
size: 300
active_type: ""
inputs {
input_layer_name: "input"
input_parameter_name: "___prelu_layer_0__.w0"
}
}
parameters {
name: "___prelu_layer_0__.w0"
size: 300
initial_mean: 0.0
initial_std: 0.057735026919
initial_strategy: 0
initial_smart: true
}
input_layer_names: "input"
output_layer_names: "__prelu_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__prelu_layer_0__"
input_layer_names: "input"
output_layer_names: "__prelu_layer_0__"
is_recurrent_layer_group: false
}

@ -0,0 +1,6 @@
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
prelu = prelu_layer(input=data)
outputs(prelu)

@ -12,7 +12,7 @@ from paddle.trainer.config_parser import logger
try:
import cv2
except ImportError:
logger.warning("OpenCV2 is not installed, using PIL to prcoess")
logger.warning("OpenCV2 is not installed, using PIL to process")
cv2 = None
__all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"]

@ -0,0 +1,184 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module will download dataset from
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
and parse train/test set intopaddle reader creators.
This set contains images of flowers belonging to 102 different categories.
The images were acquired by searching the web and taking pictures. There are a
minimum of 40 images for each category.
The database was used in:
Nilsback, M-E. and Zisserman, A. Automated flower classification over a large
number of classes.Proceedings of the Indian Conference on Computer Vision,
Graphics and Image Processing (2008)
http://www.robots.ox.ac.uk/~vgg/publications/papers/nilsback08.{pdf,ps.gz}.
"""
import cPickle
import itertools
from common import download
import tarfile
import scipy.io as scio
from paddle.v2.image import *
import os
import numpy as np
import paddle.v2 as paddle
from multiprocessing import cpu_count
__all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
def default_mapper(sample):
'''
map image bytes data to type needed by model input layer
'''
img, label = sample
img = paddle.image.load_image_bytes(img)
img = paddle.image.simple_transform(img, 256, 224, True)
return img.flatten().astype('float32'), label
def reader_creator(data_file,
label_file,
setid_file,
dataset_name,
mapper=default_mapper,
buffered_size=1024):
'''
1. read images from tar file and
merge images into batch files in 102flowers.tgz_batch/
2. get a reader to read sample from batch file
:param data_file: downloaded data file
:type data_file: string
:param label_file: downloaded label file
:type label_file: string
:param setid_file: downloaded setid file containing information
about how to split dataset
:type setid_file: string
:param dataset_name: data set name (tstid|trnid|valid)
:type dataset_name: string
:param mapper: a function to map image bytes data to type
needed by model input layer
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: data reader
:rtype: callable
'''
labels = scio.loadmat(label_file)['labels'][0]
indexes = scio.loadmat(setid_file)[dataset_name][0]
img2label = {}
for i in indexes:
img = "jpg/image_%05d.jpg" % i
img2label[img] = labels[i - 1]
file_list = batch_images_from_tar(data_file, dataset_name, img2label)
def reader():
for file in open(file_list):
file = file.strip()
batch = None
with open(file, 'r') as f:
batch = cPickle.load(f)
data = batch['data']
labels = batch['label']
for sample, label in itertools.izip(data, batch['label']):
yield sample, int(label)
return paddle.reader.xmap_readers(mapper, reader,
cpu_count(), buffered_size)
def train(mapper=default_mapper, buffered_size=1024):
'''
Create flowers training set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: train data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'trnid', mapper,
buffered_size)
def test(mapper=default_mapper, buffered_size=1024):
'''
Create flowers test set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'tstid', mapper,
buffered_size)
def valid(mapper=default_mapper, buffered_size=1024):
'''
Create flowers validation set reader.
It returns a reader, each sample in the reader is
image pixels in [0, 1] and label in [1, 102]
translated from original color image by steps:
1. resize to 256*256
2. random crop to 224*224
3. flatten
:param mapper: a function to map sample.
:type mapper: callable
:param buffered_size: the size of buffer used to process images
:type buffered_size: int
:return: test data reader
:rtype: callable
'''
return reader_creator(
download(DATA_URL, 'flowers', DATA_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5),
download(SETID_URL, 'flowers', SETID_MD5), 'valid', mapper,
buffered_size)
def fetch():
download(DATA_URL, 'flowers', DATA_MD5)
download(LABEL_URL, 'flowers', LABEL_MD5)
download(SETID_URL, 'flowers', SETID_MD5)

@ -0,0 +1,51 @@
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.v2.dataset.flowers
import unittest
class TestFlowers(unittest.TestCase):
def check_reader(self, reader):
sum = 0
label = 0
size = 224 * 224 * 3
for l in reader():
self.assertEqual(l[0].size, size)
if l[1] > label:
label = l[1]
sum += 1
return sum, label
def test_train(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.train())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
def test_test(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.test())
self.assertEqual(instances, 6149)
self.assertEqual(max_label_value, 102)
def test_valid(self):
instances, max_label_value = self.check_reader(
paddle.v2.dataset.flowers.valid())
self.assertEqual(instances, 1020)
self.assertEqual(max_label_value, 102)
if __name__ == '__main__':
unittest.main()

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save