commit
ec8f541325
@ -0,0 +1,241 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "dataset/text/kernels/to_number_op.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/core/tensor_shape.h"
|
||||
#include "dataset/kernels/data/data_utils.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// Store the target DataType that Compute() will convert string inputs to.
ToNumberOp::ToNumberOp(const DataType &cast_to_type) : cast_to_type_(cast_to_type) {}
|
||||
|
||||
// Overload taking the target type by name; the string is handed to the DataType constructor.
ToNumberOp::ToNumberOp(const std::string &cast_to_type) : cast_to_type_(DataType(cast_to_type)) {}
|
||||
|
||||
// Convert every string element of `input` to the numeric type selected at construction.
// @param const std::shared_ptr<Tensor> &input - tensor whose type must be DE_STRING.
// @param std::shared_ptr<Tensor> *output - receives the tensor created by the selected conversion helper.
// @return Status::OK() on success; an error Status when the input is not a string tensor, when the
//     target type is not one of the numeric types handled below, or when a conversion helper fails.
Status ToNumberOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "Input tensors should have type string.");

  switch (cast_to_type_.value()) {
    case DataType::DE_INT8:
      RETURN_IF_NOT_OK(ToSignedIntegral<int8_t>(input, output));
      break;
    case DataType::DE_INT16:
      RETURN_IF_NOT_OK(ToSignedIntegral<int16_t>(input, output));
      break;
    case DataType::DE_INT32:
      RETURN_IF_NOT_OK(ToSignedIntegral<int32_t>(input, output));
      break;
    case DataType::DE_INT64:
      RETURN_IF_NOT_OK(ToSignedIntegral<int64_t>(input, output));
      break;
    case DataType::DE_UINT8:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint8_t>(input, output));
      break;
    case DataType::DE_UINT16:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint16_t>(input, output));
      break;
    case DataType::DE_UINT32:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint32_t>(input, output));
      break;
    case DataType::DE_UINT64:
      RETURN_IF_NOT_OK(ToUnsignedIntegral<uint64_t>(input, output));
      break;
    case DataType::DE_FLOAT16:
      RETURN_IF_NOT_OK(this->ToFloat16(input, output));
      break;
    case DataType::DE_FLOAT32:
      RETURN_IF_NOT_OK(ToFloat(input, output));
      break;
    case DataType::DE_FLOAT64:
      RETURN_IF_NOT_OK(ToDouble(input, output));
      break;
    default:
      // Any other target type has no conversion helper; fail loudly instead of
      // reporting success while leaving *output unwritten.
      RETURN_STATUS_UNEXPECTED("ToNumberOp: unsupported cast type: " + cast_to_type_.ToString());
  }

  return Status::OK();
}
|
||||
|
||||
void ToNumberOp::Print(std::ostream &out) const { out << "ToNumberOp: casting to " << '\n'; }
|
||||
|
||||
// Report the output shape for each input shape. The conversion helpers build their result
// with the input tensor's shape, so every input shape is appended unchanged.
// @param const std::vector<TensorShape> &input_shapes - shapes of the input tensors.
// @param std::vector<TensorShape> &output_shapes - vector the shapes are appended to.
// @return Status::OK()
Status ToNumberOp::OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) {
  for (const TensorShape &shape : input_shapes) {
    output_shapes.push_back(shape);
  }
  return Status::OK();
}
|
||||
|
||||
template <typename T>
|
||||
Status ToNumberOp::ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
std::vector<T> casted;
|
||||
|
||||
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||
bool is_cast_out_of_range = false;
|
||||
int64_t result = 0;
|
||||
|
||||
try {
|
||||
result = std::stoll(std::string(*it));
|
||||
} catch (const std::out_of_range &) {
|
||||
is_cast_out_of_range = true;
|
||||
} catch (const std::invalid_argument &) {
|
||||
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
|
||||
}
|
||||
|
||||
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||
|
||||
RETURN_STATUS_UNEXPECTED(error_message);
|
||||
}
|
||||
|
||||
T casted_result = static_cast<T>(result);
|
||||
casted.push_back(casted_result);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status ToNumberOp::ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
|
||||
std::vector<T> casted;
|
||||
|
||||
for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
|
||||
bool is_cast_out_of_range = false;
|
||||
uint64_t result = 0;
|
||||
|
||||
// If there is a - at the start of the string, it is considered by us to
|
||||
// be out of bounds. If the - is somewhere else in the string, it is
|
||||
// deemed invalid by std::stoull and will throw std::invalid_argument
|
||||
for (int i = 0; i < (*it).size(); i++) {
|
||||
if ((*it)[i] == '-') {
|
||||
is_cast_out_of_range = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
result = std::stoull(std::string(*it));
|
||||
} catch (const std::out_of_range &) {
|
||||
is_cast_out_of_range = true;
|
||||
} catch (const std::invalid_argument &) {
|
||||
RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to an unsigned integer.");
|
||||
}
|
||||
|
||||
if (result > std::numeric_limits<T>::max() || result < std::numeric_limits<T>::min() || is_cast_out_of_range) {
|
||||
std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
|
||||
cast_to_type_.ToString() + ". The valid range is: [" +
|
||||
std::to_string(std::numeric_limits<T>::min()) + ", " +
|
||||
std::to_string(std::numeric_limits<T>::max()) + "].";
|
||||
|
||||
RETURN_STATUS_UNEXPECTED(error_message);
|
||||
}
|
||||
|
||||
T casted_result = static_cast<T>(result);
|
||||
casted.push_back(casted_result);
|
||||
}
|
||||
|
||||
RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Convert a string tensor to float16 in two steps.
// C++ has no native float16 type to parse into, so the strings are first parsed to float
// by ToFloat() and the intermediate tensor is then narrowed by the free function
// mindspore::dataset::ToFloat16.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert.
// @param std::shared_ptr<Tensor> *output - receives the float16 tensor.
// @return error Status propagated from either step.
Status ToNumberOp::ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  // ToFloat() creates the intermediate tensor itself, so nothing is allocated up front.
  std::shared_ptr<Tensor> float_tensor;
  RETURN_IF_NOT_OK(ToFloat(input, &float_tensor));
  RETURN_IF_NOT_OK(mindspore::dataset::ToFloat16(float_tensor, output));
  return Status::OK();
}
|
||||
|
||||
// Parse each string element of `input` with std::stof and collect the results as floats.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert.
// @param std::shared_ptr<Tensor> *output - receives a tensor of float with the shape of `input`.
// @return error Status when an element is not numeric text or is outside the finite float range.
Status ToNumberOp::ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  std::vector<float> casted;

  for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
    bool is_cast_out_of_range = false;
    float result = 0;

    try {
      result = std::stof(std::string(*it));
    } catch (const std::out_of_range &) {
      is_cast_out_of_range = true;
    } catch (const std::invalid_argument &) {
      // This path parses floating-point text, so the failure is reported as "a number".
      RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
    }

    // The comparisons catch values std::stof returns without throwing but that lie
    // outside the finite range, such as an infinity parsed from the text.
    if (result > std::numeric_limits<float>::max() || result < std::numeric_limits<float>::lowest() ||
        is_cast_out_of_range) {
      std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
                                  cast_to_type_.ToString() + ". The valid range is: [" +
                                  std::to_string(std::numeric_limits<float>::lowest()) + ", " +
                                  std::to_string(std::numeric_limits<float>::max()) + "].";

      RETURN_STATUS_UNEXPECTED(error_message);
    }

    casted.push_back(result);
  }

  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
  return Status::OK();
}
|
||||
|
||||
// Parse each string element of `input` with std::stod and collect the results as doubles.
// @param const std::shared_ptr<Tensor> &input - string tensor to convert.
// @param std::shared_ptr<Tensor> *output - receives a tensor of double with the shape of `input`.
// @return error Status when an element is not numeric text or is outside the finite double range.
Status ToNumberOp::ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) {
  std::vector<double> casted;

  for (auto it = input->begin<std::string_view>(); it != input->end<std::string_view>(); ++it) {
    bool is_cast_out_of_range = false;
    double result = 0;

    try {
      result = std::stod(std::string(*it));
    } catch (const std::out_of_range &) {
      is_cast_out_of_range = true;
    } catch (const std::invalid_argument &) {
      // This path parses floating-point text, so the failure is reported as "a number".
      RETURN_STATUS_UNEXPECTED("It is invalid to convert " + std::string(*it) + " to a number.");
    }

    // The comparisons catch values std::stod returns without throwing but that lie
    // outside the finite range, such as an infinity parsed from the text.
    if (result > std::numeric_limits<double>::max() || result < std::numeric_limits<double>::lowest() ||
        is_cast_out_of_range) {
      std::string error_message = "String input " + std::string(*it) + " will be out of bounds if casted to " +
                                  cast_to_type_.ToString() + ". The valid range is: [" +
                                  std::to_string(std::numeric_limits<double>::lowest()) + ", " +
                                  std::to_string(std::numeric_limits<double>::max()) + "].";

      RETURN_STATUS_UNEXPECTED(error_message);
    }

    casted.push_back(result);
  }

  RETURN_IF_NOT_OK(Tensor::CreateTensor(output, casted, input->shape()));
  return Status::OK();
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,79 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||
#define DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "dataset/core/data_type.h"
|
||||
#include "dataset/core/tensor.h"
|
||||
#include "dataset/kernels/tensor_op.h"
|
||||
#include "dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
class ToNumberOp : public TensorOp {
|
||||
public:
|
||||
// Constructor of ToNumberOp
|
||||
// @param const DataType &cast_to_type - the type to convert string inputs to.
|
||||
explicit ToNumberOp(const DataType &cast_to_type);
|
||||
|
||||
// Constructor of ToNumberOp
|
||||
// @param const std::string &cast_to_type - the type in string form to convert string inputs to.
|
||||
explicit ToNumberOp(const std::string &cast_to_type);
|
||||
|
||||
~ToNumberOp() override = default;
|
||||
|
||||
// Perform numeric conversion on each string in each tensor.
|
||||
// @param const std::shared_ptr<Tensor> &input
|
||||
// @param std::shared_ptr<Tensor> *output
|
||||
// @return error code
|
||||
Status Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output) override;
|
||||
|
||||
// For each input shape, find the output shape
|
||||
// @param std::vector<TensorShape> &inputs - shape of input tensors
|
||||
// @param std::vector<TensorShape> &outputs - shape of output tensors
|
||||
// @return error code
|
||||
Status OutputShape(const std::vector<TensorShape> &input_shapes, std::vector<TensorShape> &output_shapes) override;
|
||||
|
||||
// print arg for debugging
|
||||
// @param std::ostream &out
|
||||
void Print(std::ostream &out) const override;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
Status ToSignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
template <typename T>
|
||||
Status ToUnsignedIntegral(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToFloat16(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToFloat(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
Status ToDouble(const std::shared_ptr<Tensor> &input, std::shared_ptr<Tensor> *output);
|
||||
|
||||
DataType cast_to_type_;
|
||||
};
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // DATASET_TEXT_KERNELS_TO_NUMBER_OP_H_
|
@ -0,0 +1,194 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.text as text
|
||||
|
||||
# NumPy integer types, in the same order as ms_integral_types so the two lists can be zipped.
np_integral_types = [
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
]
# MindSpore counterparts of np_integral_types (signed first, then unsigned).
ms_integral_types = [
    mstype.int8, mstype.int16, mstype.int32, mstype.int64,
    mstype.uint8, mstype.uint16, mstype.uint32, mstype.uint64,
]

# Floating-point types, again index-aligned between NumPy and MindSpore.
np_non_integral_types = [np.float16, np.float32, np.float64]
ms_non_integral_types = [mstype.float16, mstype.float32, mstype.float64]
|
||||
|
||||
def string_dataset_generator(strings):
    """Yield one single-column row per entry of `strings`, each a NumPy byte-string array."""
    yield from ((np.array(entry, dtype='S'),) for entry in strings)
|
||||
|
||||
|
||||
def test_to_number_typical_case_integral():
    """Check that ordinary in-range integer strings convert to the matching integers."""
    # One pair of inputs per target type. zip() stops at the shorter sequence, so only
    # the first four entries of ms_integral_types (the signed types) are exercised here.
    input_strings = [
        ["-121", "14"],
        ["-2219", "7623"],
        ["-8162536", "162371864"],
        ["-1726483716", "98921728421"],
    ]

    for ms_type, inputs in zip(ms_integral_types, input_strings):
        source = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
        converted = source.map(input_columns=["strings"], operations=text.ToNumber(ms_type))

        expected_output = list(map(int, inputs))
        output = [row["strings"] for row in converted.create_dict_iterator()]

        assert output == expected_output
|
||||
|
||||
|
||||
def test_to_number_typical_case_non_integral():
    """Check that ordinary decimal strings convert to floats close to the Python float value."""
    input_strings = [["-1.1", "1.4"], ["-2219.321", "7623.453"], ["-816256.234282", "162371864.243243"]]
    # Tolerance applied to every comparison. This test previously carried a six-entry
    # list that was zipped against the two values of each dataset, so only its first
    # two entries (both 0.001) were ever used; the single constant states that directly.
    epsilon = 0.001

    for ms_type, inputs in zip(ms_non_integral_types, input_strings):
        dataset = ds.GeneratorDataset(string_dataset_generator(inputs), "strings")
        dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))

        expected_output = [float(string) for string in inputs]
        output = []
        for data in dataset.create_dict_iterator():
            output.append(data["strings"])

        for expected, actual in zip(expected_output, output):
            assert abs(expected - actual) < epsilon
|
||||
|
||||
|
||||
def out_of_bounds_error_message_check(dataset, np_type, value_to_cast):
    """Iterate `dataset`, expect a RuntimeError, and check its out-of-bounds message.

    The message must name `value_to_cast` and the NumPy type name of `np_type`, and must
    quote that integer type's valid range.
    """
    limits = np.iinfo(np_type)
    lower_bound = str(limits.min)
    upper_bound = str(limits.max)
    dtype_name = str(np.dtype(np_type))

    with pytest.raises(RuntimeError) as info:
        for _ in dataset.create_dict_iterator():
            pass

    message = str(info.value)
    assert "String input " + value_to_cast + " will be out of bounds if casted to " + dtype_name in message
    assert "valid range is: [" + lower_bound + ", " + upper_bound + "]" in message
|
||||
|
||||
|
||||
def test_to_number_out_of_bounds_integral():
    """Check that values just outside each integer type's range are rejected."""
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        limits = np.iinfo(np_type)

        # First a value above the maximum, then one below the minimum.
        for out_of_range_value in (limits.max + 10, limits.min - 10):
            input_strings = [str(out_of_range_value)]
            dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
            dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
            out_of_bounds_error_message_check(dataset, np_type, input_strings[0])
|
||||
|
||||
|
||||
def test_to_number_out_of_bounds_non_integral():
    """Check that floating-point strings beyond each type's range raise the expected error."""

    def expect_conversion_error(value, ms_type, expected_fragment):
        # Convert the single string `value` and require a RuntimeError mentioning the fragment.
        input_strings = [value]
        dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
        dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))

        with pytest.raises(RuntimeError) as info:
            for _ in dataset.create_dict_iterator():
                pass
        assert expected_fragment in str(info.value)

    def out_of_bounds_fragment(value, type_name):
        return "String input " + value + " will be out of bounds if casted to " + type_name

    float16_fragment = "outside of valid float16 range"

    above_range = [str(np.finfo(np.float16).max * 10), str(np.finfo(np.float32).max * 10), "1.8e+308"]

    expect_conversion_error(above_range[0], ms_non_integral_types[0], float16_fragment)
    expect_conversion_error(above_range[1], ms_non_integral_types[1],
                            out_of_bounds_fragment(above_range[1], "float32"))
    expect_conversion_error(above_range[2], ms_non_integral_types[2],
                            out_of_bounds_fragment(above_range[2], "float64"))

    below_range = [str(np.finfo(np.float16).min * 10), str(np.finfo(np.float32).min * 10), "-1.8e+308"]

    expect_conversion_error(below_range[0], ms_non_integral_types[0], float16_fragment)
    expect_conversion_error(below_range[1], ms_non_integral_types[1],
                            out_of_bounds_fragment(below_range[1], "float32"))
    expect_conversion_error(below_range[2], ms_non_integral_types[2],
                            out_of_bounds_fragment(below_range[2], "float64"))
