!1904 Add IndexedSlices

Merge pull request !1904 from riemann_penn/add_indexed_slices
pull/1904/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit fe82d82155

@ -17,6 +17,7 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import IndexedSlices
from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C
from . import standard_method as M
@ -135,4 +136,7 @@ convert_object_map = {
math.sin: NO_IMPLEMENT,
math.cos: NO_IMPLEMENT,
math.tan: NO_IMPLEMENT,
# user defined
IndexedSlices: F.make_indexed_slices,
}

@ -120,6 +120,10 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
type_proto->mutable_tensor_type()->mutable_shape()->add_dim()->set_size(elem);
}
}
} else if (type->isa<IndexedSlicesType>()) {
// Do Nothing
} else if (type->isa<UndeterminedType>()) {
// Do Nothing
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);

@ -94,6 +94,48 @@ bool Slice::operator==(const Type &other) const {
std::string Slice::DumpText() const { return ToString(); }
// Deep-copies this type. A generic instance (no concrete element type) is
// copied as another generic instance; otherwise the element type is deep-copied.
TypePtr UndeterminedType::DeepCopy() const {
  // Check IsGeneric() first: a default-constructed UndeterminedType is generic
  // and carries a null element_type_, so running the null check before this
  // branch would throw on a perfectly valid generic instance.
  if (IsGeneric()) {
    return std::make_shared<UndeterminedType>();
  }
  MS_EXCEPTION_IF_NULL(element_type_);
  return std::make_shared<UndeterminedType>(element_type_->DeepCopy());
}
// Short representation, e.g. "Undetermined" or "Undetermined[Float32]".
std::string UndeterminedType::ToReprString() const {
  return element_type_ == nullptr ? "Undetermined"
                                  : "Undetermined[" + element_type_->ToReprString() + "]";
}
// Full string form; generic instances (null element) print just "Undetermined".
std::string UndeterminedType::ToString() const {
  return element_type_ == nullptr ? "Undetermined"
                                  : "Undetermined[" + element_type_->ToString() + "]";
}
// Dump form used by IR text output; mirrors ToString but uses DumpText of the element.
std::string UndeterminedType::DumpText() const {
  return element_type_ == nullptr ? "Undetermined"
                                  : "Undetermined[" + element_type_->DumpText() + "]";
}
// Two UndeterminedTypes are equal when both are generic (null element) or
// both wrap equal element types.
bool UndeterminedType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  const auto &rhs_elem = static_cast<const UndeterminedType &>(other).element_type_;
  if (element_type_ == nullptr || rhs_elem == nullptr) {
    // Equal only when both sides are generic.
    return element_type_ == rhs_elem;
  }
  return *element_type_ == *rhs_elem;
}
TypePtr TensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
@ -137,6 +179,48 @@ bool TensorType::operator==(const Type &other) const {
return *element_type_ == *other_elem_type;
}
// Deep-copies this type. A generic instance (no concrete element type) is
// copied as another generic instance; otherwise the element type is deep-copied.
TypePtr IndexedSlicesType::DeepCopy() const {
  // Check IsGeneric() first: a default-constructed IndexedSlicesType is generic
  // and carries a null element_type_, so running the null check before this
  // branch would throw on a perfectly valid generic instance.
  if (IsGeneric()) {
    return std::make_shared<IndexedSlicesType>();
  }
  MS_EXCEPTION_IF_NULL(element_type_);
  return std::make_shared<IndexedSlicesType>(element_type_->DeepCopy());
}
// Short representation, e.g. "IndexedSlices" or "IndexedSlices[Float32]".
std::string IndexedSlicesType::ToReprString() const {
  return element_type_ == nullptr ? "IndexedSlices"
                                  : "IndexedSlices[" + element_type_->ToReprString() + "]";
}
// Full string form; generic instances (null element) print just "IndexedSlices".
std::string IndexedSlicesType::ToString() const {
  return element_type_ == nullptr ? "IndexedSlices"
                                  : "IndexedSlices[" + element_type_->ToString() + "]";
}
// Dump form used by IR text output; mirrors ToString but uses DumpText of the element.
std::string IndexedSlicesType::DumpText() const {
  return element_type_ == nullptr ? "IndexedSlices"
                                  : "IndexedSlices[" + element_type_->DumpText() + "]";
}
// Two IndexedSlicesTypes are equal when both are generic (null element) or
// both wrap equal element types.
bool IndexedSlicesType::operator==(const Type &other) const {
  if (!IsSameObjectType(*this, other)) {
    return false;
  }
  const auto &rhs_elem = static_cast<const IndexedSlicesType &>(other).element_type_;
  if (element_type_ == nullptr || rhs_elem == nullptr) {
    // Equal only when both sides are generic.
    return element_type_ == rhs_elem;
  }
  return *element_type_ == *rhs_elem;
}
Function::Function() : Object(kObjectTypeFunction) {
args_ = std::vector<TypePtr>();
retval_ = nullptr;

@ -108,10 +108,34 @@ class Slice : public Object {
};
using SlicePtr = std::shared_ptr<Slice>;
// Object type for a value whose concrete type is not yet determined during
// abstract interpretation. Other types (e.g. TensorType, IndexedSlicesType in
// this change) name kObjectTypeUndeterminedType as their parent type id.
class UndeterminedType : public Object {
public:
// Generic instance: no concrete element type (element_type_ stays null).
UndeterminedType() : Object(kObjectTypeUndeterminedType) {}
// Non-generic instance wrapping a concrete element type.
explicit UndeterminedType(const TypePtr &ele)
: Object(kObjectTypeUndeterminedType, kMetaTypeObject, false), element_type_(ele) {}
~UndeterminedType() override = default;
MS_DECLARE_PARENT(UndeterminedType, Object)
TypeId generic_type_id() const override { return kObjectTypeUndeterminedType; }
// May return nullptr for a generic (default-constructed) instance.
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
protected:
TypePtr element_type_;
};
// NOTE(review): alias name reads "MetaTensorTypePtr" yet aliases
// UndeterminedType — presumably intentional/historical; confirm before renaming.
using MetaTensorTypePtr = std::shared_ptr<UndeterminedType>;
class TensorType : public Object {
public:
TensorType() : Object(kObjectTypeTensorType) {}
explicit TensorType(const TypePtr &ele) : Object(kObjectTypeTensorType, false), element_type_(ele) {}
TensorType() : Object(kObjectTypeTensorType, kObjectTypeUndeterminedType) {}
explicit TensorType(const TypePtr &ele)
: Object(kObjectTypeTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~TensorType() override = default;
MS_DECLARE_PARENT(TensorType, Object)
@ -130,6 +154,29 @@ class TensorType : public Object {
};
using TensorTypePtr = std::shared_ptr<TensorType>;
// Object type for an IndexedSlices (sparse gradient) value. Declares
// kObjectTypeUndeterminedType as its parent type so it participates in
// parent/child type matching alongside TensorType.
class IndexedSlicesType : public Object {
public:
// Generic instance: no concrete element type (element_type_ stays null).
IndexedSlicesType() : Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType) {}
// Non-generic instance wrapping the element (values) dtype.
explicit IndexedSlicesType(const TypePtr &ele)
: Object(kObjectTypeIndexedSlicesType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~IndexedSlicesType() override = default;
MS_DECLARE_PARENT(IndexedSlicesType, Object)
TypeId generic_type_id() const override { return kObjectTypeIndexedSlicesType; }
// May return nullptr for a generic (default-constructed) instance.
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
class Function : public Object {
public:
Function();
@ -255,6 +302,8 @@ TypePtr StringToType(const std::string &type_name);
// Judge whether x is predicate or is a subclass of predicate.
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type);
bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type);
// Whether t1 is identity or a subclass of t2.
bool IsSubType(TypePtr const &t1, TypePtr const &t2 = nullptr);

@ -115,6 +115,10 @@ const char *ObjectIdLabel(const TypeId &v) {
return "kObjectTypeKeyword";
case kObjectTypeTensorType:
return "kObjectTypeTensorType";
case kObjectTypeIndexedSlicesType:
return "kObjectTypeIndexedSlicesType";
case kObjectTypeUndeterminedType:
return "kObjectTypeUndeterminedType";
case kObjectTypeDictionary:
return "kObjectTypeDictionary";
case kObjectTypeClass:

@ -67,6 +67,7 @@ class Type : public Value {
virtual bool equal(const TypePtr other) const { return *this == *other; }
virtual TypeId object_type() const { return kTypeUnknown; }
virtual TypeId parent_type() const { return kTypeUnknown; }
virtual TypeId number_type() const { return kTypeUnknown; }
virtual TypePtr DeepCopy() const = 0;
virtual TypePtr Clone() const { return DeepCopy(); }
@ -97,13 +98,16 @@ using TypePtrList = std::vector<TypePtr>;
//
class Object : public Type {
public:
Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject) {}
Object() : Type(kMetaTypeObject), object_type_(kMetaTypeObject), parent_type_(kMetaTypeObject) {}
explicit Object(const TypeId object_type, bool is_generic = true)
: Type(kMetaTypeObject, is_generic), object_type_(object_type) {}
: Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(kMetaTypeObject) {}
explicit Object(const TypeId object_type, const TypeId parent_type, bool is_generic = true)
: Type(kMetaTypeObject, is_generic), object_type_(object_type), parent_type_(parent_type) {}
~Object() override = default;
MS_DECLARE_PARENT(Object, Type)
TypeId object_type() const override { return object_type_; }
TypeId parent_type() const override { return parent_type_; }
TypeId type_id() const override { return object_type_; }
TypeId generic_type_id() const override { return kMetaTypeObject; }
bool equal(const TypePtr other) const override;
@ -114,6 +118,7 @@ class Object : public Type {
private:
const TypeId object_type_;
const TypeId parent_type_;
};
std::ostream &operator<<(std::ostream &os, const TypePtrList &types);

@ -50,6 +50,8 @@ enum TypeId : int {
kObjectTypeSlice,
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeIndexedSlicesType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,
kObjectTypeFunction,

@ -192,6 +192,40 @@ TypePtr TensorStrToType(const std::string &type_name) {
return type;
}
// Parses "IndexedSlices" or "IndexedSlices[<element>]" into an IndexedSlicesType.
// Returns nullptr when the string is malformed or the element type is unknown.
TypePtr IndexedSlicesStrToType(const std::string &type_name) {
  if (type_name == "IndexedSlices") {
    return std::make_shared<IndexedSlicesType>();
  }
  auto lbracket = type_name.find_first_of('[');
  auto rbracket = type_name.find_last_of(']');
  // Reject missing or misordered brackets explicitly. The previous
  // `start >= size` guard missed the npos case: npos + 1 wraps to 0, so a
  // name like "IndexedSlicesFoo" would substr the whole string and recurse
  // through StringToType on the same input forever.
  if (lbracket == std::string::npos || rbracket == std::string::npos || rbracket < lbracket) {
    return nullptr;
  }
  auto element_str = type_name.substr(lbracket + 1, rbracket - lbracket - 1);
  auto element_type = StringToType(element_str);
  if (element_type == nullptr) {
    return nullptr;
  }
  return std::make_shared<IndexedSlicesType>(element_type);
}
// Parses "Undetermined" or "Undetermined[<element>]" into an UndeterminedType.
// Returns nullptr when the string is malformed or the element type is unknown.
TypePtr UndeterminedStrToType(const std::string &type_name) {
  if (type_name == "Undetermined") {
    return std::make_shared<UndeterminedType>();
  }
  auto lbracket = type_name.find_first_of('[');
  auto rbracket = type_name.find_last_of(']');
  // Reject missing or misordered brackets explicitly. The previous
  // `start >= size` guard missed the npos case: npos + 1 wraps to 0, so a
  // name like "UndeterminedFoo" would substr the whole string and recurse
  // through StringToType on the same input forever.
  if (lbracket == std::string::npos || rbracket == std::string::npos || rbracket < lbracket) {
    return nullptr;
  }
  auto element_str = type_name.substr(lbracket + 1, rbracket - lbracket - 1);
  auto element_type = StringToType(element_str);
  if (element_type == nullptr) {
    return nullptr;
  }
  return std::make_shared<UndeterminedType>(element_type);
}
TypePtr ListStrToType(const std::string &type_name) {
TypePtr type = nullptr;
if (type_name == "List") {
@ -313,6 +347,10 @@ TypePtr StringToType(const std::string &type_name) {
type = StringToNumberType<Float>(type_name, "Float");
} else if (type_name.compare(0, strlen("Tensor"), "Tensor") == 0) {
type = TensorStrToType(type_name);
} else if (type_name.compare(0, strlen("Undetermined"), "Undetermined") == 0) {
type = UndeterminedStrToType(type_name);
} else if (type_name.compare(0, strlen("IndexedSlices"), "IndexedSlices") == 0) {
type = IndexedSlicesStrToType(type_name);
} else if (type_name.compare(0, strlen("List"), "List") == 0) {
type = ListStrToType(type_name);
} else if (type_name.compare(0, strlen("Tuple"), "Tuple") == 0) {
@ -340,6 +378,20 @@ TypePtr StringToType(const std::string &type_name) {
return type;
}
// Returns true when either type names the other as its parent type
// (e.g. TensorType/IndexedSlicesType vs UndeterminedType in this change).
bool IsParentOrChildrenType(TypePtr const &x, TypePtr const &base_type) {
  if (x == nullptr || base_type == nullptr) {
    MS_LOG(ERROR) << "Type is nullptr.";
    return false;
  }
  const auto x_id = x->type_id();
  const auto base_id = base_type->type_id();
  if (x_id == kTypeUnknown || base_id == kTypeUnknown) {
    return false;
  }
  return base_id == x->parent_type() || x_id == base_type->parent_type();
}
bool IsIdentidityOrSubclass(TypePtr const &x, TypePtr const &base_type) {
if (x == nullptr || base_type == nullptr) {
MS_LOG(ERROR) << "Type is nullptr.";
@ -481,6 +533,10 @@ REGISTER_PYBIND_DEFINE(
TensorType data(TypeIdToType(TypeId(static_cast<int>(t[0].cast<py::int_>()))));
return data;
}));
(void)py::class_<IndexedSlicesType, Type, std::shared_ptr<IndexedSlicesType>>(m_sub, "IndexedSlicesType")
.def(py::init());
(void)py::class_<UndeterminedType, Type, std::shared_ptr<UndeterminedType>>(m_sub, "UndeterminedType")
.def(py::init());
(void)py::class_<Function, Type, std::shared_ptr<Function>>(m_sub, "Function")
.def(py::init())
.def(py::init<std::vector<TypePtr>, TypePtr>(), py::arg("args"), py::arg("retval"));
@ -501,6 +557,8 @@ const TypePtr kTypeExternal = std::make_shared<External>();
const TypePtr kTypeEnv = std::make_shared<EnvType>();
const TypePtr kTypeType = std::make_shared<TypeType>();
const TypePtr kTensorType = std::make_shared<TensorType>();
const TypePtr kIndexedSlicesType = std::make_shared<IndexedSlicesType>();
const TypePtr kUndeterminedType = std::make_shared<UndeterminedType>();
const TypePtr kString = std::make_shared<String>();
const TypePtr kList = std::make_shared<List>();
const TypePtr kTuple = std::make_shared<Tuple>();

@ -93,15 +93,17 @@ static TypePtr UnwrapRef(const TypePtr &type) {
}
return type;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
bool find_fn = false;
py::function py_fn;
// Return Exact match if exists, else return non ambiguous sub class match
// Return py::none() if matching is ambiguous
const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
// Exact match
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
}
bool match = true;
auto match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i])) {
match = false;
@ -111,13 +113,45 @@ FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
if (!match) {
continue;
}
find_fn = true;
py_fn = item.second;
break;
return item.second;
}
// Try best match
py::function py_fn_subclass;
size_t subclass_match_cnt = 0;
for (auto &item : fn_cache_py_) {
TypePtrList sign = item.first;
if (sign.size() != types.size()) {
continue;
}
auto match = true;
for (size_t i = 0; i < sign.size(); ++i) {
if (!IsIdentidityOrSubclass(UnwrapRef(types[i]), sign[i]) &&
!IsParentOrChildrenType(UnwrapRef(types[i]), sign[i])) {
match = false;
break;
}
}
if (!match) {
continue;
}
py_fn_subclass = item.second;
subclass_match_cnt++;
}
if (subclass_match_cnt > 1) {
MS_LOG(EXCEPTION) << "There are more than one prototypes for overload function match by subclass";
}
if (subclass_match_cnt == 1) {
MS_LOG(DEBUG) << "Found one subclass match";
return py_fn_subclass;
}
return py::none();
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
auto py_fn = SignMatch(types);
std::ostringstream buffer;
buffer << types;
if (find_fn) {
if (py_fn != py::none()) {
FuncGraphPtr func_graph = parse::ParsePythonCode(py_fn);
if (func_graph == nullptr) {
MS_LOG(EXCEPTION) << "Fail to parse overload function " << buffer.str();

@ -54,6 +54,7 @@ class MultitypeFuncGraph : public MetaFuncGraph {
}
private:
const py::function SignMatch(const TypePtrList &types);
std::unordered_map<TypePtrList, specialize_fn, TypeListHasher, TypeListEqual> fn_cache_;
std::unordered_map<TypePtrList, py::function, TypeListHasher, TypeListEqual> fn_cache_py_;
};

@ -277,5 +277,12 @@ const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary
const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// IndexedSlices
const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeIndexedSlices");
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
} // namespace prim
} // namespace mindspore

@ -287,6 +287,13 @@ extern const PrimitivePtr kPrimMirror;
extern const PrimitivePtr kPrimVirtualDiv;
extern const PrimitivePtr kPrimVirtualDataset;
// IndexedSlices
extern const PrimitivePtr kPrimMakeIndexedSlices;
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
extern const PrimitivePtr kPrimIsIndexedSlices;
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)

@ -24,6 +24,7 @@
#include "pipeline/static_analysis/prim.h"
#include "pipeline/static_analysis/utils.h"
#include "utils/symbolic.h"
#include "utils/context/ms_context.h"
namespace mindspore {
namespace abstract {
@ -173,6 +174,13 @@ AbstractBasePtr InferImplEnvGetItem(const AnalysisEnginePtr &, const PrimitivePt
return std::make_shared<AbstractTuple>(sparse_list);
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse_flag = context->enable_sparse_flag();
if (enable_sparse_flag && key->has_indexed_slices_grad() && dflt->isa<AbstractTensor>()) {
auto dflt_tensor = dflt->cast<AbstractTensorPtr>();
return std::make_shared<AbstractUndetermined>(dflt_tensor->element()->Clone(), dflt_tensor->shape()->Clone());
}
if (!key->GetValueTrack()->isa<SymbolicKeyInstance>()) {
return dflt;
}
@ -236,6 +244,7 @@ AbstractBasePtr InferImplMakeRef(const AnalysisEnginePtr &, const PrimitivePtr &
}
auto ret = std::make_shared<AbstractRef>(args_spec_list[0], args_spec_list[1], args_spec_list[2]);
ret->set_sparse_grad(args_spec_list[2]->sparse_grad());
ret->set_has_indexed_slices_grad(args_spec_list[2]->has_indexed_slices_grad());
return ret;
}
@ -437,5 +446,72 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv
}
return std::make_shared<AbstractScalar>(kAnyValue, kBool);
}
// Abstract inference for MakeIndexedSlices(indices, values, dense_shape):
// builds an AbstractIndexedSlices carrying the values' element dtype, the
// concrete dense shape, and the three component abstracts.
AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 3);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
// dense_shape must be a compile-time constant tuple of ints; a non-constant
// shape yields a null ValueTuple and raises here.
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
std::vector<int> dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int {
auto elem = GetValue<int>(e);
return elem;
});
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
return ret;
}
// Abstract inference for IndexedSlicesGetValues: forwards the stored values
// abstract of the input IndexedSlices.
AbstractBasePtr InferImplIndexedSlicesGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a single IndexedSlices.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(indexed_slices->values());
return indexed_slices->values();
}
// Abstract inference for IndexedSlicesGetIndices: forwards the stored indices
// abstract of the input IndexedSlices.
AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a single IndexedSlices.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(indexed_slices->indices());
return indexed_slices->indices();
}
// Abstract inference for IndexedSlicesGetDenseShape: forwards the stored
// dense_shape abstract (a tuple) of the input IndexedSlices.
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a single IndexedSlices.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto indexed_slices = CheckArg<AbstractIndexedSlices>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(indexed_slices->dense_shape());
return indexed_slices->dense_shape();
}
// Abstract inference for IsIndexedSlices: returns a concrete boolean scalar
// telling whether the single input abstract is an AbstractIndexedSlices.
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const AbstractBasePtrList &args_spec_list) {
  const std::string op_name = primitive->name();
  CheckArgsSize(op_name, args_spec_list, 1);
  const bool ret = args_spec_list[0]->isa<AbstractIndexedSlices>();
  MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
  return std::make_shared<AbstractScalar>(ret);
}
} // namespace abstract
} // namespace mindspore

