Change mem layout of string tensor

add support for MindRecord and TFRecord
----
optimize tensorshape

optimize tensorshape and FlatIndex

TFRecord and MindRecord support for string tensor

Modify mem layout
Add new constructor
Add method Allocate

Change some GetMutableBuffer usages to AllocateBuffer
pull/1308/head
hesham 5 years ago
parent d9c74e0acd
commit df361d1d26

@ -1,6 +1,10 @@
ms_protobuf_generate(EXAMPLE_SRCS EXAMPLE_HDRS example.proto)
ms_protobuf_generate(FEATURE_SRCS FEATURE_HDRS feature.proto)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
add_library(core OBJECT
${EXAMPLE_SRCS}
${FEATURE_SRCS}
client.cc
config_manager.cc
cv_tensor.cc
@ -9,4 +13,5 @@ add_library(core OBJECT
tensor.cc
tensor_shape.cc
)
add_dependencies(core mindspore::protobuf)
target_include_directories(core PRIVATE ${pybind11_INCLUDE_DIRS})

@ -25,14 +25,14 @@ namespace dataset {
uint8_t DataType::SizeInBytes() const {
if (type_ < DataType::NUM_OF_TYPES)
return SIZE_IN_BYTES[type_];
return kTypeInfo[type_].sizeInBytes_;
else
return 0;
}
py::dtype DataType::AsNumpyType() const {
if (type_ < DataType::NUM_OF_TYPES)
return py::dtype(PYBIND_TYPES[type_]);
return py::dtype(kTypeInfo[type_].pybindType_);
else
return py::dtype("unknown");
}
@ -40,7 +40,7 @@ py::dtype DataType::AsNumpyType() const {
uint8_t DataType::AsCVType() const {
uint8_t res = kCVInvalidType;
if (type_ < DataType::NUM_OF_TYPES) {
res = CV_TYPES[type_];
res = kTypeInfo[type_].cvType_;
}
if (res == kCVInvalidType) {
@ -108,7 +108,7 @@ DataType::DataType(const std::string &type_str) {
std::string DataType::ToString() const {
if (type_ < DataType::NUM_OF_TYPES)
return TO_STRINGS[type_];
return kTypeInfo[type_].name_;
else
return "unknown";
}
@ -149,7 +149,7 @@ DataType DataType::FromNpArray(const py::array &arr) {
std::string DataType::GetPybindFormat() const {
std::string res;
if (type_ < DataType::NUM_OF_TYPES) {
res = PYBIND_FORMAT_DESCRIPTOR[type_];
res = kTypeInfo[type_].pybindFormatDescriptor_;
}
if (res.empty()) {

@ -51,56 +51,31 @@ class DataType {
NUM_OF_TYPES
};
inline static constexpr uint8_t SIZE_IN_BYTES[] = {0, // DE_UNKNOWN
1, // DE_BOOL
1, // DE_INT8
1, // DE_UINT8
2, // DE_INT16
2, // DE_UINT16
4, // DE_INT32
4, // DE_UINT32
8, // DE_INT64
8, // DE_UINT64
2, // DE_FLOAT16
4, // DE_FLOAT32
8, // DE_FLOAT64
0}; // DE_STRING
inline static const char *TO_STRINGS[] = {"unknown", "bool", "int8", "uint8", "int16", "uint16", "int32",
"uint32", "int64", "uint64", "float16", "float32", "float64", "string"};
inline static const char *PYBIND_TYPES[] = {"object", "bool", "int8", "uint8", "int16", "uint16", "int32",
"uint32", "int64", "uint64", "float16", "float32", "double", "bytes"};
inline static const std::string PYBIND_FORMAT_DESCRIPTOR[] = {"", // DE_UNKNOWN
py::format_descriptor<bool>::format(), // DE_BOOL
py::format_descriptor<int8_t>::format(), // DE_INT8
py::format_descriptor<uint8_t>::format(), // DE_UINT8
py::format_descriptor<int16_t>::format(), // DE_INT16
py::format_descriptor<uint16_t>::format(), // DE_UINT16
py::format_descriptor<int32_t>::format(), // DE_INT32
py::format_descriptor<uint32_t>::format(), // DE_UINT32
py::format_descriptor<int64_t>::format(), // DE_INT64
py::format_descriptor<uint64_t>::format(), // DE_UINT64
"e", // DE_FLOAT16
py::format_descriptor<float>::format(), // DE_FLOAT32
py::format_descriptor<double>::format(), // DE_FLOAT64
"S"}; // DE_STRING
inline static constexpr uint8_t CV_TYPES[] = {kCVInvalidType, // DE_UNKNOWN
CV_8U, // DE_BOOL
CV_8S, // DE_INT8
CV_8U, // DE_UINT8
CV_16S, // DE_INT16
CV_16U, // DE_UINT16
CV_32S, // DE_INT32
kCVInvalidType, // DE_UINT32
kCVInvalidType, // DE_INT64
kCVInvalidType, // DE_UINT64
CV_16F, // DE_FLOAT16
CV_32F, // DE_FLOAT32
CV_64F, // DE_FLOAT64
kCVInvalidType}; // DE_STRING
struct TypeInfo {
const char *name_; // name to be represent the type while printing
const uint8_t sizeInBytes_; // number of bytes needed for this type
const char *pybindType_; // Python matching type, used in get_output_types
const std::string pybindFormatDescriptor_; // pybind format used for numpy types
const uint8_t cvType_; // OpenCv matching type
};
static inline const TypeInfo kTypeInfo[] = {
// name, sizeInBytes, pybindTypem formatDescriptor, openCV
{"unknown", 0, "object", "", kCVInvalidType}, // DE_UNKNOWN
{"bool", 1, "bool", py::format_descriptor<bool>::format(), CV_8U}, // DE_BOOL
{"int8", 1, "int8", py::format_descriptor<int8_t>::format(), CV_8S}, // DE_INT8
{"uint8", 1, "uint8", py::format_descriptor<uint8_t>::format(), CV_8U}, // DE_UINT8
{"int16", 2, "int16", py::format_descriptor<int16_t>::format(), CV_16S}, // DE_INT16
{"uint16", 2, "uint16", py::format_descriptor<uint16_t>::format(), CV_16U}, // DE_UINT16
{"int32", 4, "int32", py::format_descriptor<int32_t>::format(), CV_32S}, // DE_INT32
{"uint32", 4, "uint32", py::format_descriptor<uint32_t>::format(), kCVInvalidType}, // DE_UINT32
{"int64", 8, "int64", py::format_descriptor<int64_t>::format(), kCVInvalidType}, // DE_INT64
{"uint64", 8, "uint64", py::format_descriptor<uint64_t>::format(), kCVInvalidType}, // DE_UINT64
{"float16", 2, "float16", "e", CV_16F}, // DE_FLOAT16
{"float32", 4, "float32", py::format_descriptor<float>::format(), CV_32F}, // DE_FLOAT32
{"float64", 8, "double", py::format_descriptor<double>::format(), CV_64F}, // DE_FLOAT64
{"string", 0, "bytes", "S", kCVInvalidType} // DE_STRING
};
// No arg constructor to create an unknown shape
DataType() : type_(DE_UNKNOWN) {}

File diff suppressed because it is too large Load Diff

@ -35,6 +35,7 @@
#include "dataset/util/allocator.h"
#include "dataset/util/de_error.h"
#include "dataset/util/status.h"
#include "proto/example.pb.h"
namespace py = pybind11;
namespace mindspore {
@ -64,6 +65,8 @@ class Tensor {
// @param data unsigned char*, pointer to the data.
Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data);
Tensor(const TensorShape &shape, const DataType &type, const unsigned char *data, const dsize_t &length);
Tensor(const Tensor &other) = delete;
Tensor &operator=(const Tensor &other) = delete;
@ -72,6 +75,8 @@ class Tensor {
Tensor &operator=(Tensor &&other) noexcept;
Status AllocateBuffer(const dsize_t &length);
// type of offest values to store strings information
using offset_t = uint32_t;
// const of the size of the offset variable
@ -84,15 +89,24 @@ class Tensor {
// Construct a tensor from a list of strings. Reshape the tensor with `shape` if given, otherwise assume the shape is
// the size of the vector `strings`.
// The memory layout of a Tensor of strings consists of the Offset_array followed by the strings.
// OFFSET1, OFFSET2, ... String1, String2, ...
// The value of each offset is the end index of the corresponding string
// Thr offset array will store one extra value to find the length of the last string.
// OFFSET1, OFFSET2, ..., OFFSETn+1, STRING1, STRING2, ..., STRINGn
// The value of each offset is the start index of the corresponding string
// Offsets is of type offest_t
// strings will ne null-terminated
// example: Tensor(['abc', 'de'], shape={2}, type=DE_STRING)
// 3 6 a b c \0 d e \0
// |----------------------------------------------------------------|
// | OFFSET ARRAY | STRINGS |
// | bytes 0-3 | bytes 3-6 | bytes 7-10 | bytes 11-14 | bytes 15-17 |
// | 11 | 15 | 18 | abc\0 | de\0 |
// |----------------------------------------------------------------|
explicit Tensor(const std::vector<std::string> &strings,
const TensorShape &shape = TensorShape::CreateUnknownRankShape());
// Same as Tensor(vector<string>) but the input is protobuf bytelist
explicit Tensor(const dataengine::BytesList &bytes_list,
const TensorShape &shape = TensorShape::CreateUnknownRankShape());
// A static factory method to create the given flavour of derived Tensor
// Returns the base class reference for the Tensor.
// @param ptr output argument to hold the created Tensor of given tensor_impl
@ -121,6 +135,9 @@ class Tensor {
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const std::vector<std::string> &strings,
const TensorShape &shape = TensorShape::CreateUnknownRankShape());
static Status CreateTensor(std::shared_ptr<Tensor> *ptr, const dataengine::BytesList &bytes_list,
const TensorShape &shape);
// Copy raw data of a array based on shape and strides to the destination pointer
// @param dst Pointer to the destination array where the content is to be copied
// @param src Pointer to the source of strided array to be copied
@ -166,7 +183,7 @@ class Tensor {
// @param value of type `T`
template <typename T>
Status SetItemAt(const std::vector<dsize_t> &index, const T &value) {
static_cast<void>(GetMutableBuffer());
RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes()));
T *ptr = nullptr;
RETURN_IF_NOT_OK(GetItemPtr<T>(&ptr, index));
*ptr = value;
@ -203,7 +220,7 @@ class Tensor {
template <typename T>
Status Fill(const T &value) {
CHECK_FAIL_RETURN_UNEXPECTED(type_ != DataType::DE_STRING, "Cannot use fill on tensor of strings.");
static_cast<void>(GetMutableBuffer());
RETURN_IF_NOT_OK(AllocateBuffer(SizeInBytes()));
int64_t cellSize = type_.SizeInBytes();
if ((data_ != nullptr) && type_.IsCompatible<T>()) {
for (dsize_t i = 0; i < Size(); i++) {
@ -418,32 +435,28 @@ class Tensor {
using pointer = std::string_view *;
using reference = std::string_view &;
explicit TensorIterator(uchar *offset = nullptr, const uchar *buf = nullptr, dsize_t index = 0) {
offset_ = reinterpret_cast<offset_t *>(offset);
buf_ = reinterpret_cast<const char *>(buf);
explicit TensorIterator(uchar *data = nullptr, dsize_t index = 0) {
data_ = reinterpret_cast<const char *>(data);
index_ = index;
}
TensorIterator(const TensorIterator<std::string_view, DUMMY> &raw_iterator) {
offset_ = raw_iterator.offset_;
buf_ = raw_iterator.buf_;
data_ = raw_iterator.data_;
index_ = raw_iterator.index_;
}
~TensorIterator() = default;
bool operator==(const TensorIterator<std::string_view> &rhs) {
return buf_ == rhs.buf_ && offset_ == rhs.offset_ && index_ == rhs.index_;
}
bool operator==(const TensorIterator<std::string_view> &rhs) { return data_ == rhs.data_ && index_ == rhs.index_; }
bool operator!=(const TensorIterator<std::string_view> &rhs) { return !(*this == rhs); }
operator bool() const { return offset_ != nullptr; }
operator bool() const { return data_ != nullptr; }
std::string_view operator*() const {
offset_t start = 0;
if (index_ != 0) start = offset_[index_ - 1] + 1;
return std::string_view{buf_ + start};
auto offset_ = reinterpret_cast<const offset_t *>(data_);
offset_t start = offset_[index_];
return std::string_view{data_ + start};
}
TensorIterator<std::string_view> &operator+=(const dsize_t &inc) {
@ -496,8 +509,7 @@ class Tensor {
protected:
dsize_t index_;
offset_t *offset_;
const char *buf_;
const char *data_;
};
// Return a TensorIterator that points to the start of the Tensor.
@ -518,11 +530,6 @@ class Tensor {
}
protected:
// Returns the location of the item assuming row major memory layout.
// @param index
// @return
Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const;
// A function that prints Tensor recursively, first called by print
// @param out
// @param cur_dim
@ -559,7 +566,7 @@ class Tensor {
// Skip the offsets and returns the start of the buffer where the real strings is stored. Caller needs to check if the
// tensor's type is a string, otherwise undefined address would be returned.
// @return address of the first string of the tensor.
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements(); }
uchar *GetStringsBuffer() const { return data_ + kOffsetSize * shape_.NumOfElements() + kOffsetSize; }
// all access to shape_ should be via shape
TensorShape shape_;
@ -573,14 +580,8 @@ class Tensor {
unsigned char *data_end_ = nullptr;
};
template <>
inline Tensor::TensorIterator<std::string_view> Tensor::begin<std::string_view>() {
uchar *buf = GetStringsBuffer();
return TensorIterator<std::string_view>(data_, buf);
}
template <>
inline Tensor::TensorIterator<std::string_view> Tensor::end<std::string_view>() {
uchar *buf = GetStringsBuffer();
return TensorIterator<std::string_view>(data_, buf, shape_.NumOfElements());
return TensorIterator<std::string_view>(data_, shape_.NumOfElements());
}
} // namespace dataset
} // namespace mindspore

@ -40,16 +40,7 @@ dsize_t TensorShape::NumOfElements() const {
if (!known()) {
return 0;
}
dsize_t num = 1;
for (auto i : raw_shape_) {
if (multi_ok(num, i)) {
num *= i;
} else {
// dsize_t can wrap since it is signed int, we double check here
MS_LOG(ERROR) << "Tensor shape larger than maximum allowed value!";
}
}
return num;
return strides_[0];
}
void TensorShape::Print(std::ostream &out) const {
@ -72,20 +63,23 @@ void TensorShape::Print(std::ostream &out) const {
}
TensorShape::TensorShape(const std::initializer_list<dsize_t> &list)
: raw_shape_(*GlobalContext::Instance()->int_allocator()) {
: raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
AddListToShape(list);
}
TensorShape::TensorShape(const std::vector<dsize_t> &list) : raw_shape_(*GlobalContext::Instance()->int_allocator()) {
TensorShape::TensorShape(const std::vector<dsize_t> &list)
: raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
AddListToShape(list);
}
TensorShape::TensorShape(const TensorShape &shape) : raw_shape_(*GlobalContext::Instance()->int_allocator()) {
TensorShape::TensorShape(const TensorShape &shape)
: raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
AddListToShape(shape.AsVector());
known_ = shape.known_; // override with the input shape in case of unknown-rank tensor shape.
}
TensorShape::TensorShape(py::list l) : raw_shape_(*GlobalContext::Instance()->int_allocator()) {
TensorShape::TensorShape(py::list l)
: raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
std::vector<dsize_t> list_c;
for (auto &i : l) {
if (!i.is_none()) {
@ -97,6 +91,18 @@ TensorShape::TensorShape(py::list l) : raw_shape_(*GlobalContext::Instance()->in
AddListToShape(list_c);
}
TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type)
: raw_shape_(*GlobalContext::Instance()->int_allocator()), strides_(*GlobalContext::Instance()->int_allocator()) {
for (int i = 0; i < cv_size.dims(); i++) {
raw_shape_.push_back(cv_size[i]);
}
auto channels = static_cast<uint8_t>(1 + (type >> static_cast<uint8_t>(CV_CN_SHIFT)));
if (channels != 1) {
raw_shape_.push_back(channels);
}
known_ = true;
}
TensorShape TensorShape::CreateUnknownRankShape() {
TensorShape s({});
s.known_ = false;
@ -109,17 +115,6 @@ TensorShape TensorShape::InsertDim(dsize_t axis, dsize_t dim) const {
return TensorShape(tmp);
}
TensorShape::TensorShape(cv::MatSize cv_size, uint32_t type) : raw_shape_(*GlobalContext::Instance()->int_allocator()) {
for (int i = 0; i < cv_size.dims(); i++) {
raw_shape_.push_back(cv_size[i]);
}
auto channels = static_cast<uint8_t>(1 + (type >> static_cast<uint8_t>(CV_CN_SHIFT)));
if (channels != 1) {
raw_shape_.push_back(channels);
}
known_ = true;
}
std::vector<dsize_t> TensorShape::AsVector() const {
return std::vector<dsize_t>(raw_shape_.begin(), raw_shape_.end());
}
@ -139,23 +134,28 @@ bool TensorShape::IsValidIndex(const std::vector<dsize_t> &index) const {
template <typename T>
void TensorShape::AddListToShape(const T &list) {
raw_shape_.resize(list.size());
strides_.resize(list.size() + 1);
strides_[list.size()] = 1;
known_ = true;
dsize_t num = 1;
dsize_t size = 0;
for (const auto &itr : list) {
if (itr > 0) {
if (num > std::numeric_limits<int64_t>::max() / itr) {
auto itr = std::rbegin(list); // iterate over the list in reverse order
auto s = list.size() - 1; // to compute strides while adding dims
for (; itr != std::rend(list); itr++, s--) {
dsize_t dim = *itr;
if (dim > 0) {
if (strides_[s + 1] > std::numeric_limits<int64_t>::max() / dim) {
MS_LOG(ERROR) << "Invalid shape data, overflow occurred!";
known_ = false;
raw_shape_.clear();
return;
}
num *= itr;
strides_[s] = dim * strides_[s + 1];
}
if (itr < 0) {
if (dim < 0) {
known_ = false;
}
if (itr > kDeMaxDim) {
if (dim > kDeMaxDim) {
std::stringstream ss;
ss << "Invalid shape data, dim (" << size << ") is larger than the maximum dim size(" << kDeMaxDim << ")!";
MS_LOG(ERROR) << ss.str().c_str();
@ -163,7 +163,7 @@ void TensorShape::AddListToShape(const T &list) {
raw_shape_.clear();
return;
}
raw_shape_.push_back(itr);
raw_shape_[s] = dim;
size++;
}
if (size > kDeMaxRank) {
@ -215,17 +215,18 @@ TensorShape TensorShape::Squeeze() const {
}
return TensorShape(new_shape);
}
std::vector<dsize_t> TensorShape::Strides() {
std::vector<dsize_t> strides(Rank());
dsize_t count = NumOfElements();
for (dsize_t i = 0; i < Rank(); i++) {
if (raw_shape_[i] != 0)
count /= raw_shape_[i];
else
count = 0;
strides[i] = count;
std::vector<dsize_t> TensorShape::Strides() const { return std::vector<dsize_t>{strides_.begin() + 1, strides_.end()}; }
// Name: ToFlatIndex()
// Description: convert a vector style index to number, used to access memory internal use only
Status TensorShape::ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const {
*flat_index = 0;
for (size_t k = 0; k < index.size(); k++) {
*flat_index += index[k] * strides_[k + 1]; // skip the first element of strides_ which is numOfElements
}
return strides;
CHECK_FAIL_RETURN_UNEXPECTED(*flat_index < NumOfElements(), "Not a valid index");
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

@ -156,13 +156,20 @@ class TensorShape {
TensorShape Squeeze() const;
std::vector<dsize_t> Strides();
std::vector<dsize_t> Strides() const;
// Returns the location of the item assuming row major memory layout.
// @param index
// @return
Status ToFlatIndex(const std::vector<dsize_t> &index, dsize_t *flat_index) const;
private:
// True if known and valid shape, false otherwise
bool known_;
// Vector to keep the dims of the shape.
std::vector<dsize_t, IntAlloc> raw_shape_;
// Vector to keep the strides of the shape. The size is rank+1
std::vector<dsize_t, IntAlloc> strides_;
// Internal utility function to iterate over a list, check if the dim is valid and then insert it into the shape.
// @tparam T list

@ -1,5 +1,3 @@
ms_protobuf_generate(EXAMPLE_SRCS EXAMPLE_HDRS example.proto)
ms_protobuf_generate(FEATURE_SRCS FEATURE_HDRS feature.proto)
add_subdirectory(sampler)
file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc")
set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD)
@ -15,13 +13,9 @@ add_library(engine-datasetops-source OBJECT
image_folder_op.cc
mnist_op.cc
voc_op.cc
${EXAMPLE_SRCS}
${FEATURE_SRCS}
manifest_op.cc
cifar_op.cc
random_data_op.cc
celeba_op.cc
text_file_op.cc
)
add_dependencies(engine-datasetops-source mindspore::protobuf)
)

@ -127,8 +127,10 @@ Status MindRecordOp::Init() {
std::string type_str = mindrecord::ColumnDataTypeNameNormalized[col_data_types[i]];
DataType t_dtype = DataType(type_str); // valid types: {"bytes", "string", "int32", "int64", "float32", "float64"}
if (col_data_types[i] == mindrecord::ColumnBytes || col_data_types[i] == mindrecord::ColumnString) { // rank = 1
if (col_data_types[i] == mindrecord::ColumnBytes) { // rank = 1
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 1);
} else if (col_data_types[i] == mindrecord::ColumnString) { // rank = 0
col_desc = ColDescriptor(colname, t_dtype, TensorImpl::kFlexible, 0);
} else if (col_shapes[i].size() > 0) {
std::vector<dsize_t> vec(col_shapes[i].size()); // temporary vector to hold shape
(void)std::copy(col_shapes[i].begin(), col_shapes[i].end(), vec.begin());
@ -310,7 +312,10 @@ Status MindRecordOp::LoadTensorRow(TensorRow *tensor_row, const std::vector<uint
// Set shape
auto num_elements = n_bytes / column_data_type_size;
if (column.hasShape()) {
if (type == DataType::DE_STRING) {
std::string s{data, data + n_bytes};
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, {s}, TensorShape::CreateScalar()));
} else if (column.hasShape()) {
auto new_shape = TensorShape(column.shape());
RETURN_IF_NOT_OK(column.MaterializeTensorShape(static_cast<int32_t>(num_elements), &new_shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(&tensor, column.tensorImpl(), new_shape, type, data));

@ -63,7 +63,8 @@ Status Sampler::CreateSamplerTensor(std::shared_ptr<Tensor> *sample_ids, int64_t
}
TensorShape shape(std::vector<dsize_t>(1, num_elements));
RETURN_IF_NOT_OK(Tensor::CreateTensor(sample_ids, col_desc_->tensorImpl(), shape, col_desc_->type()));
(void)(*sample_ids)->GetMutableBuffer(); // allocate memory in case user forgets!
RETURN_IF_NOT_OK(
(*sample_ids)->AllocateBuffer((*sample_ids)->SizeInBytes())); // allocate memory in case user forgets!
return Status::OK();
}

@ -724,18 +724,26 @@ Status TFReaderOp::LoadBytesList(const ColDescriptor &current_col, const dataeng
// kBytesList can map to the following DE types ONLY!
// DE_UINT8, DE_INT8
// Must be single byte type for each element!
if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8) {
if (current_col.type() != DataType::DE_UINT8 && current_col.type() != DataType::DE_INT8 &&
current_col.type() != DataType::DE_STRING) {
std::string err_msg = "Invalid datatype for Tensor at column: " + current_col.name();
RETURN_STATUS_UNEXPECTED(err_msg);
}
const dataengine::BytesList &bytes_list = column_values_list.bytes_list();
*num_elements = bytes_list.value_size();
if (current_col.type() == DataType::DE_STRING) {
TensorShape shape = TensorShape::CreateScalar();
RETURN_IF_NOT_OK(current_col.MaterializeTensorShape(*num_elements, &shape));
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, bytes_list, shape));
return Status::OK();
}
uint64_t max_size = 0;
for (uint32_t i = 0; i < bytes_list.value_size(); ++i) max_size = std::max(max_size, bytes_list.value(i).size());
*num_elements = bytes_list.value_size();
int64_t pad_size = max_size;
// if user provides a shape in the form of [-1, d1, 2d, ... , dn], we need to pad to d1 * d2 * ... * dn
@ -879,7 +887,7 @@ Status TFReaderOp::LoadIntList(const ColDescriptor &current_col, const dataengin
RETURN_IF_NOT_OK(Tensor::CreateTensor(tensor, current_col.tensorImpl(), current_shape, current_col.type()));
// Tensors are lazily allocated, this eagerly allocates memory for the tensor.
(void)(*tensor)->GetMutableBuffer();
RETURN_IF_NOT_OK((*tensor)->AllocateBuffer((*tensor)->SizeInBytes()));
int64_t i = 0;
auto it = (*tensor)->begin<T>();

@ -162,7 +162,7 @@ void CastFrom(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *out
Status TypeCast(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output, const DataType &data_type) {
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), data_type));
static_cast<void>((*output)->GetMutableBuffer());
RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
switch (input->type().value()) {
case DataType::DE_BOOL:
CastFrom<bool>(input, output);
@ -211,7 +211,7 @@ Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *
// initiate new tensor for type cast
DataType new_type = DataType("float16");
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, TensorImpl::kFlexible, input->shape(), new_type));
static_cast<void>((*output)->GetMutableBuffer());
RETURN_IF_NOT_OK((*output)->AllocateBuffer((*output)->SizeInBytes()));
auto in_itr = input->begin<float>();
auto out_itr = (*output)->begin<float16>();

@ -64,7 +64,8 @@ Status Flip(std::shared_ptr<Tensor> input, std::shared_ptr<Tensor> *output, int
std::shared_ptr<CVTensor> output_cv = std::make_shared<CVTensor>(input_cv->shape(), input_cv->type());
RETURN_UNEXPECTED_IF_NULL(output_cv);
(void)output_cv->GetMutableBuffer();
RETURN_IF_NOT_OK(output_cv->AllocateBuffer(output_cv->SizeInBytes()));
if (input_cv->mat().data) {
try {
cv::flip(input_cv->mat(), output_cv->mat(), flip_code);

@ -51,7 +51,7 @@ enum ColumnDataType {
// mapping as {"bytes", "string", "int32", "int64", "float32", "float64"};
const uint32_t ColumnDataTypeSize[kDataTypes] = {1, 1, 4, 8, 4, 8};
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "uint8", "int32",
const std::vector<std::string> ColumnDataTypeNameNormalized = {"uint8", "string", "int32",
"int64", "float32", "float64"};
const std::unordered_map<std::string, ColumnDataType> ColumnDataTypeMap = {

@ -48,6 +48,7 @@ def mstype_to_detype(type_):
mstype.float16: cde.DataType("float16"),
mstype.float32: cde.DataType("float32"),
mstype.float64: cde.DataType("float64"),
mstype.string: cde.DataType("string"),
}[type_]

@ -26,7 +26,7 @@ from . import datasets
INT32_MAX = 2147483647
valid_detype = [
"bool", "int8", "int16", "int32", "int64", "uint8", "uint16",
"uint32", "uint64", "float16", "float32", "float64"
"uint32", "uint64", "float16", "float32", "float64", "string"
]

@ -32,47 +32,47 @@ class MindDataTestDatatype : public UT::Common {
TEST_F(MindDataTestDatatype, TestSizes) {
uint8_t x = DataType::SIZE_IN_BYTES[DataType::DE_BOOL];
uint8_t x = DataType::kTypeInfo[DataType::DE_BOOL].sizeInBytes_;
DataType d = DataType(DataType::DE_BOOL);
ASSERT_EQ(x, 1);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_INT8];
x = DataType::kTypeInfo[DataType::DE_INT8].sizeInBytes_;
d = DataType(DataType::DE_INT8);
ASSERT_EQ(x, 1);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_UINT8];
x = DataType::kTypeInfo[DataType::DE_UINT8].sizeInBytes_;
d = DataType(DataType::DE_UINT8);
ASSERT_EQ(x, 1);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_INT16];
x = DataType::kTypeInfo[DataType::DE_INT16].sizeInBytes_;
d = DataType(DataType::DE_INT16);
ASSERT_EQ(x, 2);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_UINT16];
x = DataType::kTypeInfo[DataType::DE_UINT16].sizeInBytes_;
d = DataType(DataType::DE_UINT16);
ASSERT_EQ(x, 2);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_INT32];
x = DataType::kTypeInfo[DataType::DE_INT32].sizeInBytes_;
d = DataType(DataType::DE_INT32);
ASSERT_EQ(x, 4);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_UINT32];
x = DataType::kTypeInfo[DataType::DE_UINT32].sizeInBytes_;
d = DataType(DataType::DE_UINT32);
ASSERT_EQ(x, 4);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_INT64];
x = DataType::kTypeInfo[DataType::DE_INT64].sizeInBytes_;
d = DataType(DataType::DE_INT64);
ASSERT_EQ(x, 8);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_UINT64];
x = DataType::kTypeInfo[DataType::DE_UINT64].sizeInBytes_;
d = DataType(DataType::DE_UINT64);
ASSERT_EQ(x, 8);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_FLOAT32];
x = DataType::kTypeInfo[DataType::DE_FLOAT32].sizeInBytes_;
d = DataType(DataType::DE_FLOAT32);
ASSERT_EQ(x, 4);
ASSERT_EQ(d.SizeInBytes(), x);
x = DataType::SIZE_IN_BYTES[DataType::DE_FLOAT64];
x = DataType::kTypeInfo[DataType::DE_FLOAT64].sizeInBytes_;
d = DataType(DataType::DE_FLOAT64);
ASSERT_EQ(x, 8);
ASSERT_EQ(d.SizeInBytes(), x);

@ -14,9 +14,7 @@
* limitations under the License.
*/
#include "common/common.h"
#include "common/cvop_common.h"
#include "dataset/kernels/data/one_hot_op.h"
#include "dataset/core/cv_tensor.h"
#include "utils/log_adapter.h"
using namespace mindspore::dataset;
@ -24,9 +22,9 @@ using mindspore::MsLogLevel::INFO;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::LogStream;
class MindDataTestOneHotOp : public UT::CVOP::CVOpCommon {
class MindDataTestOneHotOp : public UT::Common {
protected:
MindDataTestOneHotOp() : CVOpCommon() {}
MindDataTestOneHotOp() {}
};
TEST_F(MindDataTestOneHotOp, TestOp) {

@ -65,14 +65,14 @@ TEST_F(MindDataTestStringTensorDE, Basics) {
TEST_F(MindDataTestStringTensorDE, Basics2) {
std::shared_ptr<Tensor> t =
std::make_shared<Tensor>(std::vector<std::string>{"abc", "defg", "hi", "klmno", "123", "789"}, TensorShape({2, 3}));
ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 20);
std::vector<uint32_t> offsets = {3, 8, 11, 17, 21, 25};
ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 20 + 4);
std::vector<uint32_t> offsets = {0, 4, 9, 12, 18, 22, 26};
uint32_t ctr = 0;
for (auto i : offsets) {
ASSERT_TRUE(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)) == i);
ASSERT_TRUE(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)) == i + 28);
ctr += 4;
}
const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4;
const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4 + 4;
std::vector<uint32_t> starts = {0, 4, 9, 12, 18, 22};
uint32_t index = 0;
@ -90,14 +90,14 @@ TEST_F(MindDataTestStringTensorDE, Empty) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(strings, TensorShape({2, 3}));
// abc_defg___123__
// 0123456789012345
ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 10);
std::vector<uint32_t> offsets = {3, 8, 9, 10, 14, 15};
ASSERT_TRUE(t->SizeInBytes() == 6 * 5 + 10 + 4);
std::vector<uint32_t> offsets = {0, 4, 9, 10, 11, 15, 16};
uint32_t ctr = 0;
for (auto i : offsets) {
ASSERT_TRUE(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)) == i);
ASSERT_TRUE(*(reinterpret_cast<uint32_t *>(t->GetMutableBuffer() + ctr)) == i + 28);
ctr += 4;
}
const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4;
const char *buf = reinterpret_cast<char *>(t->GetMutableBuffer()) + 6 * 4 + 4;
std::vector<uint32_t> starts = {0, 4, 9, 10, 11, 15};
uint32_t index = 0;

@ -41,6 +41,7 @@ class MindDataTestTensorDE : public UT::Common {
TEST_F(MindDataTestTensorDE, Basics) {
std::shared_ptr<Tensor> t = std::make_shared<Tensor>(TensorShape({2, 3}), DataType(DataType::DE_UINT64));
ASSERT_TRUE((t->AllocateBuffer(t->SizeInBytes())).IsOk());
ASSERT_EQ(t->shape(), TensorShape({2, 3}));
ASSERT_EQ(t->type(), DataType::DE_UINT64);
ASSERT_EQ(t->SizeInBytes(), 2 * 3 * 8);

@ -0,0 +1,18 @@
{
"datasetType": "TF",
"numRows": 3,
"columns": {
"line": {
"type": "string",
"rank": 0
},
"words": {
"type": "string",
"rank": 1
},
"chinese": {
"type": "string",
"rank": 0
}
}
}

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save