support decoder fp16

pull/12447/head
cjh9368 4 years ago
parent 4b805b87ee
commit 10b9f430d0

@ -275,7 +275,6 @@ union PrimitiveType {
Erf, Erf,
StridedSliceGrad, StridedSliceGrad,
IsFinite, IsFinite,
BatchMatMul,
LinSpace, LinSpace,
UniformReal, UniformReal,
AbsGrad AbsGrad

@ -1280,12 +1280,6 @@ table Erf {
table IsFinite { table IsFinite {
} }
table BatchMatMul {
transpose_a :bool;
transpose_b :bool;
}
table LinSpace { table LinSpace {
} }

@ -1,85 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/batch_matmul.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
void BatchMatMul::SetTransposeA(bool transpose_a) {
this->primitive_->value.AsBatchMatMul()->transpose_a = transpose_a;
}
void BatchMatMul::SetTransposeB(bool transpose_b) {
this->primitive_->value.AsBatchMatMul()->transpose_b = transpose_b;
}
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
}
if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BatchMatMulT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new FusedBatchMatMulT failed";
delete this->primitive_;
this->primitive_ = nullptr;
return RET_ERROR;
}
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
this->primitive_->value.value = attr;
}
return RET_OK;
}
#else
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
auto attr = primitive->value_as_BatchMatMul();
if (attr == nullptr) {
MS_LOG(ERROR) << "value_as_Add return nullptr";
return RET_ERROR;
}
auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
fbb->Finish(prim_offset);
return RET_OK;
}
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif
} // namespace lite
} // namespace mindspore

@ -1,45 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
class BatchMatMul : public PrimitiveC {
public:
BatchMatMul() = default;
~BatchMatMul() = default;
#ifdef PRIMITIVE_WRITEABLE
MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
void SetTransposeA(bool transpose_a);
void SetTransposeB(bool transpose_b);
#else
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
bool GetTransposeA() const;
bool GetTransposeB() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

@ -170,7 +170,6 @@
#include "src/ops/crop_and_resize.h" #include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h" #include "src/ops/nonzero.h"
#include "src/ops/erf.h" #include "src/ops/erf.h"
#include "src/ops/batch_matmul.h"
#include "src/ops/lin_space.h" #include "src/ops/lin_space.h"
#include "src/ops/uniform_real.h" #include "src/ops/uniform_real.h"
#include "src/ops/rank.h" #include "src/ops/rank.h"
@ -1057,8 +1056,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Erf(primitive); return new (std::nothrow) Erf(primitive);
case schema::PrimitiveType_IsFinite: case schema::PrimitiveType_IsFinite:
return new (std::nothrow) IsFinite(primitive); return new (std::nothrow) IsFinite(primitive);
case schema::PrimitiveType_BatchMatMul:
return new (std::nothrow) BatchMatMul(primitive);
case schema::PrimitiveType_LinSpace: case schema::PrimitiveType_LinSpace:
return new (std::nothrow) LinSpace(primitive); return new (std::nothrow) LinSpace(primitive);
case schema::PrimitiveType_UniformReal: case schema::PrimitiveType_UniformReal:

@ -125,6 +125,9 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_ASSERT(inputs_.at(1) != nullptr); MS_ASSERT(inputs_.at(1) != nullptr);
MS_ASSERT(inputs_.at(2) != nullptr); MS_ASSERT(inputs_.at(2) != nullptr);
auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0)); auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<TensorList *>(input0->root_tensor());
}
auto get_index = inputs_.at(1); auto get_index = inputs_.at(1);
MS_ASSERT(get_index != nullptr); MS_ASSERT(get_index != nullptr);
if (get_index->ElementsNum() != 1) { if (get_index->ElementsNum() != 1) {

@ -102,6 +102,7 @@ int TensorListFromTensorCPUKernel::Run() {
memcpy(out_data, in_data, data_offset); memcpy(out_data, in_data, data_offset);
in_data += data_offset; in_data += data_offset;
} }
output0->set_tensors_data_type(dtype_);
return RET_OK; return RET_OK;
} }

