support decoder fp16

pull/12447/head
cjh9368 4 years ago
parent 4b805b87ee
commit 10b9f430d0

@ -275,7 +275,6 @@ union PrimitiveType {
Erf,
StridedSliceGrad,
IsFinite,
BatchMatMul,
LinSpace,
UniformReal,
AbsGrad

@ -1280,12 +1280,6 @@ table Erf {
table IsFinite {
}
table BatchMatMul {
transpose_a :bool;
transpose_b :bool;
}
table LinSpace {
}

@ -1,85 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/ops/batch_matmul.h"
#ifndef PRIMITIVE_WRITEABLE
#include "src/ops/ops_register.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Reads transpose_a from the writeable (schema::PrimitiveT) attribute table.
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value.AsBatchMatMul()->transpose_a; }
// Reads transpose_b from the writeable (schema::PrimitiveT) attribute table.
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value.AsBatchMatMul()->transpose_b; }
// Writes transpose_a into the mutable schema attribute table.
void BatchMatMul::SetTransposeA(bool transpose_a) {
  auto *attr = this->primitive_->value.AsBatchMatMul();
  attr->transpose_a = transpose_a;
}
// Writes transpose_b into the mutable schema attribute table.
void BatchMatMul::SetTransposeB(bool transpose_b) {
  auto *attr = this->primitive_->value.AsBatchMatMul();
  attr->transpose_b = transpose_b;
}
int BatchMatMul::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
if (this->primitive_ == nullptr) {
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
if (this->primitive_ == nullptr) {
MS_LOG(ERROR) << "new primitiveT failed";
return RET_ERROR;
}
this->primitive_->value.type = schema::PrimitiveType_BatchMatMul;
}
if (this->primitive_->value.type != schema::PrimitiveType_BatchMatMul) {
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
return RET_ERROR;
}
if (this->primitive_->value.value == nullptr) {
auto attr = new (std::nothrow) schema::BatchMatMulT();
if (attr == nullptr) {
MS_LOG(ERROR) << "new FusedBatchMatMulT failed";
delete this->primitive_;
this->primitive_ = nullptr;
return RET_ERROR;
}
attr->transpose_a = GetValue<bool>(prim.GetAttr("transpose_a"));
attr->transpose_b = GetValue<bool>(prim.GetAttr("transpose_b"));
this->primitive_->value.value = attr;
}
return RET_OK;
}
#else
// Reads transpose_a from the read-only flatbuffer primitive.
bool BatchMatMul::GetTransposeA() const { return this->primitive_->value_as_BatchMatMul()->transpose_a(); }
// Reads transpose_b from the read-only flatbuffer primitive.
bool BatchMatMul::GetTransposeB() const { return this->primitive_->value_as_BatchMatMul()->transpose_b(); }
// Serializes the BatchMatMul attributes of a read-only flatbuffer primitive
// into a fresh builder, producing a new Primitive flatbuffer.
// Returns RET_OK on success, RET_ERROR if the primitive does not carry a
// BatchMatMul attribute table.
int BatchMatMul::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
  MS_ASSERT(nullptr != primitive);
  MS_ASSERT(nullptr != fbb);
  auto attr = primitive->value_as_BatchMatMul();
  if (attr == nullptr) {
    // Fixed log text: the accessor used here is value_as_BatchMatMul, not value_as_Add.
    MS_LOG(ERROR) << "value_as_BatchMatMul return nullptr";
    return RET_ERROR;
  }
  auto val_offset = schema::CreateBatchMatMul(*fbb, attr->transpose_a(), attr->transpose_b());
  auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_BatchMatMul, val_offset.o);
  fbb->Finish(prim_offset);
  return RET_OK;
}
// Factory used by the op registry to build a BatchMatMul wrapper from a
// read-only flatbuffer primitive.
PrimitiveC *BatchMatMulCreator(const schema::Primitive *primitive) {
return PrimitiveC::NewPrimitiveC<BatchMatMul>(primitive);
}
// Registers the factory under PrimitiveType_BatchMatMul at static-init time.
Registry BatchMatMulRegistry(schema::PrimitiveType_BatchMatMul, BatchMatMulCreator);
#endif
} // namespace lite
} // namespace mindspore

