!7414 C++ api add config_manager
Merge pull request !7414 from xiaotianci/c_api_configpull/7414/MERGE
commit
947f6d96d1
@ -0,0 +1,109 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "minddata/dataset/core/config_manager.h"
|
||||
#include "minddata/dataset/core/global_context.h"
|
||||
#include "minddata/dataset/include/config.h"
|
||||
#include "minddata/dataset/util/status.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Config operations for setting and getting the configuration.
|
||||
namespace config {
|
||||
|
||||
// Shared handle to the ConfigManager returned by GlobalContext::config_manager(); the config functions in this
// namespace read and write the configuration through it.
std::shared_ptr<ConfigManager> _config = GlobalContext::config_manager();
|
||||
|
||||
// Function to set the seed to be used in any random generator
|
||||
bool set_seed(int32_t seed) {
|
||||
if (seed < 0 || seed > UINT32_MAX) {
|
||||
MS_LOG(ERROR) << "Seed given is not within the required range: " << seed;
|
||||
return false;
|
||||
}
|
||||
_config->set_seed((uint32_t)seed);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to get the seed
|
||||
uint32_t get_seed() { return _config->seed(); }
|
||||
|
||||
// Function to set the number of rows to be prefetched
|
||||
bool set_prefetch_size(int32_t prefetch_size) {
|
||||
if (prefetch_size <= 0 || prefetch_size > INT32_MAX) {
|
||||
MS_LOG(ERROR) << "Prefetch size given is not within the required range: " << prefetch_size;
|
||||
return false;
|
||||
}
|
||||
_config->set_op_connector_size(prefetch_size);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to get prefetch size in number of rows
|
||||
int32_t get_prefetch_size() { return _config->op_connector_size(); }
|
||||
|
||||
// Function to set the default number of parallel workers
|
||||
bool set_num_parallel_workers(int32_t num_parallel_workers) {
|
||||
if (num_parallel_workers <= 0 || num_parallel_workers > INT32_MAX) {
|
||||
MS_LOG(ERROR) << "Number of parallel workers given is not within the required range: " << num_parallel_workers;
|
||||
return false;
|
||||
}
|
||||
_config->set_num_parallel_workers(num_parallel_workers);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to get the default number of parallel workers
|
||||
int32_t get_num_parallel_workers() { return _config->num_parallel_workers(); }
|
||||
|
||||
// Function to set the default interval (in milliseconds) for monitor sampling
|
||||
bool set_monitor_sampling_interval(int32_t interval) {
|
||||
if (interval <= 0 || interval > UINT32_MAX) {
|
||||
MS_LOG(ERROR) << "Interval given is not within the required range: " << interval;
|
||||
return false;
|
||||
}
|
||||
_config->set_monitor_sampling_interval((uint32_t)interval);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to get the default interval of performance monitor sampling
|
||||
int32_t get_monitor_sampling_interval() { return _config->monitor_sampling_interval(); }
|
||||
|
||||
// Function to set the default timeout (in seconds) for DSWaitedCallback
|
||||
bool set_callback_timeback(int32_t timeout) {
|
||||
if (timeout <= 0 || timeout > UINT32_MAX) {
|
||||
MS_LOG(ERROR) << "Timeout given is not within the required range: " << timeout;
|
||||
return false;
|
||||
}
|
||||
_config->set_callback_timeout((uint32_t)timeout);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Function to get the default timeout for DSWaitedCallback
|
||||
int32_t get_callback_timeout() { return _config->callback_timeout(); }
|
||||
|
||||
// Function to load configurations from a file
|
||||
bool load(std::string file) {
|
||||
Status rc = _config->LoadFile(file);
|
||||
if (rc.IsError()) {
|
||||
MS_LOG(ERROR) << "Configuration file loads failed: " << file;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace config
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONFIG_H
|
||||
#define MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONFIG_H
|
||||
|
||||
#include <cstdint>
|
||||
#include <string>
|
||||
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
namespace api {
|
||||
|
||||
// Config operations for setting and getting the configuration.
|
||||
namespace config {
|
||||
|
||||
/// \brief Function to set the seed to be used in any random generator. This is used to produce deterministic results.
|
||||
/// \param[in] seed the default seed to use.
/// \return true if the seed was accepted and applied, false if it was rejected (negative value).
|
||||
bool set_seed(int32_t seed);
|
||||
|
||||
/// \brief Function to get the seed.
|
||||
/// \return the seed set in the configuration.
|
||||
uint32_t get_seed();
|
||||
|
||||
/// \brief Function to set the number of rows to be prefetched.
|
||||
/// \param[in] prefetch_size total number of rows to be prefetched.
/// \return true if prefetch_size is positive and was applied, false otherwise.
|
||||
bool set_prefetch_size(int32_t prefetch_size);
|
||||
|
||||
/// \brief Function to get the prefetch size in number of rows.
|
||||
/// \return total number of rows to be prefetched.
|
||||
int32_t get_prefetch_size();
|
||||
|
||||
/// \brief Function to set the default number of parallel workers.
|
||||
/// \param[in] num_parallel_workers number of parallel workers to be used as a default for each operation.
/// \return true if num_parallel_workers is positive and was applied, false otherwise.
|
||||
bool set_num_parallel_workers(int32_t num_parallel_workers);
|
||||
|
||||
/// \brief Function to get the default number of parallel workers.
|
||||
/// \return number of parallel workers to be used as a default for each operation.
|
||||
int32_t get_num_parallel_workers();
|
||||
|
||||
/// \brief Function to set the default interval (in milliseconds) for monitor sampling.
|
||||
/// \param[in] interval interval (in milliseconds) to be used for performance monitor sampling.
/// \return true if interval is positive and was applied, false otherwise.
|
||||
bool set_monitor_sampling_interval(int32_t interval);
|
||||
|
||||
/// \brief Function to get the default interval of performance monitor sampling.
|
||||
/// \return interval (in milliseconds) for performance monitor sampling.
|
||||
int32_t get_monitor_sampling_interval();
|
||||
|
||||
/// \brief Function to set the default timeout (in seconds) for DSWaitedCallback. In case of a deadlock, the wait
|
||||
/// function will exit after the timeout period.
|
||||
/// \param[in] timeout timeout (in seconds) to be used to end the wait in DSWaitedCallback in case of a deadlock.
/// \return true if timeout is positive and was applied, false otherwise.
|
||||
bool set_callback_timeout(int32_t timeout);
|
||||
|
||||
/// \brief Function to get the default timeout for DSWaitedCallback. In case of a deadlock, the wait function will exit
|
||||
/// after the timeout period.
|
||||
/// \return the duration in seconds.
|
||||
int32_t get_callback_timeout();
|
||||
|
||||
/// \brief Function to load configuration from a file.
|
||||
/// \param[in] file path of the configuration file to be loaded.
/// \return true if the file was loaded successfully, false otherwise.
|
||||
bool load(std::string file);
|
||||
|
||||
} // namespace config
|
||||
} // namespace api
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_MINDDATA_DATASET_INCLUDE_CONFIG_H
|
@ -0,0 +1,230 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common.h"
|
||||
#include "minddata/dataset/include/config.h"
|
||||
#include "minddata/dataset/include/datasets.h"
|
||||
|
||||
using namespace mindspore::dataset::api;
|
||||
using mindspore::dataset::ShuffleMode;
|
||||
using mindspore::dataset::Tensor;
|
||||
|
||||
// Test fixture for the config API tests in this file; it derives from UT::DatasetOpTesting and adds no members
// of its own.
class MindDataTestPipeline : public UT::DatasetOpTesting {
|
||||
protected:
|
||||
};
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestConfigSetting) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConfigSetting.";
  // Exercises the load / set / get round trip of the config API.

  // Remember the current values so they can be put back at the end of the test.
  const auto saved_workers = config::get_num_parallel_workers();
  const auto saved_prefetch = config::get_prefetch_size();
  const auto saved_seed = config::get_seed();
  const auto saved_interval = config::get_monitor_sampling_interval();

  // Loading the configuration file under the dataset root must succeed.
  const std::string cfg_path = datasets_root_path_ + "/declient.cfg";
  const bool loaded = config::load(cfg_path);
  EXPECT_TRUE(loaded);

  // The getters must now report the values this test expects after loading the file.
  EXPECT_EQ(config::get_num_parallel_workers(), 4);
  EXPECT_EQ(config::get_prefetch_size(), 16);
  EXPECT_EQ(config::get_seed(), 5489);
  EXPECT_EQ(config::get_monitor_sampling_interval(), 15);

  // Override every value through the setters; each call must report success.
  const bool workers_ok = config::set_num_parallel_workers(2);
  const bool prefetch_ok = config::set_prefetch_size(4);
  const bool seed_ok = config::set_seed(5);
  const bool interval_ok = config::set_monitor_sampling_interval(45);
  EXPECT_TRUE(workers_ok);
  EXPECT_TRUE(prefetch_ok);
  EXPECT_TRUE(seed_ok);
  EXPECT_TRUE(interval_ok);

  // The getters must reflect the overrides.
  EXPECT_EQ(config::get_num_parallel_workers(), 2);
  EXPECT_EQ(config::get_prefetch_size(), 4);
  EXPECT_EQ(config::get_seed(), 5);
  EXPECT_EQ(config::get_monitor_sampling_interval(), 45);

  // Put the saved configuration back so later tests are not affected.
  config::set_num_parallel_workers(saved_workers);
  config::set_prefetch_size(saved_prefetch);
  config::set_seed(saved_seed);
  config::set_monitor_sampling_interval(saved_interval);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestConfigParamCheck) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestConfigParamCheck.";
  // Verifies that the config API rejects invalid input.

  // Remember the current values so they can be put back at the end of the test.
  const auto saved_workers = config::get_num_parallel_workers();
  const auto saved_prefetch = config::get_prefetch_size();
  const auto saved_seed = config::get_seed();
  const auto saved_interval = config::get_monitor_sampling_interval();

  // Loading from a path that does not exist must fail.
  const std::string missing_cfg_path = datasets_root_path_ + "/not_exist.cfg";
  const bool loaded = config::load(missing_cfg_path);
  EXPECT_FALSE(loaded);

  // Every setter must refuse an out-of-range argument.
  const bool workers_ok = config::set_num_parallel_workers(0);
  const bool prefetch_ok = config::set_prefetch_size(0);
  const bool seed_ok = config::set_seed(-1);
  const bool interval_ok = config::set_monitor_sampling_interval(0);
  EXPECT_FALSE(workers_ok);
  EXPECT_FALSE(prefetch_ok);
  EXPECT_FALSE(seed_ok);
  EXPECT_FALSE(interval_ok);

  // Put the saved configuration back so later tests are not affected.
  config::set_num_parallel_workers(saved_workers);
  config::set_prefetch_size(saved_prefetch);
  config::set_seed(saved_seed);
  config::set_monitor_sampling_interval(saved_interval);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestShuffleWithSeed) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestShuffleWithSeed.";
  // Test deterministic shuffle with setting the seed

  // Save and set the seed. The worker count is kept as int32_t, the type its getter returns and its setter accepts,
  // so saving and restoring it involves no signed/unsigned conversion.
  uint32_t original_seed = config::get_seed();
  int32_t original_num_parallel_workers = config::get_num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  config::set_seed(654);
  config::set_num_parallel_workers(1);

  // Create a TextFile Dataset with single text file which has three samples
  std::string text_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::shared_ptr<Dataset> ds = TextFile({text_file}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Shuffle the dataset with buffer_size=3
  ds = ds->Shuffle(3);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  EXPECT_NE(row.find("text"), row.end());

  std::vector<std::string> expected_result = {"Good luck to everyone.", "Be happy every day.", "This is a text file."};

  uint64_t i = 0;
  while (row.size() != 0) {
    auto text = row["text"];
    std::string_view sv;
    text->GetItemAt(&sv, {0});
    std::string ss(sv);
    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
    // Compare against expected result. Rows beyond the expected count are not used as an index into
    // expected_result (that would read out of bounds); the count check after the loop reports them instead.
    if (i < expected_result.size()) {
      EXPECT_STREQ(ss.c_str(), expected_result[i].c_str());
    }
    i++;
    iter->GetNextRow(&row);
  }

  // Expect 3 samples
  EXPECT_EQ(i, 3);

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  config::set_seed(original_seed);
  config::set_num_parallel_workers(original_num_parallel_workers);
}
|
||||
|
||||
TEST_F(MindDataTestPipeline, TestCallShuffleTwice) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestCallShuffleTwice.";
  // Test shuffle and repeat with setting the seed.
  // The second copy will be different from the first one because results will be different when calling shuffle twice.

  // Save and set the seed. The worker count is kept as int32_t, the type its getter returns and its setter accepts,
  // so saving and restoring it involves no signed/unsigned conversion.
  uint32_t original_seed = config::get_seed();
  int32_t original_num_parallel_workers = config::get_num_parallel_workers();
  MS_LOG(DEBUG) << "ORIGINAL seed: " << original_seed << ", num_parallel_workers: " << original_num_parallel_workers;
  config::set_seed(654);
  config::set_num_parallel_workers(1);

  // Create a TextFile Dataset with single text file which has three samples
  std::string text_file = datasets_root_path_ + "/testTextFileDataset/1.txt";
  std::shared_ptr<Dataset> ds = TextFile({text_file}, 0, ShuffleMode::kFalse);
  EXPECT_NE(ds, nullptr);

  // Shuffle the dataset with buffer_size=3
  ds = ds->Shuffle(3);
  EXPECT_NE(ds, nullptr);

  // Repeat the dataset twice
  ds = ds->Repeat(2);
  EXPECT_NE(ds, nullptr);

  // Create an iterator over the result of the above dataset.
  // This will trigger the creation of the Execution Tree and launch it.
  std::shared_ptr<Iterator> iter = ds->CreateIterator();
  EXPECT_NE(iter, nullptr);

  // Iterate the dataset and get each row
  std::unordered_map<std::string, std::shared_ptr<Tensor>> row;
  iter->GetNextRow(&row);
  EXPECT_NE(row.find("text"), row.end());

  std::vector<std::string> first_copy;
  std::vector<std::string> second_copy;

  uint64_t i = 0;
  while (row.size() != 0) {
    auto text = row["text"];
    std::string_view sv;
    text->GetItemAt(&sv, {0});
    std::string ss(sv);
    MS_LOG(INFO) << "Text length: " << ss.length() << ", Text: " << ss.substr(0, 50);
    // The first three samples are the first copy and the rest are the second
    if (i < 3) {
      first_copy.push_back(ss);
    } else {
      second_copy.push_back(ss);
    }
    i++;
    iter->GetNextRow(&row);
  }

  // Expect 6 samples
  EXPECT_EQ(i, 6);

  // Compare the two copies, which are expected to differ deterministically. Only positions present in both copies
  // are compared, so a short read is reported by the count check above rather than by an out-of-range exception
  // that would skip the pipeline shutdown and the configuration restore below.
  for (size_t j = 0; j < 3 && j < first_copy.size() && j < second_copy.size(); j++) {
    EXPECT_STRNE(first_copy.at(j).c_str(), second_copy.at(j).c_str());
  }

  // Manually terminate the pipeline
  iter->Stop();

  // Restore configuration
  config::set_seed(original_seed);
  config::set_num_parallel_workers(original_num_parallel_workers);
}
|
Loading…
Reference in new issue