From 0647b8b7dbe10c4b53ccb35e78a9a98f8381bdea Mon Sep 17 00:00:00 2001 From: buxue Date: Tue, 8 Dec 2020 17:53:57 +0800 Subject: [PATCH] optimize scalar to tensor function --- .../parallel/graph_util/generate_graph.cc | 2 +- .../pipeline/pynative/pynative_execute.cc | 2 +- mindspore/ccsrc/utils/convert_utils.cc | 97 ++++++++----------- mindspore/ccsrc/utils/convert_utils.h | 3 +- mindspore/core/ir/tensor.cc | 10 ++ mindspore/core/ir/tensor.h | 12 +++ mindspore/nn/layer/basic.py | 16 ++- mindspore/nn/wrap/cell_wrapper.py | 2 +- 8 files changed, 83 insertions(+), 61 deletions(-) diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc index 48e3187c3d..7b5caf7c40 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/generate_graph.cc @@ -75,7 +75,7 @@ AnfNodePtr CreateInt32Tensor(int64_t value) { if (it != int_tensor_map.end()) { return it->second; } - mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(py::int_(value), kInt32); + mindspore::tensor::TensorPtr tensor_ptr = std::make_shared(value, kInt32); ValuePtr value_ptr = MakeValue(tensor_ptr); auto anf_node_ptr = ValuePtrToAnfNodePtr(value_ptr); int_tensor_map[value] = anf_node_ptr; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 2543fc9878..9ef32d9ea2 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -382,7 +382,7 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr tensor_ptr = std::make_shared(input_value, kFloat32); *tensor_mask = kValueNodeTensorMask; } else if (py::isinstance(input_object)) { - tensor_ptr = std::make_shared(py::cast(input_object), kInt64); + tensor_ptr = std::make_shared(py::cast(input_object), kInt64); *tensor_mask = kValueNodeTensorMask; } else if (py::isinstance(input_object)) { tensor_ptr = TensorPy::MakeTensor(py::cast(input_object), nullptr); diff --git a/mindspore/ccsrc/utils/convert_utils.cc b/mindspore/ccsrc/utils/convert_utils.cc index 8b37f53148..6c885d2afa 100644 --- a/mindspore/ccsrc/utils/convert_utils.cc +++ b/mindspore/ccsrc/utils/convert_utils.cc @@ -20,16 +20,13 @@ #include #include #include -#include #include #include -#include "abstract/abstract_value.h" #include "ir/value.h" #include "ir/tensor.h" #include "ir/param_info.h" #include "utils/ms_context.h" -#include "utils/shape_utils.h" namespace mindspore { bool ValueToBool(const ValuePtr &v, bool *value) { @@ -37,13 +34,13 @@ bool ValueToBool(const ValuePtr &v, bool *value) { if (v->isa()) { *value = v->cast()->value(); } else if (v->isa()) { - *value = v->cast()->value() == 0 ? false : true; + *value = v->cast()->value() != 0; } else if (v->isa()) { - *value = v->cast()->value() == 0 ? false : true; + *value = v->cast()->value() != 0; } else if (v->isa()) { - *value = v->cast()->value() == 0 ? false : true; + *value = v->cast()->value() != 0; } else if (v->isa()) { - *value = v->cast()->value() == 0 ? 
false : true; + *value = v->cast()->value() != 0; } else if (v->isa()) { auto tensor = v->cast(); MS_EXCEPTION_IF_NULL(tensor); @@ -65,11 +62,11 @@ bool BaseRefToInt(const ValuePtr &v, int64_t *value) { auto tensor = v->cast(); (void)tensor->data_sync(); if (tensor->Dtype()->ToString() == "Int32") { - int32_t *tensor_data = static_cast(tensor->data_c()); + auto *tensor_data = static_cast(tensor->data_c()); auto vb = tensor_data[0]; *value = static_cast(vb); } else if (tensor->Dtype()->ToString() == "Int64") { - int64_t *tensor_data = static_cast(tensor->data_c()); + auto *tensor_data = static_cast(tensor->data_c()); auto vb = tensor_data[0]; *value = vb; } else { @@ -86,39 +83,19 @@ bool BaseRefToBool(const BaseRef &v, bool *value) { return ValueToBool(utils::cast(v), value); } else if (utils::isa(v)) { auto vb = utils::cast(v); - if (vb == true) { - *value = true; - } else { - *value = false; - } + *value = vb; } else if (utils::isa(v)) { auto vb = utils::cast(v); - if (vb == 0) { - *value = false; - } else { - *value = true; - } + *value = vb != 0; } else if (utils::isa(v)) { auto vb = utils::cast(v); - if (vb == 0) { - *value = false; - } else { - *value = true; - } + *value = vb != 0; } else if (utils::isa(v)) { auto vb = utils::cast(v); - if (vb >= -FLT_EPSILON && vb <= FLT_EPSILON) { - *value = false; - } else { - *value = true; - } + *value = !(vb >= -FLT_EPSILON && vb <= FLT_EPSILON); } else if (utils::isa(v)) { auto vb = utils::cast(v); - if (vb >= -DBL_EPSILON && vb <= DBL_EPSILON) { - *value = false; - } else { - *value = true; - } + *value = !(vb >= -DBL_EPSILON && vb <= DBL_EPSILON); } else { MS_LOG(DEBUG) << "value is not supported to cast to be bool"; return false; @@ -187,13 +164,13 @@ bool SameNode(const AnfNodePtr &node1, const AnfNodePtr &node2, FuncGraphPairMap return SameNodeShallow(node1, node2, equiv_func_graph, equiv_node); } -bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equiv_func_graph, +bool SameSubgraph(const AnfNodePtr &root1, const AnfNodePtr &root2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *const equiv_node) { std::unordered_set done; std::stack> todo; todo.push(std::make_pair(root1, root2)); - while (todo.size() > 0) { + while (!todo.empty()) { AnfNodePtr node1 = todo.top().first; if (done.count(node1) > 0) { todo.pop(); @@ -231,7 +208,7 @@ bool SameSubgraph(AnfNodePtr root1, AnfNodePtr root2, FuncGraphPairMapEquiv *equ } } // namespace -bool Isomorphic(FuncGraphPtr fg1, FuncGraphPtr fg2, FuncGraphPairMapEquiv *equiv_func_graph, +bool Isomorphic(const FuncGraphPtr &fg1, const FuncGraphPtr &fg2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *const equiv_node) { auto fg1_fg2 = std::make_pair(fg1, fg2); if (equiv_func_graph == nullptr) { @@ -267,23 +244,35 @@ tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar) { if (scalar == nullptr) { MS_EXCEPTION(ArgumentError) << "Nullptr Error!"; } - tensor::TensorPtr tensor = nullptr; - if (scalar->isa()) { - tensor = std::make_shared(static_cast(GetValue(scalar)), kFloat32); - } else if (scalar->isa()) { - tensor = std::make_shared(static_cast(GetValue(scalar)), kInt32); - } else if (scalar->isa()) { - tensor = std::make_shared(GetValue(scalar), kInt64); - } else if (scalar->isa()) { - const int64_t bool_value = GetValue(scalar) ? 1 : 0; - tensor = std::make_shared(bool_value, kBool); - } else { - auto type = scalar->type(); - auto type_str = (type == nullptr) ? 
"nullptr" : type->ToString(); - MS_LOG(EXCEPTION) << "Invalid scalar type: " << type_str; + TypePtr data_type = scalar->type(); + MS_EXCEPTION_IF_NULL(data_type); + TypeId type_id = data_type->type_id(); + switch (type_id) { + case kNumberTypeBool: + return std::make_shared(GetValue(scalar), data_type); + case kNumberTypeInt8: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeInt16: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeInt32: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeInt64: + return std::make_shared(GetValue(scalar), data_type); + case kNumberTypeUInt8: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeUInt16: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeUInt32: + return std::make_shared(static_cast(GetValue(scalar)), data_type); + case kNumberTypeUInt64: + return std::make_shared(GetValue(scalar), data_type); + case kNumberTypeFloat32: + return std::make_shared(GetValue(scalar), data_type); + case kNumberTypeFloat64: + return std::make_shared(GetValue(scalar), data_type); + default: + MS_LOG(EXCEPTION) << "When convert scalar to tensor, the scalar type: " << data_type << "is valid."; } - MS_EXCEPTION_IF_NULL(tensor); - return tensor; } void TensorValueToTensor(const ValuePtr &value, std::vector *tensors) { @@ -301,7 +290,7 @@ void TensorValueToTensor(const ValuePtr &value, std::vector * } } } else if (value->isa()) { - tensor::TensorPtr tensor = value->cast(); + auto tensor = value->cast(); MS_EXCEPTION_IF_NULL(tensor); tensors->push_back(tensor); } diff --git a/mindspore/ccsrc/utils/convert_utils.h b/mindspore/ccsrc/utils/convert_utils.h index 8c14d75787..21cc6b707e 100644 --- a/mindspore/ccsrc/utils/convert_utils.h +++ b/mindspore/ccsrc/utils/convert_utils.h @@ -57,7 +57,8 @@ enum EquivState { kNotEquiv = 0, kEquiv = 1, kPending = 2 }; using FuncGraphPairMapEquiv = std::unordered_map, EquivState, PairHasher>; using NodeMapEquiv = std::unordered_map; -bool Isomorphic(FuncGraphPtr g1, FuncGraphPtr g2, FuncGraphPairMapEquiv *equiv_func_graph, NodeMapEquiv *equiv_node); +bool Isomorphic(const FuncGraphPtr &g1, const FuncGraphPtr &g2, FuncGraphPairMapEquiv *equiv_func_graph, + NodeMapEquiv *equiv_node); tensor::TensorPtr ScalarToTensor(const ScalarPtr &scalar); diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 67523d6c43..8107322b93 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -491,6 +491,16 @@ Tensor::Tensor(double input, const TypePtr &data_type) data_(MakeTensorData(data_type_, {}, input)), id_(MakeId()) {} +Tensor::Tensor(uint64_t input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeUInt64), {}), + data_(MakeTensorData(data_type_, {}, input)), + id_(MakeId()) {} + +Tensor::Tensor(bool input, const TypePtr &data_type) + : MetaTensor(TypeIdOf(data_type, kNumberTypeBool), {}), + data_(MakeTensorData(data_type_, {}, input)), + id_(MakeId()) {} + bool Tensor::operator==(const Tensor &tensor) const { return (&tensor == this || (MetaTensor::operator==(tensor) && data_ == tensor.data_)); } diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 5d45b280af..672c0f3bb3 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -172,6 +172,18 @@ class Tensor : public MetaTensor { // param data_type [TypeId] data type explicit Tensor(double input, const TypePtr &data_type 
= nullptr); + // brief Create 0 dimension tensor from a uint scalar. + // + // param input [uint] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(uint64_t input, const TypePtr &data_type = nullptr); + + // brief Create 0 dimension tensor from a bool scalar. + // + // param input [bool] the data for tensor + // param data_type [TypeId] data type + explicit Tensor(bool input, const TypePtr &data_type = nullptr); + ~Tensor() override = default; MS_DECLARE_PARENT(Tensor, MetaTensor); diff --git a/mindspore/nn/layer/basic.py b/mindspore/nn/layer/basic.py index 47b8df8b29..c29a7319ea 100644 --- a/mindspore/nn/layer/basic.py +++ b/mindspore/nn/layer/basic.py @@ -88,6 +88,7 @@ class L1Regularizer(Cell): l1_regularization = self.scale * self.reduce_sum(self.abs(weights)) return l1_regularization + class Dropout(Cell): r""" Dropout layer for the input. @@ -210,6 +211,7 @@ class Flatten(Cell): def construct(self, x): return F.reshape(x, (F.shape(x)[0], -1)) + @constexpr def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): """get broadcast_weight_bias shape""" @@ -217,6 +219,7 @@ def get_broadcast_weight_bias_shape(x_shape, out_channel, in_channel): broad_bias_shape = x_shape[:-1] + (out_channel,) return broad_weight_shape, broad_bias_shape + class Dense(Cell): r""" The dense connected layer. @@ -262,6 +265,7 @@ class Dense(Cell): [[ 2.5246444 2.2738023 0.5711005 -3.9399147 ] [ 1.0739875 4.0155234 0.94188046 -5.459526 ]] """ + @cell_attr_register(attrs=['has_bias', 'activation', 'in_channels', 'out_channels']) def __init__(self, in_channels, @@ -323,7 +327,6 @@ class Dense(Cell): x = self.activation(x) return x - def extend_repr(self): s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels) if self.has_bias: @@ -339,10 +342,12 @@ def _is_equal_one(x): return False return bool(x.asnumpy().mean() == 1.0) + @constexpr def _dtype_check(x_dtype): if x_dtype not in [mstype.float32, mstype.float16]: - raise TypeError("The input type must be float32 or float16.") + raise TypeError("The input type must be float32 or float16.") + @constexpr def _is_float_dtype(dtype): @@ -539,7 +544,6 @@ class OneHot(Cell): return self.onehot(indices, self.depth, F.cast(self.on_value, self.dtype), F.cast(self.off_value, self.dtype)) - class Pad(Cell): """ Pads the input tensor according to the paddings and mode. @@ -672,6 +676,7 @@ class Interpolate(Cell): >>> print(result.shape) (1, 1, 5, 5) """ + def __init__(self): super(Interpolate, self).__init__() @@ -767,6 +772,7 @@ class Tril(Cell): [[1 0] [3 4]] """ + def __init__(self): super(Tril, self).__init__() self.dtype = P.DType() @@ -809,6 +815,7 @@ class Triu(Cell): [[1 2] [0 4]] """ + def __init__(self): super(Triu, self).__init__() self.dtype = P.DType() @@ -859,6 +866,7 @@ class MatrixDiag(Cell): [[ 1. 0.] [ 0. -1.]] """ + def __init__(self): super(MatrixDiag, self).__init__() self.matrix_diag = inner.MatrixDiag() @@ -895,6 +903,7 @@ class MatrixDiagPart(Cell): [-1. 1.] [-1. 1.]] """ + def __init__(self): super(MatrixDiagPart, self).__init__() self.matrix_diag_part = inner.MatrixDiagPart() @@ -936,6 +945,7 @@ class MatrixSetDiag(Cell): [[-1. 0.] [ 0. 
1.]]]
     """
+
     def __init__(self):
         super(MatrixSetDiag, self).__init__()
         self.matrix_set_diag = inner.MatrixSetDiag()
diff --git a/mindspore/nn/wrap/cell_wrapper.py b/mindspore/nn/wrap/cell_wrapper.py
index 544dbacd3e..7dd94304d5 100644
--- a/mindspore/nn/wrap/cell_wrapper.py
+++ b/mindspore/nn/wrap/cell_wrapper.py
@@ -407,7 +407,7 @@ class ParameterUpdate(Cell):
         >>> param = network.parameters_dict()['weight']
         >>> update = nn.ParameterUpdate(param)
         >>> update.phase = "update_param"
-        >>> weight = Tensor(0.001, mindspore.float32)
+        >>> weight = Tensor(np.arange(12).reshape((4, 3)), mindspore.float32)
         >>> update(weight)
     """
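
Usage note (illustrative sketch only, not part of the patch): the snippet below shows how the new bool/uint64_t Tensor constructors and the TypeId-driven ScalarToTensor dispatch introduced above are expected to behave. BoolImm, kBool, kUInt64 and the include paths are assumed from the existing MindSpore code base rather than taken from this diff.

// Illustrative only; identifiers not appearing in the diff are assumptions.
#include <memory>
#include "ir/value.h"
#include "ir/tensor.h"
#include "utils/convert_utils.h"

void ScalarToTensorSketch() {
  using namespace mindspore;
  // New scalar constructors: 0-dimensional tensors that keep the scalar's own dtype.
  auto bool_tensor = std::make_shared<tensor::Tensor>(true, kBool);
  auto uint_tensor = std::make_shared<tensor::Tensor>(static_cast<uint64_t>(42), kUInt64);
  // ScalarToTensor now dispatches on the scalar's TypeId, so a bool scalar stays bool
  // instead of being widened to a fixed integer or float dtype.
  ScalarPtr flag = std::make_shared<BoolImm>(true);
  tensor::TensorPtr t = ScalarToTensor(flag);  // 0-dim tensor, dtype kBool
}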