@ -45,6 +45,9 @@ int TensorListGetItemCPUKernel::Run() {
MS_ASSERT(in_tensors_.at(1) != nullptr); MS_ASSERT(in_tensors_.at(1) != nullptr);
MS_ASSERT(out_tensors_.at(0) != nullptr); MS_ASSERT(out_tensors_.at(0) != nullptr);
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0)); auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
}
if (dtype_ != input0->tensors_data_type()) { if (dtype_ != input0->tensors_data_type()) {
MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type(); MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
return RET_ERROR; return RET_ERROR;

@ -16,6 +16,7 @@
#include "src/sub_graph_kernel.h" #include "src/sub_graph_kernel.h"
#include "src/tensor.h" #include "src/tensor.h"
#include "src/tensorlist.h"
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16) #if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h" #include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif #endif
@ -176,7 +177,8 @@ int CpuSubGraph::Prepare() {
#ifdef ENABLE_FP16 #ifdef ENABLE_FP16
void CpuFp16SubGraph::FreeOriginInputData() { void CpuFp16SubGraph::FreeOriginInputData() {
for (auto *data_store : this->origin_input_data_) { for (auto &iter : this->origin_input_data_) {
auto *data_store = iter.second;
if (data_store == nullptr) { if (data_store == nullptr) {
continue; continue;
} }
@ -199,37 +201,99 @@ void CpuFp16SubGraph::FreeOriginInputData() {
this->origin_input_data_.clear(); this->origin_input_data_.clear();
} }
int CpuFp16SubGraph::Float32TensorToFloat16Tensor(lite::Tensor *tensor) {
auto float32_data = tensor->data_c();
if (float32_data == nullptr) {
MS_LOG(ERROR) << "tensor data is null.";
return lite::RET_NULL_PTR;
}
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat16);
auto ret = tensor->MallocData();
if (RET_OK != ret) {
MS_LOG(ERROR) << "malloc data failed";
this->FreeOriginInputData();
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum());
auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get());
if (data_store == nullptr) {
MS_LOG(ERROR) << "Create DataStore failed";
this->FreeOriginInputData();
return RET_ERROR;
}
origin_input_data_[tensor] = data_store;
return RET_OK;
}
int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
auto float16_data = tensor->data_c();
if (float16_data == nullptr) {
MS_LOG(ERROR) << "tensor data is null.";
return lite::RET_NULL_PTR;
}
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat32);
auto ret = tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "malloc data failed";
if (this->context_ != nullptr && this->context_->allocator != nullptr) {
this->context_->allocator->Free(float16_data);
} else {
free(float16_data);
}
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
if (tensor->allocator() != nullptr) {
tensor->allocator()->Free(float16_data);
} else {
free(float16_data);
}
return RET_OK;
}
int CpuFp16SubGraph::PreProcess() { int CpuFp16SubGraph::PreProcess() {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
if (!mindspore::lite::IsSupportFloat16()) { if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupport fp16 in this devices"; MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR; return RET_ERROR;
} }
MS_ASSERT(origin_input_data_.empty()); int ret;
for (auto tensor : this->in_tensors_) { for (auto tensor : this->in_tensors_) {
MS_ASSERT(tensor != nullptr); MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat32) { auto real_tensor = tensor;
auto float32_data = tensor->data_c(); if (tensor->root_tensor() != nullptr) {
MS_ASSERT(float32_data != nullptr); real_tensor = tensor->root_tensor();
tensor->set_data(nullptr); if (tensor->data_type() == kNumberTypeFloat32) {
tensor->set_data_type(TypeId::kNumberTypeFloat16); tensor->set_data_type(kNumberTypeFloat16);
auto ret = tensor->MallocData(); } else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
tensorlist->set_tensors_data_type(kNumberTypeFloat16);
}
}
}
if (real_tensor->data_type() == kNumberTypeFloat32) {
ret = Float32TensorToFloat16Tensor(real_tensor);
if (RET_OK != ret) { if (RET_OK != ret) {
MS_LOG(ERROR) << "malloc data failed"; MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
this->FreeOriginInputData(); return ret;
return RET_ERROR;
} }
MS_ASSERT(tensor->data_c() != nullptr); } else if (real_tensor->data_type() == kObjectTypeTensorType) {
Float32ToFloat16_fp16_handler(float32_data, tensor->data_c(), tensor->ElementsNum()); auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
auto *data_store = DataStore::CreateDataStore(float32_data, tensor->allocator(), this->context_->allocator.get()); if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
if (data_store == nullptr) { tensorlist->set_tensors_data_type(kNumberTypeFloat16);
MS_LOG(ERROR) << "Create DataStore failed"; for (auto inner_tensor : tensorlist->tensors()) {
this->FreeOriginInputData(); ret = Float32TensorToFloat16Tensor(inner_tensor);
return RET_ERROR; if (RET_OK != ret) {
MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
return ret;
}
}
} }
origin_input_data_.emplace_back(data_store);
} else {
origin_input_data_.emplace_back(nullptr);
} }
} }
for (auto kernel : this->nodes_) { for (auto kernel : this->nodes_) {
@ -239,6 +303,11 @@ int CpuFp16SubGraph::PreProcess() {
} }
if (tensor->data_type() == kNumberTypeFloat32) { if (tensor->data_type() == kNumberTypeFloat32) {
tensor->set_data_type(kNumberTypeFloat16); tensor->set_data_type(kNumberTypeFloat16);
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
tensorlist->set_tensors_data_type(kNumberTypeFloat16);
}
} }
} }
} }
@ -251,47 +320,72 @@ int CpuFp16SubGraph::PreProcess() {
int CpuFp16SubGraph::PostProcess() { int CpuFp16SubGraph::PostProcess() {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
if (!mindspore::lite::IsSupportFloat16()) { if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupport fp16 in this devices"; MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR; return RET_ERROR;
} }
int ret;
for (auto tensor : this->out_tensors_) { for (auto tensor : this->out_tensors_) {
MS_ASSERT(tensor != nullptr); MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat16) { if (tensor->data_type() == kNumberTypeFloat16) {
auto float16_data = tensor->data_c(); ret = Float16TensorToFloat32Tensor(tensor);
MS_ASSERT(float16_data != nullptr); if (RET_OK != ret) {
tensor->set_data(nullptr); MS_LOG(ERROR) << "Float16TensorToFloat32Tensor failed.";
tensor->set_data_type(TypeId::kNumberTypeFloat32); return ret;
auto ret = tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "malloc data failed";
if (this->context_ != nullptr && this->context_->allocator != nullptr) {
this->context_->allocator->Free(float16_data);
} else {
free(float16_data);
}
return RET_ERROR;
} }
MS_ASSERT(tensor->data_c() != nullptr); } else if (tensor->data_type() == kObjectTypeTensorType) {
Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum()); auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensor->allocator() != nullptr) { if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensor->allocator()->Free(float16_data); tensorlist->set_tensors_data_type(kNumberTypeFloat32);
} else { for (auto inner_tensor : tensorlist->tensors()) {
free(float16_data); ret = Float16TensorToFloat32Tensor(inner_tensor);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
return ret;
}
}
} }
} }
} }
MS_ASSERT(this->origin_input_data_.size() == this->in_tensors_.size());
int tensor_count = 0;
for (size_t i = 0; i < this->in_tensors_.size(); i++) { for (size_t i = 0; i < this->in_tensors_.size(); i++) {
auto tensor = in_tensors_.at(i); auto tensor = in_tensors_.at(i);
MS_ASSERT(tensor != nullptr); MS_ASSERT(tensor != nullptr);
auto origin_tensor_data = origin_input_data_.at(i); auto real_tensor = tensor;
if (tensor->data_type() == kNumberTypeFloat16 && origin_tensor_data != nullptr) { if (tensor->root_tensor() != nullptr) {
MS_ASSERT(tensor != nullptr); real_tensor = tensor->root_tensor();
tensor->FreeData(); if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensorlist->set_tensors_data_type(kNumberTypeFloat32);
}
}
}
if (real_tensor->data_type() == kNumberTypeFloat16 && origin_input_data_.at(real_tensor) != nullptr) {
auto origin_tensor_data = origin_input_data_.at(real_tensor);
real_tensor->FreeData();
MS_ASSERT(origin_tensor_data->data_ != nullptr); MS_ASSERT(origin_tensor_data->data_ != nullptr);
tensor->set_data(origin_tensor_data->data_); real_tensor->set_data(origin_tensor_data->data_);
tensor->set_data_type(kNumberTypeFloat32); real_tensor->set_data_type(kNumberTypeFloat32);
origin_tensor_data->data_ = nullptr; origin_tensor_data->data_ = nullptr;
tensor_count++;
} else if (real_tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensorlist->set_tensors_data_type(kNumberTypeFloat32);
for (auto inner_tensor : tensorlist->tensors()) {
MS_ASSERT(inner_tensor != nullptr);
auto origin_tensor_data = origin_input_data_.at(inner_tensor);
inner_tensor->FreeData();
MS_ASSERT(origin_tensor_data->data_ != nullptr);
inner_tensor->set_data(origin_tensor_data->data_);
inner_tensor->set_data_type(kNumberTypeFloat32);
origin_tensor_data->data_ = nullptr;
tensor_count++;
}
}
} }
} }
this->FreeOriginInputData(); this->FreeOriginInputData();

