commit
317a97e6b9
@@ -0,0 +1,122 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <cmath>

#include "minddata/dataset/engine/ir/datasetops/batch_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
|
||||
#include "minddata/dataset/engine/ir/datasetops/map_node.h"
|
||||
#include "minddata/dataset/engine/opt/post/auto_worker_pass.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace dataset {
|
||||
|
||||
// this will become the RootNode:DatasetNode when it is turned on
|
||||
Status AutoWorkerPass::RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) {
  uint8_t config = GlobalContext::config_manager()->get_auto_worker_config_();

  // select the weight profile indicated by the config; fall back to the profile at index 0 if out of range
  OpWeightPass pass(kOpWeightConfigs[config < kOpWeightConfigs.size() ? config : 0]);

  std::string weight_str;
  for (const auto &p : pass.weight_profile_) weight_str += ("(" + p.first + "=" + std::to_string(p.second) + ")");
  int32_t num_shards = GlobalContext::config_manager()->get_num_shards_for_auto_num_workers();
  // clamp num_shards to [1, thread_cnt_] so each shard is assumed to own at least one thread
  num_shards = std::min(std::max(1, num_shards), thread_cnt_);

  MS_LOG(INFO) << "AutoWorkerPass is enabled; this could override the existing num_workers set in each parallel op. "
               << "Total number of threads on this CPU: " << thread_cnt_ << ", "
               << "min num_workers to override: " << min_num_workers_ << ", "
               << "max num_workers to override: " << max_num_workers_ << ", "
               << "adjusted num_shards (between 1 and total thread cnt): " << num_shards
               << ", weight profile: " << weight_str << ".";

  // get the maximum weight among all the ops; this value is used to preserve the num_workers ratio between ops
  float max_weight = 0;
  for (const auto &p : pass.weight_profile_) max_weight = std::max(max_weight, p.second);
  RETURN_IF_NOT_OK(pass.Run(root_ir, modified));
  if (pass.parallel_ops_.size() > 3) {
    MS_LOG(WARNING) << "AutoWorkerPass at its current stage is only optimized for simple networks that have a "
                    << "LeafNode, a BatchNode and a MapNode. User discretion is advised when it is used on "
                    << "other, more complex networks.";
  }

  for (auto &p : pass.parallel_ops_) {
    // get the num_workers via the weight ratio
    int32_t num_workers = std::ceil((thread_cnt_ * p.second) / (pass.weight_sum_ * num_shards));
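    // Worked example (illustrative numbers only, not from this commit): with thread_cnt_ = 8, num_shards = 1,
    // and a leaf-batch-map pipeline under the first profile (weights 8, 8, 8, so weight_sum_ = 24), each op
    // gets ceil(8 * 8 / 24) = 3 workers before the per-op cap below is applied.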
    // this cap ensures that when thread_cnt_ is very large (say 192), the num_workers ratio between ops is
    // still kept, e.g. the intended 2:1 ratio between a MindDataset op and a batch op
    int32_t cur_node_max = std::ceil(p.second * max_num_workers_ / max_weight);
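    // Worked example for the cap (again illustrative): with max_num_workers_ = 8 and max_weight = 8, an op of
    // weight 8 is capped at ceil(8 * 8 / 8) = 8 workers and an op of weight 4 at ceil(4 * 8 / 8) = 4, so a
    // 2:1 weight ratio survives even on a 192-thread CPU.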
    // clamp num_workers into the range [min_num_workers_, cur_node_max]
    int32_t cur_node_num_worker = std::max(std::min(num_workers, cur_node_max), min_num_workers_);
    // log the change via a warning msg so the user can see what num_workers each op is being set to
    MS_LOG(WARNING) << "num_workers in " << p.first->Name() << " is auto-adjusted from "
                    << p.first->num_workers() << " to " << cur_node_num_worker << ".";
    p.first->SetNumWorkers(cur_node_num_worker);
  }
  return Status::OK();
}
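
// A minimal usage sketch (hypothetical harness, not part of this commit): once auto num_workers is enabled in
// the config, the optimizer is expected to drive this pass roughly like so:
//   std::shared_ptr<DatasetNode> root_ir = /* IR tree built from the user's pipeline */;
//   bool modified = false;
//   AutoWorkerPass pass;
//   RETURN_IF_NOT_OK(pass.RunOnTree(root_ir, &modified));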

Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MapNode> node, bool *modified) {
  auto itr = weight_profile_.find(node->Name());
  CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
  float weight = itr->second;
  weight_sum_ += weight;
  parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
  return Status::OK();
}

Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<BatchNode> node, bool *modified) {
  auto itr = weight_profile_.find(node->Name());
  CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(), node->Name() + "'s weight doesn't exist.");
  float weight = itr->second;
  weight_sum_ += weight;
  parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
  return Status::OK();
}

Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) {
  RETURN_OK_IF_TRUE(node->Name() == kGeneratorNode);  // GeneratorNode is a pipeline op, skip it here
  auto itr = weight_profile_.find("MappableSource");
  CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
                               "MappableSourceNode::" + node->Name() + "'s weight doesn't exist.");
  float weight = itr->second;
  weight_sum_ += weight;
  parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
  return Status::OK();
}

Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) {
  auto itr = weight_profile_.find("NonMappableSource");  // must match the key used in kOpWeightConfigs
  CHECK_FAIL_RETURN_UNEXPECTED(itr != weight_profile_.end(),
                               "NonMappableSourceNode::" + node->Name() + "'s weight doesn't exist.");
  float weight = itr->second;
  weight_sum_ += weight;
  parallel_ops_.emplace_back(std::make_pair(std::static_pointer_cast<DatasetNode>(node), weight));
  return Status::OK();
}

Status AutoWorkerPass::OpWeightPass::Visit(std::shared_ptr<DatasetNode> node, bool *modified) {
  weight_sum_ += GetNodeWeightFromProfile(node);
  return Status::OK();
}

float AutoWorkerPass::OpWeightPass::GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node) {
  auto itr = weight_profile_.find(node->Name());
  // return 0 if the name doesn't exist in the weight profile
  return itr == weight_profile_.end() ? 0 : itr->second;
}

}  // namespace dataset
}  // namespace mindspore
@@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_
#define DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_

#include <map>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"
#include "minddata/dataset/engine/opt/pass.h"

namespace mindspore {
namespace dataset {

class AutoWorkerPass : public IRTreePass {
 public:
  // these maps contain the weights for the basic ops in the pipeline; a pipeline op takes up 1 thread
  // but doesn't have configurable workers
  const std::vector<std::map<std::string, float>> kOpWeightConfigs = {
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 8}},  // config1 leaf:batch:map=1:1:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 4}},  // config2 leaf:batch:map=2:1:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 4}},  // config3 leaf:batch:map=1:2:1
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 4}, {kMapNode, 8}},  // config4 leaf:batch:map=1:1:2
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 8}, {kMapNode, 4}},  // config5 leaf:batch:map=2:2:1
    {{"MappableSource", 8}, {"NonMappableSource", 8}, {kBatchNode, 4}, {kMapNode, 8}},  // config6 leaf:batch:map=2:1:2
    {{"MappableSource", 4}, {"NonMappableSource", 4}, {kBatchNode, 8}, {kMapNode, 8}},  // config7 leaf:batch:map=1:2:2
  };
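  // Illustrative reading of the table (hypothetical numbers, not from this commit): under config2
  // (leaf:batch:map = 2:1:1), a pipeline of one mappable leaf, one batch and one map has
  // weight_sum_ = 8 + 4 + 4 = 16; on a 16-thread CPU with num_shards = 1 the leaf gets
  // ceil(16 * 8 / 16) = 8 workers and the map gets ceil(16 * 4 / 16) = 4, preserving the 2:1 ratio.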
  AutoWorkerPass()
      : min_num_workers_(1),
        max_num_workers_(8),
        thread_cnt_(GlobalContext::Instance()->config_manager()->num_cpu_threads()) {}

  Status RunOnTree(std::shared_ptr<DatasetNode> root_ir, bool *modified) override;

 private:
  class OpWeightPass : public IRNodePass {
   public:
    explicit OpWeightPass(const std::map<std::string, float> &weight_profile)
        : IRNodePass(), weight_sum_(0), weight_profile_(weight_profile) {}
    // this base-class overload contains the logic to handle most of the pipeline ops; although a pipeline op
    // can't configure num_workers, it still runs 1 thread, so it needs to be factored into the weight sum
    Status Visit(std::shared_ptr<DatasetNode> node, bool *modified) override;
    // these overloads calculate the weights of the more complex nodes, which may depend on their input args;
    // they also push these nodes to a vector whose num_workers will be set in the tree pass
    Status Visit(std::shared_ptr<BatchNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<MapNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<MappableSourceNode> node, bool *modified) override;
    Status Visit(std::shared_ptr<NonMappableSourceNode> node, bool *modified) override;

    // helper function to look up the weight according to the name of this op
    float GetNodeWeightFromProfile(std::shared_ptr<DatasetNode> node);

    float weight_sum_;                                   // sum of the weights of all ops in the pipeline
    const std::map<std::string, float> weight_profile_;  // key: name of ir node, val: weight of this node
    std::vector<std::pair<std::shared_ptr<DatasetNode>, float>> parallel_ops_;  // first: node, second: weight
  };

  const int32_t min_num_workers_;  // the minimum num_workers allowed for each op
  const int32_t max_num_workers_;  // the maximum num_workers allowed for each op
  const int32_t thread_cnt_;       // thread count of the current CPU, obtained through the config manager
};
}  // namespace dataset
}  // namespace mindspore

#endif  // DATASET_ENGINE_OPT_POST_AUTO_WORKER_PASS_H_