|
|
|
@ -93,7 +93,7 @@ double GetWeights(const Graph::NodeType &node) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Sort all the nodes by their weights
|
|
|
|
|
std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
|
|
|
|
|
std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
|
std::vector<std::pair<double, size_t>> weight_to_node_index;
|
|
|
|
@ -124,7 +124,7 @@ std::vector<size_t> SortByWeight(const std::shared_ptr<Graph> graph) {
|
|
|
|
|
// Get optimal strategy to partition the target node
|
|
|
|
|
StrategyRec PartitionNode(const Graph::NodeType &node,
|
|
|
|
|
const std::vector<std::pair<std::string, StrategyRec>> &node_name_to_strategy,
|
|
|
|
|
std::shared_ptr<Graph> graph) {
|
|
|
|
|
const std::shared_ptr<Graph> &graph) {
|
|
|
|
|
bool enable_conv_chw_partition = false;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
|
|
|
|
@ -191,7 +191,8 @@ StrategyRec PartitionNode(const Graph::NodeType &node,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Parttion graph into all devices.
|
|
|
|
|
Status PartitionForAllDevices(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
|
Status PartitionForAllDevices(const size_t num_device, const double device_memory,
|
|
|
|
|
const std::shared_ptr<Graph> &graph) {
|
|
|
|
|
if (num_device < 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ERROR: Number of devices can't be " << num_device << ".";
|
|
|
|
|
}
|
|
|
|
@ -261,7 +262,7 @@ Graph::NodeType ApplyStrToTensor(Graph::NodeType Node) {
|
|
|
|
|
return Node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status DevicesMemoryControl(const size_t num_device, const double device_memory, std::shared_ptr<Graph> graph) {
|
|
|
|
|
Status DevicesMemoryControl(const size_t num_device, const double device_memory, const std::shared_ptr<Graph> &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
if (num_device == 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: device number is 0.";
|
|
|
|
|