@ -20,6 +20,7 @@
#include <utility> #include <utility>
#include <string> #include <string>
#include <vector> #include <vector>
#include <map>
#include "src/lite_kernel.h" #include "src/lite_kernel.h"
#include "src/executor.h" #include "src/executor.h"
#include "src/common/log_adapter.h" #include "src/common/log_adapter.h"
@ -179,9 +180,11 @@ class CpuFp16SubGraph : public CpuSubGraph {
private: private:
void FreeOriginInputData(); void FreeOriginInputData();
int Float32TensorToFloat16Tensor(lite::Tensor *tensor);
int Float16TensorToFloat32Tensor(lite::Tensor *tensor);
private: private:
std::vector<DataStore *> origin_input_data_{}; std::map<lite::Tensor *, DataStore *> origin_input_data_;
}; };
#endif #endif
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -35,7 +35,7 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "New PrimitiveT failed"; MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR; return RET_NULL_PTR;
} }
auto attr = std::make_unique<schema::BatchMatMulT>(); auto attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) { if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed"; MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR; return RET_NULL_PTR;
@ -45,13 +45,13 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "The begin_mask attr should be specified"; MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR; return RET_ERROR;
} }
attr->transpose_a = attr_value.b(); attr->transposeA = attr_value.b();
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) { if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified"; MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR; return RET_ERROR;
} }
attr->transpose_b = attr_value.b(); attr->transposeB = attr_value.b();
primitive->value.type = schema::PrimitiveType_BatchMatMul; primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release(); primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release()); *primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) { if (*primitiveC == nullptr) {

Loading…
Cancel
Save