diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index 8696264135..ef8906952a 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -275,7 +275,6 @@ union PrimitiveType {
     Erf,
     StridedSliceGrad,
     IsFinite,
-    BatchMatMul,
     LinSpace,
     UniformReal,
     AbsGrad
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 64f0aba4e8..ff3c43eab4 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -1280,12 +1280,6 @@ table Erf {
 table IsFinite {
 }
 
-table BatchMatMul {
-    transpose_a :bool;
-    transpose_b :bool;
-}
-
-
 table LinSpace {
 }
 
diff --git a/mindspore/lite/src/ops/batch_matmul.cc b/mindspore/lite/src/ops/batch_matmul.cc
deleted file mode 100644
index 7c51d8d42e..0000000000
--- a/mindspore/lite/src/ops/batch_matmul.cc
+++ /dev/null
@@ -1,85 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include "src/ops/batch_matmul.h"
-#ifndef PRIMITIVE_WRITEABLE
-#include "src/ops/ops_register.h"
-#endif
-
-namespace mindspore {
-namespace lite {
-#ifdef PRIMITIVE_WRITEABLE
-bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
-
-bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
-
-void BatchMatMul::SetTransposeA(bool transpose_a) {
-  this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
-}
-
-void BatchMatMul::SetTransposeB(bool transpose_b) {
-  this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
-}
-int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
-  if (this->primitive_ == nullptr) {
-    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
-    if (this->primitive_ == nullptr) {
-      MS_LOG(ERROR) << "new primitiveT failed";
-      return RET_ERROR;
-    }
-    this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
-  }
-  if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
-    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
-    return RET_ERROR;
-  }
-  if (this->primitive_->value.value == nullptr) {
-    auto attr = new (std::nothrow) schema::BatchMatMulT();
-    if (attr == nullptr) {
-      MS_LOG(ERROR) << "new FusedBatchMatMulT failed";
-      delete this->primitive_;
-      this->primitive_ = nullptr;
-      return RET_ERROR;
-    }
-    attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
-    attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
-    this->primitive_->value.value = attr;
-  }
-  return RET_OK;
-}
-#else
-bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
-bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
-int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
-  MS_ASSERT(nullptr != primitive);
-  MS_ASSERT(nullptr != fbb);
-  auto attr = primitive->value_as_BatchMatMul();
-  if (attr == nullptr) {
-    MS_LOG(ERROR) << "value_as_Add return nullptr";
-    return RET_ERROR;
-  }
-  auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
-  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
-  fbb->Finish(prim_offset);
-  return RET_OK;
-}
-PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
-  return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
-}
-Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
-#endif
-}  // namespace lite
-}  // namespace mindspore
diff --git a/mindspore/lite/src/ops/batch_matmul.h b/mindspore/lite/src/ops/batch_matmul.h
deleted file mode 100644
index 49d5898663..0000000000
--- a/mindspore/lite/src/ops/batch_matmul.h
+++ /dev/null
@@ -1,45 +0,0 @@
-/**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
-#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
-
-#include <vector>
-#include <set>
-#include <cmath>
-#include "src/ops/primitive_c.h"
-
-namespace mindspore {
-namespace lite {
-class BatchMatMul : public PrimitiveC {
- public:
-  BatchMatMul() = default;
-  ~BatchMatMul() = default;
-#ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
-  explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
-  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
-  void SetTransposeA(bool transpose_a);
-  void SetTransposeB(bool transpose_b);
-#else
-  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
-#endif
-  bool GetTransposeA() const;
-  bool GetTransposeB() const;
-};
-}  // namespace lite
-}  // namespace mindspore
-#endif  // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc
index 7125b11872..cf63ea3b6e 100644
--- a/mindspore/lite/src/ops/primitive_c.cc
+++ b/mindspore/lite/src/ops/primitive_c.cc
@@ -170,7 +170,6 @@
 #include "src/ops/crop_and_resize.h"
 #include "src/ops/nonzero.h"
 #include "src/ops/erf.h"
-#include "src/ops/batch_matmul.h"
 #include "src/ops/lin_space.h"
 #include "src/ops/uniform_real.h"
 #include "src/ops/rank.h"
@@ -1057,8 +1056,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
       return new (std::nothrow) Erf(primitive);
     case schema::PrimitiveType_IsFinite:
       return new (std::nothrow) IsFinite(primitive);
-    case schema::PrimitiveType_BatchMatMul:
-      return new (std::nothrow) BatchMatMul(primitive);
     case schema::PrimitiveType_LinSpace:
       return new (std::nothrow) LinSpace(primitive);
     case schema::PrimitiveType_UniformReal:
diff --git a/mindspore/lite/src/ops/tensorlist_getitem.cc b/mindspore/lite/src/ops/tensorlist_getitem.cc
index 01cb7e1592..2a499f775f 100644
--- a/mindspore/lite/src/ops/tensorlist_getitem.cc
+++ b/mindspore/lite/src/ops/tensorlist_getitem.cc
@@ -125,6 +125,9 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
   MS_ASSERT(inputs_.at(1) != nullptr);
   MS_ASSERT(inputs_.at(2) != nullptr);
   auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0));
+  if (input0->root_tensor() != nullptr) {
+    input0 = reinterpret_cast<TensorList *>(input0->root_tensor());
+  }
   auto get_index = inputs_.at(1);
   MS_ASSERT(get_index != nullptr);
   if (get_index->ElementsNum() != 1) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc
index e018e9bc9e..228a157b1e 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_fromtensor_fp32.cc
@@ -102,6 +102,7 @@ int TensorListFromTensorCPUKernel::Run() {
     memcpy(out_data, in_data, data_offset);
     in_data += data_offset;
   }
+  output0->set_tensors_data_type(dtype_);
   return RET_OK;
 }
 
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc
index 7f88cb6ee4..82b94453c3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/tensorlist_getitem_fp32.cc
@@ -45,6 +45,9 @@ int TensorListGetItemCPUKernel::Run() {
   MS_ASSERT(in_tensors_.at(1) != nullptr);
   MS_ASSERT(out_tensors_.at(0) != nullptr);
   auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
+  if (input0->root_tensor() != nullptr) {
+    input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
+  }
   if (dtype_ != input0->tensors_data_type()) {
     MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
     return RET_ERROR;
diff --git a/mindspore/lite/src/sub_graph_kernel.cc b/mindspore/lite/src/sub_graph_kernel.cc
index bde13ee6ad..bb1cc7efe8 100644
--- a/mindspore/lite/src/sub_graph_kernel.cc
+++ b/mindspore/lite/src/sub_graph_kernel.cc
@@ -16,6 +16,7 @@
 
 #include "src/sub_graph_kernel.h"
 #include "src/tensor.h"
+#include "src/tensorlist.h"
 #if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
 #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
 #endif
@@ -176,7 +177,8 @@ int CpuSubGraph::Prepare() {
 
 #ifdef ENABLE_FP16
 void CpuFp16SubGraph::FreeOriginInputData() {
-  for (auto *data_store : this->origin_input_data_) {
+  for (auto &iter : this->origin_input_data_) {
+    auto *data_store = iter.second;
     if (data_store == nullptr) {
       continue;
     }
@@ -199,37 +201,99 @@ void CpuFp16SubGraph::FreeOriginInputData() {
   this->origin_input_data_.clear();
 }
 
+int CpuFp16SubGraph::Float32TensorToFloat16Tensor(lite::Tensor *tensor) {
+  auto float32_data = tensor->data_c();
+  if (float32_data == nullptr) {
+    MS_LOG(ERROR) << "tensor data is null.";
+    return lite::RET_NULL_PTR;
+  }
+  tensor->set_data(nullptr);
+  tensor->set_data_type(TypeId::kNumberTypeFloat16);
+  auto ret = tensor->MallocData();
+  if (RET_OK != ret) {
+    MS_LOG(ERROR) << "malloc data failed";
+    this->FreeOriginInputData();
+    return RET_ERROR;
+  }
+  MS_ASSERT(tensor->data_c() != nullptr);
+  Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
+  auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
+  if (data_store == nullptr) {
+    MS_LOG(ERROR) << "Create DataStore failed";
+    this->FreeOriginInputData();
+    return RET_ERROR;
+  }
+  origin_input_data_[tensor] = data_store;
+  return RET_OK;
+}
+
+int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
+  auto float16_data = tensor->data_c();
+  if (float16_data == nullptr) {
+    MS_LOG(ERROR) << "tensor data is null.";
+    return lite::RET_NULL_PTR;
+  }
+  tensor->set_data(nullptr);
+  tensor->set_data_type(TypeId::kNumberTypeFloat32);
+  auto ret = tensor->MallocData();
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "malloc data failed";
+    if (this->context_ != nullptr && this->context_->allocator != nullptr) {
+      this->context_->allocator->Free(float16_data);
+    } else {
+      free(float16_data);
+    }
+    return RET_ERROR;
+  }
+  MS_ASSERT(tensor->data_c() != nullptr);
+  Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
+  if (tensor->allocator() != nullptr) {
+    tensor->allocator()->Free(float16_data);
+  } else {
+    free(float16_data);
+  }
+  return RET_OK;
+}
+
 int CpuFp16SubGraph::PreProcess() {
 #ifdef ENABLE_ARM64
   if (!mindspore::lite::IsSupportFloat16()) {
-    MS_LOG(ERROR) << "Unsupport fp16 in this devices";
+    MS_LOG(ERROR) << "Unsupported fp16 in this devices";
     return RET_ERROR;
   }
-  MS_ASSERT(origin_input_data_.empty());
+  int ret;
   for (auto tensor : this->in_tensors_) {
     MS_ASSERT(tensor != nullptr);
-    if (tensor->data_type() == kNumberTypeFloat32) {
-      auto float32_data = tensor->data_c();
-      MS_ASSERT(float32_data != nullptr);
-      tensor->set_data(nullptr);
-      tensor->set_data_type(TypeId::kNumberTypeFloat16);
-      auto ret = tensor->MallocData();
+    auto real_tensor = tensor;
+    if (tensor->root_tensor() != nullptr) {
+      real_tensor = tensor->root_tensor();
+      if (tensor->data_type() == kNumberTypeFloat32) {
+        tensor->set_data_type(kNumberTypeFloat16);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        }
+      }
+    }
+    if (real_tensor->data_type() == kNumberTypeFloat32) {
+      ret = Float32TensorToFloat16Tensor(real_tensor);
       if (RET_OK != ret) {
-        MS_LOG(ERROR) << "malloc data failed";
-        this->FreeOriginInputData();
-        return RET_ERROR;
+        MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
+        return ret;
       }
-      MS_ASSERT(tensor->data_c() != nullptr);
-      Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
-      auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
-      if (data_store == nullptr) {
-        MS_LOG(ERROR) << "Create DataStore failed";
-        this->FreeOriginInputData();
-        return RET_ERROR;
+    } else if (real_tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          ret = Float32TensorToFloat16Tensor(inner_tensor);
+          if (RET_OK != ret) {
+            MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
+            return ret;
+          }
+        }
       }
-      origin_input_data_.emplace_back(data_store);
-    } else {
-      origin_input_data_.emplace_back(nullptr);
     }
   }
   for (auto kernel : this->nodes_) {
@@ -239,6 +303,11 @@ int CpuFp16SubGraph::PreProcess() {
       }
       if (tensor->data_type() == kNumberTypeFloat32) {
         tensor->set_data_type(kNumberTypeFloat16);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat16);
+        }
       }
     }
   }
@@ -251,47 +320,72 @@
 int CpuFp16SubGraph::PostProcess() {
 #ifdef ENABLE_ARM64
   if (!mindspore::lite::IsSupportFloat16()) {
-    MS_LOG(ERROR) << "Unsupport fp16 in this devices";
+    MS_LOG(ERROR) << "Unsupported fp16 in this devices";
     return RET_ERROR;
   }
+  int ret;
   for (auto tensor : this->out_tensors_) {
     MS_ASSERT(tensor != nullptr);
     if (tensor->data_type() == kNumberTypeFloat16) {
-      auto float16_data = tensor->data_c();
-      MS_ASSERT(float16_data != nullptr);
-      tensor->set_data(nullptr);
-      tensor->set_data_type(TypeId::kNumberTypeFloat32);
-      auto ret = tensor->MallocData();
-      if (ret != RET_OK) {
-        MS_LOG(ERROR) << "malloc data failed";
-        if (this->context_ != nullptr && this->context_->allocator != nullptr) {
-          this->context_->allocator->Free(float16_data);
-        } else {
-          free(float16_data);
-        }
-        return RET_ERROR;
+      ret = Float16TensorToFloat32Tensor(tensor);
+      if (RET_OK != ret) {
+        MS_LOG(ERROR) << "Float16TensorToFloat32Tensor failed.";
+        return ret;
       }
-      MS_ASSERT(tensor->data_c() != nullptr);
-      Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
-      if (tensor->allocator() != nullptr) {
-        tensor->allocator()->Free(float16_data);
-      } else {
-        free(float16_data);
+    } else if (tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          ret = Float16TensorToFloat32Tensor(inner_tensor);
+          if (RET_OK != ret) {
+            MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
+            return ret;
+          }
+        }
       }
     }
   }
-  MS_ASSERT(this->origin_input_data_.size() == this->in_tensors_.size());
+
+  int tensor_count = 0;
   for (size_t i = 0; i < this->in_tensors_.size(); i++) {
     auto tensor = in_tensors_.at(i);
     MS_ASSERT(tensor != nullptr);
-    auto origin_tensor_data = origin_input_data_.at(i);
-    if (tensor->data_type() == kNumberTypeFloat16 && origin_tensor_data != nullptr) {
-      MS_ASSERT(tensor != nullptr);
-      tensor->FreeData();
+    auto real_tensor = tensor;
+    if (tensor->root_tensor() != nullptr) {
+      real_tensor = tensor->root_tensor();
+      if (tensor->data_type() == kNumberTypeFloat16) {
+        tensor->set_data_type(kNumberTypeFloat32);
+      } else if (tensor->data_type() == kObjectTypeTensorType) {
+        auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
+        if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+          tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        }
+      }
+    }
+    if (real_tensor->data_type() == kNumberTypeFloat16 && origin_input_data_.at(real_tensor) != nullptr) {
+      auto origin_tensor_data = origin_input_data_.at(real_tensor);
+      real_tensor->FreeData();
       MS_ASSERT(origin_tensor_data->data_ != nullptr);
-      tensor->set_data(origin_tensor_data->data_);
-      tensor->set_data_type(kNumberTypeFloat32);
+      real_tensor->set_data(origin_tensor_data->data_);
+      real_tensor->set_data_type(kNumberTypeFloat32);
       origin_tensor_data->data_ = nullptr;
+      tensor_count++;
+    } else if (real_tensor->data_type() == kObjectTypeTensorType) {
+      auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
+      if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
+        tensorlist->set_tensors_data_type(kNumberTypeFloat32);
+        for (auto inner_tensor : tensorlist->tensors()) {
+          MS_ASSERT(inner_tensor != nullptr);
+          auto origin_tensor_data = origin_input_data_.at(inner_tensor);
+          inner_tensor->FreeData();
+          MS_ASSERT(origin_tensor_data->data_ != nullptr);
+          inner_tensor->set_data(origin_tensor_data->data_);
+          inner_tensor->set_data_type(kNumberTypeFloat32);
+          origin_tensor_data->data_ = nullptr;
+          tensor_count++;
+        }
+      }
     }
   }
   this->FreeOriginInputData();
diff --git a/mindspore/lite/src/sub_graph_kernel.h b/mindspore/lite/src/sub_graph_kernel.h
index c3cdedeae1..2375be879c 100644
--- a/mindspore/lite/src/sub_graph_kernel.h
+++ b/mindspore/lite/src/sub_graph_kernel.h
@@ -20,6 +20,7 @@
 #include <atomic>
 #include <string>
 #include <vector>
+#include <map>
 #include "src/lite_kernel.h"
 #include "src/executor.h"
 #include "src/common/log_adapter.h"
@@ -179,9 +180,11 @@ class CpuFp16SubGraph : public CpuSubGraph {
 
  private:
   void FreeOriginInputData();
+  int Float32TensorToFloat16Tensor(lite::Tensor *tensor);
+  int Float16TensorToFloat32Tensor(lite::Tensor *tensor);
 
  private:
-  std::vector<DataStore *> origin_input_data_{};
+  std::map<lite::Tensor *, DataStore *> origin_input_data_;
 };
 #endif
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc
index 8c555736a2..cf3cdb09bb 100644
--- a/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tf/tf_batch_matmul_parser.cc
@@ -35,7 +35,7 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "New PrimitiveT failed";
     return RET_NULL_PTR;
   }
-  auto attr = std::make_unique<schema::BatchMatMulT>();
+  auto attr = std::make_unique<schema::MatMulT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new attr failed";
     return RET_NULL_PTR;
@@ -45,13 +45,13 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
     MS_LOG(ERROR) << "The begin_mask attr should be specified";
     return RET_ERROR;
   }
-  attr->transpose_a = attr_value.b();
+  attr->transposeA = attr_value.b();
   if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
     MS_LOG(ERROR) << "The begin_mask attr should be specified";
     return RET_ERROR;
   }
-  attr->transpose_b = attr_value.b();
-  primitive->value.type = schema::PrimitiveType_BatchMatMul;
+  attr->transposeB = attr_value.b();
+  primitive->value.type = schema::PrimitiveType_MatMul;
   primitive->value.value = attr.release();
   *primitiveC = PrimitiveC::Create(primitive.release());
   if (*primitiveC == nullptr) {
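
Note (illustration, not part of the patch): the sub_graph_kernel change above replaces CpuFp16SubGraph's positional std::vector<DataStore *> bookkeeping with a map keyed by tensor pointer, so PreProcess can stash the original fp32 buffer of every converted tensor — including the inner tensors of a TensorList and tensors reached through root_tensor() — and PostProcess can restore each one by lookup instead of by index. The sketch below only illustrates that save/convert/restore pattern in isolation; FakeTensor, ToFp16 and RestoreFp32 are hypothetical stand-ins, not the MindSpore Lite API.

// Minimal standalone sketch of the map-based save/restore bookkeeping.
// All types and names here are hypothetical; they mimic, not reproduce,
// lite::Tensor / DataStore and the fp16 conversion handlers in the patch.
#include <cstdint>
#include <map>
#include <vector>

struct FakeTensor {
  std::vector<float> fp32_data;     // original fp32 payload
  std::vector<uint16_t> fp16_data;  // converted payload (bit pattern only)
  bool is_fp16 = false;
};

// Original fp32 buffers, keyed by tensor pointer (the vector -> map change).
std::map<FakeTensor *, std::vector<float>> origin_input_data;

void ToFp16(FakeTensor *t) {
  origin_input_data[t] = t->fp32_data;          // remember the fp32 copy
  t->fp16_data.assign(t->fp32_data.size(), 0);  // placeholder "conversion"
  t->fp32_data.clear();
  t->is_fp16 = true;
}

void RestoreFp32(FakeTensor *t) {
  auto it = origin_input_data.find(t);
  if (it == origin_input_data.end()) {
    return;  // nothing was saved for this tensor
  }
  t->fp32_data = it->second;  // hand the original buffer back
  t->fp16_data.clear();
  t->is_fp16 = false;
  origin_input_data.erase(it);
}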