|
|
|
@ -48,8 +48,8 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *mod
|
|
|
|
|
for (const auto &p : pass.weight_profile_) max_weight = std::max(max_weight, p.second);
|
|
|
|
|
RETURN_IF_NOT_OK(pass.Run(root_ir, modified));
|
|
|
|
|
if (pass.parallel_ops_.size() > 3) {
|
|
|
|
|
MS_LOG(WARNING) << "AutoWorkerPass at current stage is only optimized for simple network that has LeafNode, "
|
|
|
|
|
<< "BatchNode and MapNode. User discretion is advised for usage on other complex networks.";
|
|
|
|
|
MS_LOG(WARNING) << "AutoNumWorker right now is only suitable for simple dataset pipelines that has at most, 1 leaf "
|
|
|
|
|
<< "1 batch and 1 map. AutoNumWorker may not be optimal for usage on complex pipelines.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &p : pass.parallel_ops_) {
|
|
|
|
@ -60,8 +60,11 @@ Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *mod
|
|
|
|
|
int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight);
|
|
|
|
|
// this will ensure that num_workers will fall with the range of [1,cur_node_max]
|
|
|
|
|
int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
|
|
|
|
|
|
|
|
|
|
// if the num_worker to set is same as original, skip setting and printing the logs
|
|
|
|
|
if (cur_node_num_worker == p.first->num_workers()) continue;
|
|
|
|
|
// log the change via warning msg so user can see what the num_worker is being set for which op
|
|
|
|
|
MS_LOG(WARNING) << "num_workers in " << p.first->Name() << " is auto-adjusted from "
|
|
|
|
|
MS_LOG(WARNING) << "AutoNumWorker enabled, num_workers in " << p.first->Name() << " is auto-adjusted from "
|
|
|
|
|
<< std::to_string(p.first->num_workers()) + " to " + std::to_string(cur_node_num_worker);
|
|
|
|
|
p.first->SetNumWorkers(cur_node_num_worker);
|
|
|
|
|
}
|
|
|
|
|