mindspore/mindspore/mindrecord/shardwriter.py


# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
This module is used to write data into MindRecord files.
"""
import numpy as np
import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHeaderSizeError, \
MRMInvalidPageSizeError, MRMSetHeaderError, MRMWriteDatasetError, MRMCommitError
__all__ = ['ShardWriter']
class ShardWriter:
"""
Wrapper class which represents the ShardWriter class in the C++ module.
The class writes a series of MindRecord files.
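Examples:
A minimal sketch of the typical write flow; the file name, the prepared
ShardHeader object (header) and the list of sample dicts (raw_data) are
illustrative placeholders.
>>> writer = ShardWriter()
>>> writer.open(["imagenet.mindrecord0"])
>>> writer.set_shard_header(header)
>>> writer.write_raw_data(raw_data)
>>> writer.commit()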
"""
def __init__(self):
self._writer = ms.ShardWriter()
self._header = None
self._is_open = False
def open(self, paths):
"""
Open a new MindRecord File and prepare to write raw data.
Args:
paths (list[str]): List of file paths.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMOpenError: If failed to open MindRecord File.
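Examples:
A minimal sketch; the file names are illustrative.
>>> writer = ShardWriter()
>>> writer.open(["imagenet.mindrecord0", "imagenet.mindrecord1"])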
"""
ret = self._writer.open(paths, False)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to open paths")
raise MRMOpenError
self._is_open = True
return ret
def open_for_append(self, path):
"""
Open an existing MindRecord file and prepare to append raw data.
Args:
path (str): String of file path.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMOpenForAppendError: If failed to open MindRecord file for appending.
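Examples:
A minimal sketch; the file name is illustrative.
>>> writer = ShardWriter()
>>> writer.open_for_append("imagenet.mindrecord0")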
"""
ret = self._writer.open_for_append(path)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to open path to append.")
raise MRMOpenForAppendError
self._is_open = True
return ret
def set_header_size(self, header_size):
"""
Set the size of header.
Args:
header_size (int): Size of header, between 16KB and 128MB.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMInvalidHeaderSizeError: If failed to set header size.
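Examples:
A minimal sketch; the value is illustrative and must stay within
the 16KB to 128MB range.
>>> writer.set_header_size(1 << 24)  # 16MB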
"""
ret = self._writer.set_header_size(header_size)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to set header size.")
raise MRMInvalidHeaderSizeError
return ret
def set_page_size(self, page_size):
"""
Set the size of page.
Args:
page_size (int): Size of page, between 16KB and 128MB.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMInvalidPageSizeError: If failed to set page size.
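Examples:
A minimal sketch; the value is illustrative and must stay within
the 16KB to 128MB range.
>>> writer.set_page_size(1 << 25)  # 32MB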
"""
ret = self._writer.set_page_size(page_size)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to set page size.")
raise MRMInvalidPageSizeError
return ret
def set_shard_header(self, shard_header):
"""
Set the header which contains the schema and index before writing raw data.
Args:
shard_header (ShardHeader): Object of ShardHeader.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMSetHeaderError: If failed to set header.
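Examples:
A minimal sketch; header is assumed to be a ShardHeader built elsewhere
with the desired schema and index fields.
>>> writer.set_shard_header(header)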
"""
self._header = shard_header
ret = self._writer.set_shard_header(shard_header.header)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to set header.")
raise MRMSetHeaderError
return ret
def get_shard_header(self):
return self._header
def write_raw_data(self, data, validate=True, parallel_writer=False):
"""
Write raw data into MindRecord files.
Filter data according to the schema and separate blob data from raw data.
Optionally verify data against the schema and remove invalid data.
Args:
data (list[dict]): List of raw data.
validate (bool, optional): Verify data against the schema if True.
parallel_writer (bool, optional): Write data in parallel if True.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMWriteDatasetError: If failed to write the dataset.
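Examples:
A minimal sketch; the field names and image_bytes are illustrative
placeholders matching whatever schema was set via set_shard_header.
>>> data = [{"file_name": "1.jpg", "label": 0, "data": image_bytes}]
>>> writer.write_raw_data(data)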
"""
blob_data = []
raw_data = []
# slice data to blob data and raw data
for item in data:
row_blob = self._merge_blob({field: item[field] for field in self._header.blob_fields})
if row_blob:
blob_data.append(list(row_blob))
# filter raw data according to schema
row_raw = {field: self._convert_np_types(item[field])
for field in self._header.schema.keys() - self._header.blob_fields if field in item}
if row_raw:
raw_data.append(row_raw)
raw_data = {0: raw_data} if raw_data else {}
ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to write dataset.")
raise MRMWriteDatasetError
return ret
def _convert_np_types(self, val):
"""convert numpy type to python primitive type"""
if isinstance(val, (np.int32, np.int64, np.float32, np.float64)):
return val.item()
return val
def _merge_blob(self, blob_data):
"""
Merge multiple blob fields whose values are bytes or ndarray.
Args:
blob_data (dict): Dict of blob data.
Returns:
bytes, merged blob data.
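Examples:
An illustration of the merged layout, assuming two bytes-typed fields
{"mask": b"ab", "image": b"xyz"}: each value is prefixed with its length
as 8 big-endian bytes, giving
b'\x00' * 7 + b'\x02' + b'ab' + b'\x00' * 7 + b'\x03' + b'xyz'.
A single field is returned as-is, without a length prefix.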
"""
if len(blob_data) == 1:
values = [v for v in blob_data.values()]
return bytes(values[0])
# convert int to bytes
def int_to_bytes(x: int) -> bytes:
return x.to_bytes(8, 'big')
merged = bytes()
for field, v in blob_data.items():
# convert ndarray to bytes
if isinstance(v, np.ndarray):
v = v.astype(self._header.schema[field]["type"]).tobytes()
merged += int_to_bytes(len(v))
merged += v
return merged
def commit(self):
"""
Flush data to disk.
Returns:
MSRStatus, SUCCESS or FAILED.
Raises:
MRMCommitError: If failed to flush data to disk.
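Examples:
A minimal sketch; assumes raw data has already been written.
>>> writer.commit()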
"""
ret = self._writer.commit()
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to commit.")
raise MRMCommitError
return ret
@property
def is_open(self):
"""getter function"""
return self._is_open