@@ -40,7 +40,7 @@ Status TensorRedistribution::Init(const TensorLayout& from, const TensorLayout&
   return Status::SUCCESS;
 }
 
-RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList() {
+RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorList(bool is_cost_model) {
   // Step 1: Match device arrangement between from_ and to_
   RedistributionLayoutTransfer layout_transfer;
   Status status = layout_transfer.Init(from_, to_);
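
The matching header change is not part of this diff. One plausible companion change (a sketch, not the actual declaration) is a defaulted parameter, which would let call sites that do not care about cost modeling stay unchanged:

  // Hypothetical declaration in the header; the default value is an assumption.
  RedistributionOpListPtr InferTensorRedistributionOperatorList(bool is_cost_model = false);
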
@@ -62,7 +62,7 @@ RedistributionOpListPtr TensorRedistribution::InferTensorRedistributionOperatorL
   MS_LOG(DEBUG) << "reshape to_ " << to_.ToString();
   // Step 2: Infer redistribution and insert operators
   RedistributionOperatorInfer operator_infer(construct_op_flag_);
-  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_) == Status::FAILED) {
+  if (operator_infer.Init(from_layout, to_layout.tensor_map(), dev_list_, is_cost_model) == Status::FAILED) {
     MS_LOG(ERROR) << "Init operatorInfer failed!";
     return nullptr;
   }
@@ -138,7 +138,7 @@ Status TensorRedistribution::InferReshape(const TensorLayout& from_layout, const
 }
 
 Status TensorRedistribution::ComputeCost() {
-  RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList();
+  RedistributionOpListPtr redistribution_oplist_ptr = InferTensorRedistributionOperatorList(true);
   if (redistribution_oplist_ptr == nullptr) {
     MS_LOG(ERROR) << "Failure: InferTensorRedistribution failed";
     return Status::FAILED;
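
Taken together, the three hunks above thread one piece of context through the inference chain: ComputeCost() now requests the cost-model variant via InferTensorRedistributionOperatorList(true), which in turn forwards is_cost_model to RedistributionOperatorInfer::Init(). The remaining hunks rework how the per-operator costs are accumulated. Throughout them, prod is the element count of the operator's slice shape, computed with std::accumulate; for an illustrative slice shape of {2, 4, 8}, prod = 2 * 4 * 8 = 64, and every cost term below is a multiple of it.
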
@@ -151,14 +151,22 @@ Status TensorRedistribution::ComputeCost() {
       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
     std::string str = op.first;
     if (str == PERMUTE_BY_AXIS) {
-      // The shape does not change after PermuteByAxis operation.
-      // communication cost = all_to_all + all_to_all = 2 * slice_shape
-      // computation cost = slice_shape
-      forward_comm_cost_ += prod;
-      backward_comm_cost_ += prod;
-      comm_cost_ += 2.0 * prod;
-      computation_cost_ += prod;
-      memory_cost_ += prod;
+      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
+      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
+      int32_t concat_dim = op.second[2];
+      if (concat_dim == 0) {
+        // memory cost = all_gather
+        computation_cost_ += prod;
+        memory_cost_ += prod;
+      } else {
+        // memory cost = all_gather + split + concat
+        int32_t dev_num = op.second[4];
+        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
+        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
+      }
     } else if (str == CONCAT_BY_AXIS) {
       // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
       // computation cost = before_slice_shape
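
To sanity-check the new PERMUTE_BY_AXIS branch, here is a small self-contained sketch of the same arithmetic. The slice shape, the device count, and the value of ALLTOALL_SCALE_FACTOR are made-up illustration values, not taken from this diff:

  // Illustration only: mirrors the cost accumulation in the PERMUTE_BY_AXIS branch.
  #include <functional>
  #include <iostream>
  #include <numeric>
  #include <vector>

  int main() {
    const double kAlltoAllScaleFactor = 0.1;  // assumed value of ALLTOALL_SCALE_FACTOR
    std::vector<double> slice_shape = {2, 4, 8};
    double prod = std::accumulate(slice_shape.begin(), slice_shape.end(),
                                  static_cast<double>(1.0), std::multiplies<double>());  // 64
    double forward_comm = prod * kAlltoAllScaleFactor;   // 6.4
    double backward_comm = prod * kAlltoAllScaleFactor;  // 6.4
    double comm = 2.0 * prod * kAlltoAllScaleFactor;     // 12.8 == forward + backward
    // concat_dim != 0 path: memory cost = all_gather + split + concat
    double dev_num = 4;  // assumed device count
    double computation = prod + prod * dev_num + prod * dev_num;  // 64 + 256 + 256 = 576
    double memory = prod * dev_num + prod * dev_num + prod;       // 256 + 256 + 64 = 576
    std::cout << comm << " " << computation << " " << memory << std::endl;  // 12.8 576 576
  }

The check that comm equals forward_comm + backward_comm mirrors an invariant visible throughout the diff: comm_cost_ always receives the sum of the forward and backward contributions.
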
@@ -168,9 +176,9 @@ Status TensorRedistribution::ComputeCost() {
       }
       double dev_num = op.second[2];
       // here, communication cost = all_gather + reduce_scatter
-      forward_comm_cost_ += prod * dev_num;
-      backward_comm_cost_ += prod;
-      comm_cost_ += prod * (dev_num + 1.0);
+      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
       int32_t concat_dim = op.second[0];
       if (concat_dim == 0) {
         // computation cost = all_gather
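
With the same illustrative numbers (prod = 64, dev_num = 4) and an assumed ALLGATHER_REDUCESCATTER_SCALE_FACTOR of 0.5, the rescaled CONCAT_BY_AXIS terms work out to:

  forward_comm_cost_  += 64 * 4 * 0.5       = 128  (all_gather: every device contributes one slice)
  backward_comm_cost_ += 64 * 0.5           = 32   (reduce_scatter: one slice)
  comm_cost_          += 64 * (4 + 1) * 0.5 = 160  (= forward + backward)

Note that the positional operator arguments differ between the two branches: for PERMUTE_BY_AXIS the concat dimension and device count are read from op.second[2] and op.second[4], while for CONCAT_BY_AXIS they are op.second[0] and op.second[2].
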