DataLoader support dict str (#31481)
* add dict/str/list support for DataLoader. test=develop
pull/1/head
parent
30a627aaf3
commit
a32e8bf1e7
@ -0,0 +1,87 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import numbers
|
||||
import numpy as np
|
||||
from ..framework import in_dygraph_mode
|
||||
from .. import core, layers
|
||||
|
||||
try:
|
||||
from collections.abc import Sequence, Mapping
|
||||
except:
|
||||
from collections import Sequence, Mapping
|
||||
|
||||
|
||||
def default_collate_fn(batch):
    """
    Default batch collating function for :code:`paddle.io.DataLoader`.

    ``batch`` is a non-empty list of samples. The samples are merged into
    a single batch, dispatching on the type of the first sample:

    * ``numpy.ndarray``: samples are stacked along a new axis 0 with
      ``numpy.stack``.
    * ``paddle.Tensor``: samples are stacked along a new axis 0 with
      ``layers.stack``.
    * number: samples are packed into one array with ``numpy.array``.
    * ``str`` / ``bytes``: the list of samples is returned unchanged.
    * mapping: a ``dict`` is returned that has the keys of the first
      sample; each value is the collated list of the values stored under
      that key in every sample.
    * sequence: every sample must have the same number of fields; the
      samples are zipped together and each field is collated recursively,
      so that

          [[field1, field2, ...], [field1, field2, ...], ...]

      becomes

          [batch_field1, batch_field2, ...]

    Args:
        batch(list): the samples to collate. Each sample is a numpy array,
            paddle.Tensor, number, str, bytes, or a mapping / sequence
            nesting those types.

    Returns:
        The collated batch: a numpy array, a paddle.Tensor, the input list
        (for str / bytes samples), a dict, or a list, as described above.

    Raises:
        RuntimeError: if sequence samples do not all have the same length.
        TypeError: if the first sample is of an unsupported type.
    """
    sample = batch[0]
    if isinstance(sample, np.ndarray):
        return np.stack(batch, axis=0)
    elif isinstance(sample, paddle.Tensor):
        return layers.stack(batch, axis=0)
    elif isinstance(sample, numbers.Number):
        return np.array(batch)
    elif isinstance(sample, (str, bytes)):
        # text cannot be stacked, keep the samples as a plain list
        return batch
    elif isinstance(sample, Mapping):
        return {
            key: default_collate_fn([d[key] for d in batch])
            for key in sample
        }
    elif isinstance(sample, Sequence):
        sample_fields_num = len(sample)
        if not all(len(item) == sample_fields_num for item in batch):
            raise RuntimeError(
                "fileds number not same among samples in a batch")
        # zip(*batch) regroups the samples field by field
        return [default_collate_fn(fields) for fields in zip(*batch)]

    raise TypeError("batch data con only contains: tensor, numpy.ndarray, "
                    "dict, list, number, but got {}".format(type(sample)))
|
||||
|
||||
|
||||
def default_convert_fn(batch):
    """
    Recursively rebuild the containers of ``batch`` without batching it.

    Mappings are rebuilt as ``dict`` and sequences as ``list``, with every
    contained value converted the same way. ``paddle.Tensor``,
    ``numpy.ndarray``, ``str``, ``bytes`` and any other value are returned
    unchanged.

    Args:
        batch: the data to convert.

    Returns:
        ``batch`` itself for leaf values, otherwise a new ``dict`` or
        ``list`` holding the converted items.
    """
    # leaf values; str / bytes must be caught here because they are
    # sequences themselves
    if isinstance(batch, (paddle.Tensor, np.ndarray, str, bytes)):
        return batch
    if isinstance(batch, Mapping):
        return {key: default_convert_fn(batch[key]) for key in batch}
    if isinstance(batch, Sequence):
        return [default_convert_fn(item) for item in batch]
    return batch
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,150 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle
|
||||
import numbers
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from collections.abc import Sequence, Mapping
|
||||
except:
|
||||
from collections import Sequence, Mapping
|
||||
|
||||
# Prefix of the placeholder strings ("<prefix><index>") that stand in for
# array fields inside a flattened batch structure; the index is the
# position of the array in the flat list.
FIELD_PREFIX = "_paddle_field_"
|
||||
|
||||
|
||||
def _flatten_batch(batch):
|
||||
"""
|
||||
For lod_blocking_queue only receive tensor array, flatten batch
|
||||
data, extract numpy.array data out as a list of numpy.array to
|
||||
send to lod_blocking_queue, and save the batch data structure
|
||||
such as fields in other types (str, int, etc) or key-value map
|
||||
of dictionaries
|
||||
"""
|
||||
|
||||
def _flatten(batch, flat_batch, structure, field_idx):
|
||||
if isinstance(batch, Sequence):
|
||||
for field in batch:
|
||||
if isinstance(field, np.ndarray):
|
||||
structure.append('{}{}'.format(FIELD_PREFIX, field_idx))
|
||||
flat_batch.append(field)
|
||||
field_idx += 1
|
||||
elif isinstance(field, paddle.Tensor):
|
||||
structure.append('{}{}'.format(FIELD_PREFIX, field_idx))
|
||||
flat_batch.append(field.numpy())
|
||||
field_idx += 1
|
||||
elif isinstance(field, (str, bytes, numbers.Number)):
|
||||
structure.append(field)
|
||||
elif isinstance(field, Sequence):
|
||||
field_struct, field_idx = _flatten(field, flat_batch, [],
|
||||
field_idx)
|
||||
structure.append(field_struct)
|
||||
elif isinstance(field, Mapping):
|
||||
field_struct, field_idx = _flatten(field, flat_batch, {},
|
||||
field_idx)
|
||||
structure.append(field_struct)
|
||||
else:
|
||||
structure.append(field)
|
||||
elif isinstance(batch, Mapping):
|
||||
for k, field in batch.items():
|
||||
if isinstance(field, np.ndarray):
|
||||
structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx)
|
||||
flat_batch.append(field)
|
||||
field_idx += 1
|
||||
elif isinstance(field, paddle.Tensor):
|
||||
structure[k] = '{}{}'.format(FIELD_PREFIX, field_idx)
|
||||
flat_batch.append(field.numpy())
|
||||
field_idx += 1
|
||||
elif isinstance(field, (str, bytes, numbers.Number)):
|
||||
structure[k] = field
|
||||
elif isinstance(field, Sequence):
|
||||
field_struct, field_idx = _flatten(field, flat_batch, [],
|
||||
field_idx)
|
||||
structure[k] = field_struct
|
||||
elif isinstance(field, Mapping):
|
||||
field_struct, field_idx = _flatten(field, flat_batch, {},
|
||||
field_idx)
|
||||
structure[k] = field_struct
|
||||
else:
|
||||
structure[k] = field
|
||||
else:
|
||||
raise TypeError("wrong flat data type: {}".format(type(batch)))
|
||||
|
||||
return structure, field_idx
|
||||
|
||||
# sample only contains single fields
|
||||
if not isinstance(batch, Sequence):
|
||||
flat_batch = []
|
||||
structure, _ = _flatten([batch], flat_batch, [], 0)
|
||||
return flat_batch, structure[0]
|
||||
flat_batch = []
|
||||
structure, _ = _flatten(batch, flat_batch, [], 0)
|
||||
return flat_batch, structure
|
||||
|
||||
|
||||
def _restore_batch(flat_batch, structure):
|
||||
"""
|
||||
After reading list of Tensor data from lod_blocking_queue outputs,
|
||||
use this function to restore the batch data structrue, replace
|
||||
:attr:`_paddle_field_x` with data from flat_batch
|
||||
"""
|
||||
|
||||
def _restore(structure, field_idx):
|
||||
if isinstance(structure, Sequence):
|
||||
for i, field in enumerate(structure):
|
||||
if isinstance(field, str) and field.startswith(FIELD_PREFIX):
|
||||
cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
|
||||
field_idx = max(field_idx, cur_field_idx)
|
||||
assert flat_batch[cur_field_idx] is not None, \
|
||||
"flat_batch[{}] parsed repeatly"
|
||||
structure[i] = flat_batch[cur_field_idx]
|
||||
flat_batch[cur_field_idx] = None
|
||||
elif isinstance(field, (str, bytes, numbers.Number)):
|
||||
continue
|
||||
elif isinstance(field, (Sequence, Mapping)):
|
||||
field_idx = _restore(structure[i], field_idx)
|
||||
elif isinstance(structure, Mapping):
|
||||
for k, field in structure.items():
|
||||
if isinstance(field, str) and field.startswith(FIELD_PREFIX):
|
||||
cur_field_idx = int(field.replace(FIELD_PREFIX, ''))
|
||||
field_idx = max(field_idx, cur_field_idx)
|
||||
assert flat_batch[cur_field_idx] is not None, \
|
||||
"flat_batch[{}] parsed repeatly"
|
||||
structure[k] = flat_batch[cur_field_idx]
|
||||
flat_batch[cur_field_idx] = None
|
||||
elif isinstance(field, (str, bytes, numbers.Number)):
|
||||
continue
|
||||
elif isinstance(field, (Sequence, Mapping)):
|
||||
field_idx = _restore(structure[k], field_idx)
|
||||
else:
|
||||
raise TypeError("wrong flat data type: {}".format(type(batch)))
|
||||
|
||||
return field_idx
|
||||
|
||||
assert isinstance(flat_batch, Sequence), \
|
||||
"flat_batch is not a list or tuple"
|
||||
|
||||
# no np.array in dataset, no output tensor from blocking queue
|
||||
# simply return structure
|
||||
if len(flat_batch) == 0:
|
||||
return structure
|
||||
|
||||
# sample only contains single fields
|
||||
if isinstance(structure, (str, bytes)):
|
||||
assert structure == '{}{}'.format(FIELD_PREFIX, 0), \
|
||||
"invalid structure: {}".format(structure)
|
||||
return flat_batch[0]
|
||||
field_idx = _restore(structure, 0)
|
||||
assert field_idx + 1 == len(flat_batch), "Tensor parse incomplete"
|
||||
return structure
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue