add coo_tensor

pull/3114/head
panyifeng 5 years ago
parent 11732f0ea2
commit 5a10383cc3

@ -17,7 +17,7 @@
"""Resources for ast tree parse.""" """Resources for ast tree parse."""
import ast import ast
import math import math
from mindspore import IndexedSlices from mindspore import IndexedSlices, SparseTensor
from mindspore.ops.composite import multitype_ops from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C from mindspore.ops import functional as F, composite as C
from . import standard_method as M from . import standard_method as M
@ -140,4 +140,5 @@ convert_object_map = {
# user defined # user defined
IndexedSlices: F.make_indexed_slices, IndexedSlices: F.make_indexed_slices,
SparseTensor: F.make_sparse_tensor,
} }

@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
// Do Nothing // Do Nothing
} else if (type->isa<UndeterminedType>()) { } else if (type->isa<UndeterminedType>()) {
// Do Nothing // Do Nothing
} else if (type->isa<SparseTensorType>()) {
// Do Nothing
} else if (type->isa<Tuple>()) { } else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type); TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE); type_proto->set_data_type(irpb::DT_TUPLE);

@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a); abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b); abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
if (a_tuple == nullptr || b_tuple == nullptr) { if (a_tuple == nullptr || b_tuple == nullptr) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType();
});
auto stub = GenerateStubFunc(types);
if (stub != nullptr) {
MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
<< ", function: " << stub->ToString();
return stub;
}
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", " MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
<< args_spec_list[1]->ToString(); << args_spec_list[1]->ToString();
} }

@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
return py::none(); return py::none();
} }
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
return nullptr;
}
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) { FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
auto py_fn = SignMatch(types); auto py_fn = SignMatch(types);
std::ostringstream buffer; std::ostringstream buffer;

@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues"); const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices"); const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape"); const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
// SparseTensor
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
} // namespace prim } // namespace prim
} // namespace mindspore } // namespace mindspore

@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
extern const PrimitivePtr kPrimIndexedSlicesGetValues; extern const PrimitivePtr kPrimIndexedSlicesGetValues;
extern const PrimitivePtr kPrimIndexedSlicesGetIndices; extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
extern const PrimitivePtr kPrimIsIndexedSlices;
// SparseTensor
extern const PrimitivePtr kPrimMakeSparseTensor;
extern const PrimitivePtr kPrimSparseTensorGetValues;
extern const PrimitivePtr kPrimSparseTensorGetIndices;
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; const char SWITCH_UNROLL_FLAG[] = "unroll_flag";

@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1); auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2); auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>(); auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value); MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value(); auto shp = dense_shape_value->value();
@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto elem = GetValue<int>(e); auto elem = GetValue<int>(e);
return elem; return elem;
}); });
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec); auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices); ret->set_indices(indices);
ret->set_values(values); ret->set_values(values);
@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
return indexed_slices->dense_shape(); return indexed_slices->dense_shape();
} }
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) { const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name(); const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1); CheckArgsSize(op_name, args_spec_list, 3);
bool ret = false; auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) { auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
ret = true; auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 2) {
MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (values_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
<< " dimension tensor";
}
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
} }
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString(); auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
return std::make_shared<AbstractScalar>(ret); MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
std::vector<int> dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int {
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
}
auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
return ret;
}
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
return sparse_tensor->values();
}
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
return sparse_tensor->indices();
}
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
return sparse_tensor->dense_shape();
} }
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
return IsPrimitiveCNode(user.first, prim); return IsPrimitiveCNode(user.first, prim);
}); });
if (cnode == users.end()) { if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find cnode."; MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
} }
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1; auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;

@ -43,6 +43,7 @@
#include "frontend/optimizer/irpass/transpose_eliminate.h" #include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/opt.h" #include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h" #include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
indexed_slices_eliminate_ = MakeSubstitution( indexed_slices_eliminate_ = MakeSubstitution(
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate", std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape}); {prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
// SparseTensor Eliminate
sparse_tensor_eliminate_ = MakeSubstitution(
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
} }
ResolveIRPassLib::ResolveIRPassLib() { ResolveIRPassLib::ResolveIRPassLib() {

@ -107,6 +107,9 @@ class OptimizeIRPassLib {
// IndexedSlices Eliminate // IndexedSlices Eliminate
SubstitutionPtr indexed_slices_eliminate_; SubstitutionPtr indexed_slices_eliminate_;
// SparseTensor Eliminate
SubstitutionPtr sparse_tensor_eliminate_;
}; };
// the collection of irpass for resolve action // the collection of irpass for resolve action

@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
class SparseTensorEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
CNodePtr tuple_{nullptr};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_

@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.make_ref_eliminate_, irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_, irpass.get_ref_param_eliminate_,
irpass.indexed_slices_eliminate_, irpass.indexed_slices_eliminate_,
irpass.sparse_tensor_eliminate_,
}); });
OptPassGroupMap map({ OptPassGroupMap map({
{"b_1", b_1}, {"b_1", b_1},

@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices {"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape {"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
}}, }},
{kObjectTypeSparseTensorType,
{
{"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
}},
{kObjectTypeJTagged, {}}, {kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}}, {kObjectTypeSymbolicKeyType, {}},
{kObjectTypeEnvType, {}}}; {kObjectTypeEnvType, {}}};

