|
|
|
@ -32,11 +32,7 @@ namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
|
Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -45,11 +41,7 @@ Status Activation::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
Status Activation::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -58,11 +50,7 @@ Status Activation::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
Status DropoutInfo::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -125,7 +113,6 @@ Status Activation::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
is_auto_parallel_ = true;
|
|
|
|
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
|
|
|
|
Shapes splittable_inputs = {input0_split};
|
|
|
|
|
|
|
|
|
@ -146,7 +133,6 @@ Status Activation::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
is_auto_parallel_ = true;
|
|
|
|
|
Shape input0_split(inputs_shape_[0].size(), 1);
|
|
|
|
|
Shapes splittable_inputs = {input0_split};
|
|
|
|
|
|
|
|
|
@ -168,11 +154,7 @@ Status DropoutInfo::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
|
|
|
|
|
Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Invalid strategy.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Invalid strategy.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -189,11 +171,7 @@ Status Softmax::CheckStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
int32_t axis_strategy = input_strategy.at(IntToSize(axis_index));
|
|
|
|
|
// Dimension corresponding to axis is un-splittable
|
|
|
|
|
if (axis_strategy != MIN_SLICE_NUM) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : The strategy corresponding to axis dimension(" << axis_strategy << ") is not 1";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -253,11 +231,7 @@ Status Softmax::GetAttrs() {
|
|
|
|
|
|
|
|
|
|
Status Softmax::SetCostUnderStrategy(const StrategyPtr &strategy) {
|
|
|
|
|
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Set cost under strategy failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -274,7 +248,6 @@ Status Softmax::GenerateStrategies(int32_t stage_id) {
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
is_auto_parallel_ = true;
|
|
|
|
|
Shape input0_split;
|
|
|
|
|
(void)input0_split.insert(input0_split.begin(), inputs_shape_[0].size(), 1);
|
|
|
|
|
for (auto &element : axis_) {
|
|
|
|
@ -418,11 +391,7 @@ Status ActivationBase::Init(const StrategyPtr &strategy) {
|
|
|
|
|
|
|
|
|
|
Status ActivationBase::InitForCostModel(const StrategyPtr &strategy) {
|
|
|
|
|
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
|
|
|
|
|
if (is_auto_parallel_) {
|
|
|
|
|
MS_LOG(DEBUG) << name_ << " : Init for cost model failed.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(ERROR) << name_ << " : Init for cost model failed.";
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|