add auto parallel pipeline

pull/8406/head
lichenever 4 years ago
parent a321f402c8
commit 2e1c43483e

@ -49,10 +49,12 @@ DFunctor::DFunctor(const FuncGraphPtr &primal_graph, const pipeline::ResourceBas
}
// To keep switch_layer's inputs from being inlined
k_graph_->set_switch_layer_input(primal_graph->switch_layer_input());
k_graph_->set_stage(primal_graph->stage());
TraceManager::EndTrace();
TraceManager::DebugTrace(std::make_shared<TraceGradBprop>(primal_graph->debug_info()));
tape_ = std::make_shared<FuncGraph>();
tape_->set_stage(primal_graph->stage());
// Add "_Grad" postfix
if (primal_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
std::string grad_op_name = GetValue<std::string>(primal_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) + "_Grad";

@ -41,7 +41,7 @@ class ReplaceApplicator : public AnfVisitor {
}
auto fg = GetValueNode<FuncGraphPtr>(node);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub() || *(fg->switch_layer_input())) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub() || *(fg->switch_layer_input())) {
return nullptr;
}
@ -124,7 +124,7 @@ class InlinerBase : public AnfVisitor {
// G
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stub()) {
if (fg->has_flag(FUNC_GRAPH_FLAG_DEFER_INLINE) || fg->stage() != -1 || fg->stub()) {
return nullptr;
}

@ -122,7 +122,7 @@ class ParallelContext {
std::string parallel_mode_;
std::string strategy_search_mode_;
std::vector<int64_t> stages_;
int32_t pipeline_stage_split_num_;
int64_t pipeline_stage_split_num_ = 0;
bool parameter_broadcast_;
bool device_num_is_set_;
bool global_rank_is_set_;

@ -0,0 +1,72 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_
#include <utility>
#include "ir/value.h"
#include "ir/graph_utils.h"
#include "base/base.h"
namespace mindspore {
namespace parallel {
typedef struct {
ValueListPtr shape;
TypePtr type;
AnfNodePtr depend;
} SendAttr;
class PipelineTransformer {
public:
PipelineTransformer(const FuncGraphManagerPtr &manager, const int &stage, const FuncGraphPtr &root,
const int64_t &global_rank, const int64_t &per_stage_rank_num)
: manager_(manager),
stage_(stage),
root_(root),
global_rank_(global_rank),
per_stage_rank_num_(per_stage_rank_num) {}
void Coloring();
void BroadCastColoring();
void HandleSharedParameter();
void CutGraph();
void ParameterColoring();
void CoverSensShape();
void ElimGraphStage();
void ElimParameter();
private:
void DoBroadCast(const FuncGraphPtr &func);
SendAttr InsertSend(const FuncGraphPtr &graph, const AnfNodePtr &parameter, const int &user_node_stage,
const int &node_stage);
void InsertReceive(const FuncGraphPtr &graph, const AnfNodePtr &node, const AnfNodePtr &use_node, const int &index,
const int &user_node_stage, const int &node_stage);
void CutBorder(const FuncGraphPtr &graph);
void ElimRootParameter();
bool IsStageNode(const CNodePtr &node);
std::pair<CNodePtr, FuncGraphPtr> FindSensNode();
FuncGraphManagerPtr manager_;
int64_t stage_;
FuncGraphPtr root_;
int64_t global_rank_;
int64_t per_stage_rank_num_;
TypePtr type_ptr_;
ValueListPtr shape_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_PIPELINE_TRANSFORMER_PIPELINE_TRANSFORMER_H_

@ -5,6 +5,7 @@ file(GLOB_RECURSE _PIPELINE_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"action.cc"
"validator.cc"
"remove_value_node_dup.cc"
"pipeline_split.cc"
"parse/*.cc"
"static_analysis/*.cc"
)

@ -302,6 +302,10 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
}
bool OptInlineAction(const ResourcePtr &res) {
if (parallel::ParallelContext::GetInstance()->parallel_mode() == "semi_auto_parallel" ||
parallel::ParallelContext::GetInstance()->parallel_mode() == "auto_parallel") {
return OptimizeAction(res, kInlinePasses);
}
if (opt::python_pass::PyPassManager::GetInstance()->GetPassGroup(opt::python_pass::Phase::PREAD)->size() != 0) {
return OptimizeAction(res, kInlinePasses);
}
@ -480,6 +484,7 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
return true;
}
bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res); }
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
@ -559,6 +564,8 @@ static std::vector<ActionItem> CommonPipeline() {
actions.emplace_back(std::make_pair("inline", OptInlineAction));
// Add pre-ad, post-inline python pass stub
actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
// Do PipelineSplit
actions.emplace_back(std::make_pair("pipeline_split", PipelineSplitAction));
return actions;
}

@ -246,6 +246,10 @@ bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
func_graph->set_flag(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
}
}
if (py::hasattr(obj, STAGE_NAME)) {
auto stage = py::cast<int>(py::getattr(obj, STAGE_NAME));
func_graph->set_stage(stage);
}
*data = func_graph;
return true;
}

