!6307 Conversion to and from TensorRow in C++

Merge pull request !6307 from MahdiRahmaniHanzaki/c_func
pull/6307/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 70221f5261

@ -72,6 +72,99 @@ class TensorRow {
// Destructor
~TensorRow() = default;
/// Convert a vector of primitive types to a TensorRow consisting of n single data Tensors.
/// \tparam `T`
/// \param[in] o input vector
/// \param[out] output TensorRow
template <typename T>
static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) {
DataType data_type = DataType::FromCType<T>();
if (data_type == DataType::DE_UNKNOWN) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
}
if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
}
for (int i = 0; i < o.size(); i++) {
std::shared_ptr<Tensor> tensor;
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
std::string_view s;
tensor->SetItemAt({0}, o[i]);
output->push_back(tensor);
}
return Status::OK();
}
/// Convert a single primitive type to a TensorRow consisting of one single data Tensor.
/// \tparam `T`
/// \param[in] o input
/// \param[out] output TensorRow
template <typename T>
static Status ConvertToTensorRow(const T &o, TensorRow *output) {
DataType data_type = DataType::FromCType<T>();
if (data_type == DataType::DE_UNKNOWN) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
}
if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
}
std::shared_ptr<Tensor> tensor;
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
tensor->SetItemAt({0}, o);
output->push_back(tensor);
return Status::OK();
}
/// Return the value in a TensorRow consiting of 1 single data Tensor
/// \tparam `T`
/// \param[in] input TensorRow
/// \param[out] o the primitive variable
template <typename T>
static Status ConvertFromTensorRow(const TensorRow &input, T *o) {
DataType data_type = DataType::FromCType<T>();
if (data_type == DataType::DE_UNKNOWN) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
}
if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
}
if (input.size() != 1) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow is empty.");
}
if (input.at(0)->type() != data_type) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
}
if (input.at(0)->shape() != TensorShape({1})) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must have a shape of {1}.");
}
return input.at(0)->GetItemAt(o, {0});
}
/// Convert a TensorRow consisting of n single data tensors to a vector of size n
/// \tparam `T`
/// \param[in] o TensorRow consisting of n single data tensors
/// \param[out] o vector of primitive variable
template <typename T>
static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
DataType data_type = DataType::FromCType<T>();
if (data_type == DataType::DE_UNKNOWN) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
}
if (data_type == DataType::DE_STRING) {
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
}
for (int i = 0; i < input.size(); i++) {
if (input.at(i)->shape() != TensorShape({1})) {
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a shape of 1.");
}
T item;
RETURN_IF_NOT_OK(input.at(i)->GetItemAt(&item, {0}));
o->push_back(item);
}
return Status::OK();
}
// Functions to fetch/set id/vector
row_id_type getId() const { return id_; }

@ -71,6 +71,7 @@ SET(DE_UT_SRCS
status_test.cc
task_manager_test.cc
tensor_test.cc
tensor_row_test.cc
tensor_string_test.cc
tensorshape_test.cc
tfReader_op_test.cc

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save