commit
ec8f541325
@ -0,0 +1,241 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "dataset/text/kernels/to_number_op.h"
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <limits>
|
||||||
|
#include <memory>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dataset/core/data_type.h"
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/core/tensor_shape.h"
|
||||||
|
#include "dataset/kernels/data/data_utils.h"
|
||||||
|
#include "dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {}
|
||||||
|
|
||||||
|
ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {}
|
||||||
|
|
||||||
|
Status ToNumberOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tenosrs should have type string.");
|
||||||
|
|
||||||
|
switch (cast_to_type_.value()) {
|
||||||
|
case DataType::DE_INT8:
|
||||||
|
RETURN_IF_NOT_OK(ToSignedIntegral<int8_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT16:
|
||||||
|
RETURN_IF_NOT_OK(ToSignedIntegral<int16_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT32:
|
||||||
|
RETURN_IF_NOT_OK(ToSignedIntegral<int32_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_INT64:
|
||||||
|
RETURN_IF_NOT_OK(ToSignedIntegral<int64_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT8:
|
||||||
|
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint8_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT16:
|
||||||
|
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint16_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT32:
|
||||||
|
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint32_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_UINT64:
|
||||||
|
RETURN_IF_NOT_OK(ToUnsignedIntegral<uint64_t>(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_FLOAT16:
|
||||||
|
RETURN_IF_NOT_OK(this->ToFloat16(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_FLOAT32:
|
||||||
|
RETURN_IF_NOT_OK(ToFloat(input, output));
|
||||||
|
break;
|
||||||
|
case DataType::DE_FLOAT64:
|
||||||
|
RETURN_IF_NOT_OK(ToDouble(input, output));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; }
|
||||||
|
|
||||||
|
Status ToNumberOp::OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) {
|
||||||
|
(void)std::copy(input_shapes.begin(), input_shapes.end(), std::back_inserter(output_shapes));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ToNumberOp::ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
std::vector<T> casted;
|
||||||
|
|
||||||
|
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||||
|
bool is_cast_out_of_range = false;
|
||||||
|
int64_t result = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
result = std::stoll(std::string(*it));
|
||||||
|
} catch (const std::out_of_range &) {
|
||||||
|
is_cast_out_of_range = true;
|
||||||
|
} catch (const std::invalid_argument &) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||||
|
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||||
|
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||||
|
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||||
|
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||||
|
|
||||||
|
RETURN_STATUS_UNEXPECTED(error_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
T casted_result = static_cast<T>(result);
|
||||||
|
casted.push_back(casted_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
std::vector<T> casted;
|
||||||
|
|
||||||
|
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||||
|
bool is_cast_out_of_range = false;
|
||||||
|
uint64_t result = 0;
|
||||||
|
|
||||||
|
// If there is a - at the start of the string, it is considered by us to
|
||||||
|
// be out of bounds. If the - is somewhere else in the string, it is
|
||||||
|
// deemed invalid by std::stoull and will throw std::invalid_argument
|
||||||
|
for (int i = 0; i < (*it).size(); i++) {
|
||||||
|
if ((*it)[i] == '-') {
|
||||||
|
is_cast_out_of_range = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
result = std::stoull(std::string(*it));
|
||||||
|
} catch (const std::out_of_range &) {
|
||||||
|
is_cast_out_of_range = true;
|
||||||
|
} catch (const std::invalid_argument &) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||||
|
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||||
|
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||||
|
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||||
|
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||||
|
|
||||||
|
RETURN_STATUS_UNEXPECTED(error_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
T casted_result = static_cast<T>(result);
|
||||||
|
casted.push_back(casted_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToNumberOp::ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
// special case, float16 does not exist in c++, no native support for
|
||||||
|
// casting, so cast to float first then use this method, which use Eigen.
|
||||||
|
std::shared_ptr<Tensor> temp;
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(&temp, TensorImpl::kFlexible, input->shape(), DataType("float32")));
|
||||||
|
RETURN_IF_NOT_OK(ToFloat(input, &temp));
|
||||||
|
RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(temp, output));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToNumberOp::ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
std::vector<float> casted;
|
||||||
|
|
||||||
|
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||||
|
bool is_cast_out_of_range = false;
|
||||||
|
float result = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
result = std::stof(std::string(*it));
|
||||||
|
} catch (const std::out_of_range &) {
|
||||||
|
is_cast_out_of_range = true;
|
||||||
|
} catch (const std::invalid_argument &) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result > std::numeric_limits<float>::max() || result < std::numeric_limits<float>::lowest() ||
|
||||||
|
is_cast_out_of_range) {
|
||||||
|
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||||
|
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||||
|
std::to_string(std::numeric_limits<float>::lowest()) + ", " +
|
||||||
|
std::to_string(std::numeric_limits<float>::max()) + "].";
|
||||||
|
|
||||||
|
RETURN_STATUS_UNEXPECTED(error_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
float casted_result = static_cast<float>(result);
|
||||||
|
casted.push_back(casted_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
Status ToNumberOp::ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||||
|
std::vector<double> casted;
|
||||||
|
|
||||||
|
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||||
|
bool is_cast_out_of_range = false;
|
||||||
|
double result = 0;
|
||||||
|
|
||||||
|
try {
|
||||||
|
result = std::stod(std::string(*it));
|
||||||
|
} catch (const std::out_of_range &) {
|
||||||
|
is_cast_out_of_range = true;
|
||||||
|
} catch (const std::invalid_argument &) {
|
||||||
|
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if (result > std::numeric_limits<double>::max() || result < std::numeric_limits<double>::lowest() ||
|
||||||
|
is_cast_out_of_range) {
|
||||||
|
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||||
|
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||||
|
std::to_string(std::numeric_limits<double>::lowest()) + ", " +
|
||||||
|
std::to_string(std::numeric_limits<double>::max()) + "].";
|
||||||
|
|
||||||
|
RETURN_STATUS_UNEXPECTED(error_message);
|
||||||
|
}
|
||||||
|
|
||||||
|
double casted_result = static_cast<double>(result);
|
||||||
|
casted.push_back(casted_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,79 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||||
|
#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "dataset/core/data_type.h"
|
||||||
|
#include "dataset/core/tensor.h"
|
||||||
|
#include "dataset/kernels/tensor_op.h"
|
||||||
|
#include "dataset/util/status.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace dataset {
|
||||||
|
|
||||||
|
class ToNumberOp : public TensorOp {
|
||||||
|
public:
|
||||||
|
// Constructor of ToNumberOp
|
||||||
|
// @param const DataType &cast_to_type - the type to convert string inputs to.
|
||||||
|
explicit ToNumberOp(const DataType &cast_to_type);
|
||||||
|
|
||||||
|
// Constructor of ToNumberOp
|
||||||
|
// @param const std::string &cast_to_type - the type in string form to convert string inputs to.
|
||||||
|
explicit ToNumberOp(const std::string &cast_to_type);
|
||||||
|
|
||||||
|
~ToNumberOp() override = default;
|
||||||
|
|
||||||
|
// Perform numeric conversion on each string in each tensor.
|
||||||
|
// @param const std::shared_ptr<Tensor> &input
|
||||||
|
// @param std::shared_ptr<Tensor> *output
|
||||||
|
// @return error code
|
||||||
|
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||||
|
|
||||||
|
// For each input shape, find the output shape
|
||||||
|
// @param std::vector<TensorShape> &inputs - shape of input tensors
|
||||||
|
// @param std::vector<TensorShape> &outputs - shape of output tensors
|
||||||
|
// @return error code
|
||||||
|
Status OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) override;
|
||||||
|
|
||||||
|
// print arg for debugging
|
||||||
|
// @param std::ostream &out
|
||||||
|
void Print(std::ostream &out) const override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
template <typename T>
|
||||||
|
Status ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
Status ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||||
|
|
||||||
|
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||||
|
|
||||||
|
Status ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||||
|
|
||||||
|
Status ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||||
|
|
||||||
|
DataType cast_to_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dataset
|
||||||
|
} // namespace mindspore
|
||||||
|
|
||||||
|
#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
@ -0,0 +1,194 @@
|
|||||||
|
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.common.dtype as mstype
|
||||||
|
import mindspore.dataset as ds
|
||||||
|
import mindspore.dataset.text as text
|
||||||
|
|
||||||
|
np_integral_types = [np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16,
|
||||||
|
np.uint32, np.uint64]
|
||||||
|
ms_integral_types = [mstype.int8, mstype.int16, mstype.int32, mstype.int64, mstype.uint8,
|
||||||
|
mstype.uint16, mstype.uint32, mstype.uint64]
|
||||||
|
|
||||||
|
np_non_integral_types = [np.float16, np.float32, np.float64]
|
||||||
|
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]
|
||||||
|
|
||||||
|
def string_dataset_generator(strings):
|
||||||
|
for string in strings:
|
||||||
|
yield (np.array(string, dtype='S'),)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_typical_case_integral():
|
||||||
|
input_strings = [["-121", "14"], ["-2219", "7623"], ["-8162536", "162371864"],
|
||||||
|
["-1726483716", "98921728421"]]
|
||||||
|
|
||||||
|
for ms_type, inputs in zip(ms_integral_types, input_strings):
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
|
||||||
|
expected_output = [int(string) for string in inputs]
|
||||||
|
output = []
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
output.append(data["strings"])
|
||||||
|
|
||||||
|
assert output == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_typical_case_non_integral():
|
||||||
|
input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
|
||||||
|
epsilons = [0.001, 0.001, 0.0001, 0.0001, 0.0000001, 0.0000001]
|
||||||
|
|
||||||
|
for ms_type, inputs in zip(ms_non_integral_types, input_strings):
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
|
||||||
|
expected_output = [float(string) for string in inputs]
|
||||||
|
output = []
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
output.append(data["strings"])
|
||||||
|
|
||||||
|
for expected, actual, epsilon in zip(expected_output, output, epsilons):
|
||||||
|
assert abs(expected - actual) < epsilon
|
||||||
|
|
||||||
|
|
||||||
|
def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
|
||||||
|
type_info = np.iinfo(np_type)
|
||||||
|
type_max = str(type_info.max)
|
||||||
|
type_min = str(type_info.min)
|
||||||
|
type_name = str(np.dtype(np_type))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "String input " + value_to_cast + " will be out of bounds if casted to " + type_name in str(info.value)
|
||||||
|
assert "valid range is: [" + type_min + ", " + type_max + "]" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_out_of_bounds_integral():
|
||||||
|
for np_type, ms_type in zip(np_integral_types, ms_integral_types):
|
||||||
|
type_info = np.iinfo(np_type)
|
||||||
|
input_strings = [str(type_info.max + 10)]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
|
||||||
|
|
||||||
|
input_strings = [str(type_info.min - 10)]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_out_of_bounds_non_integral():
|
||||||
|
above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"]
|
||||||
|
|
||||||
|
input_strings = [above_range[0]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "outside of valid float16 range" in str(info.value)
|
||||||
|
|
||||||
|
input_strings = [above_range[1]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value)
|
||||||
|
|
||||||
|
input_strings = [above_range[2]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value)
|
||||||
|
|
||||||
|
below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"]
|
||||||
|
|
||||||
|
input_strings = [below_range[0]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[0]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "outside of valid float16 range" in str(info.value)
|
||||||
|
|
||||||
|
input_strings = [below_range[1]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[1]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "String input " + input_strings[0] + " will be out of bounds if casted to float32" in str(info.value)
|
||||||
|
|
||||||
|
input_strings = [below_range[2]]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_non_integral_types[2]))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "String input " + input_strings[0] + " will be out of bounds if casted to float64" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_boundaries_integral():
|
||||||
|
for np_type, ms_type in zip(np_integral_types, ms_integral_types):
|
||||||
|
type_info = np.iinfo(np_type)
|
||||||
|
input_strings = [str(type_info.max)]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
assert data["strings"] == int(input_strings[0])
|
||||||
|
|
||||||
|
input_strings = [str(type_info.min)]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
assert data["strings"] == int(input_strings[0])
|
||||||
|
|
||||||
|
input_strings = [str(0)]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
|
||||||
|
for data in dataset.create_dict_iterator():
|
||||||
|
assert data["strings"] == int(input_strings[0])
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_number_invalid_input():
|
||||||
|
input_strings = ["a8fa9ds8fa"]
|
||||||
|
dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
|
||||||
|
dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as info:
|
||||||
|
for _ in dataset.create_dict_iterator():
|
||||||
|
pass
|
||||||
|
assert "It is invalid to convert " + input_strings[0] + " to a number" in str(info.value)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_to_number_typical_case_integral()
|
||||||
|
test_to_number_typical_case_non_integral()
|
||||||
|
test_to_number_boundaries_integral()
|
||||||
|
test_to_number_out_of_bounds_integral()
|
||||||
|
test_to_number_out_of_bounds_non_integral()
|
||||||
|
test_to_number_invalid_input()
|
Loading…
Reference in new issue