@ -132,6 +132,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
// define the parse constant
const int64_t MAX_COMPARISON_OPS_SUPPORTED = 1;
const char CUSTOM_BPROP_NAME[] = "bprop";
const char STAGE_NAME[] = "stage";
// define the Namespace name
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace

@ -38,6 +38,7 @@
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
#include "utils/log_adapter.h"
#include "pipeline/jit/pipeline_split.h"
namespace mindspore {
namespace pipeline {
@ -418,6 +419,8 @@ bool TransformTopGraphPass(const ResourcePtr &res) {
return true;
}
bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }
bool ValidatePass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();

@ -33,6 +33,7 @@ extern std::vector<PassItem> kInlinePasses;
extern std::vector<PassItem> kPynativePasses;
bool CconvPass(const ResourcePtr &res);
bool PipelineSplitPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res);
bool ConvertPrepareAdapt(const ResourcePtr &res);
bool AddControlDependPass(const ResourcePtr &res);

@ -0,0 +1,99 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <memory>
#include "pipeline/jit/pipeline_split.h"
#include "utils/ms_context.h"
#include "utils/comm_manager.h"
#include "frontend/parallel/context.h"
#include "frontend/parallel/pipeline_transformer/pipeline_transformer.h"
namespace mindspore {
namespace pipeline {
static int64_t GetRank();
static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num);
static int64_t GetRank() {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::string world_group;
std::string backend = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (backend == kAscendDevice) {
world_group = parallel::HCCL_WORLD_GROUP;
} else if (backend == kGPUDevice) {
world_group = parallel::NCCL_WORLD_GROUP;
} else {
MS_LOG(EXCEPTION) << "Invalid backend: " << backend;
}
int64_t global_rank = parallel::ParallelContext::GetInstance()->global_rank();
uint32_t rank_id;
if (!parallel::ParallelContext::GetInstance()->global_rank_is_set()) {
if (!CommManager::GetInstance().GetRankID(world_group, &rank_id)) {
MS_LOG(EXCEPTION) << "Get rank id failed.";
}
global_rank = UintToInt(rank_id);
}
return global_rank;
}
static int64_t InferStage(const int64_t &rank_id, const int64_t &stage_num, const int64_t &device_num) {
if (device_num % stage_num != 0) {
MS_LOG(EXCEPTION) << "Device_num must be divisible by the stage_num, got device_num: " << device_num
<< "stage_num: " << stage_num;
}
auto per_stage_rank_num = device_num / stage_num;
return rank_id / per_stage_rank_num;
}
// Only auto_parallel and semi_auto_parallel support PipelineSplit
bool PipelineSplit(const ResourcePtr &res) {
auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode();
if (parallel_mode != parallel::SEMI_AUTO_PARALLEL || parallel_mode != parallel::AUTO_PARALLEL) {
MS_LOG(INFO) << "Only auto_parallel and semi_auto_parallel support pipeline split.";
return true;
}
auto stage_num = parallel::ParallelContext::GetInstance()->pipeline_stage_split_num();
if (stage_num <= 1) {
MS_LOG(INFO) << "stage num is: " << stage_num << ". No need Pipeline split.";
return true;
}
auto manager = res->manager();
auto root = res->func_graph();
auto global_rank = GetRank();
auto device_num = parallel::ParallelContext::GetInstance()->device_num();
auto stage = InferStage(global_rank, stage_num, device_num);
auto per_stage_rank_num = device_num / stage_num;
auto transformer =
std::make_shared<parallel::PipelineTransformer>(manager, stage, root, global_rank, per_stage_rank_num);
// step1: Do color graph
transformer->Coloring();
// step2: Do color broadcast
transformer->BroadCastColoring();
// step3: Handle shared parameters
transformer->ParameterColoring();
transformer->HandleSharedParameter();
// step4: Cut Graph
transformer->CutGraph();
// step5: Handle Sens
transformer->CoverSensShape();
// step6: Elim Graph stages and no used parameter
transformer->ElimGraphStage();
transformer->ElimParameter();
return true;
}
} // namespace pipeline
} // namespace mindspore

