!385 [Auto parallel] Adjusting backward phase communication cost of some operators

Merge pull request !385 from Xiaoda/modify-communicaiton-cost-of-operators-and-redistribution
pull/385/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 6e183fcc0f

@ -287,6 +287,31 @@ double BatchParallelCost::GetBackwardComputationCost(const std::vector<mindspore
return 0.0;
}
// Return the per-device communication cost in the backward phase.
// For each input that is a trainable parameter, if the tensor is not fully
// sharded across every device of the stage (i.e. some devices hold replicated
// slices), its gradient must be synchronized, so the slice size (in bytes)
// contributes to the communication cost.
double BatchParallelCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                              int32_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
  for (size_t j = 0; j < inputs.size(); ++j) {
    // Only parameters produce gradients that may need synchronization.
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    // The number of devices actually used to shard this tensor is the product
    // of the per-dimension split factors (full dim / slice dim).
    int32_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    // If the tensor does not occupy all devices, some slices are replicated
    // and the gradient must be communicated.
    if (total_device_num != IntToSize(used_device_num)) {
      // BUGFIX: use the type length of input j (was hard-coded to index 0),
      // so inputs with different element sizes are costed correctly.
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
    }
  }
  return result;
}
// return the per device communication cost in the forward phase.
double PReLUCost::GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, int32_t) const {
// prelu does not need communication in the forward phase
@ -432,8 +457,24 @@ double ReshapeCost::GetForwardCommCost(const std::vector<TensorInfo>& inputs, co
}
// return the per device communication cost in the backward phase.
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, int32_t) const {
return 0.0;
double ReshapeCost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
int32_t stage_id) const {
double result = 0.0;
if (is_parameter_[0]) {
TensorInfo input1 = inputs[0];
MS_EXCEPTION_IF_NULL(g_device_manager);
auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
Shape input1_shape = input1.shape();
Shape input1_slice_shape = input1.slice_shape();
int32_t used_device_num = 1;
for (size_t i = 0; i < input1_shape.size(); ++i) {
used_device_num *= input1_shape[i] / input1_slice_shape[i];
}
if (total_device_num != IntToSize(used_device_num)) {
result = ListProduct(input1_slice_shape) * static_cast<double>(inputs_type_lengths_[1]);
}
}
return result;
}
// Return the per device computation cost in the forward phase. The cost is calculated according to the bytes
@ -654,10 +695,30 @@ double GatherV2Cost::GetForwardCommCost(const std::vector<TensorInfo>&, const st
}
// return the per device communication cost in the backward phase.
double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&,
int32_t) const {
// GatherV2Cost does not need communication in the backward phase
return 0.0;
// Return the per-device communication cost in the backward phase.
// Mirrors BatchParallelCost::GetBackwardCommCost: every parameter input whose
// sharding does not cover all devices of the stage contributes its slice size
// (in bytes) to the gradient-synchronization cost.
double GatherV2Cost::GetBackwardCommCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,
                                         int32_t stage_id) const {
  double result = 0.0;
  CheckGlobalDeviceManager();
  MS_EXCEPTION_IF_NULL(g_device_manager);
  auto total_device_num = g_device_manager->GetDeviceListByStageId(stage_id).size();
  for (size_t j = 0; j < inputs.size(); ++j) {
    // Non-parameter inputs need no gradient synchronization.
    if (!is_parameter_[j]) {
      continue;
    }
    TensorInfo input_a_tensor_info = inputs[j];
    Shape input_a_shape = input_a_tensor_info.shape();
    Shape input_a_slice_shape = input_a_tensor_info.slice_shape();
    // Devices used to shard this tensor = product of per-dimension split factors.
    int32_t used_device_num = 1;
    for (size_t i = 0; i < input_a_shape.size(); ++i) {
      used_device_num *= input_a_shape[i] / input_a_slice_shape[i];
    }
    // Replicated slices exist iff the tensor does not occupy every device.
    if (total_device_num != IntToSize(used_device_num)) {
      // BUGFIX: use the type length of input j (was hard-coded to index 0),
      // keeping the per-input element size correct.
      result += ListProduct(input_a_slice_shape) * static_cast<double>(inputs_type_lengths_[j]);
    }
  }
  return result;
}
double GatherV2Cost::GetForwardComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>&,

@ -226,9 +226,7 @@ class BatchParallelCost : public OperatorCost {
// Batch-parallel operators need no inter-device communication in the
// forward phase, so the forward communication cost is always zero.
double GetForwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, int32_t) const override {
return 0.0;
}
double GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, int32_t) const override {
return 0.0;
}
double GetBackwardCommCost(const std::vector<TensorInfo>&, const std::vector<TensorInfo>&, int32_t) const override;
double GetComputationCost(const std::vector<TensorInfo>& inputs, const std::vector<TensorInfo>& outputs,
int32_t stage_id) const override {
return GetForwardComputationCost(inputs, outputs, stage_id) + GetBackwardComputationCost(inputs, outputs, stage_id);

@ -291,7 +291,7 @@ Status GatherV2Info::GenerateStrategies(int32_t stage_id) {
}
is_auto_parallel_ = true;
Shape input0_split(inputs_shape_[0].size());
Shape input0_split(inputs_shape_[0].size(), 1);
Shapes splittable_inputs = {input0_split};
std::vector<StrategyPtr> sp_vector;

Loading…
Cancel
Save