@ -1,45 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#define LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_
#include <vector>
#include <set>
#include <cmath>
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
// Lite op wrapper for BatchMatMul: batched matrix multiplication with optional
// transposition of either input (transpose_a / transpose_b).
class BatchMatMul : public PrimitiveC {
 public:
  BatchMatMul() = default;
  ~BatchMatMul() = default;
#ifdef PRIMITIVE_WRITEABLE
  MS_DECLARE_PARENT(BatchMatMul, PrimitiveC);
  // Wraps an already-unpacked, writeable primitive (ownership per PrimitiveC).
  explicit BatchMatMul(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
  // Fills primitive_ from the anf Primitive's transpose_a/transpose_b attrs.
  int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
  void SetTransposeA(bool transpose_a);
  void SetTransposeB(bool transpose_b);
#else
  // Copies the read-only flatbuffer primitive into a fresh builder.
  int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
  bool GetTransposeA() const;
  bool GetTransposeB() const;
};
} // namespace lite
} // namespace mindspore
#endif // LITE_MINDSPORE_LITE_C_OPS_BATCH_MATMUL_H_

@ -170,7 +170,6 @@
#include "src/ops/crop_and_resize.h"
#include "src/ops/nonzero.h"
#include "src/ops/erf.h"
#include "src/ops/batch_matmul.h"
#include "src/ops/lin_space.h"
#include "src/ops/uniform_real.h"
#include "src/ops/rank.h"
@ -1057,8 +1056,6 @@ PrimitiveC *PrimitiveC::Create(mindspore::schema::PrimitiveT *primitive) {
return new (std::nothrow) Erf(primitive);
case schema::PrimitiveType_IsFinite:
return new (std::nothrow) IsFinite(primitive);
case schema::PrimitiveType_BatchMatMul:
return new (std::nothrow) BatchMatMul(primitive);
case schema::PrimitiveType_LinSpace:
return new (std::nothrow) LinSpace(primitive);
case schema::PrimitiveType_UniformReal:

@ -125,6 +125,9 @@ int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
MS_ASSERT(inputs_.at(1) != nullptr);
MS_ASSERT(inputs_.at(2) != nullptr);
auto input0 = reinterpret_cast<TensorList *>(inputs_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<TensorList *>(input0->root_tensor());
}
auto get_index = inputs_.at(1);
MS_ASSERT(get_index != nullptr);
if (get_index->ElementsNum() != 1) {

@ -102,6 +102,7 @@ int TensorListFromTensorCPUKernel::Run() {
memcpy(out_data, in_data, data_offset);
in_data += data_offset;
}
output0->set_tensors_data_type(dtype_);
return RET_OK;
}

@ -45,6 +45,9 @@ int TensorListGetItemCPUKernel::Run() {
MS_ASSERT(in_tensors_.at(1) != nullptr);
MS_ASSERT(out_tensors_.at(0) != nullptr);
auto input0 = reinterpret_cast<lite::TensorList *>(in_tensors_.at(0));
if (input0->root_tensor() != nullptr) {
input0 = reinterpret_cast<lite::TensorList *>(input0->root_tensor());
}
if (dtype_ != input0->tensors_data_type()) {
MS_LOG(ERROR) << "op dtype: " << dtype_ << " is not equal in_tensor[0] dtype: " << input0->tensors_data_type();
return RET_ERROR;

@ -16,6 +16,7 @@
#include "src/sub_graph_kernel.h"
#include "src/tensor.h"
#include "src/tensorlist.h"
#if defined(ENABLE_ARM64) && defined(ENABLE_FP16)
#include "src/runtime/kernel/arm/fp16/fp16_op_handler.h"
#endif
@ -176,7 +177,8 @@ int CpuSubGraph::Prepare() {
#ifdef ENABLE_FP16
void CpuFp16SubGraph::FreeOriginInputData() {
for (auto *data_store : this->origin_input_data_) {
for (auto &iter : this->origin_input_data_) {
auto *data_store = iter.second;
if (data_store == nullptr) {
continue;
}
@ -199,18 +201,12 @@ void CpuFp16SubGraph::FreeOriginInputData() {
this->origin_input_data_.clear();
}
int CpuFp16SubGraph::PreProcess() {
#ifdef ENABLE_ARM64
if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupport fp16 in this devices";
return RET_ERROR;
}
MS_ASSERT(origin_input_data_.empty());
for (auto tensor : this->in_tensors_) {
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat32) {
int CpuFp16SubGraph::Float32TensorToFloat16Tensor(lite::Tensor *tensor) {
auto float32_data = tensor->data_c();
MS_ASSERT(float32_data != nullptr);
if (float32_data == nullptr) {
MS_LOG(ERROR) << "tensor data is null.";
return lite::RET_NULL_PTR;
}
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat16);
auto ret = tensor->MallocData();
@ -227,9 +223,77 @@ int CpuFp16SubGraph::PreProcess() {
this->FreeOriginInputData();
return RET_ERROR;
}
origin_input_data_.emplace_back(data_store);
origin_input_data_[tensor] = data_store;
return RET_OK;
}
int CpuFp16SubGraph::Float16TensorToFloat32Tensor(lite::Tensor *tensor) {
auto float16_data = tensor->data_c();
if (float16_data == nullptr) {
MS_LOG(ERROR) << "tensor data is null.";
return lite::RET_NULL_PTR;
}
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat32);
auto ret = tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "malloc data failed";
if (this->context_ != nullptr && this->context_->allocator != nullptr) {
this->context_->allocator->Free(float16_data);
} else {
origin_input_data_.emplace_back(nullptr);
free(float16_data);
}
return RET_ERROR;
}
MS_ASSERT(tensor->data_c() != nullptr);
Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
if (tensor->allocator() != nullptr) {
tensor->allocator()->Free(float16_data);
} else {
free(float16_data);
}
return RET_OK;
}
int CpuFp16SubGraph::PreProcess() {
#ifdef ENABLE_ARM64
if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR;
}
int ret;
for (auto tensor : this->in_tensors_) {
MS_ASSERT(tensor != nullptr);
auto real_tensor = tensor;
if (tensor->root_tensor() != nullptr) {
real_tensor = tensor->root_tensor();
if (tensor->data_type() == kNumberTypeFloat32) {
tensor->set_data_type(kNumberTypeFloat16);
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
tensorlist->set_tensors_data_type(kNumberTypeFloat16);
}
}
}
if (real_tensor->data_type() == kNumberTypeFloat32) {
ret = Float32TensorToFloat16Tensor(real_tensor);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
return ret;
}
} else if (real_tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
tensorlist->set_tensors_data_type(kNumberTypeFloat16);
for (auto inner_tensor : tensorlist->tensors()) {
ret = Float32TensorToFloat16Tensor(inner_tensor);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
return ret;
}
}
}
}
}
for (auto kernel : this->nodes_) {
@ -239,6 +303,11 @@ int CpuFp16SubGraph::PreProcess() {
}
if (tensor->data_type() == kNumberTypeFloat32) {
tensor->set_data_type(kNumberTypeFloat16);
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat32) {
tensorlist->set_tensors_data_type(kNumberTypeFloat16);
}
}
}
}
@ -251,47 +320,72 @@ int CpuFp16SubGraph::PreProcess() {
int CpuFp16SubGraph::PostProcess() {
#ifdef ENABLE_ARM64
if (!mindspore::lite::IsSupportFloat16()) {
MS_LOG(ERROR) << "Unsupport fp16 in this devices";
MS_LOG(ERROR) << "Unsupported fp16 in this devices";
return RET_ERROR;
}
int ret;
for (auto tensor : this->out_tensors_) {
MS_ASSERT(tensor != nullptr);
if (tensor->data_type() == kNumberTypeFloat16) {
auto float16_data = tensor->data_c();
MS_ASSERT(float16_data != nullptr);
tensor->set_data(nullptr);
tensor->set_data_type(TypeId::kNumberTypeFloat32);
auto ret = tensor->MallocData();
if (ret != RET_OK) {
MS_LOG(ERROR) << "malloc data failed";
if (this->context_ != nullptr && this->context_->allocator != nullptr) {
this->context_->allocator->Free(float16_data);
} else {
free(float16_data);
ret = Float16TensorToFloat32Tensor(tensor);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Float16TensorToFloat32Tensor failed.";
return ret;
}
return RET_ERROR;
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensorlist->set_tensors_data_type(kNumberTypeFloat32);
for (auto inner_tensor : tensorlist->tensors()) {
ret = Float16TensorToFloat32Tensor(inner_tensor);
if (RET_OK != ret) {
MS_LOG(ERROR) << "Float32TensorToFloat16Tensor failed.";
return ret;
}
MS_ASSERT(tensor->data_c() != nullptr);
Float16ToFloat32_fp16_handler(float16_data, tensor->data_c(), tensor->ElementsNum());
if (tensor->allocator() != nullptr) {
tensor->allocator()->Free(float16_data);
} else {
free(float16_data);
}
}
}
MS_ASSERT(this->origin_input_data_.size() == this->in_tensors_.size());
}
int tensor_count = 0;
for (size_t i = 0; i < this->in_tensors_.size(); i++) {
auto tensor = in_tensors_.at(i);
MS_ASSERT(tensor != nullptr);
auto origin_tensor_data = origin_input_data_.at(i);
if (tensor->data_type() == kNumberTypeFloat16 && origin_tensor_data != nullptr) {
MS_ASSERT(tensor != nullptr);
tensor->FreeData();
MS_ASSERT(origin_tensor_data->data_ != nullptr);
tensor->set_data(origin_tensor_data->data_);
auto real_tensor = tensor;
if (tensor->root_tensor() != nullptr) {
real_tensor = tensor->root_tensor();
if (tensor->data_type() == kNumberTypeFloat16) {
tensor->set_data_type(kNumberTypeFloat32);
} else if (tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensorlist->set_tensors_data_type(kNumberTypeFloat32);
}
}
}
if (real_tensor->data_type() == kNumberTypeFloat16 && origin_input_data_.at(real_tensor) != nullptr) {
auto origin_tensor_data = origin_input_data_.at(real_tensor);
real_tensor->FreeData();
MS_ASSERT(origin_tensor_data->data_ != nullptr);
real_tensor->set_data(origin_tensor_data->data_);
real_tensor->set_data_type(kNumberTypeFloat32);
origin_tensor_data->data_ = nullptr;
tensor_count++;
} else if (real_tensor->data_type() == kObjectTypeTensorType) {
auto tensorlist = reinterpret_cast<lite::TensorList *>(real_tensor);
if (tensorlist->tensors_data_type() == kNumberTypeFloat16) {
tensorlist->set_tensors_data_type(kNumberTypeFloat32);
for (auto inner_tensor : tensorlist->tensors()) {
MS_ASSERT(inner_tensor != nullptr);
auto origin_tensor_data = origin_input_data_.at(inner_tensor);
inner_tensor->FreeData();
MS_ASSERT(origin_tensor_data->data_ != nullptr);
inner_tensor->set_data(origin_tensor_data->data_);
inner_tensor->set_data_type(kNumberTypeFloat32);
origin_tensor_data->data_ = nullptr;
tensor_count++;
}
}
}
}
this->FreeOriginInputData();

@ -20,6 +20,7 @@
#include <utility>
#include <string>
#include <vector>
#include <map>
#include "src/lite_kernel.h"
#include "src/executor.h"
#include "src/common/log_adapter.h"
@ -179,9 +180,11 @@ class CpuFp16SubGraph : public CpuSubGraph {
private:
void FreeOriginInputData();
int Float32TensorToFloat16Tensor(lite::Tensor *tensor);
int Float16TensorToFloat32Tensor(lite::Tensor *tensor);
private:
std::vector<DataStore *> origin_input_data_{};
std::map<lite::Tensor *, DataStore *> origin_input_data_;
};
#endif
} // namespace mindspore::kernel

@ -35,7 +35,7 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "New PrimitiveT failed";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::BatchMatMulT>();
auto attr = std::make_unique<schema::MatMulT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
@ -45,13 +45,13 @@ STATUS TFBatchMatMulParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_a = attr_value.b();
attr->transposeA = attr_value.b();
if (!TensorFlowUtils::FindAttrValue(tf_op, "adj_y", &attr_value)) {
MS_LOG(ERROR) << "The begin_mask attr should be specified";
return RET_ERROR;
}
attr->transpose_b = attr_value.b();
primitive->value.type = schema::PrimitiveType_BatchMatMul;
attr->transposeB = attr_value.b();
primitive->value.type = schema::PrimitiveType_MatMul;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {

Loading…
Cancel
Save