@@ -149,60 +149,75 @@ Status TensorRedistribution::ComputeCost() {
     double prod =
       std::accumulate(slice_shape.begin(), slice_shape.end(), static_cast<double>(1.0), std::multiplies<double>());
     std::string str = op.first;
-    if (str == PERMUTE_BY_AXIS) {
-      // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      forward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLTOALL_SCALE_FACTOR;
-      comm_cost_ += 2.0 * prod * ALLTOALL_SCALE_FACTOR;
-      int32_t concat_dim = op.second[2];
-      if (concat_dim == 0) {
-        // memory cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod;
-      } else {
-        // memory cost = all_gather + split + concat
-        int32_t dev_num = op.second[4];
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
-    } else if (str == CONCAT_BY_AXIS) {
-      // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
-      // computation cost = before_slice_shape
-      if (op.second.size() < 3) {
-        MS_LOG(ERROR) << "op.second size should not be less than 3!";
-        return Status::FAILED;
-      }
-      double dev_num = op.second[2];
-      // here, communication cost = all_gather + reduce_scatter
-      forward_comm_cost_ += prod * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      backward_comm_cost_ += prod * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      comm_cost_ += prod * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
-      int32_t concat_dim = op.second[0];
-      if (concat_dim == 0) {
-        // computation cost = all_gather
-        computation_cost_ += prod;
-        memory_cost_ += prod * dev_num;
-      } else {
-        // computation cost = all_gather + split + concat
-        computation_cost_ += (prod + prod * dev_num + prod * dev_num);
-        memory_cost_ += (prod * dev_num + prod * dev_num + prod);
-      }
+    if (str == PERMUTE_BY_AXIS && ComputePermuteCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
+    } else if (str == CONCAT_BY_AXIS && ComputeConcatCost(prod, op.second) != Status::SUCCESS) {
+      return Status::FAILED;
     } else {
       // There is only computation cost in SplitByAxis.
       // computation cost = before_slice_shape
       computation_cost_ += prod;
       // This addtion may be erroneous
       memory_cost_ += prod;
     }
   }
   if (reshape_flag()) {
     Shape prev_slice_shape = from_.slice_shape().array();
     double prev_prod = std::accumulate(prev_slice_shape.begin(), prev_slice_shape.end(), 1, std::multiplies<int>());
     computation_cost_ += 2.0 * prev_prod;
     memory_cost_ += 2.0 * prev_prod;
   }
   return Status::SUCCESS;
 }
 
+Status TensorRedistribution::ComputePermuteCost(double input_size, Shape attrs) {
+  // Since AlltoAll is a virtual operator, the expanded operators are used here to compute cost.
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+  if (attrs.size() < 4) {
+    MS_LOG(ERROR) << "attrs size should not be less than 4!";
+    return Status::FAILED;
+  }
+  forward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLTOALL_SCALE_FACTOR;
+  comm_cost_ += 2.0 * input_size * ALLTOALL_SCALE_FACTOR;
+  int32_t concat_dim = attrs[2];
+  if (concat_dim == 0) {
+    // memory cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size;
+  } else {
+    // memory cost = all_gather + split + concat
+    int32_t dev_num = attrs[4];
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
+
+Status TensorRedistribution::ComputeConcatCost(double input_size, Shape attrs) {
+  // communication cost = all_gather + reduce_scatter = before_slice_shape + after_slice_shape
+  // computation cost = before_slice_shape
+  if (attrs.size() < 3) {
+    MS_LOG(ERROR) << "op.second size should not be less than 3!";
+    return Status::FAILED;
+  }
+  double dev_num = attrs[2];
+  // here, communication cost = all_gather + reduce_scatter
+  forward_comm_cost_ += input_size * dev_num * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  backward_comm_cost_ += input_size * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  comm_cost_ += input_size * (dev_num + 1.0) * ALLGATHER_REDUCESCATTER_SCALE_FACTOR;
+  int32_t concat_dim = attrs[0];
+  if (concat_dim == 0) {
+    // computation cost = all_gather
+    computation_cost_ += input_size;
+    memory_cost_ += input_size * dev_num;
+  } else {
+    // computation cost = all_gather + split + concat
+    computation_cost_ += (input_size + input_size * dev_num + input_size * dev_num);
+    memory_cost_ += (input_size * dev_num + input_size * dev_num + input_size);
+  }
+  return Status::SUCCESS;
+}
+
 } // namespace parallel
 } // namespace mindspore
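For reference, below is a minimal standalone sketch (not part of the patch) of the cost arithmetic above. It re-derives input_size as the product of a slice shape and evaluates the two communication-cost formulas from ComputePermuteCost and ComputeConcatCost so they can be checked by hand; the scale-factor values, the slice shape, and dev_num are made-up example numbers, not the constants defined in the MindSpore sources.

#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Assumed example values; the real ALLTOALL_SCALE_FACTOR and
  // ALLGATHER_REDUCESCATTER_SCALE_FACTOR constants live in the MindSpore sources.
  const double kAllToAllScale = 2.0;
  const double kAllGatherReduceScatterScale = 2.0;

  // input_size is the element count of the slice, i.e. the product of its dimensions.
  std::vector<double> slice_shape = {32, 64};
  double input_size =
      std::accumulate(slice_shape.begin(), slice_shape.end(), 1.0, std::multiplies<double>());

  // PermuteByAxis (expanded AlltoAll): the slice moves once forward and once backward.
  double permute_comm = 2.0 * input_size * kAllToAllScale;

  // ConcatByAxis over dev_num devices: AllGather costs dev_num slices forward,
  // ReduceScatter costs one slice backward.
  double dev_num = 4.0;
  double concat_comm = input_size * (dev_num + 1.0) * kAllGatherReduceScatterScale;

  std::cout << "input_size   = " << input_size << "\n"    // 2048
            << "permute comm = " << permute_comm << "\n"  // 8192
            << "concat comm  = " << concat_comm << "\n";  // 20480
  return 0;
}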