@ -36,6 +36,7 @@ using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractUndetermined;
static AbstractBasePtr Reabs(const AbstractBasePtr &t) {
if (t == nullptr) {
@ -78,7 +79,7 @@ AnfNodePtr ConvertGetAttrToTupleGetItem(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(cons);
auto dt = data->abstract();
if (dt == nullptr) {
if (dt == nullptr || dt->BuildType()->type_id() == kObjectTypeUndeterminedType) {
return nullptr;
}

@ -42,6 +42,7 @@
#include "optimizer/irpass/tile_eliminate.h"
#include "optimizer/irpass/transpose_eliminate.h"
#include "optimizer/opt.h"
#include "optimizer/irpass/indexed_slices_eliminate.h"
namespace mindspore {
namespace opt {
@ -153,6 +154,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Mark interface fusion
mark_interface_fusion_ =
MakeSubstitution(std::make_shared<MarkInterfaceFusion>(), "mark_interface_fusion", prim::kPrimSelect);
// IndexedSlices Eliminate
indexed_slices_eliminate_ = MakeSubstitution(
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
}
ResolveIRPassLib::ResolveIRPassLib() {

@ -104,6 +104,9 @@ class OptimizeIRPassLib {
// Fusion
SubstitutionPtr mark_interface_fusion_;
// IndexedSlices Eliminate
SubstitutionPtr indexed_slices_eliminate_;
};
// the collection of irpass for resolve action

@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include "optimizer/irpass.h"
#include "optimizer/optimizer.h"
#include "ir/visitor.h"
#include "operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimIndexedSlicesGetIndices, {prim::kPrimMakeIndexedSlices, Xs}}
// {prim::kPrimIndexedSlicesGetValues, {prim::kPrimMakeIndexedSlices, Xs}}
// {prim::kPrimIndexedSlicesGetDenseShape, {prim::kPrimMakeIndexedSlices, Xs}}
// Optimizer pass that collapses a getter applied directly to MakeIndexedSlices
// into the corresponding constructor argument:
//   GetIndices(MakeIndexedSlices(i, v, s))    -> i
//   GetValues(MakeIndexedSlices(i, v, s))     -> v
//   GetDenseShape(MakeIndexedSlices(i, v, s)) -> s
class IndexedSlicesEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
// Each Match() call visits the getter's CNode argument; Visit() below sets
// tuple_/is_match_ only when that argument is a MakeIndexedSlices call.
AnfVisitor::Match(prim::kPrimIndexedSlicesGetIndices, {IsCNode})(node);
if (is_match_) {
// input(0) is the primitive; inputs 1..3 are indices, values, dense_shape.
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimIndexedSlicesGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimIndexedSlicesGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
// No getter-over-constructor pattern found; leave the node unchanged.
return nullptr;
}
// Invoked by Match() on the getter's argument node.
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeIndexedSlices)) {
tuple_ = cnode;
is_match_ = true;
}
}
// Clears per-node state so one visitor instance can be reused across nodes.
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
// The matched MakeIndexedSlices CNode, valid only while is_match_ is true.
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_INDEXED_SLICES_ELIMINATE_H_

@ -232,6 +232,9 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
auto sparse_grad =
py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
ptr->set_sparse_grad(sparse_grad);
auto has_indexed_slices_grad =
py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);

@ -154,7 +154,9 @@ PYBIND11_MODULE(_c_expression, m) {
.def("set_print_file_path", &mindspore::MsContext::set_print_file_path, "Set path to print.")
.def("set_enable_graph_kernel", &mindspore::MsContext::set_enable_graph_kernel,
"Set the GraphKernel switch to on or off.")
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.");
.def("get_enable_graph_kernel", &mindspore::MsContext::enable_graph_kernel, "Get the value of GraphKernel switch.")
.def("get_enable_sparse_flag", &mindspore::MsContext::enable_sparse_flag, "Get whether to enable sparse.")
.def("set_enable_sparse_flag", &mindspore::MsContext::set_enable_sparse_flag, "Set whether to enable sparse.");
(void)py::class_<mindspore::MpiConfig, std::shared_ptr<mindspore::MpiConfig>>(m, "MpiConfig")
.def_static("get_instance", &mindspore::MpiConfig::GetInstance, "Get mpi config instance.")

@ -156,6 +156,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.replace_refkey_by_param_,
irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_,
irpass.indexed_slices_eliminate_,
});
OptPassGroupMap map({
{"b_1", b_1},

File diff suppressed because it is too large Load Diff

@ -30,6 +30,10 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
if (tid() != other.tid()) {
return false;
}
if (BuildType()->type_id() == kObjectTypeUndeterminedType &&
other.BuildType()->type_id() == kObjectTypeUndeterminedType) {
return true;
}
if (value_ == nullptr || other.value_ == nullptr) {
MS_LOG(EXCEPTION) << "If value_ is nullptr, AbstractBase::operator== should not be called. this: "
<< this->ToString() << ", other: " << other.ToString();
@ -65,7 +69,7 @@ std::string AbstractBase::ToString() const {
MS_EXCEPTION_IF_NULL(shape_);
buffer << type_name() << "("
<< "Type: " << type_->ToString() << " Value: " << value << " Shape: " << shape_->ToString()
<< " sparse_grad: " << sparse_grad_ << ")";
<< " sparse_grad: " << sparse_grad_ << " has_indexed_slices_grad: " << has_indexed_slices_grad_ << ")";
return buffer.str();
}
@ -76,6 +80,7 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
if (*this == *other) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
}
auto value_self = GetValueTrack();
@ -85,10 +90,12 @@ AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) {
if (res_value == value_self) {
auto ret = shared_from_base<AbstractBase>();
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
}
auto ret = std::make_shared<AbstractScalar>(res_value, res_type);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
}
@ -409,6 +416,14 @@ std::size_t AbstractSlice::hash() const {
return hash_combine({tid(), start_->hash(), stop_->hash(), step_->hash()});
}
// Returns the tracked shape, raising when no concrete Shape is attached.
ShapePtr AbstractUndetermined::shape() const {
  const auto shape_ptr = dyn_cast<Shape>(GetShapeTrack());
  if (shape_ptr == nullptr) {
    MS_LOG(EXCEPTION) << "Tensor should have a shape.";
  }
  return shape_ptr;
}
TypePtr AbstractTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element_);
TypePtr element_type = element_->BuildType();
@ -425,6 +440,13 @@ BaseShapePtr AbstractTensor::BuildShape() const {
}
AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
if (other->BuildType()->type_id() == kObjectTypeUndeterminedType) {
auto other_tensor = dyn_cast<AbstractUndetermined>(other);
auto element = element_->Join(other_tensor->element());
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractUndetermined>(element, shape);
return ret;
}
auto other_tensor = dyn_cast<AbstractTensor>(other);
if (other_tensor == nullptr) {
MS_LOG(EXCEPTION) << "Join failed as type mismatch, this: " << ToString() << ", other: " << other->ToString();
@ -433,6 +455,7 @@ AbstractBasePtr AbstractTensor::Join(const AbstractBasePtr &other) {
auto shape = ShapeJoin(this->shape(), other_tensor->shape());
auto ret = std::make_shared<AbstractTensor>(element, shape);
ret->set_sparse_grad(sparse_grad());
ret->set_has_indexed_slices_grad(has_indexed_slices_grad());
return ret;
}
@ -474,6 +497,7 @@ AbstractBasePtr AbstractTensor::Clone() const {
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_sparse_grad(sparse_grad());
clone->set_has_indexed_slices_grad(has_indexed_slices_grad());
return clone;
}
@ -484,6 +508,7 @@ AbstractBasePtr AbstractTensor::Broaden() const {
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
@ -495,17 +520,10 @@ AbstractBasePtr AbstractTensor::BroadenWithShape() const {
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_sparse_grad(sparse_grad());
broaden->set_has_indexed_slices_grad(has_indexed_slices_grad());
return broaden;
}
ShapePtr AbstractTensor::shape() const {
auto shp = dyn_cast<Shape>(GetShapeTrack());
if (shp == nullptr) {
MS_LOG(EXCEPTION) << "Tensor should have a shape.";
}
return shp;
}
std::string AbstractTensor::ToString() const {
std::ostringstream buffer;
BaseShapePtr shape_track = GetShapeTrack();
@ -516,7 +534,7 @@ std::string AbstractTensor::ToString() const {
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element_->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << " sparse_grad " << sparse_grad()
<< ")";
<< " has_indexed_slices_grad " << has_indexed_slices_grad() << ")";
return buffer.str();
}
@ -1019,5 +1037,64 @@ std::size_t AbstractBasePtrListHasher::operator()(const AbstractBasePtrList &arg
bool AbstractBasePtrListEqual::operator()(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs) const {
return AbstractBasePtrListDeepEqual(lhs, rhs);
}
// IndexedSlices
// Builds the static type of this abstract: IndexedSlices[<element dtype>].
TypePtr AbstractIndexedSlices::BuildType() const {
  MS_EXCEPTION_IF_NULL(element());
  return std::make_shared<IndexedSlicesType>(element()->BuildType());
}
// Deep-copies this abstract, including element, shape, value track and the
// three component abstracts (indices/values/dense_shape).
AbstractBasePtr AbstractIndexedSlices::Clone() const {
  MS_EXCEPTION_IF_NULL(element());
  auto clone = std::make_shared<AbstractIndexedSlices>(element()->Clone());
  ShapePtr shp = shape();
  clone->set_shape(shp->Clone());
  clone->set_value(GetValueTrack());
  // Guard the component abstracts before dereferencing: they are only filled
  // in via the setters and may still be null on a partially-built instance,
  // which previously crashed instead of reporting a clear error.
  MS_EXCEPTION_IF_NULL(indices_);
  MS_EXCEPTION_IF_NULL(values_);
  MS_EXCEPTION_IF_NULL(dense_shape_);
  clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
  clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
  clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
  return clone;
}
// Broadens this abstract: element is broadened, the value becomes kAnyValue,
// the shape and component abstracts are cloned as-is.
AbstractBasePtr AbstractIndexedSlices::Broaden() const {
  MS_EXCEPTION_IF_NULL(element());
  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
  auto shp = shape();
  broaden->set_shape(shp->Clone());
  broaden->set_value(kAnyValue);
  // Guard the component abstracts before dereferencing: they are only filled
  // in via the setters and may still be null on a partially-built instance.
  MS_EXCEPTION_IF_NULL(indices_);
  MS_EXCEPTION_IF_NULL(values_);
  MS_EXCEPTION_IF_NULL(dense_shape_);
  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
  return broaden;
}
// Like Broaden(), but additionally broadens the shape track itself.
AbstractBasePtr AbstractIndexedSlices::BroadenWithShape() const {
  MS_EXCEPTION_IF_NULL(element());
  auto broaden = std::make_shared<AbstractIndexedSlices>(element()->Broaden());
  auto shp = shape()->Clone();
  shp->Broaden();
  broaden->set_shape(shp);
  broaden->set_value(kAnyValue);
  // Guard the component abstracts before dereferencing: they are only filled
  // in via the setters and may still be null on a partially-built instance.
  MS_EXCEPTION_IF_NULL(indices_);
  MS_EXCEPTION_IF_NULL(values_);
  MS_EXCEPTION_IF_NULL(dense_shape_);
  broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
  broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
  broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
  return broaden;
}
// Render a human-readable debug description of this abstract IndexedSlices.
// Fixes from the original: the "values" field was missing its ": " separator,
// and the closing ")" was emitted before the indices/values/dense_shape
// fields, leaving the output unbalanced. Null guards added before
// dereferencing the component members.
std::string AbstractIndexedSlices::ToString() const {
  std::ostringstream buffer;
  BaseShapePtr shape_track = GetShapeTrack();
  MS_EXCEPTION_IF_NULL(shape_track);
  MS_EXCEPTION_IF_NULL(element());
  auto value_track = GetValueTrack();
  MS_EXCEPTION_IF_NULL(value_track);
  MS_EXCEPTION_IF_NULL(indices_);
  MS_EXCEPTION_IF_NULL(values_);
  MS_EXCEPTION_IF_NULL(dense_shape_);
  buffer << type_name() << "("
         << "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
         << ", value_ptr: " << value_track << ", value: " << value_track->ToString()
         << ", indices: " << indices_->ToString() << ", values: " << values_->ToString()
         << ", dense_shape: " << dense_shape_->ToString() << ")";
  return buffer.str();
}
} // namespace abstract
} // namespace mindspore

@ -44,7 +44,7 @@ class AbstractBase : public Base {
public:
explicit AbstractBase(const ValuePtr &value = nullptr, const TypePtr &type = kAnyType,
const BaseShapePtr &shape = kNoShape)
: value_(value), type_(type), shape_(shape), sparse_grad_("") {}
: value_(value), type_(type), shape_(shape), sparse_grad_(""), has_indexed_slices_grad_(false) {}
~AbstractBase() override = default;
MS_DECLARE_PARENT(AbstractBase, Base)
@ -54,12 +54,16 @@ class AbstractBase : public Base {
virtual bool operator==(const AbstractBase &other) const;
void set_value(const ValuePtr &value) { value_ = value; }
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
void set_has_indexed_slices_grad(const bool &has_indexed_slices_grad) {
has_indexed_slices_grad_ = has_indexed_slices_grad;
}
void set_type(const TypePtr &type) { type_ = type; }
void set_shape(const BaseShapePtr &shape) { shape_ = shape; }
void set_value_desc(const std::string &desc) { value_desc_ = desc; }
const std::string &value_desc() const { return value_desc_; }
ValuePtr GetValueTrack() const { return value_; }
const std::string &sparse_grad() const { return sparse_grad_; }
const bool &has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
TypePtr GetTypeTrack() const { return type_; }
BaseShapePtr GetShapeTrack() const { return shape_; }
@ -88,6 +92,7 @@ class AbstractBase : public Base {
BaseShapePtr shape_;
std::string value_desc_; // store initial value description for error report
std::string sparse_grad_;
bool has_indexed_slices_grad_;
};
class AbstractScalar : public AbstractBase {
@ -231,35 +236,49 @@ class AbstractKeywordArg : public AbstractBase {
};
using AbstractKeywordArgPtr = std::shared_ptr<AbstractKeywordArg>;
class AbstractTensor : public AbstractBase {
class AbstractUndetermined : public AbstractBase {
public:
// shape and type are all unknown
AbstractUndetermined() : AbstractBase(kAnyValue) {}
// only element_ and value, shape track are valid member, type track are unknown.
explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
explicit AbstractUndetermined(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractBase(kAnyValue), element_(element) {
if (element == nullptr) {
MS_LOG(EXCEPTION) << "element is nullptr";
}
if (element->isa<AbstractTensor>()) {
if (element->isa<AbstractUndetermined>()) {
MS_LOG(EXCEPTION) << "element type error";
}
set_shape(shape);
}
AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
AbstractUndetermined(const TypePtr &element_type, const std::vector<int> &shape)
: AbstractBase(kAnyValue), element_(std::make_shared<AbstractScalar>(kAnyValue, element_type)) {
if (element_type == nullptr) {
MS_LOG(EXCEPTION) << "element_type is nullptr";
}
set_shape(std::make_shared<Shape>(shape));
}
explicit AbstractTensor(const tensor::TensorPtr &tensor)
: AbstractBase(tensor), element_(std::make_shared<AbstractScalar>(kAnyValue, tensor->Dtype())) {
if (tensor == nullptr) {
MS_LOG(EXCEPTION) << "tensor is nullptr";
}
set_shape(std::make_shared<Shape>(tensor->shape()));
}
~AbstractUndetermined() override = default;
MS_DECLARE_PARENT(AbstractUndetermined, AbstractBase)
TypePtr BuildType() const override { return std::make_shared<UndeterminedType>(); }
AbstractBasePtr Clone() const override { return std::make_shared<AbstractUndetermined>(); }
const AbstractBasePtr element() const { return element_; }
ShapePtr shape() const;
protected:
AbstractBasePtr element_;
};
class AbstractTensor : public AbstractUndetermined {
public:
// only element_ and value, shape track are valid member, type track are unknown.
explicit AbstractTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractTensor(const TypePtr &element_type, const std::vector<int> &shape)
: AbstractUndetermined(element_type, shape) {}
explicit AbstractTensor(const tensor::TensorPtr &tensor) : AbstractUndetermined(tensor->Dtype(), tensor->shape()) {}
~AbstractTensor() override = default;
MS_DECLARE_PARENT(AbstractTensor, AbstractBase)
MS_DECLARE_PARENT(AbstractTensor, AbstractUndetermined)
TypePtr BuildType() const override;
BaseShapePtr BuildShape() const override;
@ -271,9 +290,7 @@ class AbstractTensor : public AbstractBase {
bool operator==(const AbstractTensor &other) const;
bool operator==(const AbstractBase &other) const override;
ShapePtr shape() const;
std::string ToString() const override;
const AbstractBasePtr element() const { return element_; }
std::size_t hash() const override {
auto value = GetValueTrack();
auto hash_sum = hash_combine(tid(), element_->hash());
@ -285,9 +302,6 @@ class AbstractTensor : public AbstractBase {
}
return hash_sum;
}
private:
AbstractBasePtr element_;
};
using AbstractTensorPtr = std::shared_ptr<AbstractTensor>;
using AbstractTensorPtrList = std::vector<AbstractTensorPtr>;
@ -585,6 +599,35 @@ struct AbstractBasePtrListEqual {
std::size_t AbstractBasePtrListHash(const AbstractBasePtrList &args_spec_list);
bool AbstractBasePtrListDeepEqual(const AbstractBasePtrList &lhs, const AbstractBasePtrList &rhs);
// IndexedSlices
// Abstract value for an IndexedSlices: a tensor represented by three
// components — an indices tensor, a values tensor, and a dense_shape tuple.
// The base AbstractUndetermined tracks the element (scalar) type and the
// dense shape of the represented tensor.
class AbstractIndexedSlices : public AbstractUndetermined {
 public:
  // `element` is the abstract scalar element type; `shape` is the dense shape track.
  explicit AbstractIndexedSlices(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
      : AbstractUndetermined(element, shape) {}
  // Convenience overload taking a concrete element type and dense shape dims.
  AbstractIndexedSlices(const TypePtr &element_type, const std::vector<int> &shape)
      : AbstractUndetermined(element_type, shape) {}
  ~AbstractIndexedSlices() override = default;
  MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)

  // Component accessors; each may be nullptr until the matching setter runs.
  const AbstractTensorPtr indices() const { return indices_; }
  const AbstractTensorPtr values() const { return values_; }
  const AbstractTuplePtr dense_shape() const { return dense_shape_; }
  void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
  void set_values(const AbstractTensorPtr &values) { values_ = values; }
  void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }

  TypePtr BuildType() const override;
  AbstractBasePtr Clone() const override;
  AbstractBasePtr Broaden() const override;
  // Broadens value AND shape track (non-virtual; specific to this class).
  AbstractBasePtr BroadenWithShape() const;
  std::string ToString() const override;

 private:
  AbstractTensorPtr indices_;
  AbstractTensorPtr values_;
  AbstractTuplePtr dense_shape_;
};
} // namespace abstract
} // namespace mindspore
#endif // PIPELINE_STATIC_ANALYSIS_ABSTRACT_VALUE_H_

@ -58,6 +58,20 @@ class Evaluator : public Base {
return args_spec_list;
}
// If any argument's type is undetermined, concrete evaluation cannot proceed:
// return a placeholder AbstractUndetermined result. Otherwise return nullptr
// so the caller falls through to normal evaluation.
// Simplification: the lambda returns the predicate directly instead of the
// verbose `if (cond) return true; return false;` form.
virtual EvalResultPtr AbstractEval(const AbstractBasePtrList &args_spec_list) {
  auto is_undetermined = std::any_of(args_spec_list.begin(), args_spec_list.end(), [](const auto &arg) {
    return arg->BuildType()->type_id() == kObjectTypeUndeterminedType;
  });
  if (is_undetermined) {
    MS_LOG(DEBUG) << "Eval " << identifier_ << " return abstract result";
    return std::make_shared<EvalResult>(std::make_shared<AbstractUndetermined>(), std::make_shared<AttrValueMap>());
  }
  return nullptr;
}
std::string ToString() const override { return identifier_; }
virtual AnfNodePtr bound_node() const { return bound_node_.lock(); }

@ -66,6 +66,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Function)
ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
template <typename T>
std::shared_ptr<T> CheckArg(const std::string &op, const AbstractBasePtrList &args_spec_list, size_t index) {

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save