@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}}, {prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}}, {prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}}, {prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}}, // SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
}; };
return prim_eval_implement_map; return prim_eval_implement_map;
} }

@ -358,8 +358,14 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive, AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list); const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
using mindspore::abstract::AbstractJTagged; using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList; using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar; using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
using mindspore::abstract::AbstractTensor; using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple; using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractType; using mindspore::abstract::AbstractType;
@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() || if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() || ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
ptrBase->isa<abstract::AbstractRefKey>()) { ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
return; return;
} }

@ -17,10 +17,10 @@ from . import dtype
from .api import ms_function from .api import ms_function
from .dtype import * from .dtype import *
from .parameter import Parameter, ParameterTuple from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, IndexedSlices from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
__all__ = [ __all__ = [
"MetaTensor", "Tensor", "IndexedSlices", # tensor "MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
'ms_function', # api 'ms_function', # api
'Parameter', 'ParameterTuple', # parameter 'Parameter', 'ParameterTuple', # parameter
"dtype" "dtype"

@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
from . import dtype as mstype from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices'] __all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64, np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16, np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_) np.float32, np.float64, np.bool_)
@ -211,3 +211,7 @@ class Tensor(Tensor_):
class IndexedSlices: class IndexedSlices:
def __init__(self, indices, values, dense_shape): def __init__(self, indices, values, dense_shape):
raise NotImplementedError raise NotImplementedError
class SparseTensor:
def __init__(self, indices, values, dense_shape):
raise NotImplementedError

@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
<< ", dense_shape: " << dense_shape_->ToString(); << ", dense_shape: " << dense_shape_->ToString();
return buffer.str(); return buffer.str();
} }
// SparseTensor
TypePtr AbstractSparseTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element());
TypePtr element_type = element()->BuildType();
return std::make_shared<SparseTensorType>(element_type);
}
AbstractBasePtr AbstractSparseTensor::Clone() const {
MS_EXCEPTION_IF_NULL(element());
auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return clone;
}
AbstractBasePtr AbstractSparseTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape()->Clone();
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
std::string AbstractSparseTensor::ToString() const {
std::ostringstream buffer;
BaseShapePtr shape_track = GetShapeTrack();
MS_EXCEPTION_IF_NULL(shape_track);
MS_EXCEPTION_IF_NULL(element());
auto value_track = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
<< ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
} // namespace abstract } // namespace abstract
} // namespace mindspore } // namespace mindspore

@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined) MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; } const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; } const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; } const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
private:
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
};
// SparseTensor
class AbstractSparseTensor : public AbstractUndetermined {
public:
explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
: AbstractUndetermined(element_type, shape) {}
~AbstractSparseTensor() override = default;
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; } void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; } void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; } void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override; TypePtr BuildType() const override;
AbstractBasePtr Clone() const override; AbstractBasePtr Clone() const override;

@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg) ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class) ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices) ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
ABSTRACT_REPORT_NAME_TRAITS(Sequeue) ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
template <typename T> template <typename T>

@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
return *element_type_ == *other_elem_type; return *element_type_ == *other_elem_type;
} }
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
Function::Function() : Object(kObjectTypeFunction) { Function::Function() : Object(kObjectTypeFunction) {
args_ = std::vector<TypePtr>(); args_ = std::vector<TypePtr>();
retval_ = nullptr; retval_ = nullptr;

@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
}; };
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>; using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
class Function : public Object { class Function : public Object {
public: public:
Function(); Function();

@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
return "kObjectTypeTensorType"; return "kObjectTypeTensorType";
case kObjectTypeIndexedSlicesType: case kObjectTypeIndexedSlicesType:
return "kObjectTypeIndexedSlicesType"; return "kObjectTypeIndexedSlicesType";
case kObjectTypeSparseTensorType:
return "kObjectTypeSparseTensorType";
case kObjectTypeUndeterminedType: case kObjectTypeUndeterminedType:
return "kObjectTypeUndeterminedType"; return "kObjectTypeUndeterminedType";
case kObjectTypeDictionary: case kObjectTypeDictionary:

@ -51,6 +51,7 @@ enum TypeId : int {
kObjectTypeKeyword, kObjectTypeKeyword,
kObjectTypeTensorType, kObjectTypeTensorType,
kObjectTypeIndexedSlicesType, kObjectTypeIndexedSlicesType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType, kObjectTypeUndeterminedType,
kObjectTypeClass, kObjectTypeClass,
kObjectTypeDictionary, kObjectTypeDictionary,

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save