Add a high-level API with training and inference into Paddle. (#24293)
* Merge hapi into Paddle Hapi is a high level API for training and inference. The main modules include Model, Loss, Metrics, Dataset. Also includes common modules and models in NLP and computer vision, such as BERT, ResNet. These modules are developed by: 0YuanZhang0, guoshengCS heavengate, LielinJiang, qingqing01, xyzhou-puck huangjun12, wangxiao1021, zhangyang.release/2.0-alpha
parent
4af3ec0f8a
commit
43625bdabd
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import logger
|
||||
from . import progressbar
|
||||
from . import callbacks
|
||||
from . import download
|
||||
from . import model
|
||||
from . import metrics
|
||||
from . import loss
|
||||
from . import datasets
|
||||
from . import distributed
|
||||
from . import vision
|
||||
|
||||
# Import-time side effect: configure the package logger with the defaults
# of logger.setup_logger().
logger.setup_logger()

# Submodule names exposed as the public interface of this package.
__all__ = [
    'callbacks',
    'datasets',
    'distributed',
    'download',
    'metrics',
    'loss',
    'vision',
]

# Append the names that the model submodule lists in its own __all__.
__all__ += model.__all__
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import folder
|
||||
from . import mnist
|
||||
from . import flowers
|
||||
|
||||
from .folder import *
|
||||
from .mnist import *
|
||||
from .flowers import *
|
||||
|
||||
# The public names of this package are the concatenation of the public
# names declared by each dataset submodule.
__all__ = folder.__all__ \
    + mnist.__all__ \
    + flowers.__all__
|
@ -0,0 +1,129 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import io
|
||||
import tarfile
|
||||
import numpy as np
|
||||
import scipy.io as scio
|
||||
from PIL import Image
|
||||
|
||||
from paddle.io import Dataset
|
||||
from .utils import _check_exists_and_download
|
||||
|
||||
__all__ = ["Flowers"]
|
||||
|
||||
DATA_URL = 'http://paddlemodels.bj.bcebos.com/flowers/102flowers.tgz'
|
||||
LABEL_URL = 'http://paddlemodels.bj.bcebos.com/flowers/imagelabels.mat'
|
||||
SETID_URL = 'http://paddlemodels.bj.bcebos.com/flowers/setid.mat'
|
||||
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
|
||||
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
|
||||
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
|
||||
|
||||
# In official 'readme', tstid is the flag of test data
|
||||
# and trnid is the flag of train data. But test data is more than train data.
|
||||
# So we exchange the train data and test data.
|
||||
MODE_FLAG_MAP = {'train': 'tstid', 'test': 'trnid', 'valid': "valid"}
|
||||
|
||||
|
||||
class Flowers(Dataset):
    """
    Flowers dataset backed by an image tarball plus two ``.mat`` files.

    Images are read lazily from the tar archive on every access; only the
    archive index, the labels and the subset indexes are loaded up front.

    Args:
        data_file(str): path to the image archive, can be set None if
            :attr:`download` is True. Default None
        label_file(str): path to the label ``.mat`` file, can be set None
            if :attr:`download` is True. Default None
        setid_file(str): path to the subset index ``.mat`` file, can be set
            None if :attr:`download` is True. Default None
        mode(str): 'train', 'valid' or 'test' (case-insensitive); selects
            the subset through ``MODE_FLAG_MAP``. Default 'train'.
        transform(callable): optional callable applied to the decoded image
            array before it is returned. Default None
        download(bool): whether to download any of the three files that
            was not given. Default True

    Examples:

        .. code-block:: python

            from paddle.incubate.hapi.datasets import Flowers

            flowers = Flowers(mode='test')

            for i in range(len(flowers)):
                sample = flowers[i]
                print(sample[0].shape, sample[1])

    """

    def __init__(self,
                 data_file=None,
                 label_file=None,
                 setid_file=None,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'valid', 'test'], \
            "mode should be 'train', 'valid' or 'test', but got {}".format(mode)
        self.flag = MODE_FLAG_MAP[mode.lower()]

        # Each file that was not supplied is fetched (or found in the
        # download cache) before being stored on the instance.
        if data_file is None:
            assert download, "data_file not set and auto download disabled"
            data_file = _check_exists_and_download(
                data_file, DATA_URL, DATA_MD5, 'flowers', download)
        self.data_file = data_file

        if label_file is None:
            assert download, "label_file not set and auto download disabled"
            label_file = _check_exists_and_download(
                label_file, LABEL_URL, LABEL_MD5, 'flowers', download)
        self.label_file = label_file

        if setid_file is None:
            assert download, "setid_file not set and auto download disabled"
            setid_file = _check_exists_and_download(
                setid_file, SETID_URL, SETID_MD5, 'flowers', download)
        self.setid_file = setid_file

        self.transform = transform

        # load the archive index and the annotations into memory
        self._load_anno()

    def _load_anno(self):
        """Index the tar members by name and load labels / subset indexes."""
        self.name2mem = {}
        # The archive stays open because __getitem__ extracts from it lazily.
        self.data_tar = tarfile.open(self.data_file)
        self.name2mem.update(
            (member.name, member) for member in self.data_tar.getmembers())

        label_mat = scio.loadmat(self.label_file)
        self.labels = label_mat['labels'][0]
        setid_mat = scio.loadmat(self.setid_file)
        self.indexes = setid_mat[self.flag][0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx`` of the subset."""
        # Image ids are 1-based: they index the label array at ``id - 1``
        # and appear verbatim in the member file name.
        image_id = self.indexes[idx]
        label = np.array([self.labels[image_id - 1]])
        member = self.name2mem["jpg/image_%05d.jpg" % image_id]
        raw_bytes = self.data_tar.extractfile(member).read()
        image = np.array(Image.open(io.BytesIO(raw_bytes)))

        if self.transform is not None:
            image = self.transform(image)

        return image, label.astype('int64')

    def __len__(self):
        """Number of samples in the selected subset."""
        return len(self.indexes)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,162 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import struct
|
||||
import numpy as np
|
||||
|
||||
import paddle.dataset.common
|
||||
from paddle.io import Dataset
|
||||
from .utils import _check_exists_and_download
|
||||
|
||||
__all__ = ["MNIST"]
|
||||
|
||||
URL_PREFIX = 'https://dataset.bj.bcebos.com/mnist/'
|
||||
TEST_IMAGE_URL = URL_PREFIX + 't10k-images-idx3-ubyte.gz'
|
||||
TEST_IMAGE_MD5 = '9fb629c4189551a2d022fa330f9573f3'
|
||||
TEST_LABEL_URL = URL_PREFIX + 't10k-labels-idx1-ubyte.gz'
|
||||
TEST_LABEL_MD5 = 'ec29112dd5afa0611ce80d1b7f02629c'
|
||||
TRAIN_IMAGE_URL = URL_PREFIX + 'train-images-idx3-ubyte.gz'
|
||||
TRAIN_IMAGE_MD5 = 'f68b3c2dcbeaaa9fbdd348bbdeb94873'
|
||||
TRAIN_LABEL_URL = URL_PREFIX + 'train-labels-idx1-ubyte.gz'
|
||||
TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
|
||||
|
||||
|
||||
class MNIST(Dataset):
    """
    Implementation of the MNIST dataset.

    The gzip-compressed IDX image and label files are fully decoded into
    memory when the dataset is constructed. Pixel values are converted to
    float32 and rescaled from [0, 255] to [-1, 1].

    Args:
        image_path(str): path to image file, can be set None if
            :attr:`download` is True. Default None
        label_path(str): path to label file, can be set None if
            :attr:`download` is True. Default None
        chw_format(bool): If set True, the output shape is [1, 28, 28],
            otherwise, output shape is [784]. Default True.
        mode(str): 'train' or 'test' mode (case-insensitive).
            Default 'train'.
        transform(callable): optional callable applied to each image
            before it is returned. Default None
        download(bool): whether auto download mnist dataset if
            :attr:`image_path`/:attr:`label_path` unset. Default
            True

    Returns:
        Dataset: MNIST Dataset.

    Examples:

        .. code-block:: python

            from paddle.incubate.hapi.datasets import MNIST

            mnist = MNIST(mode='test')

            for i in range(len(mnist)):
                sample = mnist[i]
                print(sample[0].shape, sample[1])

    """

    def __init__(self,
                 image_path=None,
                 label_path=None,
                 chw_format=True,
                 mode='train',
                 transform=None,
                 download=True):
        assert mode.lower() in ['train', 'test'], \
            "mode should be 'train' or 'test', but got {}".format(mode)
        self.mode = mode.lower()
        self.chw_format = chw_format
        # Use the normalized mode for every split decision: the assert above
        # accepts e.g. 'Train', which must not fall through to the test files.
        is_train = self.mode == 'train'

        self.image_path = image_path
        if self.image_path is None:
            assert download, "image_path not set and auto download disabled"
            image_url = TRAIN_IMAGE_URL if is_train else TEST_IMAGE_URL
            image_md5 = TRAIN_IMAGE_MD5 if is_train else TEST_IMAGE_MD5
            self.image_path = _check_exists_and_download(
                image_path, image_url, image_md5, 'mnist', download)

        self.label_path = label_path
        if self.label_path is None:
            assert download, "label_path not set and auto download disabled"
            label_url = TRAIN_LABEL_URL if is_train else TEST_LABEL_URL
            label_md5 = TRAIN_LABEL_MD5 if is_train else TEST_LABEL_MD5
            self.label_path = _check_exists_and_download(
                label_path, label_url, label_md5, 'mnist', download)

        self.transform = transform

        # read dataset into memory
        self._parse_dataset()

    def _parse_dataset(self, buffer_size=100):
        """Decode both IDX files into ``self.images`` and ``self.labels``.

        Args:
            buffer_size(int): number of samples decoded per iteration.
                The final batch is shortened when the sample count is not
                a multiple of this value.
        """
        self.images = []
        self.labels = []
        with gzip.GzipFile(self.image_path, 'rb') as image_file:
            img_buf = image_file.read()
        with gzip.GzipFile(self.label_path, 'rb') as label_file:
            lab_buf = label_file.read()

        step_label = 0
        offset_img = 0
        # read from Big-endian
        # get file info from magic byte
        # image file : 16B
        magic_byte_img = '>IIII'
        magic_img, image_num, rows, cols = struct.unpack_from(
            magic_byte_img, img_buf, offset_img)
        offset_img += struct.calcsize(magic_byte_img)

        offset_lab = 0
        # label file : 8B
        magic_byte_lab = '>II'
        magic_lab, label_num = struct.unpack_from(magic_byte_lab,
                                                  lab_buf, offset_lab)
        offset_lab += struct.calcsize(magic_byte_lab)

        while step_label < label_num:
            # Never read past the declared sample count: the last batch
            # may hold fewer than buffer_size samples.
            batch = min(buffer_size, label_num - step_label)

            fmt_label = '>' + str(batch) + 'B'
            labels = struct.unpack_from(fmt_label, lab_buf, offset_lab)
            offset_lab += struct.calcsize(fmt_label)
            step_label += batch

            fmt_images = '>' + str(batch * rows * cols) + 'B'
            images_temp = struct.unpack_from(fmt_images, img_buf,
                                             offset_img)
            images = np.reshape(images_temp,
                                (batch, rows * cols)).astype('float32')
            offset_img += struct.calcsize(fmt_images)

            # rescale [0, 255] -> [-1, 1]
            images = images / 255.0
            images = images * 2.0
            images = images - 1.0

            for i in range(batch):
                self.images.append(images[i, :])
                self.labels.append(
                    np.array([labels[i]]).astype('int64'))

    def __getitem__(self, idx):
        """Return ``(image, label)``; label is an int64 array of shape [1]."""
        image, label = self.images[idx], self.labels[idx]
        if self.chw_format:
            image = np.reshape(image, [1, 28, 28])
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Number of decoded samples."""
        return len(self.labels)
|
@ -0,0 +1,29 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import paddle.dataset.common
|
||||
|
||||
|
||||
def _check_exists_and_download(path, url, md5, module_name, download=True):
    """Return a usable local path for a dataset file.

    An existing ``path`` is returned unchanged. Otherwise the file is
    fetched through ``paddle.dataset.common.download`` when ``download``
    is true, and a ``ValueError`` is raised when it is false.
    """
    # Fast path: the caller already has the file on disk.
    if path and os.path.exists(path):
        return path

    if not download:
        raise ValueError('{} not exists and auto download disabled'.format(
            path))

    return paddle.dataset.common.download(url, module_name, md5)
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,235 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import os.path as osp
|
||||
import shutil
|
||||
import requests
|
||||
import hashlib
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from paddle.fluid.dygraph.parallel import ParallelEnv
|
||||
|
||||
try:
    from tqdm import tqdm
except ImportError:
    # Only a missing tqdm should select the fallback; a bare ``except``
    # would also swallow KeyboardInterrupt / SystemExit raised during import.

    class tqdm(object):
        """Minimal stand-in for ``tqdm.tqdm`` that reports on stderr.

        Supports only what this module uses: construction with ``total``,
        ``update(n)`` and use as a context manager.
        """

        def __init__(self, total=None):
            self.total = total
            self.n = 0

        def update(self, n):
            """Advance the counter by ``n`` and redraw the progress line."""
            self.n += n
            if self.total is None:
                sys.stderr.write("\r{0:.1f} bytes".format(self.n))
            else:
                sys.stderr.write("\r{0:.1f}%".format(100 * self.n / float(
                    self.total)))
            sys.stderr.flush()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # finish the carriage-return driven progress line
            sys.stderr.write('\n')
|
||||
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ['get_weights_path_from_url']
|
||||
|
||||
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/hapi/weights")
|
||||
|
||||
DOWNLOAD_RETRY_LIMIT = 3
|
||||
|
||||
# Ordered mapping from a pretrained NLP model name to the URL of its
# archive. Insertion order is preserved by OrderedDict.
nlp_models = OrderedDict((
    ('RoBERTa-zh-base',
     'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz'
     ),
    ('RoBERTa-zh-large',
     'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz'
     ),
    ('ERNIE-v2-en-base',
     'https://ernie.bj.bcebos.com/ERNIE_Base_en_stable-2.0.0.tar.gz'),
    ('ERNIE-v2-en-large',
     'https://ernie.bj.bcebos.com/ERNIE_Large_en_stable-2.0.0.tar.gz'),
    ('XLNet-cased-base',
     'https://xlnet.bj.bcebos.com/xlnet_cased_L-12_H-768_A-12.tgz'),
    ('XLNet-cased-large',
     'https://xlnet.bj.bcebos.com/xlnet_cased_L-24_H-1024_A-16.tgz'),
    ('ERNIE-v1-zh-base',
     'https://baidu-nlp.bj.bcebos.com/ERNIE_stable-1.0.1.tar.gz'),
    ('ERNIE-v1-zh-base-max-len-512',
     'https://ernie.bj.bcebos.com/ERNIE_1.0_max-len-512.tar.gz'),
    ('BERT-en-uncased-large-whole-word-masking',
     'https://bert-models.bj.bcebos.com/wwm_uncased_L-24_H-1024_A-16.tar.gz'),
    ('BERT-en-cased-large-whole-word-masking',
     'https://bert-models.bj.bcebos.com/wwm_cased_L-24_H-1024_A-16.tar.gz'),
    ('BERT-en-uncased-base',
     'https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz'),
    ('BERT-en-uncased-large',
     'https://bert-models.bj.bcebos.com/uncased_L-24_H-1024_A-16.tar.gz'),
    ('BERT-en-cased-base',
     'https://bert-models.bj.bcebos.com/cased_L-12_H-768_A-12.tar.gz'),
    ('BERT-en-cased-large',
     'https://bert-models.bj.bcebos.com/cased_L-24_H-1024_A-16.tar.gz'),
    ('BERT-multilingual-uncased-base',
     'https://bert-models.bj.bcebos.com/multilingual_L-12_H-768_A-12.tar.gz'),
    ('BERT-multilingual-cased-base',
     'https://bert-models.bj.bcebos.com/multi_cased_L-12_H-768_A-12.tar.gz'),
    ('BERT-zh-base',
     'https://bert-models.bj.bcebos.com/chinese_L-12_H-768_A-12.tar.gz'), ))
|
||||
|
||||
|
||||
def is_url(path):
    """
    Tell whether ``path`` is an HTTP(S) URL.

    Args:
        path (string): the string to inspect.

    Returns:
        bool: True if ``path`` starts with ``http://`` or ``https://``.
    """
    # str.startswith accepts a tuple of alternative prefixes
    return path.startswith(('http://', 'https://'))
|
||||
|
||||
|
||||
def get_weights_path_from_url(url, md5sum=None):
    """Resolve a weights URL to a local file under WEIGHTS_HOME.

    Delegates to ``get_path_from_url`` with ``WEIGHTS_HOME`` as the
    root directory.

    Args:
        url (str): download url
        md5sum (str): md5 sum of download package

    Returns:
        str: a local path to save downloaded weights.

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.download import get_weights_path_from_url

            resnet18_pretrained_weight_url = 'https://paddle-hapi.bj.bcebos.com/models/resnet18.pdparams'
            local_weight_path = get_weights_path_from_url(resnet18_pretrained_weight_url)

    """
    return get_path_from_url(url, WEIGHTS_HOME, md5sum)
|
||||
|
||||
|
||||
def _map_path(url, root_dir):
    """Return the local path under ``root_dir`` for the file named by ``url``."""
    # The local file name is the last path component of the URL.
    return osp.join(root_dir, osp.basename(url))
|
||||
|
||||
|
||||
def get_path_from_url(url, root_dir, md5sum=None, check_exist=True):
    """ Return the local path of the file named by url under root_dir.

    A file that already exists under root_dir and passes the md5 check
    is reused when check_exist is true. Otherwise the process with
    local rank 0 downloads it, and every other process polls once per
    second until the file shows up.

    Args:
        url (str): download url
        root_dir (str): root dir for downloading, it should be
                        WEIGHTS_HOME or DATASET_HOME
        md5sum (str): md5 sum of download package
        check_exist (bool): whether an existing, md5-valid file may be
                        reused instead of downloading. Default True

    Returns:
        str: a local path to save downloaded models & weights & datasets.
    """
    assert is_url(url), "downloading from {} not a url".format(url)
    # local location the download will end up at
    fullpath = _map_path(url, root_dir)

    reusable = (osp.exists(fullpath) and check_exist and
                _md5check(fullpath, md5sum))
    if reusable:
        logger.info("Found {}".format(fullpath))
        return fullpath

    # Only one process per node performs the download.
    if ParallelEnv().local_rank == 0:
        return _download(url, root_dir, md5sum)

    # The remaining processes wait for the file to appear.
    while not os.path.exists(fullpath):
        time.sleep(1)
    return fullpath
|
||||
|
||||
|
||||
def _download(url, path, md5sum=None):
    """
    Download from url, save to path.

    The download is repeated, up to DOWNLOAD_RETRY_LIMIT times, until the
    file exists and passes the md5 check.

    Args:
        url (str): download url
        path (str): directory to download into, created if missing
        md5sum (str): expected md5 of the file, None skips the check

    Returns:
        str: full path of the downloaded file.

    Raises:
        RuntimeError: if the server answers with a status other than 200,
            or the retry limit is reached.
    """
    if not osp.exists(path):
        os.makedirs(path)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        logger.info("Downloading {} from {}".format(fname, url))

        req = requests.get(url, stream=True)
        try:
            if req.status_code != 200:
                raise RuntimeError("Downloading from {} failed with code "
                                   "{}!".format(url, req.status_code))

            # To protect against an interrupted download, write to
            # tmp_fullname first and move it to fullname only after the
            # download finished.
            tmp_fullname = fullname + "_tmp"
            total_size = req.headers.get('content-length')
            with open(tmp_fullname, 'wb') as f:
                if total_size:
                    # progress is counted in 1024-byte chunks
                    with tqdm(total=(int(total_size) + 1023) // 1024) as pbar:
                        for chunk in req.iter_content(chunk_size=1024):
                            f.write(chunk)
                            pbar.update(1)
                else:
                    for chunk in req.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
        finally:
            # A streamed response holds its connection until the body is
            # fully consumed or it is closed; close it explicitly so the
            # error paths above do not leak the connection.
            req.close()
        shutil.move(tmp_fullname, fullname)

    return fullname
|
||||
|
||||
|
||||
def _md5check(fullname, md5sum=None):
    """Check the md5 digest of a file.

    Args:
        fullname (str): path of the file to hash
        md5sum (str): expected hex digest; None disables the check

    Returns:
        bool: True if no digest was given or the digests match.
    """
    if md5sum is None:
        return True

    logger.info("File {} md5 checking...".format(fullname))
    digest = hashlib.md5()
    with open(fullname, 'rb') as f:
        # hash in 4 KiB pieces so large files are never fully in memory
        while True:
            piece = f.read(4096)
            if piece == b"":
                break
            digest.update(piece)
    calc_md5sum = digest.hexdigest()

    matches = calc_md5sum == md5sum
    if not matches:
        logger.info("File {} md5 check failed, {}(calc) != "
                    "{}(base)".format(fullname, calc_md5sum, md5sum))
    return matches
|
@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
|
||||
from paddle.fluid.dygraph.parallel import ParallelEnv
|
||||
|
||||
|
||||
def setup_logger(output=None, name="hapi", log_level=logging.INFO):
    """
    Initialize logger of hapi and set its verbosity level to "INFO".

    A stdout handler is attached only on the process with local rank 0,
    and only when the logger has no handler yet. When ``output`` is given,
    a file handler is attached on every process; processes with local
    rank > 0 write to ``<filename>.rank<local_rank>``.

    Args:
        output (str): a file name or a directory to save log. If None, will not save log file.
            If ends with ".txt" or ".log", assumed to be a file name.
            Otherwise, logs will be saved to `output/log.txt`.
        name (str): the root module name of this logger. Default: 'hapi'.
        log_level (enum): log level. eg.'INFO', 'DEBUG', 'ERROR'. Default: logging.INFO.

    Returns:
        logging.Logger: a logger
    """
    logger = logging.getLogger(name)
    logger.propagate = False
    logger.setLevel(log_level)

    format_str = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    # stdout logging: only local rank==0
    local_rank = ParallelEnv().local_rank
    if local_rank == 0 and len(logger.handlers) == 0:
        ch = logging.StreamHandler(stream=sys.stdout)
        ch.setLevel(log_level)

        ch.setFormatter(logging.Formatter(format_str))
        logger.addHandler(ch)

    # file logging if output is not None: all workers
    if output is not None:
        if output.endswith(".txt") or output.endswith(".log"):
            filename = output
        else:
            filename = os.path.join(output, "log.txt")

        if local_rank > 0:
            filename = filename + ".rank{}".format(local_rank)

        # A bare file name has an empty dirname; os.makedirs('') would
        # raise, so only create the directory when there is one.
        log_dir = os.path.dirname(filename)
        if log_dir and not os.path.exists(log_dir):
            os.makedirs(log_dir)

        # FileHandler opens the path for writing. StreamHandler would
        # treat the path string itself as the stream and fail on emit.
        fh = logging.FileHandler(filename)
        fh.setLevel(log_level)
        fh.setFormatter(logging.Formatter(format_str))
        logger.addHandler(fh)

    return logger
|
@ -0,0 +1,145 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from paddle import fluid
|
||||
from paddle.fluid.framework import in_dygraph_mode, Variable
|
||||
from paddle.fluid.dygraph.base import to_variable
|
||||
|
||||
from .utils import to_list
|
||||
|
||||
__all__ = ['Loss', 'CrossEntropy', 'SoftmaxWithCrossEntropy']
|
||||
|
||||
|
||||
class Loss(object):
    """
    Base class of all losses: subclasses implement :code:`forward`, and
    calling the instance normalizes the arguments and reduces each loss.

    Usage:
        custom_loss = CustomLoss()
        loss = custom_loss(inputs, labels)

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.loss import Loss
            from paddle import fluid

            class SoftmaxWithCrossEntropy(Loss):
                def __init__(self, average=True):
                    super(SoftmaxWithCrossEntropy, self).__init__(average)

                def forward(self, outputs, labels):
                    return [
                        fluid.layers.softmax_with_cross_entropy(
                            o, l, return_softmax=False) for o, l in zip(outputs, labels)
                    ]

    """

    def __init__(self, average=True):
        super(Loss, self).__init__()
        # True: reduce each loss with mean; False: reduce with sum.
        self.average = average

    def forward(self, outputs, labels):
        """Compute the unreduced losses; must be overridden."""
        raise NotImplementedError()

    def __call__(self, outputs, labels=None):
        """Run :code:`forward` and reduce every resulting loss.

        Returns:
            list: one reduced loss per item returned by :code:`forward`.
        """
        labels = to_list(labels)
        # In dygraph mode the labels are converted to variables first.
        if in_dygraph_mode() and labels:
            labels = [to_variable(item) for item in labels]
        raw_losses = to_list(self.forward(to_list(outputs), labels))
        if self.average:
            reducer = fluid.layers.reduce_mean
        else:
            reducer = fluid.layers.reduce_sum
        return [reducer(item) for item in raw_losses]
|
||||
|
||||
|
||||
class CrossEntropy(Loss):
    """
    Cross entropy loss, computed pairwise for each output and label.

    Args:
        input (list[Variable]): Input tensor, the data type is float32,
            float64, int32, int64.
        label (list[Variable]): Label tensor, the data type is float32,
            float64, int32, int64.
        average (bool, optional): Indicate whether to average the loss, Default: True.
    Returns:
        list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.model import Input
            from paddle.incubate.hapi.vision.models import LeNet
            from paddle.incubate.hapi.loss import CrossEntropy

            inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
            labels = [Input([None, 1], 'int64', name='label')]

            model = LeNet()
            loss = CrossEntropy()
            model.prepare(loss_function=loss, inputs=inputs, labels=labels)

    """

    def __init__(self, average=True):
        super(CrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        """Return one cross entropy loss per (output, label) pair."""
        losses = []
        for output, label in zip(outputs, labels):
            losses.append(fluid.layers.cross_entropy(output, label))
        return losses
|
||||
|
||||
|
||||
class SoftmaxWithCrossEntropy(Loss):
    """
    Loss that combines softmax and cross entropy, computed pairwise for
    each output and label.

    Args:
        input (list[Variable]): Input tensor, the data type is float32,
            float64, int32, int64.
        label (list[Variable]): Label tensor, the data type is float32,
            float64, int32, int64.
        average (bool, optional): Indicate whether to average the loss, Default: True.
    Returns:
        list[Variable]: The tensor variable storing the cross_entropy_loss of inputs and labels.

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.model import Input
            from paddle.incubate.hapi.vision.models import LeNet
            from paddle.incubate.hapi.loss import SoftmaxWithCrossEntropy

            inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
            labels = [Input([None, 1], 'int64', name='label')]

            model = LeNet(classifier_activation=None)
            loss = SoftmaxWithCrossEntropy()
            model.prepare(loss_function=loss, inputs=inputs, labels=labels)
    """

    def __init__(self, average=True):
        super(SoftmaxWithCrossEntropy, self).__init__(average)

    def forward(self, outputs, labels):
        """Return one softmax-with-cross-entropy loss per (output, label) pair."""
        losses = []
        for output, label in zip(outputs, labels):
            # return_softmax=False: only the loss is returned
            losses.append(
                fluid.layers.softmax_with_cross_entropy(
                    output, label, return_softmax=False))
        return losses
|
@ -0,0 +1,242 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
import abc
|
||||
import numpy as np
|
||||
import paddle.fluid as fluid
|
||||
|
||||
import logging
|
||||
|
||||
# Layout of every log line: timestamp, severity, then the message.
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
# NOTE: executed at import time, so importing this module configures logging
# (INFO level, FORMAT above) as a side effect.
logging.basicConfig(level=logging.INFO, format=FORMAT)
# Module-scoped logger named after this module.
logger = logging.getLogger(__name__)

# Names exported by `from <module> import *`.
__all__ = ['Metric', 'Accuracy']
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
class Metric(object):
    """
    Abstract base class for metrics; defines the metric logic and API.

    Basic usage:

        m = SomeMetric()
        for prediction, label in ...:
            m.update(prediction, label)
        m.accumulate()

    Advanced usage with :code:`add_metric_op`

    Metric computation can be accelerated by calculating the metric states
    from the model outputs and labels with Paddle OPs inside
    :code:`add_metric_op`. The resulting metric states are fetched as numpy
    arrays and :code:`update` is then called with those numpy states.

    The metric is calculated as follows (operations performed by Model and
    Metric are written in curly brackets, data nodes are not):

             inputs & labels              || ------------------
                   |                      ||
                {model}                   ||
                   |                      ||
            outputs & labels              ||
                   |                      ||    tensor data
            {Metric.add_metric_op}        ||
                   |                      ||
            metric states(tensor)         ||
                   |                      ||
            {fetch as numpy}              || ------------------
                   |                      ||
            metric states(numpy)          ||    numpy data
                   |                      ||
              {Metric.update}             \\/ ------------------

    Examples:

        For the :code:`Accuracy` metric, which takes :code:`pred` and
        :code:`label` as inputs, the matrix of correct predictions between
        :code:`pred` and :code:`label` can be calculated in
        :code:`add_metric_op`. For example, if the prediction covers 10
        classes, :code:`pred` has shape [N, 10] and :code:`label` has shape
        [N, 1] (N is the mini-batch size), and only top-1 and top-5 accuracy
        are needed, the correct-prediction matrix of the top-5 scores of
        each sample can be calculated as follows; its shape is [N, 5].

        .. code-block:: python

            def add_metric_op(pred, label):
                # sort prediction and slice the top-5 scores
                pred = fluid.layers.argsort(pred, descending=True)[1][:, :5]
                # calculate whether the predictions are correct
                correct = pred == label
                return fluid.layers.cast(correct, dtype='float32')

        With :code:`add_metric_op`, part of the calculation is moved into
        OPs (which may run on GPU devices and be faster), and only 1 tensor
        of shape [N, 5] is fetched instead of 2 tensors of shapes [N, 10]
        and [N, 1].
        :code:`update` can then be defined as follows:

        .. code-block:: python

            def update(self, correct):
                accs = []
                for i, k in enumerate(self.topk):
                    num_corrects = correct[:, :k].sum()
                    num_samples = len(correct)
                    accs.append(float(num_corrects) / num_samples)
                    self.total[i] += num_corrects
                    self.count[i] += num_samples
                return accs
    """

    def __init__(self):
        pass

    @abc.abstractmethod
    def reset(self):
        """
        Reset the accumulated states and results.
        """
        message = "function 'reset' not implemented in {}.".format(
            self.__class__.__name__)
        raise NotImplementedError(message)

    @abc.abstractmethod
    def update(self, *args):
        """
        Update the metric states.

        The inputs of :code:`update` are the outputs of
        :code:`Metric.add_metric_op`. If :code:`add_metric_op` is not
        overridden, the inputs are the flattened **outputs** of the model
        followed by the **labels** from the data:
        :code:`update(output1, output2, ..., label1, label2,...)`

        see :code:`Metric.add_metric_op`
        """
        message = "function 'update' not implemented in {}.".format(
            self.__class__.__name__)
        raise NotImplementedError(message)

    @abc.abstractmethod
    def accumulate(self):
        """
        Accumulate the statistics, then compute and return the metric value.
        """
        message = "function 'accumulate' not implemented in {}.".format(
            self.__class__.__name__)
        raise NotImplementedError(message)

    @abc.abstractmethod
    def name(self):
        """
        Return the metric name.
        """
        message = "function 'name' not implemented in {}.".format(
            self.__class__.__name__)
        raise NotImplementedError(message)

    def add_metric_op(self, *args):
        """
        Advanced hook to accelerate metric calculation: the computation that
        turns model outputs into the states consumed by the metric can be
        defined here, and Paddle OPs may be used. Whatever this method
        returns becomes the input of "Metric.update".

        If :code:`add_metric_op` is overridden, it is called with the
        **outputs** of the model and the **labels** from the data; all
        outputs and labels are concatenated and flattened, and each field is
        passed as a separate argument:
        :code:`add_metric_op(output1, output2, ..., label1, label2,...)`

        If :code:`add_metric_op` is not overridden, the default behaviour is
        to pass the inputs straight through, so the output format is:
        :code:`return output1, output2, ..., label1, label2,...`

        see :code:`Metric.update`
        """
        return args
|
||||
|
||||
|
||||
class Accuracy(Metric):
    """
    Top-k classification accuracy metric.

    Args:
        topk (tuple of int): The k values accuracy is reported for.
            Default: (1, ).
        name (str, optional): Base name of the metric. Default: 'acc'.

    Examples:

        .. code-block:: python

            from paddle import fluid
            from paddle.incubate.hapi.metrics import Accuracy
            from paddle.incubate.hapi.loss import CrossEntropy
            from paddle.incubate.hapi.datasets import MNIST
            from paddle.incubate.hapi.model import Input
            from paddle.incubate.hapi.vision.models import LeNet

            fluid.enable_dygraph()

            train_dataset = MNIST(mode='train')

            model = LeNet()
            optim = fluid.optimizer.Adam(
                learning_rate=0.001, parameter_list=model.parameters())

            inputs = [Input([-1, 1, 28, 28], 'float32', name='image')]
            labels = [Input([None, 1], 'int64', name='label')]

            model.prepare(
                optim,
                loss_function=CrossEntropy(average=False),
                metrics=Accuracy(),
                inputs=inputs,
                labels=labels)

            model.fit(train_dataset, batch_size=64)

    """

    def __init__(self, topk=(1, ), name=None, *args, **kwargs):
        super(Accuracy, self).__init__(*args, **kwargs)
        self.topk = topk
        # Only the maxk best predictions are needed to serve every k.
        self.maxk = max(topk)
        self._init_name(name)
        self.reset()

    def add_metric_op(self, pred, label, *args):
        """Return a float32 tensor marking which of the top-maxk predicted
        indices equal the label."""
        # argsort(...)[1] holds the sorted indices; keep the maxk best ones.
        pred = fluid.layers.argsort(pred, descending=True)[1][:, :self.maxk]
        correct = pred == label
        return fluid.layers.cast(correct, dtype='float32')

    def update(self, correct, *args):
        """Fold one batch into the running totals.

        Args:
            correct: Indexable 2-D array of correctness flags, as produced
                by ``add_metric_op``; ``len(correct)`` is taken as the
                number of samples.

        Returns:
            list[float]: The accuracy of this batch for every k in ``topk``;
            0.0 for an empty batch.
        """
        accs = []
        num_samples = len(correct)
        for i, k in enumerate(self.topk):
            num_corrects = correct[:, :k].sum()
            # An empty batch carries no information; report 0 instead of
            # raising ZeroDivisionError.
            if num_samples:
                accs.append(float(num_corrects) / num_samples)
            else:
                accs.append(0.)
            self.total[i] += num_corrects
            self.count[i] += num_samples
        return accs

    def reset(self):
        """Zero the per-k totals of correct predictions and sample counts."""
        self.total = [0.] * len(self.topk)
        self.count = [0] * len(self.topk)

    def accumulate(self):
        """Return the accumulated accuracy for every k in ``topk``.

        A k for which no sample has been seen yet yields 0.0, so calling
        this before the first ``update`` does not divide by zero.
        """
        res = []
        for t, c in zip(self.total, self.count):
            res.append(float(t) / c if c > 0 else 0.)
        return res

    def _init_name(self, name):
        """Build the list of metric names: one '<name>_top<k>' per k, or
        just ``[name]`` when only top-1 is tracked."""
        name = name or 'acc'
        if self.maxk != 1:
            self._name = ['{}_top{}'.format(name, k) for k in self.topk]
        else:
            self._name = [name]

    def name(self):
        """Return the list of metric names built by ``_init_name``."""
        return self._name
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,192 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
from collections import namedtuple
|
||||
|
||||
# Names exported by `from <module> import *`.
__all__ = ['ProgressBar']
|
||||
|
||||
|
||||
class ProgressBar(object):
    """Console progress reporter for step-based loops.

    Args:
        num (int|None): Total number of steps, or None when unknown. An
            integer must be greater than 0.
        width (int): Requested bar width in characters; capped by a limit
            derived from the terminal width.
        verbose (int): 1 draws an updating bar, 2 prints one line per
            update, any other value prints nothing.
        start (bool): Whether to start the timer on construction. When
            False, ``start()`` has to be called before ``update()``.
        file: Stream used for the ``isatty`` check and flushed by
            ``start()``. NOTE: the progress text itself is written to
            ``sys.stdout``.

    Raises:
        TypeError: If ``num`` is an integer that is not greater than 0.
    """

    def __init__(self,
                 num=None,
                 width=30,
                 verbose=1,
                 start=True,
                 file=sys.stdout):
        self._num = num
        if isinstance(num, int) and num <= 0:
            raise TypeError('num should be None or integer (> 0)')
        max_width = self._get_max_width()
        self._width = width if width <= max_width else max_width
        self._total_width = 0
        self._verbose = verbose
        self.file = file
        self._values = {}
        self._values_order = []
        if start:
            self._start = time.time()
        self._last_update = 0

        # Redraw in place (backspaces + carriage return) when any of these
        # hints is present; otherwise every update starts on a new line.
        self._dynamic_display = (
            (hasattr(self.file, 'isatty') and self.file.isatty()) or
            'ipykernel' in sys.modules or 'posix' in sys.modules or
            'PYCHARM_HOSTED' in os.environ)

    def _get_max_width(self):
        """Return the widest bar allowed for the current terminal."""
        try:
            from shutil import get_terminal_size
        except ImportError:
            # Python 2: use the backport when installed, otherwise assume a
            # classic 80x24 terminal.
            try:
                from backports.shutil_get_terminal_size import get_terminal_size
            except ImportError:

                def get_terminal_size():
                    terminal_size = namedtuple("terminal_size",
                                               "columns lines")
                    return terminal_size(80, 24)

        terminal_width, _ = get_terminal_size()
        max_width = min(int(terminal_width * 0.6), terminal_width - 50)
        return max_width

    def start(self):
        """Flush the stream and (re)start the timer."""
        self.file.flush()
        self._start = time.time()

    @staticmethod
    def _format_values(values):
        """Render (name, value) pairs as ' - name: v1 v2 ...'."""
        parts = []
        for k, val in values:
            parts.append(' - %s:' % k)
            val = val if isinstance(val, list) else [val]
            for v in val:
                # Unwrap single-element float arrays so they are formatted
                # like plain floats.
                if isinstance(v, np.ndarray) and v.size == 1 and \
                        v.dtype in [np.float32, np.float64]:
                    v = v.item()
                if isinstance(v, (float, np.float32, np.float64)):
                    if abs(v) > 1e-3:
                        parts.append(' %.4f' % v)
                    else:
                        parts.append(' %.4e' % v)
                else:
                    # (v, ) keeps tuple values from being unpacked as
                    # format arguments.
                    parts.append(' %s' % (v, ))
        return ''.join(parts)

    @staticmethod
    def _format_eta(eta):
        """Render a remaining time in seconds as h:mm:ss, m:ss or Ns."""
        if eta > 3600:
            return '%d:%02d:%02d' % (eta // 3600, (eta % 3600) // 60,
                                     eta % 60)
        if eta > 60:
            return '%d:%02d' % (eta // 60, eta % 60)
        return '%ds' % eta

    def update(self, current_num, values=None):
        """Report progress for step ``current_num``.

        Args:
            current_num (int): Number of steps finished so far.
            values (iterable|None): (name, value) pairs to display; a value
                may be a scalar or a list of scalars. None shows no values.
        """
        now = time.time()
        if values is None:
            # The default must behave like "nothing to show" rather than
            # fail when iterated.
            values = []

        if current_num:
            time_per_unit = (now - self._start) / current_num
        else:
            time_per_unit = 0

        if time_per_unit >= 1 or time_per_unit == 0:
            fps = ' - %.0fs/%s' % (time_per_unit, 'step')
        elif time_per_unit >= 1e-3:
            fps = ' - %.0fms/%s' % (time_per_unit * 1e3, 'step')
        else:
            fps = ' - %.0fus/%s' % (time_per_unit * 1e6, 'step')

        if self._verbose == 1:
            prev_total_width = self._total_width

            if self._dynamic_display:
                sys.stdout.write('\b' * prev_total_width)
                sys.stdout.write('\r')
            else:
                sys.stdout.write('\n')

            if self._num is not None:
                numdigits = int(np.log10(self._num)) + 1

                bar_chars = ('step %' + str(numdigits) + 'd/%d [') % (
                    current_num, self._num)
                prog = float(current_num) / self._num
                prog_width = int(self._width * prog)

                if prog_width > 0:
                    bar_chars += ('=' * (prog_width - 1))
                    # '>' marks the head of an unfinished bar.
                    if current_num < self._num:
                        bar_chars += '>'
                    else:
                        bar_chars += '='
                bar_chars += ('.' * (self._width - prog_width))
                bar_chars += ']'
            else:
                bar_chars = 'step %3d' % current_num

            self._total_width = len(bar_chars)
            sys.stdout.write(bar_chars)

            info = self._format_values(values)

            if self._num is not None and current_num < self._num:
                eta = time_per_unit * (self._num - current_num)
                info += ' - ETA: %s' % self._format_eta(eta)

            info += fps
            self._total_width += len(info)
            # Pad with spaces so leftovers of a longer previous line are
            # overwritten.
            if prev_total_width > self._total_width:
                info += (' ' * (prev_total_width - self._total_width))

            # newline for another epoch, or after every update when the
            # total number of steps is unknown
            if self._num is None or current_num >= self._num:
                info += '\n'

            sys.stdout.write(info)
            sys.stdout.flush()
            self._last_update = now
        elif self._verbose == 2:
            if self._num:
                numdigits = int(np.log10(self._num)) + 1
                count = ('step %' + str(numdigits) + 'd/%d') % (current_num,
                                                                 self._num)
            else:
                count = 'step %3d' % current_num

            info = count + self._format_values(values) + fps + '\n'
            sys.stdout.write(info)
            sys.stdout.flush()
|
@ -0,0 +1,45 @@
|
||||
# Collect every test_*.py in this directory and strip the ".py" suffix so
# each entry can be used as a test name.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

# Same for the distributed tests (test_dist_*.py).
file(GLOB DIST_TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_dist_*.py")
string(REPLACE ".py" "" DIST_TEST_OPS "${DIST_TEST_OPS}")

# The first glob also matches test_dist_*.py; remove those entries so they
# are not registered through py_test.
foreach(TEST_OP ${DIST_TEST_OPS})
  list(REMOVE_ITEM TEST_OPS ${TEST_OP})
endforeach()

# Register every remaining test with py_test.
foreach(src ${TEST_OPS})
  py_test(${src} SRCS ${src}.py)
endforeach()
|
||||
|
||||
|
||||
# py_dist_test(<name> SRCS <script> [ARGS ...] [ENVS ...] [DEPS ...])
#
# Registers <name> as a CTest test that runs the given Python script with a
# deterministic/NCCL environment. The test is only added when WITH_TESTING
# is on AND the build has WITH_COVERAGE, WITH_GPU and WITH_NCCL on a
# non-Windows platform; otherwise the call does nothing.
# DEPS is parsed but not used inside this function.
function(py_dist_test TARGET_NAME)
  if(WITH_TESTING)
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS ARGS ENVS)
    cmake_parse_arguments(py_dist_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})

    if(WITH_COVERAGE AND WITH_GPU AND WITH_NCCL AND NOT WIN32)
      # Run through `cmake -E env` so the variables apply to this test only.
      add_test(NAME ${TARGET_NAME}
               COMMAND ${CMAKE_COMMAND} -E env FLAGS_init_allocated_mem=true FLAGS_cudnn_deterministic=true
               FLAGS_cpu_deterministic=true NCCL_P2P_DISABLE=1 NCCL_SHM_DISABLE=1
               PYTHONPATH=${PADDLE_BINARY_DIR}/python ${py_dist_test_ENVS}
               COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
               ${PYTHON_EXECUTABLE} -u ${py_dist_test_SRCS} ${py_dist_test_ARGS}
               WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
      # No unit test should exceed 10 minutes.
      # RUN_SERIAL keeps the test from running in parallel with other tests.
      set_tests_properties(${TARGET_NAME} PROPERTIES TIMEOUT 600 LABELS "RUN_TYPE=DIST" RUN_SERIAL TRUE)
    endif()

  endif()
endfunction()
|
||||
|
||||
|
||||
|
||||
# Register every distributed test through py_dist_test, printing its name at
# configure time.
foreach(src ${DIST_TEST_OPS})
  message(STATUS ${src})
  py_dist_test(${src} SRCS ${src}.py)
endforeach()
|
@ -0,0 +1,100 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import contextlib
|
||||
|
||||
from paddle import fluid
|
||||
|
||||
from paddle.incubate.hapi.model import Model, Input, set_device
|
||||
from paddle.incubate.hapi.loss import CrossEntropy
|
||||
from paddle.incubate.hapi.vision.models import LeNet
|
||||
from paddle.incubate.hapi.metrics import Accuracy
|
||||
from paddle.incubate.hapi.callbacks import ProgBarLogger
|
||||
from paddle.incubate.hapi.datasets import MNIST
|
||||
|
||||
|
||||
class MnistDataset(MNIST):
    """MNIST variant that reshapes every image to [1, 28, 28] and can
    optionally leave the label out of each sample."""

    def __init__(self, mode, return_label=True):
        super(MnistDataset, self).__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        image = np.reshape(self.images[idx], [1, 28, 28])
        if not self.return_label:
            # Prediction-only sample: a one-element tuple.
            return (image, )
        label = np.array(self.labels[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.images)
|
||||
|
||||
|
||||
def compute_accuracy(pred, gt):
    """Return the top-1 accuracy of ``pred`` against the labels ``gt``.

    Args:
        pred: Array of per-class scores; the class index is taken with
            ``argmax`` over the last axis.
        gt: Ground-truth labels, either one column per sample ([N, 1]) or a
            flat sequence of N labels.

    Returns:
        The number of matching predictions divided by the number of samples.
    """
    pred = np.argmax(pred, -1)
    gt = np.array(gt)
    # A flat label vector would broadcast against pred[:, None] into an
    # [N, N] matrix and inflate the result; compare it as a column instead.
    if gt.ndim == 1:
        gt = gt[:, np.newaxis]

    correct = pred[:, np.newaxis] == gt

    return np.sum(correct) / correct.shape[0]
|
||||
|
||||
|
||||
@unittest.skipIf(not fluid.is_compiled_with_cuda(),
                 'CPU testing is not supported')
class TestDistTraning(unittest.TestCase):
    """End-to-end LeNet/MNIST run on GPU; skipped on CPU-only builds."""

    def test_static_multiple_gpus(self):
        """Fit, evaluate and predict, then check that the accuracy computed
        from the predictions matches the one reported by ``evaluate``.

        NOTE: despite the method name, this variant calls
        ``fluid.enable_dygraph`` before building the model.
        """
        place = set_device('gpu')
        fluid.enable_dygraph(place)

        image_shape = (-1, 1, 28, 28)
        bs = 128
        image_specs = [Input(image_shape, 'float32', name='image')]
        label_specs = [Input([None, 1], 'int64', name='label')]

        train_set = MnistDataset(mode='train')
        eval_set = MnistDataset(mode='test')
        # Same split as eval_set, but samples carry no label.
        predict_set = MnistDataset(mode='test', return_label=False)

        net = LeNet()
        optimizer = fluid.optimizer.Momentum(
            learning_rate=0.001, momentum=.9, parameter_list=net.parameters())
        criterion = CrossEntropy()
        net.prepare(
            optimizer,
            criterion,
            Accuracy(),
            image_specs,
            label_specs,
            device=place)
        progress_logger = ProgBarLogger(50)

        net.fit(train_set,
                eval_set,
                epochs=2,
                batch_size=bs,
                callbacks=progress_logger)

        metrics = net.evaluate(eval_set, batch_size=bs)
        predictions = net.predict(
            predict_set, batch_size=bs, stack_outputs=True)

        # One prediction row per test sample.
        np.testing.assert_equal(predictions[0].shape[0], len(predict_set))

        accuracy = compute_accuracy(predictions[0], eval_set.labels)
        np.testing.assert_allclose(accuracy, metrics['acc'])
|
||||
|
||||
|
||||
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,99 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import contextlib
|
||||
|
||||
from paddle import fluid
|
||||
|
||||
from paddle.incubate.hapi.model import Model, Input, set_device
|
||||
from paddle.incubate.hapi.loss import CrossEntropy
|
||||
from paddle.incubate.hapi.vision.models import LeNet
|
||||
from paddle.incubate.hapi.metrics import Accuracy
|
||||
from paddle.incubate.hapi.callbacks import ProgBarLogger
|
||||
from paddle.incubate.hapi.datasets import MNIST
|
||||
|
||||
|
||||
class MnistDataset(MNIST):
    """MNIST variant that reshapes every image to [1, 28, 28] and can
    optionally leave the label out of each sample."""

    def __init__(self, mode, return_label=True):
        super(MnistDataset, self).__init__(mode=mode)
        self.return_label = return_label

    def __getitem__(self, idx):
        image = np.reshape(self.images[idx], [1, 28, 28])
        if not self.return_label:
            # Prediction-only sample: a one-element tuple.
            return (image, )
        label = np.array(self.labels[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.images)
|
||||
|
||||
|
||||
def compute_accuracy(pred, gt):
    """Return the top-1 accuracy of ``pred`` against the labels ``gt``.

    Args:
        pred: Array of per-class scores; the class index is taken with
            ``argmax`` over the last axis.
        gt: Ground-truth labels, either one column per sample ([N, 1]) or a
            flat sequence of N labels.

    Returns:
        The number of matching predictions divided by the number of samples.
    """
    pred = np.argmax(pred, -1)
    gt = np.array(gt)
    # A flat label vector would broadcast against pred[:, None] into an
    # [N, N] matrix and inflate the result; compare it as a column instead.
    if gt.ndim == 1:
        gt = gt[:, np.newaxis]

    correct = pred[:, np.newaxis] == gt

    return np.sum(correct) / correct.shape[0]
|
||||
|
||||
|
||||
@unittest.skipIf(not fluid.is_compiled_with_cuda(),
                 'CPU testing is not supported')
class TestDistTraning(unittest.TestCase):
    """End-to-end LeNet/MNIST run on GPU; skipped on CPU-only builds."""

    def test_static_multiple_gpus(self):
        """Fit, evaluate and predict, then check that the accuracy computed
        from the predictions matches the one reported by ``evaluate``.

        No dygraph switch is made here, so this presumably exercises the
        static-graph path (assumption based on the method name).
        """
        place = set_device('gpu')

        image_shape = (-1, 1, 28, 28)
        bs = 128
        image_specs = [Input(image_shape, 'float32', name='image')]
        label_specs = [Input([None, 1], 'int64', name='label')]

        train_set = MnistDataset(mode='train')
        eval_set = MnistDataset(mode='test')
        # Same split as eval_set, but samples carry no label.
        predict_set = MnistDataset(mode='test', return_label=False)

        net = LeNet()
        optimizer = fluid.optimizer.Momentum(
            learning_rate=0.001, momentum=.9, parameter_list=net.parameters())
        criterion = CrossEntropy()
        net.prepare(
            optimizer,
            criterion,
            Accuracy(),
            image_specs,
            label_specs,
            device=place)
        progress_logger = ProgBarLogger(50)

        net.fit(train_set,
                eval_set,
                epochs=2,
                batch_size=bs,
                callbacks=progress_logger)

        metrics = net.evaluate(eval_set, batch_size=bs)
        predictions = net.predict(
            predict_set, batch_size=bs, stack_outputs=True)

        # One prediction row per test sample.
        np.testing.assert_equal(predictions[0].shape[0], len(predict_set))

        accuracy = compute_accuracy(predictions[0], eval_set.labels)
        np.testing.assert_allclose(accuracy, metrics['acc'])
|
||||
|
||||
|
||||
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,106 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import time
|
||||
import random
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from paddle.incubate.hapi.model import Input
|
||||
from paddle.incubate.hapi.vision.models import LeNet
|
||||
from paddle.incubate.hapi.callbacks import config_callbacks
|
||||
|
||||
|
||||
class TestCallbacks(unittest.TestCase):
    """Drives the configured callbacks through simulated train, eval and
    test loops at each verbosity level."""

    def setUp(self):
        # Scratch directory handed to the callbacks as save_dir.
        self.save_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.save_dir)

    def run_callback(self):
        """Replay the callback hooks with synthetic logs; reads
        ``self.verbose``, which the calling test sets beforehand."""
        epochs = 2
        steps = 50
        freq = 2
        eval_steps = 20

        lenet = LeNet()
        inputs = [Input([None, 1, 28, 28], 'float32', name='image')]
        lenet.prepare(inputs=inputs)

        callback_config = {
            'model': lenet,
            'batch_size': 128,
            'epochs': epochs,
            'steps': steps,
            'log_freq': freq,
            'verbose': self.verbose,
            'metrics': ['loss', 'acc'],
            'save_dir': self.save_dir,
        }
        cbks = config_callbacks(**callback_config)
        cbks.on_begin('train')

        # Simulated training: the loss drifts down and the accuracy up.
        train_logs = {'loss': 50.341673, 'acc': 0.00256}
        for epoch_id in range(epochs):
            cbks.on_epoch_begin(epoch_id)
            for step_id in range(steps):
                cbks.on_batch_begin('train', step_id, train_logs)
                train_logs['loss'] -= random.random() * 0.1
                train_logs['acc'] += random.random() * 0.1
                time.sleep(0.005)
                cbks.on_batch_end('train', step_id, train_logs)
            cbks.on_epoch_end(epoch_id, train_logs)

            # Simulated evaluation after every epoch.
            eval_logs = {'eval_loss': 20.341673, 'eval_acc': 0.256}
            eval_params = {
                'steps': eval_steps,
                'metrics': ['eval_loss', 'eval_acc'],
            }
            cbks.on_begin('eval', eval_params)
            for step_id in range(eval_steps):
                cbks.on_batch_begin('eval', step_id, eval_logs)
                eval_logs['eval_loss'] -= random.random() * 0.1
                eval_logs['eval_acc'] += random.random() * 0.1
                eval_logs['batch_size'] = 2
                time.sleep(0.005)
                cbks.on_batch_end('eval', step_id, eval_logs)
            cbks.on_end('eval', eval_logs)

            # Simulated prediction pass: no metrics, only a batch size.
            test_logs = {}
            test_params = {'steps': eval_steps}
            cbks.on_begin('test', test_params)
            for step_id in range(eval_steps):
                cbks.on_batch_begin('test', step_id, test_logs)
                test_logs['batch_size'] = 2
                time.sleep(0.005)
                cbks.on_batch_end('test', step_id, test_logs)
            cbks.on_end('test', test_logs)

        cbks.on_end('train')

    def test_callback_verbose_0(self):
        self.verbose = 0
        self.run_callback()

    def test_callback_verbose_1(self):
        self.verbose = 1
        self.run_callback()

    def test_callback_verbose_2(self):
        self.verbose = 2
        self.run_callback()
|
||||
|
||||
|
||||
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,159 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import numpy as np
|
||||
import tempfile
|
||||
import shutil
|
||||
import cv2
|
||||
|
||||
from paddle.incubate.hapi.datasets import *
|
||||
from paddle.incubate.hapi.datasets.utils import _check_exists_and_download
|
||||
|
||||
|
||||
class TestFolderDatasets(unittest.TestCase):
    """Tests DatasetFolder and ImageFolder on a generated directory tree."""

    def setUp(self):
        # data_dir gets two class folders with two random 32x32 JPEGs each;
        # empty_dir stays empty for the error-path test.
        self.data_dir = tempfile.mkdtemp()
        self.empty_dir = tempfile.mkdtemp()
        for i in range(2):
            sub_dir = os.path.join(self.data_dir, 'class_' + str(i))
            if not os.path.exists(sub_dir):
                os.makedirs(sub_dir)
            for j in range(2):
                fake_img = (np.random.random((32, 32, 3)) * 255).astype('uint8')
                cv2.imwrite(os.path.join(sub_dir, str(j) + '.jpg'), fake_img)

    def tearDown(self):
        # mkdtemp leaves cleanup to the caller, so remove both directories
        # created in setUp.
        shutil.rmtree(self.data_dir)
        shutil.rmtree(self.empty_dir)

    def test_dataset(self):
        dataset_folder = DatasetFolder(self.data_dir)

        for _ in dataset_folder:
            pass

        # unittest assertions are not stripped under `python -O` and report
        # both values on failure.
        self.assertEqual(len(dataset_folder), 4)
        self.assertEqual(len(dataset_folder.classes), 2)

        # A second dataset over the same folder must iterate as well.
        dataset_folder = DatasetFolder(self.data_dir)
        for _ in dataset_folder:
            pass

    def test_folder(self):
        loader = ImageFolder(self.data_dir)

        for _ in loader:
            pass

        loader = ImageFolder(self.data_dir)
        for _ in loader:
            pass

        self.assertEqual(len(loader), 4)

    def test_transform(self):
        def fake_transform(img):
            # Identity transform: only checks that the hook is accepted.
            return img

        transform = fake_transform
        dataset_folder = DatasetFolder(self.data_dir, transform=transform)

        for _ in dataset_folder:
            pass

        loader = ImageFolder(self.data_dir, transform=transform)
        for _ in loader:
            pass

    def test_errors(self):
        # A folder without any sample must be rejected.
        with self.assertRaises(RuntimeError):
            ImageFolder(self.empty_dir)
        with self.assertRaises(RuntimeError):
            DatasetFolder(self.empty_dir)

        with self.assertRaises(ValueError):
            _check_exists_and_download('temp_paddle', None, None, None, False)
|
||||
|
||||
|
||||
class TestMNISTTest(unittest.TestCase):
    """Checks size and per-sample shapes of the MNIST test split."""

    def test_main(self):
        mnist = MNIST(mode='test')
        # assertEqual reports both values on failure, unlike
        # assertTrue(a == b).
        self.assertEqual(len(mnist), 10000)

        for i in range(len(mnist)):
            image, label = mnist[i]
            self.assertEqual(image.shape[0], 1)
            self.assertEqual(image.shape[1], 28)
            self.assertEqual(image.shape[2], 28)
            self.assertEqual(label.shape[0], 1)
            self.assertTrue(0 <= int(label) <= 9)
|
||||
|
||||
|
||||
class TestMNISTTrain(unittest.TestCase):
    """Checks size and per-sample shapes of the MNIST train split loaded
    with ``chw_format=False``."""

    def test_main(self):
        mnist = MNIST(mode='train', chw_format=False)
        # assertEqual reports both values on failure, unlike
        # assertTrue(a == b).
        self.assertEqual(len(mnist), 60000)

        for i in range(len(mnist)):
            image, label = mnist[i]
            # With chw_format=False the first image dimension is 784.
            self.assertEqual(image.shape[0], 784)
            self.assertEqual(label.shape[0], 1)
            self.assertTrue(0 <= int(label) <= 9)
|
||||
|
||||
|
||||
class TestFlowersTrain(unittest.TestCase):
    """Checks the size of the Flowers train split and one random sample."""

    def test_main(self):
        flowers = Flowers(mode='train')
        self.assertEqual(len(flowers), 6149)

        # traversal whole dataset may cost a
        # long time, randomly check 1 sample
        # (bound taken from the dataset so it cannot drift from the length
        # asserted above)
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)
|
||||
|
||||
|
||||
class TestFlowersValid(unittest.TestCase):
    """Checks the size of the Flowers valid split and one random sample."""

    def test_main(self):
        flowers = Flowers(mode='valid')
        self.assertEqual(len(flowers), 1020)

        # traversal whole dataset may cost a
        # long time, randomly check 1 sample
        # (bound taken from the dataset so it cannot drift from the length
        # asserted above)
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)
|
||||
|
||||
|
||||
class TestFlowersTest(unittest.TestCase):
    """Checks the size of the Flowers test split and one random sample."""

    def test_main(self):
        flowers = Flowers(mode='test')
        self.assertEqual(len(flowers), 1020)

        # traversal whole dataset may cost a
        # long time, randomly check 1 sample
        # (bound taken from the dataset so it cannot drift from the length
        # asserted above)
        idx = np.random.randint(0, len(flowers))
        image, label = flowers[idx]
        self.assertEqual(len(image.shape), 3)
        self.assertEqual(image.shape[2], 3)
        self.assertEqual(label.shape[0], 1)
|
||||
|
||||
|
||||
# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|
@ -0,0 +1,130 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import time
|
||||
import copy
|
||||
import subprocess
|
||||
import paddle.fluid as fluid
|
||||
|
||||
from paddle.distributed.utils import find_free_ports, watch_local_trainers, get_cluster, TrainerProc
|
||||
|
||||
|
||||
def get_cluster_from_args(selected_gpus):
    """Describe a single-node cluster on the loopback address.

    Args:
        selected_gpus: GPU identifiers. One free port is requested per entry
            and the sequence itself is forwarded to ``get_cluster``.

    Returns:
        The result of ``get_cluster`` for the single node ``127.0.0.1``.
        The caller in this file unpacks it as ``(cluster, pod)``.
    """
    node_ip = '127.0.0.1'
    # The cluster consists of exactly this one local node.
    node_ips = [node_ip]

    free_ports = find_free_ports(len(selected_gpus))
    # A found port collection is converted to a list; None is passed on
    # unchanged.
    if free_ports is not None:
        free_ports = list(free_ports)
    return get_cluster(node_ips, node_ip, free_ports, selected_gpus)
|
||||
|
||||
|
||||
def get_gpus(selected_gpus):
    """Split a comma-separated GPU string into whitespace-stripped items.

    Args:
        selected_gpus: String such as ``'0,1'``.

    Returns:
        A list with one stripped string per comma-separated field.
    """
    return [gpu_id.strip() for gpu_id in selected_gpus.split(',')]
|
||||
|
||||
|
||||
def start_local_trainers(cluster,
                         pod,
                         training_script,
                         training_script_args,
                         log_dir=None):
    """Spawn one training subprocess per trainer listed in ``pod``.

    Args:
        cluster: Object providing ``trainers_nranks()`` and
            ``trainers_endpoints()``.
        pod: Object whose ``trainers`` each expose ``gpus``, ``rank`` and
            ``endpoint``.
        training_script: Path of the Python script every trainer runs.
        training_script_args: Extra command-line arguments appended after
            the script path. May be empty or None.
        log_dir: Accepted for interface compatibility; not used here.

    Returns:
        A list of ``TrainerProc`` objects, one per trainer, with ``proc``,
        ``rank``, ``log_fn`` (always None) and ``cmd`` filled in.
    """
    current_env = os.environ.copy()
    # Paddle broadcasts the ncclUniqueId over a socket and a proxy may make
    # the trainers unreachable. Setting the variables to "" makes grpc log
    # a "bad uri" error, so they are removed instead.
    current_env.pop("http_proxy", None)
    current_env.pop("https_proxy", None)

    # The launcher does not depend on the trainer, so it is chosen once.
    if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
        launcher = ["python", "-m", "coverage", "run", "--branch", "-p"]
    else:
        launcher = ["python", "-u"]

    # Building the argument vector directly (rather than splitting a
    # command string on spaces) keeps script paths with spaces intact and
    # lets the extra script arguments reach the subprocess.
    extra_args = [str(arg) for arg in (training_script_args or [])]
    argv = launcher + [training_script] + extra_args
    cmd = " ".join(argv)

    procs = []
    for t in pod.trainers:
        proc_env = {
            "FLAGS_selected_gpus": "%s" % ",".join([str(g) for g in t.gpus]),
            "PADDLE_TRAINER_ID": "%d" % t.rank,
            "PADDLE_CURRENT_ENDPOINT": "%s" % t.endpoint,
            "PADDLE_TRAINERS_NUM": "%d" % cluster.trainers_nranks(),
            "PADDLE_TRAINER_ENDPOINTS": ",".join(cluster.trainers_endpoints())
        }

        # Every trainer overwrites the same keys, so one shared environment
        # dictionary can be reused across iterations.
        current_env.update(proc_env)

        print("trainer proc env:{}".format(current_env))
        print("start trainer proc:{} env:{}".format(cmd, proc_env))

        proc = subprocess.Popen(argv, env=current_env)

        tp = TrainerProc()
        tp.proc = proc
        tp.rank = t.rank
        # Output is not redirected to a log file.
        tp.log_fn = None
        tp.cmd = cmd

        procs.append(tp)

    return procs
|
||||
|
||||
|
||||
class TestMultipleGpus(unittest.TestCase):
    """Runs the hapi MNIST training scripts as two local GPU trainers."""

    def run_mnist_2gpu(self, target_file_name):
        """Launch ``target_file_name`` on GPUs 0 and 1 and wait for it.

        The test is reported as skipped, instead of silently passing, when
        no CUDA device is available.

        Args:
            target_file_name: Training script passed to
                ``start_local_trainers``.
        """
        if fluid.core.get_cuda_device_count() == 0:
            self.skipTest("no CUDA device available")

        selected_gpus = get_gpus('0,1')
        cluster, pod = get_cluster_from_args(selected_gpus)

        procs = start_local_trainers(
            cluster,
            pod,
            training_script=target_file_name,
            training_script_args=[])

        # Poll every three seconds until no local trainer is reported alive.
        while True:
            alive = watch_local_trainers(procs, cluster.trainers_nranks())

            if not alive:
                print("Local procs complete, POD info:{}".format(pod))
                break
            time.sleep(3)

    def test_hapi_multiple_gpus_static(self):
        """Run the static-graph MNIST script."""
        self.run_mnist_2gpu('dist_hapi_mnist_static.py')

    def test_hapi_multiple_gpus_dynamic(self):
        """Run the dynamic-graph MNIST script."""
        self.run_mnist_2gpu('dist_hapi_mnist_dynamic.py')
|
||||
|
||||
|
||||
# Run the test cases of this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()
|
@ -0,0 +1,50 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from paddle.incubate.hapi.download import get_weights_path_from_url
|
||||
|
||||
|
||||
class TestDownload(unittest.TestCase):
    """Exercises ``get_weights_path_from_url`` with valid and invalid input."""

    def download(self, url, md5sum):
        """Fetch ``url`` through ``get_weights_path_from_url``."""
        get_weights_path_from_url(url, md5sum)

    def test_download_model(self):
        """Download the weights with a matching md5sum."""
        self.download(
            'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
            '8ff74f291f72533f2a7956a4efff9d88')

    def test_exist_download(self):
        """Request the same weights with the same md5sum once more."""
        self.download(
            'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
            '8ff74f291f72533f2a7956a4efff9d88')

    def test_download_without_md5sum(self):
        """Download the weights without supplying an md5sum."""
        self.download(
            'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0.pdparams',
            None)

    def test_download_errors(self):
        """A bad URL must raise RuntimeError, with or without an md5sum."""
        with self.assertRaises(RuntimeError):
            self.download(
                'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0t.pdparams',
                '8ff74f291f72533f2a7956a4eftttttt')

        with self.assertRaises(RuntimeError):
            self.download(
                'https://paddle-hapi.bj.bcebos.com/models/mobilenet_v2_x1.0t.pdparams',
                None)
|
||||
|
||||
|
||||
# Run the test cases of this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue