|
|
|
@ -59,7 +59,8 @@ class UndeterminedShapeType {
|
|
|
|
|
public:
|
|
|
|
|
explicit UndeterminedShapeType(const std::string &env_str) {
|
|
|
|
|
// param_name indices_shape indices_type values_shape values_type dense_shape
|
|
|
|
|
// export UNDETERMINED_SPARSE_SHAPE_TYPES="w1:2:Int32:2 1 2:Float32:3 1 2"
|
|
|
|
|
// export UNDETERMINED_SPARSE_SHAPE_TYPES="sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1
|
|
|
|
|
// 2:Float32:3 1 2"
|
|
|
|
|
std::vector<string> fields;
|
|
|
|
|
string tmp;
|
|
|
|
|
std::stringstream input(env_str);
|
|
|
|
@ -115,6 +116,20 @@ std::vector<int> UndeterminedShapeType::GetShape(const std::string &shape_str) {
|
|
|
|
|
}
|
|
|
|
|
const size_t UndeterminedShapeType::fields_num = 6;
|
|
|
|
|
|
|
|
|
|
std::unordered_map<std::string, UndeterminedShapeType> g_undetermined_configs;
|
|
|
|
|
void InitUndeterminedFromEnv(const std::string &sparse_shape_types) {
|
|
|
|
|
if (!g_undetermined_configs.empty()) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::string tmp;
|
|
|
|
|
std::stringstream input(sparse_shape_types);
|
|
|
|
|
while (std::getline(input, tmp, ';')) {
|
|
|
|
|
auto config = UndeterminedShapeType(tmp);
|
|
|
|
|
g_undetermined_configs.insert(std::make_pair(config.param_name(), config));
|
|
|
|
|
MS_LOG(DEBUG) << "Undetermined config from env: " << tmp;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
|
|
|
|
const AbstractBasePtrList &args_spec_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(primitive);
|
|
|
|
@ -128,27 +143,33 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
|
|
|
|
|
MS_LOG(EXCEPTION) << "EnvGetItem evaluator args[1] should be a SymbolicKeyInstance but: " << key->ToString();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (key->sparse_grad()) {
|
|
|
|
|
if (!key->sparse_grad().empty()) {
|
|
|
|
|
// Will be fixed once undetermined type ready
|
|
|
|
|
auto sparse_shape_types = common::GetEnv("UNDETERMINED_SPARSE_SHAPE_TYPES");
|
|
|
|
|
if (sparse_shape_types.empty()) {
|
|
|
|
|
sparse_shape_types = "w1:2:Int32:2 1 2:Float32:3 1 2";
|
|
|
|
|
sparse_shape_types = "sparse_key_w1:2:Int32:2 1 2:Float32:3 1 2;sparse_key_w2:2:Int32:2 1 2:Float32:3 1 2";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString() << ", Undetermined shape is "
|
|
|
|
|
<< sparse_shape_types;
|
|
|
|
|
InitUndeterminedFromEnv(sparse_shape_types);
|
|
|
|
|
|
|
|
|
|
auto shape_types = UndeterminedShapeType(sparse_shape_types);
|
|
|
|
|
auto shape_types = g_undetermined_configs.find(key->sparse_grad());
|
|
|
|
|
if (shape_types == g_undetermined_configs.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Param " << key->ToString()
|
|
|
|
|
<< " has sparse_grad, but shape/type is not configured in env UNDETERMINED_SPARSE_SHAPE_TYPES: "
|
|
|
|
|
<< sparse_shape_types;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "EnvGetItem is sparse_grad " << key->ToString();
|
|
|
|
|
AbstractBasePtrList sparse_list;
|
|
|
|
|
// indices
|
|
|
|
|
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.indices_type());
|
|
|
|
|
auto indices = std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types.indices_shape()));
|
|
|
|
|
auto indices_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.indices_type());
|
|
|
|
|
auto indices =
|
|
|
|
|
std::make_shared<AbstractTensor>(indices_ele, std::make_shared<Shape>(shape_types->second.indices_shape()));
|
|
|
|
|
sparse_list.emplace_back(indices);
|
|
|
|
|
// values
|
|
|
|
|
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types.values_type());
|
|
|
|
|
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types.values_shape()));
|
|
|
|
|
auto dout_ele = std::make_shared<AbstractScalar>(kAnyValue, shape_types->second.values_type());
|
|
|
|
|
auto dout = std::make_shared<AbstractTensor>(dout_ele, std::make_shared<Shape>(shape_types->second.values_shape()));
|
|
|
|
|
sparse_list.emplace_back(dout);
|
|
|
|
|
// dense_shape
|
|
|
|
|
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types.dense_shape()));
|
|
|
|
|
sparse_list.emplace_back(std::make_shared<AbstractTuple>(shape_types->second.dense_shape()));
|
|
|
|
|
return std::make_shared<AbstractTuple>(sparse_list);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|