# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
|
|
This module is to write data into mindrecord.
|
|
"""
|
|
import numpy as np

import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .common.exceptions import MRMOpenError, MRMOpenForAppendError, MRMInvalidHeaderSizeError, \
    MRMInvalidPageSizeError, MRMSetHeaderError, MRMWriteDatasetError, MRMCommitError

__all__ = ['ShardWriter']


class ShardWriter:
    """
    Wrapper class which represents the ShardWriter class in the C++ module.

    The class writes a series of MindRecord files.
    """

    def __init__(self):
        self._writer = ms.ShardWriter()
        self._header = None
        self._is_open = False

    def open(self, paths):
        """
        Open new MindRecord file(s) and prepare to write raw data.

        Args:
            paths (list[str]): List of file paths.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMOpenError: If failed to open the MindRecord file(s).
        """
        ret = self._writer.open(paths, False)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to open paths")
            raise MRMOpenError
        self._is_open = True
        return ret

    def open_for_append(self, path):
        """
        Open an existing MindRecord file and prepare to append raw data.

        Args:
            path (str): String of the file path.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMOpenForAppendError: If failed to open the MindRecord file for appending.
        """
        ret = self._writer.open_for_append(path)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to open path to append.")
            raise MRMOpenForAppendError
        self._is_open = True
        return ret

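    # Append-mode sketch (the file name is an assumption): instead of creating new
    # shards with open(), an existing file can be reopened to add more rows:
    #   writer = ShardWriter()
    #   writer.open_for_append("train.mindrecord")
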
    def set_header_size(self, header_size):
        """
        Set the size of the header.

        Args:
            header_size (int): Size of the header, between 16KB and 128MB.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMInvalidHeaderSizeError: If failed to set the header size.
        """
        ret = self._writer.set_header_size(header_size)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to set header size.")
            raise MRMInvalidHeaderSizeError
        return ret

    def set_page_size(self, page_size):
        """
        Set the size of a page.

        Args:
            page_size (int): Size of a page, between 16KB and 128MB.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMInvalidPageSizeError: If failed to set the page size.
        """
        ret = self._writer.set_page_size(page_size)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to set page size.")
            raise MRMInvalidPageSizeError
        return ret

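    # A minimal sketch of tuning both sizes before writing (the concrete values are
    # assumptions; anything in the documented 16KB-128MB range works the same way):
    #   writer.set_header_size(1 << 24)   # 16 MB reserved for the header
    #   writer.set_page_size(1 << 25)     # 32 MB per page
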
    def set_shard_header(self, shard_header):
        """
        Set the header, which contains the schema and index, before writing raw data.

        Args:
            shard_header (ShardHeader): Object of ShardHeader.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMSetHeaderError: If failed to set the header.
        """
        self._header = shard_header
        ret = self._writer.set_shard_header(shard_header.header)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to set header.")
            raise MRMSetHeaderError
        return ret

    def get_shard_header(self):
        """Return the ShardHeader previously set by set_shard_header."""
        return self._header

    def write_raw_data(self, data, validate=True, parallel_writer=False):
        """
        Write raw data of the dataset.

        Filter data according to the schema and separate blob data from raw data.
        Support verifying data against the schema and removing invalid data.

        Args:
            data (list[dict]): List of raw data.
            validate (bool, optional): Verify data against the schema if True. Default: True.
            parallel_writer (bool, optional): Write data in parallel if True. Default: False.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMWriteDatasetError: If failed to write the dataset.
        """
        blob_data = []
        raw_data = []
        # slice data to blob data and raw data
        for item in data:
            row_blob = self._merge_blob({field: item[field] for field in self._header.blob_fields})
            if row_blob:
                blob_data.append(list(row_blob))
            # filter raw data according to schema
            row_raw = {field: self._convert_np_types(item[field])
                       for field in self._header.schema.keys() - self._header.blob_fields if field in item}
            if row_raw:
                raw_data.append(row_raw)
        raw_data = {0: raw_data} if raw_data else {}
        ret = self._writer.write_raw_data(raw_data, blob_data, validate, parallel_writer)
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to write dataset.")
            raise MRMWriteDatasetError
        return ret

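    # Illustrative split of a single row (field names and schema are assumptions,
    # not part of this module): with blob_fields == ["image"] and a raw field
    # "label", the item
    #     {"image": np.frombuffer(b"\x00\x01", dtype=np.uint8), "label": np.int64(3)}
    # contributes list(b"\x00\x01") to blob_data via _merge_blob and {"label": 3}
    # to raw_data via _convert_np_types.
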
    def _convert_np_types(self, val):
        """Convert a NumPy scalar type to a Python primitive type."""
        if isinstance(val, (np.int32, np.int64, np.float32, np.float64)):
            return val.item()
        return val

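    # For example, _convert_np_types(np.int64(3)) returns the Python int 3, while
    # values of other types (str, bytes, plain int, ...) pass through unchanged.
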
    def _merge_blob(self, blob_data):
        """
        Merge multiple blob data whose values are of type bytes or ndarray.

        Args:
            blob_data (dict): Dict of blob data.

        Returns:
            bytes, merged blob data.
        """
        if len(blob_data) == 1:
            values = [v for v in blob_data.values()]
            return bytes(values[0])
        # convert int to bytes
        def int_to_bytes(x: int) -> bytes:
            return x.to_bytes(8, 'big')
        merged = bytes()
        for field, v in blob_data.items():
            # convert ndarray to bytes
            if isinstance(v, np.ndarray):
                v = v.astype(self._header.schema[field]["type"]).tobytes()
            merged += int_to_bytes(len(v))
            merged += v
        return merged

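    # Illustrative layout of the merged blob for two fields (the values are
    # assumptions): with blob_data == {"image": b"abc", "mask": b"de"} the result is
    #     (3).to_bytes(8, 'big') + b"abc" + (2).to_bytes(8, 'big') + b"de"
    # i.e. each field is prefixed with its length as 8 big-endian bytes; a single
    # blob field is returned as raw bytes without any prefix.
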
    def commit(self):
        """
        Flush data to disk.

        Returns:
            MSRStatus, SUCCESS or FAILED.

        Raises:
            MRMCommitError: If failed to flush data to disk.
        """
        ret = self._writer.commit()
        if ret != ms.MSRStatus.SUCCESS:
            logger.error("Failed to commit.")
            raise MRMCommitError
        return ret

    @property
    def is_open(self):
        """Getter of whether the writer is open."""
        return self._is_open
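
# A minimal end-to-end sketch of the writer lifecycle (illustrative only: the
# file name and the raw-data fields are assumptions, and `header` stands for a
# ShardHeader prepared outside this module):
#
#   writer = ShardWriter()
#   writer.open(["train.mindrecord"])                 # create the file to write
#   writer.set_shard_header(header)                   # schema and index must be set first
#   writer.write_raw_data([{"label": 0, "image": b"\x00\x01"}])
#   writer.commit()                                   # flush everything to disk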