Add checks to RandomData and LookUp

pull/9752/head
Zirui Wu 5 years ago
parent bdf03fe00d
commit 91a6b2b0ca

@ -320,6 +320,12 @@ Status LookupOperation::ValidateParams() {
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
if (!data_type_.IsNumeric()) {
std::string err_msg = "Lookup does not support a string to string mapping, data_type can only be numeric.";
MS_LOG(ERROR) << err_msg;
RETURN_STATUS_SYNTAX_ERROR(err_msg);
}
return Status::OK();
}

@ -183,7 +183,7 @@ class ConfigManager {
// E.g. 0 would correspond to a 1:1:1 ratio of num_worker among leaf, batch and map.
// Please refer to AutoWorkerPass for details on what each option is.
// @return The experimental config used by AutoNumWorker, each one refers to a different setup configuration
uint8_t get_auto_worker_config_() { return auto_worker_config_; }
uint8_t get_auto_worker_config() { return auto_worker_config_; }
// setter function
// E.g. setting the value to 0 would correspond to a 1:1:1 ratio of num_worker among leaf, batch and map.

@ -147,6 +147,10 @@ void RandomDataOp::GenerateSchema() {
// All DatasetOps operate by launching a thread (see ExecutionTree). This class functor will
// provide the master loop that drives the logic for performing the work.
Status RandomDataOp::operator()() {
// Fail early unless there are at least as many rows as workers. The message states the
// relation that is actually enforced (total_rows >= num_workers); it previously claimed the opposite.
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ >= num_workers_,
                             "RandomDataOp expects total_rows >= num_workers. total_row=" +
                               std::to_string(total_rows_) + ", num_workers=" + std::to_string(num_workers_) + " .");
// First, compute how many buffers we'll need to satisfy the total row count.
// The only reason we do this is for the purpose of throttling worker count if needed.
int64_t buffers_needed = total_rows_ / rows_per_buffer_;

@ -52,6 +52,11 @@ Status RandomNode::ValidateParams() {
RETURN_IF_NOT_OK(ValidateDatasetColumnParam("RandomNode", "columns_list", columns_list_));
}
// allow total_rows == 0 for now because RandomOp would generate a random row when it gets a 0
// Otherwise require at least as many rows as workers. The message states the relation that is
// actually enforced (total_rows >= num_workers); it previously claimed the opposite.
CHECK_FAIL_RETURN_UNEXPECTED(total_rows_ == 0 || total_rows_ >= num_workers_,
                             "RandomNode needs total_rows >= num_workers. total_rows=" + std::to_string(total_rows_) +
                               ", num_workers=" + std::to_string(num_workers_) + ".");
return Status::OK();
}

@ -27,7 +27,7 @@ namespace dataset {
// this will become the RootNode:DatasetNode when it is turned on
Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
uint8_t config = GlobalContext::config_manager()->get_auto_worker_config_();
uint8_t config = GlobalContext::config_manager()->get_auto_worker_config();
OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? config : 0]);

@ -28,6 +28,7 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
IO_CHECK(input, output);
RETURN_UNEXPECTED_IF_NULL(vocab_);
CHECK_FAIL_RETURN_UNEXPECTED(input->type() == DataType::DE_STRING, "None string tensor received.");
std::vector<WordIdType> word_ids;
word_ids.reserve(input->Size());
for (auto itr = input->begin<std::string_view>(); itr != input->end<std::string_view>(); itr++) {
@ -41,6 +42,8 @@ Status LookupOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<T
// type cast to user's requirements if what user wants isn't int32_t
if ((*output)->type() != type_) {
CHECK_FAIL_RETURN_UNEXPECTED(type_.IsNumeric(),
"Lookup doesn't support string to string lookup. data_type needs to be numeric");
std::shared_ptr<Tensor> cast_to;
RETURN_IF_NOT_OK(TypeCast(*output, &cast_to, type_));
*output = cast_to;

@ -500,6 +500,7 @@ if platform.system().lower() != 'windows':
NormalizeForm.NFKD: cde.NormalizeForm.DE_NORMALIZE_NFKD
}
class NormalizeUTF8(cde.NormalizeUTF8Op):
"""
Apply normalize operation on UTF-8 string tensor.

@ -100,7 +100,6 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetBasicWithPipeline) {
ds1 = ds1->Concat({ds2});
EXPECT_NE(ds1, nullptr);
// Create an iterator over the result of the above dataset
// This will trigger the creation of the Execution Tree and launch it.
std::shared_ptr<Iterator> iter = ds1->CreateIterator();
@ -474,3 +473,10 @@ TEST_F(MindDataTestPipeline, TestRandomDatasetDuplicateColumnName) {
// Expect failure: duplicate column names
EXPECT_EQ(ds->CreateIterator(), nullptr);
}
// Negative test: a RandomData pipeline configured with more workers than rows must not
// produce an iterator.
TEST_F(MindDataTestPipeline, TestRandomDatasetFail) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestRandomDatasetFail.";
  // this will fail because num_workers (5) is greater than num_rows (3)
  std::shared_ptr<Dataset> ds = RandomData(3)->SetNumWorkers(5);
  // Fatal assertion: abort the test here rather than dereference a null dataset below.
  ASSERT_NE(ds, nullptr);
  // Expect failure: iterator creation returns nullptr for this configuration
  EXPECT_EQ(ds->CreateIterator(), nullptr);
}

@ -166,6 +166,7 @@ def test_lookup_cast_type():
assert test_config("unk") == np.dtype("int32")
# test exception, data_type isn't the correct type
assert "tldr is not of type (<class 'mindspore._c_expression.typing.Type'>,)" in test_config("unk", "tldr")
assert "Lookup doesn't support string to string lookup" in test_config("w1", mstype.string)
if __name__ == '__main__':

Loading…
Cancel
Save