!10023 address testing tickets on AutoNumWorker

From: @ziruiwu
Reviewed-by: 
Signed-off-by:
pull/10023/MERGE
mindspore-ci-bot committed by Gitee
commit 74b03da452

@@ -444,7 +444,7 @@ Status SaveToDisk::TransformTensor(const unsigned char *src, const TensorShape &
 #endif
 TreeGetters::TreeGetters() : dataset_size_(-1), init_flag_(false), first_row_obtained_(false) {
-  tree_adapter_ = std::make_unique<TreeAdapter>();
+  tree_adapter_ = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
 }
 Status TreeGetters::Init(std::shared_ptr<DatasetNode> d) {
@@ -570,7 +570,7 @@ Status DatasetSizeGetter::Init(std::shared_ptr<DatasetNode> d) {
   return Status::OK();
 }
 Status DatasetSizeGetter::DryRun(std::shared_ptr<DatasetNode> ir_node, int64_t *dataset_size) {
-  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>();
+  std::shared_ptr<TreeAdapter> tree_adapter = std::make_shared<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);
   tree_adapters_.push_back(tree_adapter);
   tree_adapter->SetPrePassOverride([](OptPass pre) {
     pre.push_back(

@@ -228,7 +228,7 @@ class DatasetSizeGetter : public TreeConsumer, public std::enable_shared_from_th
  private:
   std::shared_ptr<DatasetNode> root_;
-  std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;
+  std::vector<std::shared_ptr<TreeAdapter>> tree_adapters_;  // a vector, to handle the different branches of zip
   int64_t dataset_size_;
 };
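The added comment records a design choice: DatasetSizeGetter::DryRun() (previous hunk) compiles a fresh TreeAdapter on every call, and the getter holds one adapter per zip branch so each branch's compiled tree stays alive while its size is computed. A minimal sketch of how that could be driven, assuming a hypothetical SizeOfZip() helper, a Children() accessor, and <algorithm>/<limits> being available; this is illustrative, not the actual GetDatasetSize() implementation:

// Illustrative sketch only. DryRun() is the method from the hunk above;
// SizeOfZip(), Children() and the min-combining rule are assumptions.
Status DatasetSizeGetter::SizeOfZip(const std::shared_ptr<DatasetNode> &zip_node, int64_t *size) {
  int64_t min_rows = std::numeric_limits<int64_t>::max();
  for (const auto &branch : zip_node->Children()) {
    int64_t branch_rows = -1;
    RETURN_IF_NOT_OK(DryRun(branch, &branch_rows));  // one TreeAdapter per branch, kept in tree_adapters_
    min_rows = std::min(min_rows, branch_rows);      // a zip is bounded by its shortest branch
  }
  *size = min_rows;
  return Status::OK();
}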

@@ -48,8 +48,8 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *mod
   for (const auto &p : pass.weight_profile_) max_weight = std::max(max_weight, p.second);
   RETURN_IF_NOT_OK(pass.Run(root_ir, modified));
   if (pass.parallel_ops_.size() > 3) {
-    MS_LOG(WARNING) << "AutoWorkerPass at current stage is only optimized for simple network that has LeafNode, "
-                    << "BatchNode and MapNode. User discretion is advised for usage on other complex networks.";
+    MS_LOG(WARNING) << "AutoNumWorker is currently only suitable for simple dataset pipelines that have at most "
+                    << "1 leaf, 1 batch and 1 map. AutoNumWorker may not be optimal for complex pipelines.";
   }
   for (auto &p : pass.parallel_ops_) {
@@ -60,8 +60,11 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *mod
     int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight);
     // this ensures that num_workers falls within the range [1, cur_node_max]
     int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
+    // if the num_worker to set is the same as the original, skip setting it and printing the log
+    if (cur_node_num_worker == p.first->num_workers()) continue;
     // log the change via a warning msg so the user can see what num_worker is being set for which op
-    MS_LOG(WARNING) << "num_workers in " << p.first->Name() << " is auto-adjusted from "
+    MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
                     << std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
     p.first->SetNumWorkers(cur_node_num_worker);
   }
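To make the arithmetic above concrete, here is a standalone sketch of the clamp with made-up numbers; the weights and worker bounds are assumptions for illustration, not values taken from the pass:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const int32_t max_num_workers = 8;  // stands in for max_num_workers_ (assumed cap)
  const int32_t min_num_workers = 1;  // stands in for min_num_workers_ (assumed floor)
  const float max_weight = 4.0f;      // largest value found in weight_profile_
  const float op_weight = 3.0f;       // this op's weight, i.e. p.second in the loop
  const int32_t num_workers = 6;      // candidate count before clamping

  // ceil(3.0 * 8 / 4.0) = 6, so this op may use at most 6 workers
  const auto cur_node_max = static_cast<int32_t>(std::ceil(op_weight * max_num_workers / max_weight));
  // max(min(6, 6), 1) = 6: the result always lands in [min_num_workers, cur_node_max]
  const int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers);
  std::cout << "auto-adjusted num_workers: " << cur_node_num_worker << std::endl;
  return 0;
}

The early continue added in this hunk then suppresses both the assignment and the warning whenever the clamped value matches what the op already uses.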

@@ -13,10 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-#include <vector>
-#include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/engine/perf/monitor.h"
+#include "minddata/dataset/core/config_manager.h"
 #include "minddata/dataset/engine/execution_tree.h"
 namespace mindspore {

@@ -29,7 +29,6 @@ class ExecutionTree;
 class Monitor {
  public:
   // Monitor object constructor
   explicit Monitor(ExecutionTree *tree);
   Monitor() = default;

@@ -29,8 +29,7 @@
 namespace mindspore {
 namespace dataset {
-TreeAdapter::TreeAdapter() {
-  tree_state_ = kCompileStateInit;
+TreeAdapter::TreeAdapter(UsageFlag usage) : usage_(usage), tree_state_(kCompileStateInit) {
   optimize_ = common::GetEnv("OPTIMIZE") == "true";
 }
@@ -81,7 +80,8 @@ Status TreeAdapter::PostPass(std::shared_ptr<DatasetNode> ir) {
   MS_LOG(INFO) << "Running post pass loops.";
   // AutoWorkerPass should ideally precede CacheTransForm Pass to avoid complications of the setting
-  if (GlobalContext::config_manager()->auto_num_workers()) {
+  if (GlobalContext::config_manager()->auto_num_workers() && usage_ == kDeIterator) {
+    // skip this for getter pass
     actions.emplace_back(std::make_unique<AutoWorkerPass>());
   }

@@ -33,7 +33,12 @@ class DatasetNode;
 class TreeAdapter {
  public:
-  TreeAdapter();
+  // this flag is used to indicate the purpose of the creation of this tree adapter (type of the tree_consumer)
+  // Currently there are 3 types of consumer: Iterator, Getter and TDT/Vocab/Save ...
+  // To avoid premature optimization, the last type (TDT/Vocab/Save) is regarded as Iterator for now.
+  enum UsageFlag { kDeIterator = 0, kDeGetter = 1 };
+  explicit TreeAdapter(UsageFlag flag = kDeIterator);
   ~TreeAdapter() = default;
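Taken together with the first two hunks, the intended call pattern looks roughly like this (a short sketch composed from the call sites in this diff, not new API surface):

// Iterators keep the default flag, so AutoWorkerPass may run on their trees.
auto iterator_tree = std::make_unique<TreeAdapter>();  // usage_ == kDeIterator
// Getters tag themselves; TreeAdapter::PostPass() (hunk above) then skips
// AutoWorkerPass because usage_ != kDeIterator.
auto getter_tree = std::make_unique<TreeAdapter>(TreeAdapter::UsageFlag::kDeGetter);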
@@ -92,7 +97,7 @@ class TreeAdapter {
   int32_t cur_connector_size_;      // current connector size of root op, used for profiling
   int32_t cur_connector_capacity_;  // current connector capacity of root op, used for profiling
   std::function<OptPass(OptPass)> pre_pass_override_;  // function ptr that overrides pre pass, called in PrePrepare()
+  UsageFlag usage_;  // usage of this tree adapter (type of consumer)
   // State flags for the lifecycle of the tree
   enum CompileState {
     kCompileStateInit = 0,  // The freshly initialized state

@@ -30,6 +30,7 @@ UINT32_MAX = 4294967295
 _config = cde.GlobalContext.config_manager()
+
 def _init_device_info():
     """
     INTERNAL USE ONLY!
@@ -52,6 +53,7 @@ def _init_device_info():
             rank_id = cuda_id
         _config.set_rank_id(rank_id)
+
 def set_seed(seed):
     """
     Set the seed to be used in any random generator. This is used to produce deterministic results.
@@ -149,6 +151,7 @@ def set_num_parallel_workers(num):
 def get_num_parallel_workers():
     """
     Get the default number of parallel workers.
+    This is the DEFAULT num_parallel_workers value used for each op; it is not related to the AutoNumWorker feature.
     Returns:
         Int, number of parallel workers to be used as a default for each operation
