add batch parallel info black list

pull/6409/head
Ziyan 5 years ago
parent defd74e261
commit 9e5248497b

@@ -80,9 +80,17 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
                                           REF_TO_EMBED,
                                           STOP_GRADIENT};
 
+const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {PACK, TENSOR_SCATTER_UPDATE, MIN_MAX_UPDATE_PER_LAYER};
+
 bool IsInBlackList(const PrimitivePtr &prim) {
   MS_EXCEPTION_IF_NULL(prim);
   return (BLACK_LIST.find(prim->name()) != BLACK_LIST.end());
 }
+
+bool IsInBatchParallelBlackList(const PrimitivePtr &prim) {
+  MS_EXCEPTION_IF_NULL(prim);
+  return (BATCH_PARALLEL_BLACK_LIST.find(prim->name()) != BATCH_PARALLEL_BLACK_LIST.end());
+}
 }  // namespace parallel
 }  // namespace mindspore

@@ -22,6 +22,7 @@
 namespace mindspore {
 namespace parallel {
 bool IsInBlackList(const PrimitivePtr &prim);
+bool IsInBatchParallelBlackList(const PrimitivePtr &prim);
 }  // namespace parallel
 }  // namespace mindspore

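For orientation, here is a minimal standalone sketch (not part of the commit) of the lookup these two hunks add. The Primitive struct below is a hypothetical stand-in for the real PrimitivePtr, which the production code null-checks with MS_EXCEPTION_IF_NULL before dereferencing:

#include <cassert>
#include <set>
#include <string>

// Hypothetical stand-in for mindspore's PrimitivePtr; only the name matters here.
struct Primitive {
  std::string name;
};

const std::set<std::string> BATCH_PARALLEL_BLACK_LIST = {"Pack", "TensorScatterUpdate",
                                                         "MinMaxUpdatePerLayer"};

bool IsInBatchParallelBlackList(const Primitive &prim) {
  return BATCH_PARALLEL_BLACK_LIST.find(prim.name) != BATCH_PARALLEL_BLACK_LIST.end();
}

int main() {
  assert(IsInBatchParallelBlackList(Primitive{"Pack"}));     // blacklisted: no batch-parallel fallback
  assert(!IsInBatchParallelBlackList(Primitive{"MatMul"}));  // not blacklisted
  return 0;
}

With only three entries the std::set is chosen for clarity and easy extension, not lookup speed.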
@@ -232,7 +232,6 @@ constexpr char SQUARE[] = "Square";
 constexpr char BATCHMATMUL[] = "BatchMatMul";
 constexpr char TOPK[] = "TopK";
 constexpr char IN_TOPK[] = "InTopK";
-constexpr char PACK[] = "Pack";
 constexpr char GATHER_ND[] = "GatherNd";
 constexpr char UNSORTEF_SEGMENT_MIND[] = "UnsortedSegmentMinD";
 constexpr char UNSORTEF_SEGMENT_PRODD[] = "UnsortedSegmentProdD";
@@ -298,6 +297,11 @@ constexpr char ZEROSLIKE[] = "ZerosLike";
 constexpr char REF_TO_EMBED[] = "RefToEmbed";
 constexpr char STOP_GRADIENT[] = "stop_gradient";
 
+// Batch parallel black list
+constexpr char TENSOR_SCATTER_UPDATE[] = "TensorScatterUpdate";
+constexpr char MIN_MAX_UPDATE_PER_LAYER[] = "MinMaxUpdatePerLayer";
+constexpr char PACK[] = "Pack";
+
 constexpr size_t LAST_INDEX(size_t s) { return s - 1; }
 constexpr size_t SECOND_FROM_END(size_t s) { return s - 2; }
 constexpr size_t THIRD_FROM_END(size_t s) { return s - 3; }

@@ -1029,7 +1029,10 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
                                  const std::vector<Shapes> &shape_list) {
   MS_EXCEPTION_IF_NULL(prim);
   OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
-  if ((operator_ == nullptr) && (prim->name() != MAKE_TUPLE)) {
+  if (operator_ == nullptr) {
+    if (IsInBatchParallelBlackList(prim)) {
+      MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
+    }
     MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
     operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
     MS_EXCEPTION_IF_NULL(operator_);

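The hunk above is the behavioral core of the commit: previously any primitive without a dedicated distributed implementation (other than MAKE_TUPLE) fell back to batch parallel with only an INFO log, whereas blacklisted primitives now fail fast with an exception. A self-contained sketch of that control flow, where FakeInstanceByName, Instantiate, kBatchParallelBlackList, and the primitive names are illustrative assumptions, not the framework's API:

#include <iostream>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>

struct OperatorInfo {
  std::string name;
};
using OperatorInfoPtr = std::shared_ptr<OperatorInfo>;

const std::set<std::string> kBatchParallelBlackList = {"Pack", "TensorScatterUpdate", "MinMaxUpdatePerLayer"};

// Pretend only MatMul has a dedicated distributed implementation.
OperatorInfoPtr FakeInstanceByName(const std::string &name) {
  return name == "MatMul" ? std::make_shared<OperatorInfo>(OperatorInfo{name}) : nullptr;
}

OperatorInfoPtr Instantiate(const std::string &name) {
  OperatorInfoPtr op = FakeInstanceByName(name);
  if (op == nullptr) {
    if (kBatchParallelBlackList.count(name) > 0) {
      // Mirrors MS_LOG(EXCEPTION): a blacklisted op must not fall back to batch parallel.
      throw std::runtime_error("Operator " + name + " is not supported yet in auto parallel mode.");
    }
    // Everything else degrades to batch parallel, as the INFO log above describes.
    op = std::make_shared<OperatorInfo>(OperatorInfo{"BatchParallel"});
  }
  return op;
}

int main() {
  std::cout << Instantiate("MatMul")->name << '\n';  // MatMul: dedicated implementation found
  std::cout << Instantiate("Gelu")->name << '\n';    // BatchParallel: generic fallback
  try {
    Instantiate("Pack");
  } catch (const std::exception &e) {
    std::cout << e.what() << '\n';  // Pack is blacklisted, so it raises instead
  }
  return 0;
}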
@@ -149,60 +149,75 @@ Status TensorRedistribution::ComputeCost() {
     double prod =
       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
     std::string str = op.first;
-    if (str == PERMUTE_BY_AXIS) {
-      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
-      int32_t concat_dim = op.second[2];
-      if (concat_dim == 0) {
-        // memory cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod;
-      } else {
-        // memory cost = all_gather + split + concat
-        int32_t dev_num = op.second[4];
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
-    } else if (str == CONCAT_BY_AXIS) {
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      // computation cost = before_slice_shape
-      if (op.second.size() < 3) {
-        MS_LOG(ERROR) << "op.second size should not be less than 3!";
-        return Status::FAILED;
-      }
-      double dev_num = op.second[2];
-      // here, communication cost = all_gather + reduce_scatter
-      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      int32_t concat_dim = op.second[0];
-      if (concat_dim == 0) {
-        // computation cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod * dev_num;
-      } else {
-        // computation cost = all_gather + split + concat
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
+    if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
+    } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
     } else {
       // There is only computation cost in SplitByAxis.
       // computation cost = before_slice_shape
       computation_cost_ += prod;
       // This addition may be erroneous
       memory_cost_ += prod;
     }
   }
   if (reshape_flag()) {
     Shape prev_slice_shape = from_.slice_shape().array();
     double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
     computation_cost_ += 2.0 * prev_prod;
     memory_cost_ += 2.0 * prev_prod;
   }
   return Status::SUCCESS;
 }
+
+Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs) {
+  // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+  if (attrs.size() < 4) {
+    MS_LOG(ERROR) << "attrs size should not be less than 4!";
+    return Status::FAILED;
+  }
+  forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  comm_cost_ += 2.0 * input_size * ALLTOALL_SCALE_FACTOR;
+  int32_t concat_dim = attrs[2];
+  if (concat_dim == 0) {
+    // memory cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size;
+  } else {
+    // memory cost = all_gather + split + concat
+    int32_t dev_num = attrs[4];
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
+
+Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+  // computation cost = before_slice_shape
+  if (attrs.size() < 3) {
+    MS_LOG(ERROR) << "attrs size should not be less than 3!";
+    return Status::FAILED;
+  }
+  double dev_num = attrs[2];
+  // here, communication cost = all_gather + reduce_scatter
+  forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  int32_t concat_dim = attrs[0];
+  if (concat_dim == 0) {
+    // computation cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size * dev_num;
+  } else {
+    // computation cost = all_gather + split + concat
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore

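To make the extracted formulas concrete, here is a worked sketch of ComputeConcatCost's arithmetic under assumed numbers: input_size = 1024 slice elements, dev_num = 8, concat_dim != 0, and an illustrative scale factor of 0.5 (the framework defines the real ALLGATHER_REDUCESCATTER_SCALE_FACTOR):

#include <iostream>

int main() {
  const double input_size = 1024.0;  // product of the slice shape (prod in ComputeCost)
  const double dev_num = 8.0;
  const double kScale = 0.5;  // stand-in for ALLGATHER_REDUCESCATTER_SCALE_FACTOR

  // communication cost = all_gather + reduce_scatter
  const double forward_comm = input_size * dev_num * kScale;  // 4096
  const double backward_comm = input_size * kScale;           // 512
  const double comm = input_size * (dev_num + 1.0) * kScale;  // 4608 = forward + backward

  // concat_dim != 0 branch: computation cost = all_gather + split + concat
  const double computation = input_size + 2.0 * input_size * dev_num;  // 17408
  const double memory = 2.0 * input_size * dev_num + input_size;       // 17408

  std::cout << forward_comm << ' ' << backward_comm << ' ' << comm << '\n';
  std::cout << computation << ' ' << memory << '\n';
  return 0;
}

Note the invariant comm == forward_comm + backward_comm, which the refactor carries over verbatim from the inline version.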
@@ -61,7 +61,8 @@ class TensorRedistribution {
  private:
   Status InferReshape(const TensorLayout &from_layout, const TensorLayout &to_layout,
                       OperatorVector *const operator_vector, OutPutInfoVector *const output_info_vector);
+  Status ComputeConcatCost(double input_size, Shape attrs);
+  Status ComputePermuteCost(double input_size, Shape attrs);
   TensorLayout from_origin_;
   TensorLayout to_origin_;
   TensorLayout from_;
