add compress in mindrecord

pull/1317/head
jonwe 5 years ago committed by liyong
parent 2e3d55ed87
commit bb51bb88d7

@ -23,6 +23,7 @@
#include <queue>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@ -31,6 +32,7 @@
#include "dataset/engine/datasetops/source/io_block.h"
#include "dataset/util/queue.h"
#include "dataset/util/status.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_reader.h"
#include "mindrecord/include/common/shard_utils.h"
@ -193,8 +195,6 @@ class MindRecordOp : public ParallelOp {
Status Init();
Status SetColumnsBlob();
// Base-class override for NodePass visitor acceptor.
// @param p - Pointer to the NodePass to be accepted.
// @param modified - Whether this node visit modified the pipeline.
@ -205,56 +205,11 @@ class MindRecordOp : public ParallelOp {
Status GetBufferFromReader(std::unique_ptr<DataBuffer> *fetched_buffer, int64_t buffer_id, int32_t worker_id);
// Parses a single cell and puts the data into a tensor
// @param tensor - the tensor to put the parsed data in
// @param i_col - the id of column to parse
// @param tensor_row - the tensor row to put the parsed data in
// @param columns_blob - the blob data received from the reader
// @param columns_json - the data for fields received from the reader
template <typename T>
Status LoadFeature(std::shared_ptr<Tensor> *tensor, int32_t i_col, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json) const;
Status SwitchLoadFeature(const DataType &type, std::shared_ptr<Tensor> *tensor, int32_t i_col,
const std::vector<uint8_t> &columns_blob, const mindrecord::json &columns_json) const;
static Status LoadBlob(TensorShape *new_shape, const unsigned char **data, const std::vector<uint8_t> &columns_blob,
const int32_t pos, const ColDescriptor &column);
// Get shape and data (scalar or array) for tensor to be created (for floats and doubles)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status LoadFloat(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column, bool use_double);
// Get shape and data (scalar or array) for tensor to be created (for integers)
// @param new_shape - the shape of tensor to be created.
// @param array_data - the array where data should be put in
// @param column_name - name of current column to be processed
// @param columns_json - the data for fields received from the reader
// @param column - description of current column from schema
template <typename T>
static Status LoadInt(TensorShape *new_shape, std::unique_ptr<T[]> *array_data, const std::string &column_name,
const mindrecord::json &columns_json, const ColDescriptor &column);
static Status LoadByte(TensorShape *new_shape, std::string *string_data, const std::string &column_name,
const mindrecord::json &columns_json);
// Get a single float value from the given json
// @param value - the float to put the value in
// @param arrayData - the given json containing the float
// @param use_double - boolean to choose between float32 and float64
template <typename T>
static Status GetFloat(T *value, const mindrecord::json &data, bool use_double);
// Get a single integer value from the given json
// @param value - the integer to put the value in
// @param arrayData - the given json containing the integer
template <typename T>
static Status GetInt(T *value, const mindrecord::json &data);
Status LoadTensorRow(TensorRow *tensor_row, const std::vector<uint8_t> &columns_blob,
const mindrecord::json &columns_json);
Status FetchBlockBuffer(const int32_t &buffer_id);

@ -91,8 +91,8 @@ void BindShardReader(const py::module *m) {
.def("launch", &ShardReader::Launch)
.def("get_header", &ShardReader::GetShardHeader)
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next",
(std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>>(ShardReader::*)()) & ShardReader::GetNextPy)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}
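For context (not part of the commit): the long cast in the `.def("get_next", ...)` line above is the standard way to hand pybind11 one specific member-function signature. Here it spells out the new return type `std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>`; with genuine overloads the same cast is what resolves the ambiguity. A minimal sketch of that pattern, using a hypothetical `Reader` class with made-up overloads:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

struct Reader {
  std::vector<int> GetNext() { return {1, 2, 3}; }                            // hypothetical overload
  std::vector<int> GetNext(std::size_t n) { return std::vector<int>(n, 0); }  // hypothetical overload
};

int main() {
  // Pick the zero-argument overload explicitly, the way the .def() call above does.
  auto fp = static_cast<std::vector<int> (Reader::*)()>(&Reader::GetNext);
  Reader r;
  std::cout << (r.*fp)().size() << std::endl;  // prints 3
  return 0;
}
```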

@ -65,6 +65,9 @@ const int kUnsignedInt4 = 4;
enum LabelCategory { kSchemaLabel, kStatisticsLabel, kIndexLabel };
const char kVersion[] = "3.0";
const std::vector<std::string> kSupportedVersion = {"2.0", kVersion};
enum ShardType {
kNLP = 0,
kCV = 1,

@ -0,0 +1,163 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#define MINDRECORD_INCLUDE_SHARD_COLUMN_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mindrecord/include/shard_header.h"
namespace mindspore {
namespace mindrecord {
const uint64_t kUnsignedOne = 1;
const uint64_t kBitsOfByte = 8;
const uint64_t kDataTypeBits = 2;
const uint64_t kNumDataOfByte = 4;
const uint64_t kBytesOfColumnLen = 4;
const uint64_t kDataTypeBitMask = 3;
const uint64_t kDataTypes = 6;
enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };
enum ColumnCategory { ColumnInRaw, ColumnInBlob, ColumnNotFound };
enum ColumnDataType {
ColumnBytes = 0,
ColumnString = 1,
ColumnInt32 = 2,
ColumnInt64 = 3,
ColumnFloat32 = 4,
ColumnFloat64 = 5,
ColumnNoDataType = 6
};
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
"int64", "float32", "float64"};
const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {
{"bytes", ColumnBytes}, {"string", ColumnString}, {"int32", ColumnInt32},
{"int64", ColumnInt64}, {"float32", ColumnFloat32}, {"float64", ColumnFloat64}};
class ShardColumn {
public:
explicit ShardColumn(const std::shared_ptr<ShardHeader> &shard_header, bool compress_integer = true);
~ShardColumn() = default;
/// \brief get column value by column name
MSRStatus GetColumnValueByName(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const json &columns_json, const unsigned char **data,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes,
ColumnDataType *column_data_type, uint64_t *column_data_type_size,
std::vector<int64_t> *column_shape);
/// \brief compress blob
std::vector<uint8_t> CompressBlob(const std::vector<uint8_t> &blob);
/// \brief check if blob compressed
bool CheckCompressBlob() const { return has_compress_blob_; }
uint64_t GetNumBlobColumn() const { return num_blob_column_; }
std::vector<std::string> GetColumnName() { return column_name_; }
std::vector<ColumnDataType> GeColumnDataType() { return column_data_type_; }
std::vector<std::vector<int64_t>> GetColumnShape() { return column_shape_; }
/// \brief get column value from blob
MSRStatus GetColumnFromBlob(const std::string &column_name, const std::vector<uint8_t> &columns_blob,
const unsigned char **data, std::unique_ptr<unsigned char[]> *data_ptr,
uint64_t *n_bytes);
private:
/// \brief get column value from json
MSRStatus GetColumnFromJson(const std::string &column_name, const json &columns_json,
std::unique_ptr<unsigned char[]> *data_ptr, uint64_t *n_bytes);
/// \brief get float value from json
template <typename T>
MSRStatus GetFloat(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value, bool use_double);
/// \brief get integer value from json
template <typename T>
MSRStatus GetInt(std::unique_ptr<unsigned char[]> *data_ptr, const json &json_column_value);
/// \brief get column offset address and size from blob
MSRStatus GetColumnAddressInBlock(const uint64_t &column_id, const std::vector<uint8_t> &columns_blob,
uint64_t *num_bytes, uint64_t *shift_idx);
/// \brief check if column name is available
ColumnCategory CheckColumnName(const std::string &column_name);
/// \brief compress integer column
static vector<uint8_t> CompressInt(const vector<uint8_t> &src_bytes, const IntegerType &int_type);
/// \brief uncompress integer array column
template <typename T>
static MSRStatus UncompressInt(const uint64_t &column_id, std::unique_ptr<unsigned char[]> *data_ptr,
const std::vector<uint8_t> &columns_blob, uint64_t *num_bytes, uint64_t shift_idx);
/// \brief convert big-endian bytes to unsigned int
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param i_type integer type
/// \return unsigned int
static uint64_t BytesBigToUInt64(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &i_type);
/// \brief convert unsigned int to big-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesBig(uint64_t value, const IntegerType &i_type);
/// \brief convert unsigned int to little-endian bytes
/// \param value integer value
/// \param i_type integer type
/// \return bytes
static std::vector<uint8_t> UIntToBytesLittle(uint64_t value, const IntegerType &i_type);
/// \brief convert little-endian bytes to the smallest integer type that holds the value
/// \param bytes_array bytes array
/// \param pos shift address in bytes array
/// \param src_i_type source integer type
/// \param dst_i_type (output), destination integer type
/// \return integer
static int64_t BytesLittleToMinIntType(const std::vector<uint8_t> &bytes_array, const uint64_t &pos,
const IntegerType &src_i_type, IntegerType *dst_i_type = nullptr);
private:
std::vector<std::string> column_name_; // column name list
std::vector<ColumnDataType> column_data_type_; // column data type list
std::vector<std::vector<int64_t>> column_shape_; // column shape list
std::unordered_map<string, uint64_t> column_name_id_; // column name id map
std::vector<std::string> blob_column_; // blob column list
std::unordered_map<std::string, uint64_t> blob_column_id_; // blob column name id map
bool has_compress_blob_; // if has compress blob
uint64_t num_blob_column_; // number of blob columns
};
} // namespace mindrecord
} // namespace mindspore
#endif // MINDRECORD_INCLUDE_SHARD_COLUMN_H_
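Illustrative sketch, not part of the commit: the declarations above (CompressInt, UncompressInt, the four-value IntegerType enum, and the 2-bit type tag implied by kDataTypeBits / kDataTypeBitMask) describe an integer-narrowing scheme in which each integer column is stored at the smallest width that still holds every value; the exact on-disk layout lives in shard_column.cc and is not reproduced here. A rough sketch of the width-selection step only, with MinIntType as a made-up helper name:

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

enum IntegerType { kInt8Type = 0, kInt16Type, kInt32Type, kInt64Type };

// Pick the narrowest signed width whose range covers every value in the column.
IntegerType MinIntType(const std::vector<int64_t> &values) {
  IntegerType widest_needed = kInt8Type;
  for (int64_t v : values) {
    IntegerType need = kInt64Type;
    if (v >= std::numeric_limits<int8_t>::min() && v <= std::numeric_limits<int8_t>::max()) {
      need = kInt8Type;
    } else if (v >= std::numeric_limits<int16_t>::min() && v <= std::numeric_limits<int16_t>::max()) {
      need = kInt16Type;
    } else if (v >= std::numeric_limits<int32_t>::min() && v <= std::numeric_limits<int32_t>::max()) {
      need = kInt32Type;
    }
    if (need > widest_needed) widest_needed = need;
  }
  return widest_needed;
}

int main() {
  // Labels fit in int8, so an int64 column like this can be stored one byte per value.
  std::vector<int64_t> labels = {0, 1, 2, 99};
  std::cout << "type tag: " << MinIntType(labels) << std::endl;  // prints 0 (kInt8Type)
  return 0;
}
```

The 2-bit tag (kDataTypeBitMask = 3, four tags per byte per kNumDataOfByte) is exactly enough to record which of the four widths was chosen, which is presumably what UncompressInt reads back when widening values to the schema's declared int32/int64 type.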

@ -118,8 +118,6 @@ class ShardHeader {
void SetPageSize(const uint64_t &page_size) { page_size_ = page_size; }
const string GetVersion() { return version_; }
std::vector<std::string> SerializeHeader();
MSRStatus PagesToFile(const std::string dump_file_name);
@ -175,7 +173,6 @@ class ShardHeader {
uint32_t shard_count_;
uint64_t header_size_;
uint64_t page_size_;
string version_ = "2.0";
std::shared_ptr<Index> index_;
std::vector<std::string> shard_addresses_;

@ -43,6 +43,7 @@
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_category.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_index_generator.h"
#include "mindrecord/include/shard_operator.h"
@ -111,6 +112,10 @@ class ShardReader {
/// \return the metadata
std::shared_ptr<ShardHeader> GetShardHeader() const;
/// \brief get the shard column handler
/// \return the shard column
std::shared_ptr<ShardColumn> get_shard_column() const;
/// \brief get the number of shards
/// \return # of shards
int GetShardCount() const;
@ -185,7 +190,7 @@ class ShardReader {
/// \brief return a batch, given that one is ready, python API
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> GetNextPy();
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> GetNextPy();
/// \brief get blob field list
/// \return blob field list
@ -295,16 +300,18 @@ class ShardReader {
/// \brief get number of classes
int64_t GetNumClasses(const std::string &category_field);
/// \brief get meta of header
std::pair<MSRStatus, std::vector<std::string>> GetMeta(const std::string &file_path, json &meta_data);
/// \brief get exactly blob fields data by indices
std::vector<uint8_t> ExtractBlobFieldBySelectColumns(std::vector<uint8_t> &blob_fields_bytes,
std::vector<uint32_t> &ordered_selected_columns_index);
/// \brief extract uncompressed data based on column list
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> UnCompressBlob(const std::vector<uint8_t> &raw_blob_data);
protected:
uint64_t header_size_; // header size
uint64_t page_size_; // page size
int shard_count_; // number of shards
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard column
std::vector<sqlite3 *> database_paths_; // sqlite handle list
std::vector<string> file_paths_; // file paths

@ -36,6 +36,7 @@
#include <utility>
#include <vector>
#include "mindrecord/include/common/shard_utils.h"
#include "mindrecord/include/shard_column.h"
#include "mindrecord/include/shard_error.h"
#include "mindrecord/include/shard_header.h"
#include "mindrecord/include/shard_index.h"
@ -242,7 +243,8 @@ class ShardWriter {
std::vector<std::string> file_paths_; // file paths
std::vector<std::shared_ptr<std::fstream>> file_streams_; // file handles
std::shared_ptr<ShardHeader> shard_header_; // shard headers
std::shared_ptr<ShardHeader> shard_header_; // shard header
std::shared_ptr<ShardColumn> shard_column_; // shard columns
std::map<uint64_t, std::map<int, std::string>> err_mg_; // used for storing error raw_data info

@ -133,6 +133,12 @@ MSRStatus ShardReader::Init(const std::vector<std::string> &file_paths, bool loa
shard_header_ = std::make_shared<ShardHeader>(sh);
header_size_ = shard_header_->GetHeaderSize();
page_size_ = shard_header_->GetPageSize();
// version < 3.0
if (first_meta_data["version"] < kVersion) {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, false);
} else {
shard_column_ = std::make_shared<ShardColumn>(shard_header_, true);
}
num_rows_ = 0;
auto row_group_summary = ReadRowGroupSummary();
for (const auto &rg : row_group_summary) {
@ -226,6 +232,8 @@ void ShardReader::Close() {
std::shared_ptr<ShardHeader> ShardReader::GetShardHeader() const { return shard_header_; }
std::shared_ptr<ShardColumn> ShardReader::get_shard_column() const { return shard_column_; }
int ShardReader::GetShardCount() const { return shard_header_->GetShardCount(); }
int ShardReader::GetNumRows() const { return num_rows_; }
@ -1059,36 +1067,6 @@ MSRStatus ShardReader::CreateTasks(const std::vector<std::tuple<int, int, int, u
return SUCCESS;
}
std::vector<uint8_t> ShardReader::ExtractBlobFieldBySelectColumns(
std::vector<uint8_t> &blob_fields_bytes, std::vector<uint32_t> &ordered_selected_columns_index) {
std::vector<uint8_t> exactly_blob_fields_bytes;
auto uint64_from_bytes = [&](int64_t pos) {
uint64_t result = 0;
for (uint64_t n = 0; n < kInt64Len; n++) {
result = (result << 8) + blob_fields_bytes[pos + n];
}
return result;
};
// get the exactly blob fields
uint32_t current_index = 0;
uint64_t current_offset = 0;
uint64_t data_len = uint64_from_bytes(current_offset);
while (current_offset < blob_fields_bytes.size()) {
if (std::any_of(ordered_selected_columns_index.begin(), ordered_selected_columns_index.end(),
[&current_index](uint32_t &index) { return index == current_index; })) {
exactly_blob_fields_bytes.insert(exactly_blob_fields_bytes.end(), blob_fields_bytes.begin() + current_offset,
blob_fields_bytes.begin() + current_offset + kInt64Len + data_len);
}
current_index++;
current_offset += kInt64Len + data_len;
data_len = uint64_from_bytes(current_offset);
}
return exactly_blob_fields_bytes;
}
TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_id) {
// All tasks are done
if (task_id >= static_cast<int>(tasks_.Size())) {
@ -1126,40 +1104,10 @@ TASK_RETURN_CONTENT ShardReader::ConsumerOneTask(int task_id, uint32_t consumer_
return std::make_pair(FAILED, std::vector<std::tuple<std::vector<uint8_t>, json>>());
}
// extract the exactly blob bytes by selected columns
std::vector<uint8_t> images_with_exact_columns;
if (selected_columns_.size() == 0) {
images_with_exact_columns = images;
} else {
auto blob_fields = GetBlobFields();
std::vector<uint32_t> ordered_selected_columns_index;
uint32_t index = 0;
for (auto &blob_field : blob_fields.second) {
for (auto &field : selected_columns_) {
if (field.compare(blob_field) == 0) {
ordered_selected_columns_index.push_back(index);
break;
}
}
index++;
}
if (ordered_selected_columns_index.size() != 0) {
// extract the images
if (blob_fields.second.size() == 1) {
if (ordered_selected_columns_index.size() == 1) {
images_with_exact_columns = images;
}
} else {
images_with_exact_columns = ExtractBlobFieldBySelectColumns(images, ordered_selected_columns_index);
}
}
}
// Deliver batch data to output map
std::vector<std::tuple<std::vector<uint8_t>, json>> batch;
batch.emplace_back(std::move(images_with_exact_columns), std::move(std::get<2>(task)));
batch.emplace_back(std::move(images), std::move(std::get<2>(task)));
return std::make_pair(SUCCESS, std::move(batch));
}
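For orientation (not part of the commit): the helper deleted above walked the raw blob as a sequence of fields, each framed as an 8-byte big-endian length followed by the payload, which is the same framing the removed Python `_blob_at_position` later in this diff parses. A standalone sketch of that framing, with SplitBlobFields as an illustrative name:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Split a raw blob into per-field buffers. Each field is stored as an 8-byte
// big-endian length prefix followed by that many bytes of payload.
std::vector<std::vector<uint8_t>> SplitBlobFields(const std::vector<uint8_t> &blob) {
  std::vector<std::vector<uint8_t>> fields;
  std::size_t pos = 0;
  while (pos + 8 <= blob.size()) {
    uint64_t len = 0;
    for (int i = 0; i < 8; ++i) {
      len = (len << 8) | blob[pos + i];  // big-endian, as in uint64_from_bytes above
    }
    pos += 8;
    if (pos + len > blob.size()) break;  // malformed blob; stop rather than overrun
    auto first = blob.begin() + static_cast<std::ptrdiff_t>(pos);
    fields.emplace_back(first, first + static_cast<std::ptrdiff_t>(len));
    pos += len;
  }
  return fields;
}
```

After this commit the reader no longer slices that concatenated buffer itself; UnCompressBlob, added later in this diff, hands Python one buffer per blob column instead.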
@ -1369,16 +1317,41 @@ std::vector<std::tuple<std::vector<uint8_t>, json>> ShardReader::GetNextById(con
return std::move(ret.second);
}
std::vector<std::tuple<std::vector<uint8_t>, pybind11::object>> ShardReader::GetNextPy() {
std::pair<MSRStatus, std::vector<std::vector<uint8_t>>> ShardReader::UnCompressBlob(
const std::vector<uint8_t> &raw_blob_data) {
auto loaded_columns = selected_columns_.size() == 0 ? shard_column_->GetColumnName() : selected_columns_;
auto blob_fields = GetBlobFields().second;
std::vector<std::vector<uint8_t>> blob_data;
for (uint32_t i_col = 0; i_col < loaded_columns.size(); ++i_col) {
if (std::find(blob_fields.begin(), blob_fields.end(), loaded_columns[i_col]) == blob_fields.end()) continue;
const unsigned char *data = nullptr;
std::unique_ptr<unsigned char[]> data_ptr;
uint64_t n_bytes = 0;
auto ret = shard_column_->GetColumnFromBlob(loaded_columns[i_col], raw_blob_data, &data, &data_ptr, &n_bytes);
if (ret != SUCCESS) {
MS_LOG(ERROR) << "Error when get data from blob, column name is " << loaded_columns[i_col] << ".";
return {FAILED, std::vector<std::vector<uint8_t>>(blob_fields.size(), std::vector<uint8_t>())};
}
if (data == nullptr) {
data = reinterpret_cast<const unsigned char *>(data_ptr.get());
}
std::vector<uint8_t> column(data, data + (n_bytes / sizeof(unsigned char)));
blob_data.push_back(column);
}
return {SUCCESS, blob_data};
}
std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> ShardReader::GetNextPy() {
auto res = GetNext();
vector<std::tuple<std::vector<uint8_t>, pybind11::object>> jsonData;
std::transform(res.begin(), res.end(), std::back_inserter(jsonData),
[](const std::tuple<std::vector<uint8_t>, json> &item) {
vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>> data;
std::transform(res.begin(), res.end(), std::back_inserter(data),
[this](const std::tuple<std::vector<uint8_t>, json> &item) {
auto &j = std::get<1>(item);
pybind11::object obj = nlohmann::detail::FromJsonImpl(j);
return std::make_tuple(std::get<0>(item), std::move(obj));
auto ret = UnCompressBlob(std::get<0>(item));
return std::make_tuple(ret.second, std::move(obj));
});
return jsonData;
return data;
}
void ShardReader::Reset() {

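One consumer-side note, with hypothetical names: the buffers UnCompressBlob returns are ordered by the requested columns (or by all column names when nothing is selected), restricted to the blob fields, so `blob_data[i]` belongs to the i-th loaded blob column — the same ordering the populate_data change later in this diff assumes. A small sketch of computing that order:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Return the column names in the order their blob buffers are emitted.
std::vector<std::string> BlobColumnOrder(const std::vector<std::string> &selected,
                                         const std::vector<std::string> &all_columns,
                                         const std::vector<std::string> &blob_fields) {
  const auto &requested = selected.empty() ? all_columns : selected;
  std::vector<std::string> order;
  for (const auto &name : requested) {
    if (std::find(blob_fields.begin(), blob_fields.end(), name) != blob_fields.end()) {
      order.push_back(name);  // order[i] names the column stored in blob_data[i]
    }
  }
  return order;
}

int main() {
  auto order = BlobColumnOrder({"label", "image"}, {"image", "label", "score"}, {"image"});
  // order == {"image"}: only blob columns appear, in the requested order.
  return static_cast<int>(order.size());
}
```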
@ -206,6 +206,7 @@ MSRStatus ShardWriter::OpenForAppend(const std::string &path) {
MS_LOG(ERROR) << "Open file failed";
return FAILED;
}
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@ -271,6 +272,7 @@ MSRStatus ShardWriter::SetShardHeader(std::shared_ptr<ShardHeader> header_data)
shard_header_ = header_data;
shard_header_->SetHeaderSize(header_size_);
shard_header_->SetPageSize(page_size_);
shard_column_ = std::make_shared<ShardColumn>(shard_header_);
return SUCCESS;
}
@ -608,6 +610,14 @@ MSRStatus ShardWriter::WriteRawDataPreCheck(std::map<uint64_t, std::vector<json>
MS_LOG(ERROR) << "IO error / there is no free disk to be used";
return FAILED;
}
// compress blob
if (shard_column_->CheckCompressBlob()) {
for (auto &blob : blob_data) {
blob = shard_column_->CompressBlob(blob);
}
}
// Add 4-byte dummy blob data if there are no blob fields
if (blob_data.size() == 0 && raw_data.size() > 0) {
blob_data = std::vector<std::vector<uint8_t>>(raw_data[0].size(), std::vector<uint8_t>(kUnsignedInt4, 0));

File diff suppressed because it is too large.

@ -201,9 +201,9 @@ void ShardHeader::GetHeadersOneTask(int start, int end, std::vector<json> &heade
json header;
header = ret.second;
header["shard_addresses"] = realAddresses;
if (header["version"] != version_) {
if (std::find(kSupportedVersion.begin(), kSupportedVersion.end(), header["version"]) == kSupportedVersion.end()) {
MS_LOG(ERROR) << "Version wrong, file version is: " << header["version"].dump()
<< ", lib version is: " << version_;
<< ", lib version is: " << kVersion;
thread_status = true;
return;
}
@ -339,7 +339,7 @@ std::vector<std::string> ShardHeader::SerializeHeader() {
s += "\"shard_addresses\":" + address + ",";
s += "\"shard_id\":" + std::to_string(shardId) + ",";
s += "\"statistics\":" + stats + ",";
s += "\"version\":\"" + version_ + "\"";
s += "\"version\":\"" + std::string(kVersion) + "\"";
s += "}";
header.emplace_back(s);
}

@ -97,16 +97,13 @@ def populate_data(raw, blob, columns, blob_fields, schema):
if not blob_fields:
return raw
# Get the order preserving sequence of columns in blob
ordered_columns = []
loaded_columns = []
if columns:
for blob_field in blob_fields:
if blob_field in columns:
ordered_columns.append(blob_field)
for column in columns:
if column in blob_fields:
loaded_columns.append(column)
else:
ordered_columns = blob_fields
blob_bytes = bytes(blob)
loaded_columns = blob_fields
def _render_raw(field, blob_data):
data_type = schema[field]['type']
@ -119,24 +116,6 @@ def populate_data(raw, blob, columns, blob_fields, schema):
else:
raw[field] = blob_data
if len(blob_fields) == 1:
if len(ordered_columns) == 1:
_render_raw(blob_fields[0], blob_bytes)
return raw
return raw
def _int_from_bytes(xbytes: bytes) -> int:
return int.from_bytes(xbytes, 'big')
def _blob_at_position(pos):
start = 0
for _ in range(pos):
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8 + n_bytes
n_bytes = _int_from_bytes(blob_bytes[start : start + 8])
start += 8
return blob_bytes[start : start + n_bytes]
for i, blob_field in enumerate(ordered_columns):
_render_raw(blob_field, _blob_at_position(i))
for i, blob_field in enumerate(loaded_columns):
_render_raw(blob_field, bytes(blob[i]))
return raw

File diff suppressed because it is too large.

@ -25,8 +25,24 @@ from mindspore.mindrecord import SUCCESS
CIFAR100_DIR = "../data/mindrecord/testCifar100Data"
MINDRECORD_FILE = "./cifar100.mindrecord"
def test_cifar100_to_mindrecord_without_index_fields():
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
def test_cifar100_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar100 dataset to mindrecord without index fields."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
ret = cifar100_transformer.transform()
@ -34,25 +50,14 @@ def test_cifar100_to_mindrecord_without_index_fields():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar100_to_mindrecord():
def test_cifar100_to_mindrecord(fixture_file):
"""test transform cifar100 dataset to mindrecord."""
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, MINDRECORD_FILE)
cifar100_transformer.transform(['fine_label', 'coarse_label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read():
@ -77,8 +82,7 @@ def read():
assert count == 4
reader.close()
def test_cifar100_to_mindrecord_illegal_file_name():
def test_cifar100_to_mindrecord_illegal_file_name(fixture_file):
"""
test transform cifar100 dataset to mindrecord
when file name contains illegal character.
@ -88,8 +92,7 @@ def test_cifar100_to_mindrecord_illegal_file_name():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_start_with_space():
def test_cifar100_to_mindrecord_filename_start_with_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name starts with space.
@ -100,8 +103,7 @@ def test_cifar100_to_mindrecord_filename_start_with_space():
cifar100_transformer = Cifar100ToMR(CIFAR100_DIR, filename)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_contain_space():
def test_cifar100_to_mindrecord_filename_contain_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
@ -111,14 +113,8 @@ def test_cifar100_to_mindrecord_filename_contain_space():
cifar100_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar100_to_mindrecord_directory():
def test_cifar100_to_mindrecord_directory(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path is directory.
@ -129,8 +125,7 @@ def test_cifar100_to_mindrecord_directory():
CIFAR100_DIR)
cifar100_transformer.transform()
def test_cifar100_to_mindrecord_filename_equals_cifar100():
def test_cifar100_to_mindrecord_filename_equals_cifar100(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path equals source path.

@ -24,36 +24,60 @@ from mindspore.mindrecord import MRMOpenError, SUCCESS
CIFAR10_DIR = "../data/mindrecord/testCifar10Data"
MINDRECORD_FILE = "./cifar10.mindrecord"
def test_cifar10_to_mindrecord_without_index_fields():
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
remove_file(MINDRECORD_FILE)
yield "yield_fixture_data"
remove_file(MINDRECORD_FILE)
@pytest.fixture
def fixture_space_file():
"""add/remove file"""
def remove_file(x):
if os.path.exists("{}".format(x)):
os.remove("{}".format(x))
if os.path.exists("{}.db".format(x)):
os.remove("{}.db".format(x))
if os.path.exists("{}_test".format(x)):
os.remove("{}_test".format(x))
if os.path.exists("{}_test.db".format(x)):
os.remove("{}_test.db".format(x))
x = "./yes ok"
remove_file(x)
yield "yield_fixture_data"
remove_file(x)
def test_cifar10_to_mindrecord_without_index_fields(fixture_file):
"""test transform cifar10 dataset to mindrecord without index fields."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform()
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord():
def test_cifar10_to_mindrecord(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
cifar10_transformer.transform(['label'])
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def test_cifar10_to_mindrecord_with_return():
def test_cifar10_to_mindrecord_with_return(fixture_file):
"""test transform cifar10 dataset to mindrecord."""
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, MINDRECORD_FILE)
ret = cifar10_transformer.transform(['label'])
@ -61,11 +85,6 @@ def test_cifar10_to_mindrecord_with_return():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + "_test")
read()
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
os.remove("{}".format(MINDRECORD_FILE + "_test"))
os.remove("{}.db".format(MINDRECORD_FILE + "_test"))
def read():
@ -90,8 +109,7 @@ def read():
assert count == 4
reader.close()
def test_cifar10_to_mindrecord_illegal_file_name():
def test_cifar10_to_mindrecord_illegal_file_name(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains illegal character.
@ -101,8 +119,7 @@ def test_cifar10_to_mindrecord_illegal_file_name():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_start_with_space():
def test_cifar10_to_mindrecord_filename_start_with_space(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when file name starts with space.
@ -113,8 +130,7 @@ def test_cifar10_to_mindrecord_filename_start_with_space():
cifar10_transformer = Cifar10ToMR(CIFAR10_DIR, filename)
cifar10_transformer.transform()
def test_cifar10_to_mindrecord_filename_contain_space():
def test_cifar10_to_mindrecord_filename_contain_space(fixture_space_file):
"""
test transform cifar10 dataset to mindrecord
when file name contains space.
@ -124,14 +140,8 @@ def test_cifar10_to_mindrecord_filename_contain_space():
cifar10_transformer.transform()
assert os.path.exists(filename)
assert os.path.exists(filename + "_test")
os.remove("{}".format(filename))
os.remove("{}.db".format(filename))
os.remove("{}".format(filename + "_test"))
os.remove("{}.db".format(filename + "_test"))
def test_cifar10_to_mindrecord_directory():
def test_cifar10_to_mindrecord_directory(fixture_file):
"""
test transform cifar10 dataset to mindrecord
when destination path is directory.

@ -25,6 +25,26 @@ IMAGENET_IMAGE_DIR = "../data/mindrecord/testImageNetDataWhole/images"
MINDRECORD_FILE = "../data/mindrecord/testImageNetDataWhole/imagenet.mindrecord"
PARTITION_NUMBER = 4
@pytest.fixture
def fixture_file():
"""add/remove file"""
def remove_one_file(x):
if os.path.exists(x):
os.remove(x)
def remove_file():
x = MINDRECORD_FILE
remove_one_file(x)
x = MINDRECORD_FILE + ".db"
remove_one_file(x)
for i in range(PARTITION_NUMBER):
x = MINDRECORD_FILE + str(i)
remove_one_file(x)
x = MINDRECORD_FILE + str(i) + ".db"
remove_one_file(x)
remove_file()
yield "yield_fixture_data"
remove_file()
def read(filename):
"""test file reade"""
@ -38,8 +58,7 @@ def read(filename):
assert count == 20
reader.close()
def test_imagenet_to_mindrecord():
def test_imagenet_to_mindrecord(fixture_file):
"""test transform imagenet dataset to mindrecord."""
imagenet_transformer = ImageNetToMR(IMAGENET_MAP_FILE, IMAGENET_IMAGE_DIR,
MINDRECORD_FILE, PARTITION_NUMBER)
@ -48,12 +67,8 @@ def test_imagenet_to_mindrecord():
assert os.path.exists(MINDRECORD_FILE + str(i))
assert os.path.exists(MINDRECORD_FILE + str(i) + ".db")
read(MINDRECORD_FILE + "0")
for i in range(PARTITION_NUMBER):
os.remove(MINDRECORD_FILE + str(i))
os.remove(MINDRECORD_FILE + str(i) + ".db")
def test_imagenet_to_mindrecord_default_partition_number():
def test_imagenet_to_mindrecord_default_partition_number(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is default.
@ -64,11 +79,8 @@ def test_imagenet_to_mindrecord_default_partition_number():
assert os.path.exists(MINDRECORD_FILE)
assert os.path.exists(MINDRECORD_FILE + ".db")
read(MINDRECORD_FILE)
os.remove("{}".format(MINDRECORD_FILE))
os.remove("{}.db".format(MINDRECORD_FILE))
def test_imagenet_to_mindrecord_partition_number_0():
def test_imagenet_to_mindrecord_partition_number_0(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is 0.
@ -79,8 +91,7 @@ def test_imagenet_to_mindrecord_partition_number_0():
MINDRECORD_FILE, 0)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_partition_number_none():
def test_imagenet_to_mindrecord_partition_number_none(fixture_file):
"""
test transform imagenet dataset to mindrecord
when partition number is none.
@ -92,8 +103,7 @@ def test_imagenet_to_mindrecord_partition_number_none():
MINDRECORD_FILE, None)
imagenet_transformer.transform()
def test_imagenet_to_mindrecord_illegal_filename():
def test_imagenet_to_mindrecord_illegal_filename(fixture_file):
"""
test transform imagenet dataset to mindrecord
when file name contains illegal character.

Some files were not shown because too many files have changed in this diff.
