|
|
|
@ -72,6 +72,99 @@ class TensorRow {
|
|
|
|
|
// Destructor
|
|
|
|
|
~TensorRow() = default;
|
|
|
|
|
|
|
|
|
|
/// Convert a vector of primitive types to a TensorRow consisting of n single data Tensors.
|
|
|
|
|
/// \tparam `T`
|
|
|
|
|
/// \param[in] o input vector
|
|
|
|
|
/// \param[out] output TensorRow
|
|
|
|
|
template <typename T>
|
|
|
|
|
static Status ConvertToTensorRow(const std::vector<T> &o, TensorRow *output) {
|
|
|
|
|
DataType data_type = DataType::FromCType<T>();
|
|
|
|
|
if (data_type == DataType::DE_UNKNOWN) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
|
|
|
|
|
}
|
|
|
|
|
if (data_type == DataType::DE_STRING) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < o.size(); i++) {
|
|
|
|
|
std::shared_ptr<Tensor> tensor;
|
|
|
|
|
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
|
|
|
|
|
std::string_view s;
|
|
|
|
|
tensor->SetItemAt({0}, o[i]);
|
|
|
|
|
output->push_back(tensor);
|
|
|
|
|
}
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Convert a single primitive type to a TensorRow consisting of one single data Tensor.
|
|
|
|
|
/// \tparam `T`
|
|
|
|
|
/// \param[in] o input
|
|
|
|
|
/// \param[out] output TensorRow
|
|
|
|
|
template <typename T>
|
|
|
|
|
static Status ConvertToTensorRow(const T &o, TensorRow *output) {
|
|
|
|
|
DataType data_type = DataType::FromCType<T>();
|
|
|
|
|
if (data_type == DataType::DE_UNKNOWN) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type was not recognized.");
|
|
|
|
|
}
|
|
|
|
|
if (data_type == DataType::DE_STRING) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
|
|
|
|
}
|
|
|
|
|
std::shared_ptr<Tensor> tensor;
|
|
|
|
|
Tensor::CreateEmpty(TensorShape({1}), data_type, &tensor);
|
|
|
|
|
tensor->SetItemAt({0}, o);
|
|
|
|
|
output->push_back(tensor);
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Return the value in a TensorRow consiting of 1 single data Tensor
|
|
|
|
|
/// \tparam `T`
|
|
|
|
|
/// \param[in] input TensorRow
|
|
|
|
|
/// \param[out] o the primitive variable
|
|
|
|
|
template <typename T>
|
|
|
|
|
static Status ConvertFromTensorRow(const TensorRow &input, T *o) {
|
|
|
|
|
DataType data_type = DataType::FromCType<T>();
|
|
|
|
|
if (data_type == DataType::DE_UNKNOWN) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
|
|
|
|
|
}
|
|
|
|
|
if (data_type == DataType::DE_STRING) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
|
|
|
|
}
|
|
|
|
|
if (input.size() != 1) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input TensorRow is empty.");
|
|
|
|
|
}
|
|
|
|
|
if (input.at(0)->type() != data_type) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The output type doesn't match the input tensor type.");
|
|
|
|
|
}
|
|
|
|
|
if (input.at(0)->shape() != TensorShape({1})) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensors must have a shape of {1}.");
|
|
|
|
|
}
|
|
|
|
|
return input.at(0)->GetItemAt(o, {0});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Convert a TensorRow consisting of n single data tensors to a vector of size n
|
|
|
|
|
/// \tparam `T`
|
|
|
|
|
/// \param[in] o TensorRow consisting of n single data tensors
|
|
|
|
|
/// \param[out] o vector of primitive variable
|
|
|
|
|
template <typename T>
|
|
|
|
|
static Status ConvertFromTensorRow(const TensorRow &input, std::vector<T> *o) {
|
|
|
|
|
DataType data_type = DataType::FromCType<T>();
|
|
|
|
|
if (data_type == DataType::DE_UNKNOWN) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: Data type was not recognized.");
|
|
|
|
|
}
|
|
|
|
|
if (data_type == DataType::DE_STRING) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertToTensorRow: Data type string is not supported.");
|
|
|
|
|
}
|
|
|
|
|
for (int i = 0; i < input.size(); i++) {
|
|
|
|
|
if (input.at(i)->shape() != TensorShape({1})) {
|
|
|
|
|
RETURN_STATUS_UNEXPECTED("ConvertFromTensorRow: The input tensor must have a shape of 1.");
|
|
|
|
|
}
|
|
|
|
|
T item;
|
|
|
|
|
RETURN_IF_NOT_OK(input.at(i)->GetItemAt(&item, {0}));
|
|
|
|
|
o->push_back(item);
|
|
|
|
|
}
|
|
|
|
|
return Status::OK();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Functions to fetch/set id/vector
|
|
|
|
|
row_id_type getId() const { return id_; }
|
|
|
|
|
|
|
|
|
|