add coo_tensor

pull/3114/head
panyifeng
parent 11732f0ea2
commit 5a10383cc3

@@ -17,7 +17,7 @@
"""Resources for ast tree parse."""
import ast
import math
from mindspore import IndexedSlices
from mindspore import IndexedSlices, SparseTensor
from mindspore.ops.composite import multitype_ops
from mindspore.ops import functional as F, composite as C
from . import standard_method as M
@@ -140,4 +140,5 @@ convert_object_map = {
# user defined
IndexedSlices: F.make_indexed_slices,
SparseTensor: F.make_sparse_tensor,
}
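The two hunks above register the new Python-level class with the parser: a SparseTensor(...) call inside a compiled graph is lowered to F.make_sparse_tensor, mirroring the existing IndexedSlices handling. A minimal usage sketch (hypothetical test code; enable_sparse is assumed to be the matching context switch):

import numpy as np
from mindspore import Tensor, SparseTensor, context, nn

context.set_context(mode=context.GRAPH_MODE, enable_sparse=True)  # enable_sparse is an assumption

class MakeSparse(nn.Cell):
    def construct(self, indices, values):
        # Lowered to F.make_sparse_tensor via the convert_object_map entry above.
        x = SparseTensor(indices, values, (3, 4))
        return x.values

out = MakeSparse()(Tensor(np.array([[0, 1], [1, 2]], np.int32)),
                   Tensor(np.array([1.0, 2.0], np.float32)))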

@@ -124,6 +124,8 @@ void ProtoExporter::SetNodeOutputType(const TypePtr &type, const BaseShapePtr &s
// Do Nothing
} else if (type->isa<UndeterminedType>()) {
// Do Nothing
} else if (type->isa<SparseTensorType>()) {
// Do Nothing
} else if (type->isa<Tuple>()) {
TuplePtr tuple_type = dyn_cast<Tuple>(type);
type_proto->set_data_type(irpb::DT_TUPLE);

@@ -803,6 +803,18 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
if (a_tuple == nullptr || b_tuple == nullptr) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
[](const AbstractBasePtr &arg) -> TypePtr {
MS_EXCEPTION_IF_NULL(arg);
return arg->BuildType();
});
auto stub = GenerateStubFunc(types);
if (stub != nullptr) {
MS_LOG(DEBUG) << "GenerateStubFunc for TupleAdd "
<< ", function: " << stub->ToString();
return stub;
}
MS_LOG(EXCEPTION) << "TupleAdd argument should be tuple,but " << args_spec_list[0]->ToString() << ", "
<< args_spec_list[1]->ToString();
}

@@ -119,42 +119,6 @@ const py::function MultitypeFuncGraph::SignMatch(const TypePtrList &types) {
return py::none();
}
FuncGraphPtr GenerateStubFunc(const TypePtrList &types) {
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
bool enable_sparse = context->enable_sparse();
if (!enable_sparse) {
return nullptr;
}
std::vector<AnfNodePtr> parameters;
ParameterPtr undetermined_param = nullptr;
auto stub = std::make_shared<FuncGraph>();
for (size_t i = 0; i < types.size(); ++i) {
auto param = stub->add_parameter();
parameters.push_back(param);
if (types[i]->type_id() == kObjectTypeUndeterminedType) {
undetermined_param = param;
}
}
if (undetermined_param != nullptr) {
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
for (size_t i = 0; i < types.size(); ++i) {
if (types[i]->type_id() == kObjectTypeFunction) {
std::vector<AnfNodePtr> call_prim{parameters[i], undetermined_param};
inputs.push_back(stub->NewCNode(call_prim));
} else {
inputs.push_back(parameters[i]);
}
}
auto stub_output = stub->NewCNode(inputs);
stub->set_output(stub_output);
stub->set_stub(true);
return stub;
}
return nullptr;
}
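Worth noting: GenerateStubFunc (moved out of this file so TupleAdd can reuse it, per the hunk above) bails out early unless sparse support is switched on in the context. A one-line sketch of the assumed Python-side counterpart of context->enable_sparse():

from mindspore import context

# Assumed Python-side switch corresponding to MsContext::enable_sparse() in C++.
context.set_context(enable_sparse=True)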
FuncGraphPtr MultitypeFuncGraph::GenerateFromTypes(const TypePtrList &types) {
auto py_fn = SignMatch(types);
std::ostringstream buffer;

@ -283,6 +283,11 @@ const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared<Primitive>("MakeInd
const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared<Primitive>("IndexedSlicesGetValues");
const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared<Primitive>("IndexedSlicesGetIndices");
const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared<Primitive>("IndexedSlicesGetDenseShape");
const PrimitivePtr kPrimIsIndexedSlices = std::make_shared<Primitive>("IsIndexedSlices");
// SparseTensor
const PrimitivePtr kPrimMakeSparseTensor = std::make_shared<Primitive>("MakeSparseTensor");
const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared<Primitive>("SparseTensorGetValues");
const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared<Primitive>("SparseTensorGetIndices");
const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared<Primitive>("SparseTensorGetDenseShape");
} // namespace prim
} // namespace mindspore

@@ -292,7 +292,12 @@ extern const PrimitivePtr kPrimMakeIndexedSlices;
extern const PrimitivePtr kPrimIndexedSlicesGetValues;
extern const PrimitivePtr kPrimIndexedSlicesGetIndices;
extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape;
extern const PrimitivePtr kPrimIsIndexedSlices;
// SparseTensor
extern const PrimitivePtr kPrimMakeSparseTensor;
extern const PrimitivePtr kPrimSparseTensorGetValues;
extern const PrimitivePtr kPrimSparseTensorGetIndices;
extern const PrimitivePtr kPrimSparseTensorGetDenseShape;
// attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll
const char SWITCH_UNROLL_FLAG[] = "unroll_flag";

@@ -349,6 +349,26 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Indices must be a 1 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
@@ -358,6 +378,12 @@ AbstractBasePtr InferImplMakeIndexedSlices(const AnalysisEnginePtr &, const Prim
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
}
auto ret = std::make_shared<AbstractIndexedSlices>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
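For reference, InferImplMakeIndexedSlices above accepts the row-sparse layout: a 1-D int indices tensor, values whose first dimension matches it, and an all-int dense_shape. A sketch of conforming inputs:

import numpy as np
from mindspore import Tensor

indices = Tensor(np.array([0, 2], dtype=np.int32))  # 1-D int tensor, shape [N]
values = Tensor(np.ones((2, 4), np.float32))        # first dimension matches indices
dense_shape = (3, 4)                                # ints, non-negative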
@@ -395,16 +421,89 @@ AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, c
return indexed_slices->dense_shape();
}
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: two tensors and a tuple.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
bool ret = false;
if (args_spec_list[0]->isa<AbstractIndexedSlices>()) {
ret = true;
CheckArgsSize(op_name, args_spec_list, 3);
auto indices = CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
auto values = CheckArg<AbstractTensor>(op_name, args_spec_list, 1);
auto dense_shape = CheckArg<AbstractTuple>(op_name, args_spec_list, 2);
auto indices_dtype = indices->element()->BuildType();
if (!indices_dtype->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The dtype of indices must be a Int, but got " << indices_dtype->ToString();
}
auto indices_shp = indices->shape()->shape();
if (indices_shp.size() != 2) {
MS_EXCEPTION(TypeError) << "Indices must be a 2 dimension tensor, but got a " << indices_shp.size()
<< " dimension tensor";
}
auto values_shp = values->shape()->shape();
if (values_shp.size() != 1) {
MS_EXCEPTION(TypeError) << "Values must be a 1 dimension tensor, but got a " << values_shp.size()
<< " dimension tensor";
}
if (indices_shp[0] != values_shp[0]) {
MS_EXCEPTION(TypeError) << "The first dimension of indices must be the same with the first dimension of values "
<< values_shp[0] << ", but got " << indices_shp[0];
}
for (auto elem_type : dense_shape->ElementsType()) {
if (!elem_type->isa<Int>()) {
MS_EXCEPTION(TypeError) << "The element type of dense_shape must be Int, but got " << elem_type->ToString();
}
}
MS_LOG(DEBUG) << "IsIndexedSlices result: " << ret << ", input: " << args_spec_list[0]->ToString();
return std::make_shared<AbstractScalar>(ret);
auto dense_shape_value = dense_shape->BuildValue()->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(dense_shape_value);
auto shp = dense_shape_value->value();
std::vector<int> dense_shape_vec;
(void)std::transform(std::begin(shp), std::end(shp), std::back_inserter(dense_shape_vec),
[](const ValuePtr &e) -> int {
auto elem = GetValue<int>(e);
return elem;
});
for (auto dense_shape_elem : dense_shape_vec) {
if (dense_shape_elem < 0) {
MS_EXCEPTION(TypeError) << "The element of dense_shape must be positive, but got "
<< dense_shape_value->ToString();
}
}
auto ret = std::make_shared<AbstractSparseTensor>(values->element()->BuildType(), dense_shape_vec);
ret->set_indices(indices);
ret->set_values(values);
ret->set_dense_shape(dense_shape);
return ret;
}
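InferImplMakeSparseTensor thus enforces the usual COO layout: indices is a 2-D int tensor of shape [N, ndims], values is 1-D of length N, and dense_shape is a tuple of non-negative ints. Conforming inputs would look like:

import numpy as np
from mindspore import Tensor

indices = Tensor(np.array([[0, 1], [1, 2]], np.int32))  # 2-D, shape [N, ndims]
values = Tensor(np.array([1.0, 2.0], np.float32))       # 1-D, shape [N]
dense_shape = (3, 4)                                    # tuple of non-negative ints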
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->values());
return sparse_tensor->values();
}
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->indices());
return sparse_tensor->indices();
}
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a sparse tensor.
const std::string op_name = primitive->name();
CheckArgsSize(op_name, args_spec_list, 1);
auto sparse_tensor = CheckArg<AbstractSparseTensor>(op_name, args_spec_list, 0);
MS_EXCEPTION_IF_NULL(sparse_tensor->dense_shape());
return sparse_tensor->dense_shape();
}
} // namespace abstract
} // namespace mindspore

@@ -264,7 +264,7 @@ FuncGraphPtr KPrim::FakeBprop(const ValueNodePtr &value_node, const pipeline::Re
return IsPrimitiveCNode(user.first, prim);
});
if (cnode == users.end()) {
MS_LOG(EXCEPTION) << "Fail to find cnode.";
MS_LOG(EXCEPTION) << "Fail to find user for " << prim->ToString();
}
auto inputs_num = cnode->first->cast<CNodePtr>()->inputs().size() - 1;

@@ -43,6 +43,7 @@
#include "frontend/optimizer/irpass/transpose_eliminate.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/irpass/indexed_slices_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
namespace mindspore {
namespace opt {
@@ -159,6 +160,11 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
indexed_slices_eliminate_ = MakeSubstitution(
std::make_shared<IndexedSlicesEliminater>(), "indexed_slices_eliminate",
{prim::kPrimIndexedSlicesGetIndices, prim::kPrimIndexedSlicesGetValues, prim::kPrimIndexedSlicesGetDenseShape});
// SparseTensor Eliminate
sparse_tensor_eliminate_ = MakeSubstitution(
std::make_shared<SparseTensorEliminater>(), "sparse_tensor_eliminate",
{prim::kPrimSparseTensorGetIndices, prim::kPrimSparseTensorGetValues, prim::kPrimSparseTensorGetDenseShape});
}
ResolveIRPassLib::ResolveIRPassLib() {

@@ -107,6 +107,9 @@ class OptimizeIRPassLib {
// IndexedSlices Eliminate
SubstitutionPtr indexed_slices_eliminate_;
// SparseTensor Eliminate
SubstitutionPtr sparse_tensor_eliminate_;
};
// the collection of irpass for resolve action

@@ -0,0 +1,75 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
#include <vector>
#include <algorithm>
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {prim::kPrimSparseTensorGetIndices, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetValues, {prim::kPrimMakeSparseTensor, Xs}}
// {prim::kPrimSparseTensorGetDenseShape, {prim::kPrimMakeSparseTensor, Xs}}
class SparseTensorEliminater : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
Reset();
AnfVisitor::Match(prim::kPrimSparseTensorGetIndices, {IsCNode})(node);
if (is_match_) {
return tuple_->input(1);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetValues, {IsCNode})(node);
if (is_match_) {
return tuple_->input(2);
}
AnfVisitor::Match(prim::kPrimSparseTensorGetDenseShape, {IsCNode})(node);
if (is_match_) {
return tuple_->input(3);
}
return nullptr;
}
void Visit(const CNodePtr &cnode) override {
if (IsPrimitiveCNode(cnode, prim::kPrimMakeSparseTensor)) {
tuple_ = cnode;
is_match_ = true;
}
}
void Reset() {
tuple_ = nullptr;
is_match_ = false;
}
private:
bool is_match_{false};
CNodePtr tuple_{nullptr};  // the matched MakeSparseTensor cnode
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_SPARSE_TENSOR_ELIMINATE_H_
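The pass folds a getter applied directly to a MakeSparseTensor call into the corresponding constructor argument (input 1, 2, or 3 of the CNode). A toy Python model of the rewrite, not MindSpore IR:

def eliminate(node):
    # node is ("getter", inner) or ("make_sparse_tensor", indices, values, dense_shape)
    slots = {"get_indices": 1, "get_values": 2, "get_dense_shape": 3}
    op, *args = node
    if op in slots and args and args[0][0] == "make_sparse_tensor":
        return args[0][slots[op]]
    return node

assert eliminate(("get_values", ("make_sparse_tensor", "i", "v", "s"))) == "v"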

@@ -157,6 +157,7 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
irpass.make_ref_eliminate_,
irpass.get_ref_param_eliminate_,
irpass.indexed_slices_eliminate_,
irpass.sparse_tensor_eliminate_,
});
OptPassGroupMap map({
{"b_1", b_1},

@@ -179,6 +179,12 @@ MethodMap &GetMethodMap() {
{"indices", prim::kPrimIndexedSlicesGetIndices}, // F.indexed_slices_get_indices
{"dense_shape", prim::kPrimIndexedSlicesGetDenseShape}, // F.indexed_slices_get_dense_shape
}},
{kObjectTypeSparseTensorType,
{
{"values", prim::kPrimSparseTensorGetValues}, // F.sparse_tensor_get_values
{"indices", prim::kPrimSparseTensorGetIndices}, // F.sparse_tensor_get_indices
{"dense_shape", prim::kPrimSparseTensorGetDenseShape}, // F.sparse_tensor_get_dense_shape
}},
{kObjectTypeJTagged, {}},
{kObjectTypeSymbolicKeyType, {}},
{kObjectTypeEnvType, {}}};
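With these method-map entries, the three attributes resolve to the new primitives directly in graph code; sketched below for a SparseTensor x already built inside construct:

# Inside a Cell's construct, given x = SparseTensor(indices, values, dense_shape):
v = x.values       # -> kPrimSparseTensorGetValues
i = x.indices      # -> kPrimSparseTensorGetIndices
s = x.dense_shape  # -> kPrimSparseTensorGetDenseShape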

@@ -138,7 +138,11 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimIndexedSlicesGetValues, {InferImplIndexedSlicesGetValues, true}},
{prim::kPrimIndexedSlicesGetIndices, {InferImplIndexedSlicesGetIndices, true}},
{prim::kPrimIndexedSlicesGetDenseShape, {InferImplIndexedSlicesGetDenseShape, true}},
{prim::kPrimIsIndexedSlices, {InferImplIsIndexedSlices, true}},
// SparseTensor
{prim::kPrimMakeSparseTensor, {InferImplMakeSparseTensor, true}},
{prim::kPrimSparseTensorGetValues, {InferImplSparseTensorGetValues, true}},
{prim::kPrimSparseTensorGetIndices, {InferImplSparseTensorGetIndices, true}},
{prim::kPrimSparseTensorGetDenseShape, {InferImplSparseTensorGetDenseShape, true}},
};
return prim_eval_implement_map;
}

@@ -358,8 +358,14 @@ AbstractBasePtr InferImplIndexedSlicesGetIndices(const AnalysisEnginePtr &, cons
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIndexedSlicesGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplIsIndexedSlices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplMakeSparseTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetValues(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetIndices(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplSparseTensorGetDenseShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
} // namespace abstract
} // namespace mindspore

@@ -36,6 +36,7 @@ using mindspore::abstract::AbstractIndexedSlices;
using mindspore::abstract::AbstractJTagged;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSparseTensor;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractType;
@@ -95,7 +96,7 @@ void ValidateAbstract(const AnfNodePtr &node) {
if (ptrBase->isa<AbstractType>() || ptrBase->isa<AbstractFunction>() || ptrBase->isa<AbstractTuple>() ||
ptrBase->isa<AbstractList>() || ptrBase->isa<AbstractTensor>() || ptrBase->isa<AbstractIndexedSlices>() ||
ptrBase->isa<abstract::AbstractRefKey>()) {
ptrBase->isa<AbstractSparseTensor>() || ptrBase->isa<abstract::AbstractRefKey>()) {
return;
}

@@ -17,10 +17,10 @@ from . import dtype
from .api import ms_function
from .dtype import *
from .parameter import Parameter, ParameterTuple
from .tensor import MetaTensor, Tensor, IndexedSlices
from .tensor import MetaTensor, Tensor, IndexedSlices, SparseTensor
__all__ = [
"MetaTensor", "Tensor", "IndexedSlices", # tensor
"MetaTensor", "Tensor", "IndexedSlices", "SparseTensor", # tensor
'ms_function', # api
'Parameter', 'ParameterTuple', # parameter
"dtype"

@@ -21,7 +21,7 @@ from .._checkparam import check_type, check_typename
from . import dtype as mstype
from ._register_for_tensor import tensor_operator_registry
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices']
__all__ = ['Tensor', 'MetaTensor', 'IndexedSlices', 'SparseTensor']
np_types = (np.int8, np.int16, np.int32, np.int64,
np.uint8, np.uint16, np.uint32, np.uint64, np.float16,
np.float32, np.float64, np.bool_)
@@ -211,3 +211,7 @@ class Tensor(Tensor_):
class IndexedSlices:
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
class SparseTensor:
def __init__(self, indices, values, dense_shape):
raise NotImplementedError
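The Python class is deliberately a stub: in this version SparseTensor only exists inside compiled graphs, where the parser swaps it for F.make_sparse_tensor; eager instantiation raises. A quick check:

from mindspore import SparseTensor

try:
    SparseTensor([[0, 0]], [1.0], (2, 2))  # eager call hits the stub above
except NotImplementedError:
    print("SparseTensor is graph-mode only in this version")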

@@ -1093,5 +1093,64 @@ std::string AbstractIndexedSlices::ToString() const {
<< ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
// SparseTensor
TypePtr AbstractSparseTensor::BuildType() const {
MS_EXCEPTION_IF_NULL(element());
TypePtr element_type = element()->BuildType();
return std::make_shared<SparseTensorType>(element_type);
}
AbstractBasePtr AbstractSparseTensor::Clone() const {
MS_EXCEPTION_IF_NULL(element());
auto clone = std::make_shared<AbstractSparseTensor>(element()->Clone());
ShapePtr shp = shape();
clone->set_shape(shp->Clone());
clone->set_value(GetValueTrack());
clone->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
clone->set_values(values_->Clone()->cast<AbstractTensorPtr>());
clone->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return clone;
}
AbstractBasePtr AbstractSparseTensor::Broaden() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape();
broaden->set_shape(shp->Clone());
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
AbstractBasePtr AbstractSparseTensor::BroadenWithShape() const {
MS_EXCEPTION_IF_NULL(element());
auto broaden = std::make_shared<AbstractSparseTensor>(element()->Broaden());
auto shp = shape()->Clone();
shp->Broaden();
broaden->set_shape(shp);
broaden->set_value(kAnyValue);
broaden->set_indices(indices_->Clone()->cast<AbstractTensorPtr>());
broaden->set_values(values_->Clone()->cast<AbstractTensorPtr>());
broaden->set_dense_shape(dense_shape_->Clone()->cast<AbstractTuplePtr>());
return broaden;
}
std::string AbstractSparseTensor::ToString() const {
std::ostringstream buffer;
BaseShapePtr shape_track = GetShapeTrack();
MS_EXCEPTION_IF_NULL(shape_track);
MS_EXCEPTION_IF_NULL(element());
auto value_track = GetValueTrack();
MS_EXCEPTION_IF_NULL(value_track);
buffer << type_name() << "("
<< "shape: " << shape_track->ToString() << ", element: " << element()->ToString()
<< ", value_ptr: " << value_track << ", value: " << value_track->ToString() << ")"
<< ", indices: " << indices_->ToString() << ", values" << values_->ToString()
<< ", dense_shape: " << dense_shape_->ToString();
return buffer.str();
}
} // namespace abstract
} // namespace mindspore

@@ -604,10 +604,39 @@ class AbstractIndexedSlices : public AbstractUndetermined {
MS_DECLARE_PARENT(AbstractIndexedSlices, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;
AbstractBasePtr Broaden() const override;
AbstractBasePtr BroadenWithShape() const;
std::string ToString() const override;
private:
AbstractTensorPtr indices_;
AbstractTensorPtr values_;
AbstractTuplePtr dense_shape_;
};
// SparseTensor
class AbstractSparseTensor : public AbstractUndetermined {
public:
explicit AbstractSparseTensor(const AbstractBasePtr &element, const BaseShapePtr &shape = std::make_shared<Shape>())
: AbstractUndetermined(element, shape) {}
AbstractSparseTensor(const TypePtr &element_type, const std::vector<int> &shape)
: AbstractUndetermined(element_type, shape) {}
~AbstractSparseTensor() override = default;
MS_DECLARE_PARENT(AbstractSparseTensor, AbstractUndetermined)
const AbstractTensorPtr indices() const { return indices_; }
void set_indices(const AbstractTensorPtr &indices) { indices_ = indices; }
const AbstractTensorPtr values() const { return values_; }
void set_values(const AbstractTensorPtr &values) { values_ = values; }
const AbstractTuplePtr dense_shape() const { return dense_shape_; }
void set_dense_shape(const AbstractTuplePtr &dense_shape) { dense_shape_ = dense_shape; }
TypePtr BuildType() const override;
AbstractBasePtr Clone() const override;

@@ -67,6 +67,7 @@ ABSTRACT_REPORT_NAME_TRAITS(Type)
ABSTRACT_REPORT_NAME_TRAITS(KeywordArg)
ABSTRACT_REPORT_NAME_TRAITS(Class)
ABSTRACT_REPORT_NAME_TRAITS(IndexedSlices)
ABSTRACT_REPORT_NAME_TRAITS(SparseTensor)
ABSTRACT_REPORT_NAME_TRAITS(Sequeue)
template <typename T>

@@ -221,6 +221,48 @@ bool IndexedSlicesType::operator==(const Type &other) const {
return *element_type_ == *other_elem_type;
}
TypePtr SparseTensorType::DeepCopy() const {
MS_EXCEPTION_IF_NULL(element_type_);
if (IsGeneric()) {
return std::make_shared<SparseTensorType>();
}
return std::make_shared<SparseTensorType>(element_type_->DeepCopy());
}
std::string SparseTensorType::ToReprString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToReprString() + "]";
}
std::string SparseTensorType::ToString() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->ToString() + "]";
}
std::string SparseTensorType::DumpText() const {
if (element_type_ == nullptr) {
return "SparseTensor";
}
return "SparseTensor[" + element_type_->DumpText() + "]";
}
bool SparseTensorType::operator==(const Type &other) const {
if (!IsSameObjectType(*this, other)) {
return false;
}
auto other_elem_type = static_cast<const SparseTensorType &>(other).element_type_;
if (element_type_ == nullptr && other_elem_type == nullptr) {
return true;
} else if (element_type_ == nullptr || other_elem_type == nullptr) {
return false;
}
return *element_type_ == *other_elem_type;
}
Function::Function() : Object(kObjectTypeFunction) {
args_ = std::vector<TypePtr>();
retval_ = nullptr;

@@ -177,6 +177,29 @@ class IndexedSlicesType : public Object {
};
using IndexedSlicesTypePtr = std::shared_ptr<IndexedSlicesType>;
class SparseTensorType : public Object {
public:
SparseTensorType() : Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType) {}
explicit SparseTensorType(const TypePtr &ele)
: Object(kObjectTypeSparseTensorType, kObjectTypeUndeterminedType, false), element_type_(ele) {}
~SparseTensorType() override = default;
MS_DECLARE_PARENT(SparseTensorType, Object)
TypeId generic_type_id() const override { return kObjectTypeSparseTensorType; }
const TypePtr element() const { return element_type_; }
void set_element(const TypePtr &element_type) { element_type_ = element_type; }
TypePtr DeepCopy() const override;
std::string ToString() const override;
std::string ToReprString() const override;
std::string DumpText() const override;
bool operator==(const Type &other) const override;
private:
TypePtr element_type_;
};
using SparseTensorTypePtr = std::shared_ptr<SparseTensorType>;
class Function : public Object {
public:
Function();

@@ -117,6 +117,8 @@ const char *ObjectIdLabel(const TypeId &v) {
return "kObjectTypeTensorType";
case kObjectTypeIndexedSlicesType:
return "kObjectTypeIndexedSlicesType";
case kObjectTypeSparseTensorType:
return "kObjectTypeSparseTensorType";
case kObjectTypeUndeterminedType:
return "kObjectTypeUndeterminedType";
case kObjectTypeDictionary:

@@ -51,6 +51,7 @@ enum TypeId : int {
kObjectTypeKeyword,
kObjectTypeTensorType,
kObjectTypeIndexedSlicesType,
kObjectTypeSparseTensorType,
kObjectTypeUndeterminedType,
kObjectTypeClass,
kObjectTypeDictionary,

Some files were not shown because too many files have changed in this diff.