|
||||
|
||||
|
||||
def test_to_number_boundaries_integral():
    """Check that each integer type's maximum, minimum and zero convert without error."""
    for np_type, ms_type in zip(np_integral_types, ms_integral_types):
        limits = np.iinfo(np_type)

        for boundary in (limits.max, limits.min, 0):
            input_strings = [str(boundary)]
            dataset = ds.GeneratorDataset(string_dataset_generator(input_strings), "strings")
            dataset = dataset.map(input_columns=["strings"], operations=text.ToNumber(ms_type))
            for row in dataset.create_dict_iterator():
                assert row["strings"] == int(input_strings[0])
|
||||
|
||||
|
||||
def test_to_number_invalid_input():
    """Check that a string that does not start with a number is reported as invalid."""
    bad_value = "a8fa9ds8fa"
    source = ds.GeneratorDataset(string_dataset_generator([bad_value]), "strings")
    converted = source.map(input_columns=["strings"], operations=text.ToNumber(mstype.int32))

    with pytest.raises(RuntimeError) as info:
        for _ in converted.create_dict_iterator():
            pass

    expected_fragment = "It is invalid to convert " + bad_value + " to a number"
    assert expected_fragment in str(info.value)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Run every test in this module directly, without going through pytest.
    for test_case in (
            test_to_number_typical_case_integral,
            test_to_number_typical_case_non_integral,
            test_to_number_boundaries_integral,
            test_to_number_out_of_bounds_integral,
            test_to_number_out_of_bounds_non_integral,
            test_to_number_invalid_input,
    ):
        test_case()
|
Loading…
Reference in new issue