Merge pull request #12658 from velconia/port_pybind11

Port pybind11 and python code to support py3 CI test
revert-12469-sum_op_dim_fix
Qiyang Min 7 years ago committed by GitHub
commit 340a104c58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -202,6 +202,52 @@ std::vector<std::string> OpDesc::AttrNames() const {
} }
void OpDesc::SetAttr(const std::string &name, const Attribute &v) { void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
// NOTICE(minqiyang): pybind11 will take the empty list in python as
// the std::vector<int> type in C++; so we have to change the attr's type
// here if we meet this issue
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
if (attr_type == proto::AttrType::INTS &&
boost::get<std::vector<int>>(v).size() == 0u) {
// Find current attr via attr name and set the correct attribute value
const proto::OpProto::Attr &attr = GetProtoAttr(name);
switch (attr.type()) {
case proto::AttrType::BOOLEANS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BOOLEANS";
this->attrs_[name] = std::vector<bool>();
break;
}
case proto::AttrType::INTS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to INTS";
this->attrs_[name] = std::vector<int>();
break;
}
case proto::AttrType::FLOATS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to FLOATS";
this->attrs_[name] = std::vector<float>();
break;
}
case proto::AttrType::STRINGS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to STRINGS";
this->attrs_[name] = std::vector<std::string>();
break;
}
case proto::AttrType::BLOCKS: {
VLOG(11) << "SetAttr: " << Type() << ", " << name
<< " from INTS to BLOCKS";
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
return;
}
default:
PADDLE_THROW("Wrong attr type %d", attr.type());
}
need_update_ = true;
return;
}
this->attrs_[name] = v; this->attrs_[name] = v;
need_update_ = true; need_update_ = true;
} }
@ -229,6 +275,19 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
return it->second; return it->second;
} }
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
const std::string &name) const {
const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
for (int i = 0; i != proto.attrs_size(); ++i) {
const proto::OpProto::Attr &attr = proto.attrs(i);
if (attr.name() == name) {
return attr;
}
}
PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}
Attribute OpDesc::GetNullableAttr(const std::string &name) const { Attribute OpDesc::GetNullableAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
if (it != attrs_.end()) { if (it != attrs_.end()) {

@ -81,6 +81,8 @@ class OpDesc {
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
const proto::OpProto::Attr &GetProtoAttr(const std::string &name) const;
Attribute GetNullableAttr(const std::string &name) const; Attribute GetNullableAttr(const std::string &name) const;
int GetBlockAttrId(const std::string &name) const; int GetBlockAttrId(const std::string &name) const;

@ -205,12 +205,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) { void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", ""); pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc var_desc
.def("name", .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
[](pd::VarDesc &self) {
pybind11::bytes name = self.Name();
return name;
},
pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName) .def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape) .def("set_shape", &pd::VarDesc::SetShape)
.def("set_shapes", &pd::VarDesc::SetShapes) .def("set_shapes", &pd::VarDesc::SetShapes)

@ -54,6 +54,8 @@ limitations under the License. */
#include "paddle/fluid/platform/gpu_info.h" #include "paddle/fluid/platform/gpu_info.h"
#endif #endif
#include "pybind11/stl.h"
// disable auto conversion to list in Python // disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray); PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

@ -24,4 +24,5 @@ except ImportError:
import paddle.reader import paddle.reader
import paddle.dataset import paddle.dataset
import paddle.batch import paddle.batch
import paddle.compat
batch = batch.batch batch = batch.batch

@ -0,0 +1,237 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import math
__all__ = [
'long_type',
'to_text',
'to_bytes',
'round',
'floor_division',
'get_exception_message',
]
if six.PY2:
int_type = int
long_type = long
else:
int_type = int
long_type = int
# str and bytes related functions
def to_text(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a literal string without any encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to literal string.
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding
Args:
obj(unicode|str|bytes|list|set) : The object to be decoded.
encoding(str) : The encoding format to decode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_text(obj[i], encoding)
return obj
else:
return [_to_text(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_text(item, encoding))
return obj
else:
return set([_to_text(item, encoding) for item in obj])
else:
return _to_text(obj, encoding)
def _to_text(obj, encoding):
"""
In Python3:
Decode the bytes type object to str type with specific encoding
In Python2:
Decode the str type object to unicode type with specific encoding,
or we just return the unicode string of object
Args:
obj(unicode|str|bytes) : The object to be decoded.
encoding(str) : The encoding format
Returns:
decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, six.binary_type):
return obj.decode(encoding)
elif isinstance(obj, six.text_type):
return obj
else:
return six.u(obj)
def to_bytes(obj, encoding='utf-8', inplace=False):
"""
All string in PaddlePaddle should be represented as a literal string.
This function will convert object to a bytes with specific encoding.
Especially, if the object type is a list or set container, we will iterate
all items in the object and convert them to bytes.
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes|list|set) : The object to be encoded.
encoding(str) : The encoding format to encode a string
inplace(bool) : If we change the original object or we create a new one
Returns:
Decoded result of obj
"""
if obj is None:
return obj
if isinstance(obj, list):
if inplace:
for i in six.moves.xrange(len(obj)):
obj[i] = _to_bytes(obj[i], encoding)
return obj
else:
return [_to_bytes(item, encoding) for item in obj]
elif isinstance(obj, set):
if inplace:
for item in obj:
obj.remove(item)
obj.add(_to_bytes(item, encoding))
return obj
else:
return set([_to_bytes(item, encoding) for item in obj])
else:
return _to_bytes(obj, encoding)
def _to_bytes(obj, encoding):
"""
In Python3:
Encode the str type object to bytes type with specific encoding
In Python2:
Encode the unicode type object to str type with specific encoding,
or we just return the 8-bit string of object
Args:
obj(unicode|str|bytes) : The object to be encoded.
encoding(str) : The encoding format
Returns:
encoded result of obj
"""
if obj is None:
return obj
assert encoding is not None
if isinstance(obj, six.text_type):
return obj.encode(encoding)
elif isinstance(obj, six.binary_type):
return obj
else:
return six.b(obj)
# math related functions
def round(x, d=0):
"""
Compatible round which act the same behaviour in Python3.
Args:
x(float) : The number to round halfway.
Returns:
round result of x
"""
if six.PY3:
# The official walkaround of round in Python3 is incorrect
# we implement accroding this answer: https://www.techforgeek.info/round_python.html
if x > 0.0:
p = 10**d
return float(math.floor((x * p) + math.copysign(0.5, x))) / p
elif x < 0.0:
p = 10**d
return float(math.ceil((x * p) + math.copysign(0.5, x))) / p
else:
return math.copysign(0.0, x)
else:
import __builtin__
return __builtin__.round(x, d)
def floor_division(x, y):
"""
Compatible division which act the same behaviour in Python3 and Python2,
whose result will be a int value of floor(x / y) in Python3 and value of
(x / y) in Python2.
Args:
x(int|float) : The number to divide.
y(int|float) : The number to be divided
Returns:
division result of x // y
"""
return x // y
# exception related functions
def get_exception_message(exc):
"""
Get the error message of a specific exception
Args:
exec(Exception) : The exception to get error message.
Returns:
the error message of exec
"""
assert exc is not None
if six.PY2:
return exc.message
else:
return str(exc)

@ -32,7 +32,7 @@ import itertools
import numpy import numpy
import paddle.dataset.common import paddle.dataset.common
import tarfile import tarfile
from six.moves import zip import six
from six.moves import cPickle as pickle from six.moves import cPickle as pickle
__all__ = ['train100', 'test100', 'train10', 'test10', 'convert'] __all__ = ['train100', 'test100', 'train10', 'test10', 'convert']
@ -46,10 +46,11 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False): def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch): def read_batch(batch):
data = batch['data'] data = batch[six.b('data')]
labels = batch.get('labels', batch.get('fine_labels', None)) labels = batch.get(
six.b('labels'), batch.get(six.b('fine_labels'), None))
assert labels is not None assert labels is not None
for sample, label in zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
def reader(): def reader():
@ -59,7 +60,11 @@ def reader_creator(filename, sub_name, cycle=False):
while True: while True:
for name in names: for name in names:
batch = pickle.load(f.extractfile(name)) if six.PY2:
batch = pickle.load(f.extractfile(name))
else:
batch = pickle.load(
f.extractfile(name), encoding='bytes')
for item in read_batch(batch): for item in read_batch(batch):
yield item yield item
if not cycle: if not cycle:

@ -85,10 +85,10 @@ def download(url, module_name, md5sum, save_name=None):
total_length = r.headers.get('content-length') total_length = r.headers.get('content-length')
if total_length is None: if total_length is None:
with open(filename, 'w') as f: with open(filename, 'wb') as f:
shutil.copyfileobj(r.raw, f) shutil.copyfileobj(r.raw, f)
else: else:
with open(filename, 'w') as f: with open(filename, 'wb') as f:
dl = 0 dl = 0
total_length = int(total_length) total_length = int(total_length)
for data in r.iter_content(chunk_size=4096): for data in r.iter_content(chunk_size=4096):

@ -24,11 +24,12 @@ import tarfile
import gzip import gzip
import itertools import itertools
import paddle.dataset.common import paddle.dataset.common
from six.moves import zip import paddle.compat as cpt
from six.moves import zip, range
__all__ = ['test, get_dict', 'get_embedding', 'convert'] __all__ = ['test, get_dict', 'get_embedding', 'convert']
DATA_URL = 'http://www.cs.upc.edu/~srlconll/conll05st-tests.tar.gz' DATA_URL = 'http://paddlemodels.bj.bcebos.com/conll05st/conll05st-tests.tar.gz'
DATA_MD5 = '387719152ae52d60422c016e92a742fc' DATA_MD5 = '387719152ae52d60422c016e92a742fc'
WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt' WORDDICT_URL = 'http://paddlemodels.bj.bcebos.com/conll05st%2FwordDict.txt'
WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa' WORDDICT_MD5 = 'ea7fb7d4c75cc6254716f0177a506baa'
@ -89,8 +90,8 @@ def corpus_reader(data_path, words_name, props_name):
labels = [] labels = []
one_seg = [] one_seg = []
for word, label in zip(words_file, props_file): for word, label in zip(words_file, props_file):
word = word.strip() word = cpt.to_text(word.strip())
label = label.strip().split() label = cpt.to_text(label.strip().split())
if len(label) == 0: # end of sentence if len(label) == 0: # end of sentence
for i in range(len(one_seg[0])): for i in range(len(one_seg[0])):

@ -116,8 +116,8 @@ def reader_creator(data_file,
for file in open(file_list): for file in open(file_list):
file = file.strip() file = file.strip()
batch = None batch = None
with open(file, 'r') as f: with open(file, 'rb') as f:
batch = pickle.load(f) batch = pickle.loads(f.read())
data = batch['data'] data = batch['data']
labels = batch['label'] labels = batch['label']
for sample, label in zip(data, batch['label']): for sample, label in zip(data, batch['label']):

@ -33,6 +33,11 @@ import numpy as np
try: try:
import cv2 import cv2
except ImportError: except ImportError:
import sys
sys.stderr.write(
'''Warning with paddle image module: opencv-python should be imported,
or paddle image module could NOT work; please install opencv-python first.'''
)
cv2 = None cv2 = None
import os import os
import tarfile import tarfile
@ -56,7 +61,7 @@ def batch_images_from_tar(data_file,
:type data_file: string :type data_file: string
:param dataset_name: 'train','test' or 'valid' :param dataset_name: 'train','test' or 'valid'
:type dataset_name: string :type dataset_name: string
:param img2label: a dic with image file name as key :param img2label: a dic with image file name as key
and image's label as value and image's label as value
:type img2label: dic :type img2label: dic
:param num_per_batch: image number per batch file :param num_per_batch: image number per batch file
@ -88,7 +93,7 @@ def batch_images_from_tar(data_file,
output['data'] = data output['data'] = data
pickle.dump( pickle.dump(
output, output,
open('%s/batch_%d' % (out_path, file_id), 'w'), open('%s/batch_%d' % (out_path, file_id), 'wb'),
protocol=pickle.HIGHEST_PROTOCOL) protocol=pickle.HIGHEST_PROTOCOL)
file_id += 1 file_id += 1
data = [] data = []
@ -99,7 +104,7 @@ def batch_images_from_tar(data_file,
output['data'] = data output['data'] = data
pickle.dump( pickle.dump(
output, output,
open('%s/batch_%d' % (out_path, file_id), 'w'), open('%s/batch_%d' % (out_path, file_id), 'wb'),
protocol=pickle.HIGHEST_PROTOCOL) protocol=pickle.HIGHEST_PROTOCOL)
with open(meta_file, 'a') as meta: with open(meta_file, 'a') as meta:
@ -113,7 +118,7 @@ def load_image_bytes(bytes, is_color=True):
Load an color or gray image from bytes array. Load an color or gray image from bytes array.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
with open('cat.jpg') as f: with open('cat.jpg') as f:
@ -126,6 +131,8 @@ def load_image_bytes(bytes, is_color=True):
load and return a gray image. load and return a gray image.
:type is_color: bool :type is_color: bool
""" """
assert cv2 is not None
flag = 1 if is_color else 0 flag = 1 if is_color else 0
file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8) file_bytes = np.asarray(bytearray(bytes), dtype=np.uint8)
img = cv2.imdecode(file_bytes, flag) img = cv2.imdecode(file_bytes, flag)
@ -137,7 +144,7 @@ def load_image(file, is_color=True):
Load an color or gray image from the file path. Load an color or gray image from the file path.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = load_image('cat.jpg') im = load_image('cat.jpg')
@ -149,6 +156,8 @@ def load_image(file, is_color=True):
load and return a gray image. load and return a gray image.
:type is_color: bool :type is_color: bool
""" """
assert cv2 is not None
# cv2.IMAGE_COLOR for OpenCV3 # cv2.IMAGE_COLOR for OpenCV3
# cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version # cv2.CV_LOAD_IMAGE_COLOR for older OpenCV Version
# cv2.IMAGE_GRAYSCALE for OpenCV3 # cv2.IMAGE_GRAYSCALE for OpenCV3
@ -161,27 +170,29 @@ def load_image(file, is_color=True):
def resize_short(im, size): def resize_short(im, size):
""" """
Resize an image so that the length of shorter edge is size. Resize an image so that the length of shorter edge is size.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = load_image('cat.jpg') im = load_image('cat.jpg')
im = resize_short(im, 256) im = resize_short(im, 256)
:param im: the input image with HWC layout. :param im: the input image with HWC layout.
:type im: ndarray :type im: ndarray
:param size: the shorter edge size of image after resizing. :param size: the shorter edge size of image after resizing.
:type size: int :type size: int
""" """
assert cv2 is not None
h, w = im.shape[:2] h, w = im.shape[:2]
h_new, w_new = size, size h_new, w_new = size, size
if h > w: if h > w:
h_new = size * h / w h_new = size * h // w
else: else:
w_new = size * w / h w_new = size * w // h
im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC) im = cv2.resize(im, (h_new, w_new), interpolation=cv2.INTER_CUBIC)
return im return im
@ -193,17 +204,17 @@ def to_chw(im, order=(2, 0, 1)):
according the order (2,0,1). according the order (2,0,1).
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = load_image('cat.jpg') im = load_image('cat.jpg')
im = resize_short(im, 256) im = resize_short(im, 256)
im = to_chw(im) im = to_chw(im)
:param im: the input image with HWC layout. :param im: the input image with HWC layout.
:type im: ndarray :type im: ndarray
:param order: the transposed order. :param order: the transposed order.
:type order: tuple|list :type order: tuple|list
""" """
assert len(im.shape) == len(order) assert len(im.shape) == len(order)
im = im.transpose(order) im = im.transpose(order)
@ -215,11 +226,11 @@ def center_crop(im, size, is_color=True):
Crop the center of image with size. Crop the center of image with size.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = center_crop(im, 224) im = center_crop(im, 224)
:param im: the input image with HWC layout. :param im: the input image with HWC layout.
:type im: ndarray :type im: ndarray
:param size: the cropping size. :param size: the cropping size.
@ -228,8 +239,8 @@ def center_crop(im, size, is_color=True):
:type is_color: bool :type is_color: bool
""" """
h, w = im.shape[:2] h, w = im.shape[:2]
h_start = (h - size) / 2 h_start = (h - size) // 2
w_start = (w - size) / 2 w_start = (w - size) // 2
h_end, w_end = h_start + size, w_start + size h_end, w_end = h_start + size, w_start + size
if is_color: if is_color:
im = im[h_start:h_end, w_start:w_end, :] im = im[h_start:h_end, w_start:w_end, :]
@ -243,11 +254,11 @@ def random_crop(im, size, is_color=True):
Randomly crop input image with size. Randomly crop input image with size.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = random_crop(im, 224) im = random_crop(im, 224)
:param im: the input image with HWC layout. :param im: the input image with HWC layout.
:type im: ndarray :type im: ndarray
:param size: the cropping size. :param size: the cropping size.
@ -272,11 +283,11 @@ def left_right_flip(im, is_color=True):
Return the flipped image. Return the flipped image.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = left_right_flip(im) im = left_right_flip(im)
:param im: input image with HWC layout or HW layout for gray image :param im: input image with HWC layout or HW layout for gray image
:type im: ndarray :type im: ndarray
:param is_color: whether input image is color or not :param is_color: whether input image is color or not
@ -299,7 +310,7 @@ def simple_transform(im,
resizing, croping and flipping. resizing, croping and flipping.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = simple_transform(im, 256, 224, True) im = simple_transform(im, 256, 224, True)
@ -314,7 +325,7 @@ def simple_transform(im,
:type is_train: bool :type is_train: bool
:param is_color: whether the image is color or not. :param is_color: whether the image is color or not.
:type is_color: bool :type is_color: bool
:param mean: the mean values, which can be element-wise mean values or :param mean: the mean values, which can be element-wise mean values or
mean values per channel. mean values per channel.
:type mean: numpy array | list :type mean: numpy array | list
""" """
@ -332,7 +343,7 @@ def simple_transform(im,
im = im.astype('float32') im = im.astype('float32')
if mean is not None: if mean is not None:
mean = np.array(mean, dtype=np.float32) mean = np.array(mean, dtype=np.float32)
# mean value, may be one value per channel # mean value, may be one value per channel
if mean.ndim == 1 and is_color: if mean.ndim == 1 and is_color:
mean = mean[:, np.newaxis, np.newaxis] mean = mean[:, np.newaxis, np.newaxis]
elif mean.ndim == 1: elif mean.ndim == 1:
@ -357,7 +368,7 @@ def load_and_transform(filename,
for the transform operations. for the transform operations.
Example usage: Example usage:
.. code-block:: python .. code-block:: python
im = load_and_transform('cat.jpg', 256, 224, True) im = load_and_transform('cat.jpg', 256, 224, True)
@ -372,7 +383,7 @@ def load_and_transform(filename,
:type is_train: bool :type is_train: bool
:param is_color: whether the image is color or not. :param is_color: whether the image is color or not.
:type is_color: bool :type is_color: bool
:param mean: the mean values, which can be element-wise mean values or :param mean: the mean values, which can be element-wise mean values or
mean values per channel. mean values per channel.
:type mean: numpy array | list :type mean: numpy array | list
""" """

@ -25,6 +25,7 @@ import collections
import tarfile import tarfile
import re import re
import string import string
import six
__all__ = ['build_dict', 'train', 'test', 'convert'] __all__ = ['build_dict', 'train', 'test', 'convert']
@ -42,13 +43,14 @@ def tokenize(pattern):
# sequential access of member files, other than # sequential access of member files, other than
# tarfile.extractfile, which does random access and might # tarfile.extractfile, which does random access and might
# destroy hard disks. # destroy hard disks.
tf = next(tarf) tf = tarf.next()
while tf != None: while tf != None:
if bool(pattern.match(tf.name)): if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization. # newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip("\n\r").translate( yield tarf.extractfile(tf).read().rstrip(six.b(
None, string.punctuation).lower().split() "\n\r")).translate(
tf = next(tarf) None, six.b(string.punctuation)).lower().split()
tf = tarf.next()
def build_dict(pattern, cutoff): def build_dict(pattern, cutoff):
@ -62,11 +64,11 @@ def build_dict(pattern, cutoff):
word_freq[word] += 1 word_freq[word] += 1
# Not sure if we should prune less-frequent words here. # Not sure if we should prune less-frequent words here.
word_freq = [x for x in list(word_freq.items()) if x[1] > cutoff] word_freq = [x for x in six.iteritems(word_freq) if x[1] > cutoff]
dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0])) dictionary = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*dictionary)) words, _ = list(zip(*dictionary))
word_idx = dict(list(zip(words, list(range(len(words)))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words) word_idx['<unk>'] = len(words)
return word_idx return word_idx

@ -14,13 +14,14 @@
""" """
imikolov's simple dataset. imikolov's simple dataset.
This module will download dataset from This module will download dataset from
http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set http://www.fit.vutbr.cz/~imikolov/rnnlm/ and parse training set and test set
into paddle reader creators. into paddle reader creators.
""" """
import paddle.dataset.common import paddle.dataset.common
import collections import collections
import tarfile import tarfile
import six
__all__ = ['train', 'test', 'build_dict', 'convert'] __all__ = ['train', 'test', 'build_dict', 'convert']
@ -64,11 +65,13 @@ def build_dict(min_word_freq=50):
# remove <unk> for now, since we will set it as last index # remove <unk> for now, since we will set it as last index
del word_freq['<unk>'] del word_freq['<unk>']
word_freq = [x for x in list(word_freq.items()) if x[1] > min_word_freq] word_freq = [
x for x in six.iteritems(word_freq) if x[1] > min_word_freq
]
word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0])) word_freq_sorted = sorted(word_freq, key=lambda x: (-x[1], x[0]))
words, _ = list(zip(*word_freq_sorted)) words, _ = list(zip(*word_freq_sorted))
word_idx = dict(list(zip(words, list(range(len(words)))))) word_idx = dict(list(zip(words, six.moves.range(len(words)))))
word_idx['<unk>'] = len(words) word_idx['<unk>'] = len(words)
return word_idx return word_idx
@ -89,7 +92,7 @@ def reader_creator(filename, word_idx, n, data_type):
l = ['<s>'] + l.strip().split() + ['<e>'] l = ['<s>'] + l.strip().split() + ['<e>']
if len(l) >= n: if len(l) >= n:
l = [word_idx.get(w, UNK) for w in l] l = [word_idx.get(w, UNK) for w in l]
for i in range(n, len(l) + 1): for i in six.moves.range(n, len(l) + 1):
yield tuple(l[i - n:i]) yield tuple(l[i - n:i])
elif DataType.SEQ == data_type: elif DataType.SEQ == data_type:
l = l.strip().split() l = l.strip().split()

@ -21,6 +21,9 @@ import paddle.dataset.common
import subprocess import subprocess
import numpy import numpy
import platform import platform
import six
import tempfile
from six.moves import range
__all__ = ['train', 'test', 'convert'] __all__ = ['train', 'test', 'convert']
URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/' URL_PREFIX = 'http://yann.lecun.com/exdb/mnist/'
@ -45,23 +48,28 @@ def reader_creator(image_filename, label_filename, buffer_size):
# According to http://stackoverflow.com/a/38061619/724872, we # According to http://stackoverflow.com/a/38061619/724872, we
# cannot use standard package gzip here. # cannot use standard package gzip here.
m = subprocess.Popen([zcat_cmd, image_filename], stdout=subprocess.PIPE) tmp_image_file = tempfile.TemporaryFile(prefix='paddle_dataset')
m.stdout.read(16) # skip some magic bytes m = subprocess.Popen(
[zcat_cmd, image_filename], stdout=tmp_image_file).communicate()
tmp_image_file.seek(16) # skip some magic bytes
l = subprocess.Popen([zcat_cmd, label_filename], stdout=subprocess.PIPE) # Python3 will not take stdout as file
l.stdout.read(8) # skip some magic bytes tmp_label_file = tempfile.TemporaryFile(prefix='paddle_dataset')
l = subprocess.Popen(
[zcat_cmd, label_filename], stdout=tmp_label_file).communicate()
tmp_label_file.seek(8) # skip some magic bytes
try: # reader could be break. try: # reader could be break.
while True: while True:
labels = numpy.fromfile( labels = numpy.fromfile(
l.stdout, 'ubyte', count=buffer_size).astype("int") tmp_label_file, 'ubyte', count=buffer_size).astype("int")
if labels.size != buffer_size: if labels.size != buffer_size:
break # numpy.fromfile returns empty slice after EOF. break # numpy.fromfile returns empty slice after EOF.
images = numpy.fromfile( images = numpy.fromfile(
m.stdout, 'ubyte', count=buffer_size * 28 * 28).reshape( tmp_image_file, 'ubyte', count=buffer_size * 28 *
(buffer_size, 28 * 28)).astype('float32') 28).reshape((buffer_size, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0 images = images / 255.0 * 2.0 - 1.0

@ -27,6 +27,8 @@ import paddle.dataset.common
import re import re
import random import random
import functools import functools
import six
import paddle.compat as cpt
__all__ = [ __all__ = [
'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id', 'train', 'test', 'get_movie_title_dict', 'max_movie_id', 'max_user_id',
@ -112,6 +114,7 @@ def __initialize_meta_info__():
categories_set = set() categories_set = set()
with package.open('ml-1m/movies.dat') as movie_file: with package.open('ml-1m/movies.dat') as movie_file:
for i, line in enumerate(movie_file): for i, line in enumerate(movie_file):
line = cpt.to_text(line, encoding='latin')
movie_id, title, categories = line.strip().split('::') movie_id, title, categories = line.strip().split('::')
categories = categories.split('|') categories = categories.split('|')
for c in categories: for c in categories:
@ -136,6 +139,7 @@ def __initialize_meta_info__():
USER_INFO = dict() USER_INFO = dict()
with package.open('ml-1m/users.dat') as user_file: with package.open('ml-1m/users.dat') as user_file:
for line in user_file: for line in user_file:
line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::") uid, gender, age, job, _ = line.strip().split("::")
USER_INFO[int(uid)] = UserInfo( USER_INFO[int(uid)] = UserInfo(
index=uid, gender=gender, age=age, job_id=job) index=uid, gender=gender, age=age, job_id=job)
@ -148,6 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
with zipfile.ZipFile(file=fn) as package: with zipfile.ZipFile(file=fn) as package:
with package.open('ml-1m/ratings.dat') as rating: with package.open('ml-1m/ratings.dat') as rating:
for line in rating: for line in rating:
line = cpt.to_text(line, encoding='latin')
if (rand.random() < test_ratio) == is_test: if (rand.random() < test_ratio) == is_test:
uid, mov_id, rating, _ = line.strip().split("::") uid, mov_id, rating, _ = line.strip().split("::")
uid = int(uid) uid = int(uid)
@ -187,7 +192,7 @@ def max_movie_id():
Get the maximum value of movie id. Get the maximum value of movie id.
""" """
__initialize_meta_info__() __initialize_meta_info__()
return reduce(__max_index_info__, list(MOVIE_INFO.values())).index return six.moves.reduce(__max_index_info__, list(MOVIE_INFO.values())).index
def max_user_id(): def max_user_id():
@ -195,7 +200,7 @@ def max_user_id():
Get the maximum value of user id. Get the maximum value of user id.
""" """
__initialize_meta_info__() __initialize_meta_info__()
return reduce(__max_index_info__, list(USER_INFO.values())).index return six.moves.reduce(__max_index_info__, list(USER_INFO.values())).index
def __max_job_id_impl__(a, b): def __max_job_id_impl__(a, b):
@ -210,7 +215,8 @@ def max_job_id():
Get the maximum value of job id. Get the maximum value of job id.
""" """
__initialize_meta_info__() __initialize_meta_info__()
return reduce(__max_job_id_impl__, list(USER_INFO.values())).job_id return six.moves.reduce(__max_job_id_impl__,
list(USER_INFO.values())).job_id
def movie_categories(): def movie_categories():

@ -20,6 +20,7 @@ The script fetch and preprocess movie_reviews data set that provided by NLTK
TODO(yuyang18): Complete dataset. TODO(yuyang18): Complete dataset.
""" """
import six
import collections import collections
from itertools import chain from itertools import chain
@ -64,7 +65,7 @@ def get_word_dict():
for field in movie_reviews.fileids(category): for field in movie_reviews.fileids(category):
for words in movie_reviews.words(field): for words in movie_reviews.words(field):
word_freq_dict[words] += 1 word_freq_dict[words] += 1
words_sort_list = list(word_freq_dict.items()) words_sort_list = six.iteritems(word_freq_dict)
words_sort_list.sort(cmp=lambda a, b: b[1] - a[1]) words_sort_list.sort(cmp=lambda a, b: b[1] - a[1])
for index, word in enumerate(words_sort_list): for index, word in enumerate(words_sort_list):
words_freq_sorted.append((word[0], index)) words_freq_sorted.append((word[0], index))

@ -16,6 +16,7 @@ import paddle.dataset.common
import unittest import unittest
import tempfile import tempfile
import glob import glob
from six.moves import range
class TestCommon(unittest.TestCase): class TestCommon(unittest.TestCase):

@ -22,6 +22,7 @@ parse training set and test set into paddle reader creators.
import os import os
import numpy as np import numpy as np
import six
import tempfile import tempfile
import tarfile import tarfile
import os import os
@ -70,11 +71,11 @@ def load_data(filename, feature_num=14, ratio=0.8):
return return
data = np.fromfile(filename, sep=' ') data = np.fromfile(filename, sep=' ')
data = data.reshape(data.shape[0] / feature_num, feature_num) data = data.reshape(data.shape[0] // feature_num, feature_num)
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum( maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum(
axis=0) / data.shape[0] axis=0) / data.shape[0]
feature_range(maximums[:-1], minimums[:-1]) feature_range(maximums[:-1], minimums[:-1])
for i in range(feature_num - 1): for i in six.moves.range(feature_num - 1):
data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i]) data[:, i] = (data[:, i] - avgs[i]) / (maximums[i] - minimums[i])
offset = int(data.shape[0] * ratio) offset = int(data.shape[0] * ratio)
UCI_TRAIN_DATA = data[:offset] UCI_TRAIN_DATA = data[:offset]
@ -137,7 +138,7 @@ def predict_reader():
It returns just one tuple data to do inference. It returns just one tuple data to do inference.
:return: one tuple data :return: one tuple data
:rtype: tuple :rtype: tuple
""" """
global UCI_TEST_DATA global UCI_TEST_DATA
load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5)) load_data(paddle.dataset.common.download(URL, 'uci_housing', MD5))

@ -19,10 +19,12 @@ http://paddlepaddle.cdn.bcebos.com/demo/wmt_shrinked_data/wmt14.tgz and
parse training set and test set into paddle reader creators. parse training set and test set into paddle reader creators.
""" """
import six
import tarfile import tarfile
import gzip import gzip
import paddle.dataset.common import paddle.dataset.common
import paddle.compat as cpt
__all__ = [ __all__ = [
'train', 'train',
@ -53,7 +55,7 @@ def __read_to_dict(tar_file, dict_size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
if line_count < size: if line_count < size:
out_dict[line.strip()] = line_count out_dict[cpt.to_text(line.strip())] = line_count
else: else:
break break
return out_dict return out_dict
@ -84,7 +86,7 @@ def reader_creator(tar_file, file_name, dict_size):
] ]
for name in names: for name in names:
for line in f.extractfile(name): for line in f.extractfile(name):
line_split = line.strip().split('\t') line_split = line.strip().split(six.b('\t'))
if len(line_split) != 2: if len(line_split) != 2:
continue continue
src_seq = line_split[0] # one source sequence src_seq = line_split[0] # one source sequence
@ -153,8 +155,8 @@ def get_dict(dict_size, reverse=True):
tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN) tar_file = paddle.dataset.common.download(URL_TRAIN, 'wmt14', MD5_TRAIN)
src_dict, trg_dict = __read_to_dict(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
if reverse: if reverse:
src_dict = {v: k for k, v in list(src_dict.items())} src_dict = {v: k for k, v in six.iteritems(src_dict)}
trg_dict = {v: k for k, v in list(trg_dict.items())} trg_dict = {v: k for k, v in six.iteritems(trg_dict)}
return src_dict, trg_dict return src_dict, trg_dict

@ -29,11 +29,13 @@ Multi30K: Multilingual English-German Image Descriptions.
""" """
import os import os
import six
import tarfile import tarfile
import gzip import gzip
from collections import defaultdict from collections import defaultdict
import paddle.dataset.common import paddle.dataset.common
import paddle.compat as cpt
__all__ = [ __all__ = [
"train", "train",
@ -60,7 +62,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
word_dict = defaultdict(int) word_dict = defaultdict(int)
with tarfile.open(tar_file, mode="r") as f: with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile("wmt16/train"): for line in f.extractfile("wmt16/train"):
line_split = line.strip().split("\t") line_split = line.strip().split(six.b("\t"))
if len(line_split) != 2: continue if len(line_split) != 2: continue
sen = line_split[0] if lang == "en" else line_split[1] sen = line_split[0] if lang == "en" else line_split[1]
for w in sen.split(): for w in sen.split():
@ -70,8 +72,7 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)) fout.write("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))
for idx, word in enumerate( for idx, word in enumerate(
sorted( sorted(
iter(list(word_dict.items())), six.iteritems(word_dict), key=lambda x: x[1],
key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write("%s\n" % (word[0])) fout.write("%s\n" % (word[0]))
@ -81,16 +82,16 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or ( if not os.path.exists(dict_path) or (
len(open(dict_path, "r").readlines()) != dict_size): len(open(dict_path, "rb").readlines()) != dict_size):
__build_dict(tar_file, dict_size, dict_path, lang) __build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {} word_dict = {}
with open(dict_path, "r") as fdict: with open(dict_path, "rb") as fdict:
for idx, line in enumerate(fdict): for idx, line in enumerate(fdict):
if reverse: if reverse:
word_dict[idx] = line.strip() word_dict[idx] = cpt.to_text(line.strip())
else: else:
word_dict[line.strip()] = idx word_dict[cpt.to_text(line.strip())] = idx
return word_dict return word_dict
@ -120,7 +121,7 @@ def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
with tarfile.open(tar_file, mode="r") as f: with tarfile.open(tar_file, mode="r") as f:
for line in f.extractfile(file_name): for line in f.extractfile(file_name):
line_split = line.strip().split("\t") line_split = line.strip().split(six.b("\t"))
if len(line_split) != 2: if len(line_split) != 2:
continue continue
src_words = line_split[src_col].split() src_words = line_split[src_col].split()

@ -17,6 +17,7 @@ from . import core
import collections import collections
import copy import copy
import six import six
from .. import compat as cpt
from . import unique_name from . import unique_name
__all__ = ['append_backward'] __all__ = ['append_backward']
@ -45,13 +46,13 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
""" """
op_desc = core.OpDesc() op_desc = core.OpDesc()
op_desc.set_type(op_type) op_desc.set_type(op_type)
for para, args in list(inputs.items()): for para, args in six.iteritems(inputs):
op_desc.set_input( op_desc.set_input(
para, para,
list( list(
map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg, map(lambda arg: arg.decode() if isinstance(arg, six.binary_type) else arg,
args))) args)))
for para, args in list(outputs.items()): for para, args in six.iteritems(outputs):
op_desc.set_output( op_desc.set_output(
para, para,
list( list(
@ -63,7 +64,7 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
if op_role_attr_name not in attrs: if op_role_attr_name not in attrs:
attrs[ attrs[
op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward op_role_attr_name] = core.op_proto_and_checker_maker.OpRole.Backward
for name, val in list(attrs.items()): for name, val in six.iteritems(attrs):
if isinstance(val, framework.Block): if isinstance(val, framework.Block):
op_desc.set_block_attr(name, val.desc) op_desc.set_block_attr(name, val.desc)
else: else:
@ -75,10 +76,10 @@ def _infer_var_data_type_(grad_var_name, block):
""" """
Infer the data type of given grad variable Infer the data type of given grad variable
""" """
grad_var = block.desc.find_var(grad_var_name.encode("ascii")) grad_var = block.desc.find_var(cpt.to_bytes(grad_var_name))
fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii")) fwd_name = _strip_grad_suffix_(grad_var_name)
if block.desc.has_var_recursive(fwd_name): if block.desc.has_var_recursive(cpt.to_bytes(fwd_name)):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii")) fwd_var = block.desc.find_var_recursive(cpt.to_bytes(fwd_name))
grad_var.set_dtype(fwd_var.dtype()) grad_var.set_dtype(fwd_var.dtype())
else: else:
grad_var.set_dtype(core.VarDesc.VarType.FP32) grad_var.set_dtype(core.VarDesc.VarType.FP32)
@ -102,8 +103,10 @@ def _some_in_set_(cands, s):
""" """
if len(cands) == 0: if len(cands) == 0:
return False return False
for c in cands: literal_set = cpt.to_text(s)
if c in s: literal_cands = cpt.to_text(cands)
for c in literal_cands:
if c in literal_set:
return True return True
return False return False
@ -114,9 +117,8 @@ def _strip_grad_suffix_(name):
e.g. x@GRAD ==> x e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y y@GRAD@RENAME@1 ==> y
""" """
if isinstance(name, six.text_type): name = cpt.to_text(name)
name = name.encode() pos = name.find(core.grad_var_suffix())
pos = name.find(six.b(core.grad_var_suffix()))
return name[:pos] if pos != -1 else name return name[:pos] if pos != -1 else name
@ -125,9 +127,7 @@ def _append_grad_suffix_(name):
Append grad suffix to the given variable name Append grad suffix to the given variable name
e.g. x ==> x@GRAD e.g. x ==> x@GRAD
""" """
if isinstance(name, six.text_type): return cpt.to_text(name) + core.grad_var_suffix()
name = name.encode()
return name + six.b(core.grad_var_suffix())
def _addup_repetitive_outputs_(op_descs): def _addup_repetitive_outputs_(op_descs):
@ -187,7 +187,7 @@ def _addup_repetitive_outputs_(op_descs):
op_desc.set_output(param_name, arg_names) op_desc.set_output(param_name, arg_names)
renamed_vars[var_name].append(new_name) renamed_vars[var_name].append(new_name)
for var_name, inputs in list(renamed_vars.items()): for var_name, inputs in six.iteritems(renamed_vars):
if len(inputs) > 1: if len(inputs) > 1:
pending_sum_ops.append( pending_sum_ops.append(
(_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]}, (_create_op_desc_("sum", {"X": inputs}, {"Out": [var_name]},
@ -243,7 +243,7 @@ from .proto import framework_pb2
def serialize_op_decs(op_desc): def serialize_op_decs(op_desc):
protostr = op_desc.serialize_to_string() protostr = op_desc.serialize_to_string()
proto = framework_pb2.OpDesc.FromString(str(protostr)) proto = framework_pb2.OpDesc.FromString(six.binary_type(protostr))
return proto.__str__() return proto.__str__()
@ -364,7 +364,7 @@ def _append_backward_ops_(block,
# Getting op's corresponding grad_op # Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc( grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list) op.desc, cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
grad_op_descs.extend(grad_op_desc) grad_op_descs.extend(grad_op_desc)
grad_to_var.update(op_grad_to_var) grad_to_var.update(op_grad_to_var)
@ -411,11 +411,10 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
new_vars = set() new_vars = set()
# create new gradient variables # create new gradient variables
for grad_var_name in op_desc.output_arg_names(): for grad_var_name in op_desc.output_arg_names():
grad_var_name = grad_var_name.encode("ascii") if block.desc.has_var_recursive(cpt.to_bytes(
if block.desc.has_var_recursive( grad_var_name)) or grad_var_name == core.empty_var_name():
grad_var_name) or grad_var_name == core.empty_var_name():
continue continue
block.desc.var(grad_var_name) block.desc.var(cpt.to_bytes(grad_var_name))
new_vars.add(grad_var_name) new_vars.add(grad_var_name)
if grad_var_name not in grad_to_var: if grad_var_name not in grad_to_var:
continue continue
@ -445,7 +444,7 @@ def _rename_grad_(block, start_op_idx, grad_to_var, target_grad_map):
op_desc.rename_output(name, new_name) op_desc.rename_output(name, new_name)
var_map[name] = new_name var_map[name] = new_name
for g, ng in list(var_map.items()): for g, ng in six.iteritems(var_map):
if g in grad_to_var: if g in grad_to_var:
grad_to_var[ng] = grad_to_var[g] grad_to_var[ng] = grad_to_var[g]
grad_to_var.pop(g) grad_to_var.pop(g)
@ -595,11 +594,12 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
parameters = parameter_list parameters = parameter_list
else: else:
params = program.global_block().all_parameters() params = program.global_block().all_parameters()
program.global_block().iter_parameters()
parameters = [param.name for param in params] parameters = [param.name for param in params]
params_and_grads = [] params_and_grads = []
for param in parameters: for param in parameters:
if param not in grad_info_map: if cpt.to_text(param) not in grad_info_map:
continue continue
grad_info = grad_info_map[param] grad_info = grad_info_map[param]
grad_block = grad_info[1] grad_block = grad_info[1]

@ -14,12 +14,14 @@
""" """
This module privides a memory usage calculate function for user. This module privides a memory usage calculate function for user.
The purpose of this API is to allow users to estimate memory usage of The purpose of this API is to allow users to estimate memory usage of
a program under a special batch size, then user can set appropriate a program under a special batch size, then user can set appropriate
batch size to fully utilize a GPU. batch size to fully utilize a GPU.
This API is still under active development and may change drastically. This API is still under active development and may change drastically.
""" """
import six
from .. import core from .. import core
from ..framework import Program, Variable from ..framework import Program, Variable
@ -45,15 +47,15 @@ def memory_usage(program, batch_size):
Args: Args:
program(Program): The current Program. program(Program): The current Program.
batch_size(int): The current input data batch_size. batch_size(int): The current input data batch_size.
Returns: Returns:
min_total_memory(float): the estimate memory usage lower bound. min_total_memory(float): the estimate memory usage lower bound.
max_total_memory(float): the estimate memory usage upper bound. max_total_memory(float): the estimate memory usage upper bound.
unit_str(string): the unit of estimate usage result. unit_str(string): the unit of estimate usage result.
Examples: Examples:
>>> import paddle.fluid as fluid >>> import paddle.fluid as fluid
>>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage( >>> lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
fluid.default_main_program(), batch_size=10) fluid.default_main_program(), batch_size=10)
@ -72,7 +74,7 @@ def memory_usage(program, batch_size):
# Get the var_name list of first block and calculate # Get the var_name list of first block and calculate
total_memory = 0.0 total_memory = 0.0
for var in program.global_block().vars.itervalues(): for var in six.itervalues(program.global_block().vars):
data_count = 1 data_count = 1
for x in var.shape: for x in var.shape:
if x == -1: if x == -1:
@ -81,10 +83,10 @@ def memory_usage(program, batch_size):
data_count *= x data_count *= x
var_memory = data_count * dtype_to_size[var.dtype] var_memory = data_count * dtype_to_size[var.dtype]
if DEBUG: if DEBUG:
print "%s memory usage: %d" % (var.name, var_memory) print("%s memory usage: %d" % (var.name, var_memory))
total_memory += var_memory total_memory += var_memory
if DEBUG: if DEBUG:
print "total memory usage: %.2f" % (total_memory) print("total memory usage: %.2f" % (total_memory))
# Convert appropriate unit # Convert appropriate unit
unit_str = "B" unit_str = "B"

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import sys import sys
import six
import re import re
from .graphviz import GraphPreviewGenerator from .graphviz import GraphPreviewGenerator
from .proto import framework_pb2 from .proto import framework_pb2
@ -225,7 +226,7 @@ def draw_block_graphviz(block, highlights=None, path="./temp.dot"):
graph = GraphPreviewGenerator("some graph") graph = GraphPreviewGenerator("some graph")
# collect parameters and args # collect parameters and args
protostr = block.desc.serialize_to_string() protostr = block.desc.serialize_to_string()
desc = framework_pb2.BlockDesc.FromString(str(protostr)) desc = framework_pb2.BlockDesc.FromString(six.binary_type(protostr))
def need_highlight(name): def need_highlight(name):
if highlights is None: return False if highlights is None: return False

@ -320,8 +320,9 @@ class Executor(object):
# append fetch_operators # append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name): if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list): for i, var in enumerate(fetch_list):
assert isinstance(var, Variable) or isinstance(var, str), ( assert isinstance(var, Variable) or isinstance(
"Wrong type for fetch_list[%s]: %s" % (i, type(var))) var, six.string_types), (
"Wrong type for fetch_list[%s]: %s" % (i, type(var)))
global_block.append_op( global_block.append_op(
type='fetch', type='fetch',
inputs={'X': [var]}, inputs={'X': [var]},
@ -346,7 +347,7 @@ class Executor(object):
def _fetch_data(self, fetch_list, fetch_var_name, scope): def _fetch_data(self, fetch_list, fetch_var_name, scope):
outs = [ outs = [
core.get_fetch_variable(scope, fetch_var_name, i) core.get_fetch_variable(scope, fetch_var_name, i)
for i in range(len(fetch_list)) for i in six.moves.range(len(fetch_list))
] ]
return outs return outs

@ -19,6 +19,7 @@ import six
import numpy as np import numpy as np
from .. import compat as cpt
from .proto import framework_pb2 from .proto import framework_pb2
try: try:
from . import core from . import core
@ -27,7 +28,7 @@ except ImportError as e:
"""NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\" """NOTE: You may need to run \"export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH\"
if you encounters \"libmkldnn.so not found\" errors. If you have python if you encounters \"libmkldnn.so not found\" errors. If you have python
installed in other directory, replace \"/usr/local/lib\" with your own installed in other directory, replace \"/usr/local/lib\" with your own
directory. The original error is: \n""" + e.message) directory. The original error is: \n""" + cpt.get_exception_message(e))
except Exception as e: except Exception as e:
raise e raise e
from . import unique_name from . import unique_name
@ -87,7 +88,7 @@ def convert_np_dtype_to_dtype_(np_dtype):
elif dtype == np.uint8: elif dtype == np.uint8:
return core.VarDesc.VarType.UINT8 return core.VarDesc.VarType.UINT8
else: else:
raise ValueError("Not supported numpy dtype " + six.binary_type(dtype)) raise ValueError("Not supported numpy dtype %s" % dtype)
def dtype_is_floating(dtype): def dtype_is_floating(dtype):
@ -198,11 +199,11 @@ class Variable(object):
if name is None: if name is None:
name = unique_name.generate('_generated_var') name = unique_name.generate('_generated_var')
is_new_var = False is_new_var = False
name = name if isinstance(name, six.binary_type) else name.encode() name = cpt.to_text(name)
self.desc = self.block.desc.find_var(name) self.desc = self.block.desc.find_var(cpt.to_bytes(name))
if self.desc is None: if self.desc is None:
self.desc = self.block.desc.var(name) self.desc = self.block.desc.var(cpt.to_bytes(name))
is_new_var = True is_new_var = True
if is_new_var: if is_new_var:
@ -325,7 +326,7 @@ class Variable(object):
@property @property
def name(self): def name(self):
return self.desc.name() return cpt.to_text(self.desc.name())
@name.setter @name.setter
def name(self, new_name): def name(self, new_name):
@ -531,14 +532,7 @@ class Operator(object):
elif isinstance(arg, six.binary_type): elif isinstance(arg, six.binary_type):
in_arg_names.append(arg.decode()) in_arg_names.append(arg.decode())
else: else:
if isinstance(arg.name, six.string_types): in_arg_names.append(cpt.to_text(arg.name))
in_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
in_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
self.desc.set_input(in_proto.name, in_arg_names) self.desc.set_input(in_proto.name, in_arg_names)
else: else:
self.desc.set_input(in_proto.name, []) self.desc.set_input(in_proto.name, [])
@ -567,14 +561,7 @@ class Operator(object):
(out_proto.name, len(out_args))) (out_proto.name, len(out_args)))
out_arg_names = [] out_arg_names = []
for arg in out_args: for arg in out_args:
if isinstance(arg.name, six.string_types): out_arg_names.append(cpt.to_text(arg.name))
out_arg_names.append(arg.name)
elif isinstance(arg.name, six.binary_type):
out_arg_names.append(arg.name.decode())
else:
raise TypeError(
"arguments require unicode, str or bytes, but get %s instead."
% (type(arg.name)))
arg.op = self arg.op = self
self.desc.set_output(out_proto.name, out_arg_names) self.desc.set_output(out_proto.name, out_arg_names)
@ -970,10 +957,9 @@ class Block(object):
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
if not isinstance(name, six.string_types): if not isinstance(name, six.string_types):
if not isinstance(name, six.binary_type): raise TypeError(
raise TypeError( "var require string as parameter, but get %s instead." %
"var require string as parameter, but get %s instead." % (type(name)))
(type(name)))
v = self.vars.get(name, None) v = self.vars.get(name, None)
if v is None: if v is None:
raise ValueError("var %s not in this block" % name) raise ValueError("var %s not in this block" % name)
@ -1024,7 +1010,7 @@ class Block(object):
return list(self.iter_parameters()) return list(self.iter_parameters())
def iter_parameters(self): def iter_parameters(self):
return (item[1] for item in list(self.vars.items()) return (item[1] for item in six.iteritems(self.vars)
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
def create_var(self, *args, **kwargs): def create_var(self, *args, **kwargs):
@ -1052,6 +1038,9 @@ class Block(object):
Returns: Returns:
Variable: the Variable with the giving name. Variable: the Variable with the giving name.
""" """
name = cpt.to_text(name)
new_name = cpt.to_text(new_name)
if not self.has_var(name): if not self.has_var(name):
raise ValueError("var %s is not in current block" % name) raise ValueError("var %s is not in current block" % name)
v = self.var(name) v = self.var(name)
@ -1070,9 +1059,9 @@ class Block(object):
else: else:
raise ValueError("unsupported var type: %s", type(v)) raise ValueError("unsupported var type: %s", type(v))
orig_var_type = v.type orig_var_type = v.type
self.desc._rename_var(name, new_name) self.desc._rename_var(cpt.to_bytes(name), cpt.to_bytes(new_name))
# NOTE: v is destroyed by C++ after calling _rename_var. # NOTE: v is destroyed by C++ after calling _rename_var.
d = self.desc.find_var(new_name) d = self.desc.find_var(cpt.to_bytes(new_name))
if var_type == "Parameter": if var_type == "Parameter":
var = Parameter( var = Parameter(
self, self,
@ -1103,7 +1092,7 @@ class Block(object):
def _remove_var(self, name): def _remove_var(self, name):
self._sync_with_cpp() self._sync_with_cpp()
self.desc._remove_var(name) self.desc._remove_var(cpt.to_bytes(name))
del self.vars[name] del self.vars[name]
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
@ -1205,7 +1194,7 @@ class Block(object):
# sync variables removed from c++ end # sync variables removed from c++ end
for var in list(self.vars.keys()): for var in list(self.vars.keys()):
if not self.desc.find_var(var): if not self.desc.find_var(cpt.to_bytes(var)):
self.vars.pop(var) self.vars.pop(var)
# sync operators from cpp # sync operators from cpp
@ -1576,7 +1565,9 @@ class Program(object):
p.current_block_idx = self.current_block_idx p.current_block_idx = self.current_block_idx
p._seed = self._seed p._seed = self._seed
p.desc = core.ProgramDesc(self.desc) p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())] p.blocks = [
Block(p, i) for i in six.moves.range(self.desc.num_blocks())
]
p._current_role = self._current_role p._current_role = self._current_role
p._op_role_var = self._op_role_var p._op_role_var = self._op_role_var
@ -1632,7 +1623,9 @@ class Program(object):
targets_idx.append([t.block.idx, t.idx]) targets_idx.append([t.block.idx, t.idx])
res = Program() res = Program()
res.desc = core.prune(self.desc, targets_idx) res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())] res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp() res._sync_with_cpp()
return res return res
@ -1675,16 +1668,18 @@ class Program(object):
root_block._remove_op(0, read_op_idx + 1) root_block._remove_op(0, read_op_idx + 1)
for var in root_block.all_vars(): for var in root_block.all_vars():
if var.type() == core.VarDesc.VarType.READER: if var.type() == core.VarDesc.VarType.READER:
root_block._remove_var(var.name()) root_block._remove_var(cpt.to_bytes(var.name()))
# change all `is_test` attributes to True # change all `is_test` attributes to True
for i in range(res.desc.num_blocks()): for i in six.moves.range(res.desc.num_blocks()):
block = res.desc.block(i) block = res.desc.block(i)
for j in range(block.op_size()): for j in six.moves.range(block.op_size()):
op = block.op(j) op = block.op(j)
if op.has_attr('is_test'): if op.has_attr('is_test'):
op.set_attr('is_test', True) op.set_attr('is_test', True)
res.blocks = [Block(res, i) for i in range(res.desc.num_blocks())] res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp() res._sync_with_cpp()
return res return res
@ -1704,7 +1699,7 @@ class Program(object):
""" """
p = Program() p = Program()
p.desc = core.ProgramDesc(binary_str) p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in range(p.desc.num_blocks())] p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
p._sync_with_cpp() p._sync_with_cpp()
return p return p

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save