diff --git a/mindspore/ccsrc/operator/prim_others.cc b/mindspore/ccsrc/operator/prim_others.cc index 432b12f83b..9350e9aa3b 100644 --- a/mindspore/ccsrc/operator/prim_others.cc +++ b/mindspore/ccsrc/operator/prim_others.cc @@ -118,11 +118,9 @@ const size_t UndeterminedShapeType::fields_num = 6; std::unordered_map g_undetermined_configs; void InitUndeterminedFromEnv(const std::string &sparse_shape_types) { - if (!g_undetermined_configs.empty()) { - return; - } std::string tmp; std::stringstream input(sparse_shape_types); + g_undetermined_configs.clear(); while (std::getline(input, tmp, ';')) { auto config = UndeterminedShapeType(tmp); g_undetermined_configs.insert(std::make_pair(config.param_name(), config)); @@ -145,17 +143,19 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt if (!key->sparse_grad().empty()) { // Will be fixed once undetermined type ready - auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); - if (sparse_shape_types.empty()) { - sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; + if (g_undetermined_configs.empty()) { + auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES"); + MS_LOG(INFO) << "Undetermind sparse shape:" << sparse_shape_types; + if (sparse_shape_types.empty()) { + sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2"; + } + InitUndeterminedFromEnv(sparse_shape_types); } - InitUndeterminedFromEnv(sparse_shape_types); auto shape_types = g_undetermined_configs.find(key->sparse_grad()); if (shape_types == g_undetermined_configs.end()) { MS_LOG(EXCEPTION) << "Param " << key->ToString() - << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES: " - << sparse_shape_types; + << " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES"; } MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString(); AbstractBasePtrList sparse_list; diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc index fc7b48d267..7d1200b190 100644 --- a/mindspore/ccsrc/parallel/step_parallel.cc +++ b/mindspore/ccsrc/parallel/step_parallel.cc @@ -43,6 +43,7 @@ #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "utils/comm_manager.h" #include "utils/symbolic.h" +#include "pipeline/static_analysis/prim.h" using mindspore::tensor::Tensor; @@ -1371,6 +1372,11 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { << cloned_index << ", but not found the be cloned parameter"; } } + std::string env = common::GetEnv("SLICE_ENV"); + if (!env.empty()) { + MS_LOG(INFO) << "Slice tensors shape will be configured from env:" << env; + abstract::InitUndeterminedFromEnv(env); + } } void SetVirtualDatasetStrategy(const CNodePtr &node) { diff --git a/mindspore/ccsrc/pipeline/static_analysis/prim.h b/mindspore/ccsrc/pipeline/static_analysis/prim.h index 5b3972088a..5954179aa5 100644 --- a/mindspore/ccsrc/pipeline/static_analysis/prim.h +++ b/mindspore/ccsrc/pipeline/static_analysis/prim.h @@ -349,6 +349,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv AbstractBasePtr InferImplDebug(const AnalysisEnginePtr &, const PrimitivePtr &primitive, const AbstractBasePtrList &args_spec_list); +void InitUndeterminedFromEnv(const std::string &sparse_shape_types); } // namespace abstract } // namespace mindspore