@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_SPLIT_H_
#define MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_SPLIT_H_
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace pipeline {
bool PipelineSplit(const ResourcePtr &res);
} // namespace pipeline
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PIPELINE_JIT_PIPELINE_SPLIT_H_

@ -98,7 +98,8 @@ class AnfNode : public Base {
debug_info_(std::make_shared<NodeDebugInfo>()),
fullname_with_scope_(""),
hash_(std::hash<const AnfNode *>()),
kernel_info_(nullptr) {
kernel_info_(nullptr),
stage_(-1) {
scope_ = ScopeManager::GetInstance().GetCurrentScope();
}
@ -184,6 +185,9 @@ class AnfNode : public Base {
return user_data_.has(T::key);
}
int64_t stage() { return stage_; }
void set_stage(const int &stage) { stage_ = stage; }
protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -198,6 +202,7 @@ class AnfNode : public Base {
ScopePtr scope_;
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
int64_t stage_;
};
// CNode represents the complex node with a set of arguments.

@ -46,7 +46,8 @@ FuncGraph::FuncGraph()
is_generated_(false),
return_(nullptr),
manager_(std::weak_ptr<FuncGraphManager>()),
stub_(false) {
stub_(false),
stage_(-1) {
debug_info_ = std::make_shared<GraphDebugInfo>();
switch_layer_input_ = std::make_shared<bool>(false);
}

@ -355,6 +355,8 @@ class FuncGraph : public FuncGraphBase {
std::shared_ptr<bool> switch_layer_input() const { return switch_layer_input_; }
void set_switch_layer_input(std::shared_ptr<bool> switch_layer_input) { switch_layer_input_ = switch_layer_input; }
bool ContainMultiTarget() const;
int64_t stage() { return stage_; }
void set_stage(int64_t stage) { stage_ = stage; }
private:
// graph is manipulated by manager and others
@ -419,6 +421,7 @@ class FuncGraph : public FuncGraphBase {
// Design switch_layer_input as a ptr to
// share between derived backpropagator and cloned graphs
std::shared_ptr<bool> switch_layer_input_;
int64_t stage_;
std::unordered_map<AbstractBasePtrList, FuncGraphPtr, abstract::AbstractBasePtrListHasher,
abstract::AbstractBasePtrListEqual>
func_graph_cache_;

@ -186,6 +186,7 @@ void Cloner::CloneFuncGraphValueNodes(const FuncGraphPtr &func_graph, const Func
MS_EXCEPTION_IF_NULL(target_func_graph);
MS_EXCEPTION_IF_NULL(manager_);
target_func_graph->set_stage(func_graph->stage());
auto old_return = func_graph->get_return();
if (old_return != nullptr) {
auto return_node = repl_node_[old_return]->cast<CNodePtr>();
@ -668,6 +669,7 @@ FuncGraphPtr TransformableClone(const FuncGraphPtr &func_graph, const TraceInfoP
if (func_graph->has_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL)) {
new_func_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, func_graph->get_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL));
}
new_func_graph->set_stage(func_graph->stage());
return new_func_graph;
}

@ -20,7 +20,7 @@ from .. import operations as P
from ...common.tensor import RowTensor
from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..operations.comm_ops import (AllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
_GetTensorSlice, _MirrorOperator, ReduceOp,
_GetTensorSlice, _MirrorOperator, ReduceOp, Send, Receive,
ReduceScatter, _HostReduceScatter, _VirtualDiv)
from .grad_base import bprop_getters
@ -70,6 +70,32 @@ def get_bprop_all_reduce(self):
return bprop
@bprop_getters.register(Send)
def get_bprop_send(self):
"""Generate bprop for Send."""
shape = self.get_attr_dict()["shape"]
dtype = self.get_attr_dict()["dtype"]
send_grad = Receive(self.sr_tag, self.rank, shape, dtype, self.group)
def bprop(x, out, dout):
dx = send_grad()
return (dx,)
return bprop
@bprop_getters.register(Receive)
def get_bprop_receive(self):
"""Generate bprop for Receive."""
receive_grad = Send(self.tag, self.rank, self.group)
depend = P.Depend()
def bprop(out, dout):
send_out = receive_grad(dout)
dx = depend(dout, send_out)
return (dx,)
return bprop
@bprop_getters.register(Broadcast)
def get_bprop_broad_cast(self):
"""Generate bprop for Broadcast."""

@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Unique, GatherD, Identity, RepeatElements)
from .comm_ops import (AllGather, AllReduce, _AlltoAll, ReduceScatter, Broadcast,
_MirrorOperator, ReduceOp, _VirtualDataset,
_VirtualDiv, _GetTensorSlice,
_VirtualDiv, _GetTensorSlice, Send, Receive,
_HostAllGather, _HostReduceScatter)
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)

@ -109,6 +109,117 @@ class AllReduce(PrimitiveWithInfer):
return x_dtype
class Send(PrimitiveWithInfer):
"""
Send tensors from src_rank to the specified dest_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Send must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be received by the Receive op with the same "sr_tag".
dest_rank (int): A required integer identifying the destination rank.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as P
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.depend = P.Depend()
>>> self.send = P.Send(st_tag=0, dest_rank=8, group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, dest_rank, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.sr_tag = sr_tag
self.group = group
def infer_shape(self, x_shape):
self.add_prim_attr("shape", x_shape)
return x_shape
def infer_dtype(self, x_dtype):
self.add_prim_attr("dtype", x_dtype)
return x_dtype
class Receive(PrimitiveWithInfer):
"""
receive tensors from src_rank.
Note:
Send and Recveive must be used in combination and have same sr_tag.
Receive must be used between servers.
Args:
sr_tag (int): A required integer identifying the send/recv message tag. The message will
will be send by the Send op with the same "sr_tag".
src_rank (int): A required integer identifying the source rank.
shape (list[int]): A required list identifying the shape of the tensor to be received.
dtype (Type): A required Type indentifying the type of the tensor to be received. The supported types:
int8, int16, int32, float16, float32.
group (str): The communication group to work on. Default: "hccl_world_group/nccl_world_group".
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
Examples:
>>> import mindspore.ops.operations as P
>>> import mindspore.nn as nn
>>> from mindspore.communication import init
>>> from mindspore import Tensor
>>> import numpy as np
>>>
>>> init()
>>> class Net(nn.Cell):
>>> def __init__(self):
>>> super(Net, self).__init__()
>>> self.send = P.Receive(st_tag=0, src_rank=0, shape=[2, 8], dtype=np.float32,
>>> group="hccl_world_group")
>>>
>>> def construct(self, x):
>>> out = self.depend(x, self.send(x))
>>> return out
>>>
>>> input_ = Tensor(np.ones([2, 8]).astype(np.float32))
>>> net = Net()
>>> output = net(input_)
"""
@prim_attr_register
def __init__(self, sr_tag, src_rank, shape, dtype, group=GlobalComm.WORLD_COMM_GROUP):
self.rank = get_rank(_get_group(group))
self.tag = sr_tag
self.shape = shape
self.dtype = dtype
self.group = group
def infer_shape(self):
return self.shape
def infer_dtype(self):
return self.dtype
class AllGather(PrimitiveWithInfer):
"""
Gathers tensors from the specified communication group.

Loading…
Cancel
Save