!9425 [MS][list][x86] add new tensorlist ops
From: @lzkcode Reviewed-by: Signed-off-by:pull/9425/MERGE
commit
6550769104
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/tensorlistsetitem.h"
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/ops/populate/populate_register.h"
|
||||
#include "nnacl/tensorlist_parameter.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
OpParameter *PopulateTensorListSetItemParameter(const mindspore::lite::PrimitiveC *primitive) {
|
||||
TensorListParameter *setItem_param = reinterpret_cast<TensorListParameter *>(malloc(sizeof(TensorListParameter)));
|
||||
if (setItem_param == nullptr) {
|
||||
MS_LOG(ERROR) << "malloc TensorListParameter failed.";
|
||||
return nullptr;
|
||||
}
|
||||
memset(setItem_param, 0, sizeof(TensorListParameter));
|
||||
setItem_param->op_parameter_.type_ = primitive->Type();
|
||||
auto setItem =
|
||||
reinterpret_cast<mindspore::lite::TensorListSetItem *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
|
||||
setItem_param->element_dtype_ = setItem->GetElementDType();
|
||||
return reinterpret_cast<OpParameter *>(setItem_param);
|
||||
}
|
||||
Registry TensorListSetItemParameterRegistry(schema::PrimitiveType_TensorListSetItem,
|
||||
PopulateTensorListSetItemParameter);
|
||||
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,141 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/tensorlistfromtensor.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int TensorListFromTensor::GetElementDType() const {
|
||||
return this->primitive_->value.AsTensorListFromTensor()->elementDType;
|
||||
}
|
||||
|
||||
int TensorListFromTensor::GetShapeType() const { return this->primitive_->value.AsTensorListFromTensor()->shapeType; }
|
||||
|
||||
void TensorListFromTensor::SetElementDType(int type) {
|
||||
this->primitive_->value.AsTensorListFromTensor()->elementDType = type;
|
||||
}
|
||||
|
||||
void TensorListFromTensor::SetShapeType(int type) {
|
||||
this->primitive_->value.AsTensorListFromTensor()->shapeType = type;
|
||||
}
|
||||
|
||||
int TensorListFromTensor::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_TensorListFromTensor;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_TensorListFromTensor) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::TensorListFromTensorT();
|
||||
if (attr == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
MS_LOG(ERROR) << "new TensorListFromTensorT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prim.GetAttr("elementDType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListFromTensorT's attr elementDType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
|
||||
}
|
||||
if (prim.GetAttr("shapeType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListFromTensorT's attr shapeType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->shapeType = CastToInt(prim.GetAttr("shapeType")).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int TensorListFromTensor::GetElementDType() const {
|
||||
return this->primitive_->value_as_TensorListFromTensor()->elementDType();
|
||||
}
|
||||
|
||||
int TensorListFromTensor::GetShapeType() const {
|
||||
return this->primitive_->value_as_TensorListFromTensor()->shapeType();
|
||||
}
|
||||
|
||||
int TensorListFromTensor::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_TensorListFromTensor();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_TensorListFromTensor return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListFromTensor(*fbb, attr->elementDType(), attr->shapeType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListFromTensor, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *TensorListFromTensorCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<TensorListFromTensor>(primitive);
|
||||
}
|
||||
Registry TensorListFromTensorRegistry(schema::PrimitiveType_TensorListFromTensor, TensorListFromTensorCreator);
|
||||
#endif
|
||||
|
||||
int TensorListFromTensor::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
auto input0 = inputs_[0];
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
std::vector<int> input0_shape = input0->shape();
|
||||
if (input0_shape.size() < 1) {
|
||||
MS_LOG(ERROR) << "input0_shape.size():" << input0_shape.size() << " must be greater than 0!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int dim0 = input0_shape[0];
|
||||
if (dim0 < 0) {
|
||||
MS_LOG(ERROR) << "inputs_[0] dim0:" << dim0 << " must greater than or equal to 0";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto input1 = inputs_[1];
|
||||
MS_ASSERT(input1 != nullptr);
|
||||
auto ele_shape_ptr = reinterpret_cast<int *>(input1->data_c());
|
||||
auto output = reinterpret_cast<TensorList *>(outputs_[0]);
|
||||
MS_ASSERT(output != nullptr);
|
||||
// output->set_tensors_data_type(input0->data_type());
|
||||
std::vector<std::vector<int> > tensor_shape(dim0, std::vector<int>(input0_shape.begin() + 1, input0_shape.end()));
|
||||
output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input1->ElementsNum()));
|
||||
output->set_shape(std::vector<int>(1, dim0));
|
||||
output->set_data_type(kObjectTypeTensorType);
|
||||
output->MallocTensorListData(input0->data_type(), tensor_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/tensorlist.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TensorListFromTensor : public PrimitiveC {
|
||||
public:
|
||||
TensorListFromTensor() = default;
|
||||
~TensorListFromTensor() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(TensorListFromTensor, PrimitiveC);
|
||||
void SetElementDType(int type);
|
||||
void SetShapeType(int type);
|
||||
explicit TensorListFromTensor(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int GetElementDType() const;
|
||||
int GetShapeType() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTFROMTENSOR_H_
|
@ -0,0 +1,171 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/tensorlistgetitem.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
TypeId TensorListGetItem::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value.AsTensorListGetItem()->elementDType);
|
||||
}
|
||||
|
||||
void TensorListGetItem::SetElementDType(int type) {
|
||||
this->primitive_->value.AsTensorListGetItem()->elementDType = type;
|
||||
}
|
||||
|
||||
int TensorListGetItem::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_TensorListGetItem;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_TensorListGetItem) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::TensorListGetItemT();
|
||||
if (attr == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
MS_LOG(ERROR) << "new TensorListGetItemT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prim.GetAttr("elementDType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListGetItem's attr elementDType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
TypeId TensorListGetItem::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value_as_TensorListGetItem()->elementDType());
|
||||
}
|
||||
|
||||
int TensorListGetItem::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_TensorListGetItem();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_TensorListGetItem return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListGetItem(*fbb, attr->elementDType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListGetItem, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *TensorListGetItemCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<TensorListGetItem>(primitive);
|
||||
}
|
||||
Registry TensorListGetItemRegistry(schema::PrimitiveType_TensorListGetItem, TensorListGetItemCreator);
|
||||
#endif
|
||||
bool TensorListGetItem::IsFullyDefined(const std::vector<int> &shape) const {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if (shape[i] < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int TensorListGetItem::MergeShape(const std::vector<int> &tmp) {
|
||||
if (element_shape_.size() != tmp.size()) {
|
||||
MS_LOG(ERROR) << "element_shape_.size():" << element_shape_.size() << " must be equal to tmp.size():" << tmp.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t j = 0; j < tmp.size(); ++j) {
|
||||
if (element_shape_[j] >= 0 && tmp[j] >= 0 && element_shape_[j] != tmp[j]) {
|
||||
MS_LOG(ERROR) << "element_shape_[" << j << "]:" << element_shape_[j] << " must be equal to tmp[" << j
|
||||
<< "]:" << tmp[j];
|
||||
return RET_ERROR;
|
||||
}
|
||||
element_shape_[j] = element_shape_[j] >= 0 ? element_shape_[j] : tmp[j];
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TensorListGetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
auto input0 = reinterpret_cast<TensorList *>(inputs_[0]);
|
||||
auto get_index = inputs_[1];
|
||||
MS_ASSERT(get_index != nullptr);
|
||||
if (get_index->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "get_index->ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
index_ = reinterpret_cast<int *>(get_index->data_c())[0];
|
||||
if (index_ < 0 || index_ > (input0->ElementsNum() - 1)) {
|
||||
MS_LOG(ERROR) << "index_:" << index_ << "must in [0, " << input0->ElementsNum() - 1 << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto tensor_index = input0->GetTensorIndex(index_);
|
||||
MS_ASSERT(tensor_index != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
if (tensor_index->data_type() != kTypeUnknown) {
|
||||
output->set_data_type(tensor_index->data_type());
|
||||
output->set_shape(tensor_index->shape());
|
||||
} else {
|
||||
auto input2 = inputs_[2];
|
||||
auto ele_shape_data = reinterpret_cast<int *>(input2->data_c());
|
||||
for (int i = 0; i < input2->ElementsNum(); ++i) {
|
||||
element_shape_.push_back(ele_shape_data[i]);
|
||||
}
|
||||
auto status = MergeShape(input0->element_shape());
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsFullyDefined(element_shape_)) {
|
||||
for (int i = 0; i < input0->ElementsNum(); ++i) {
|
||||
auto input = input0->GetTensorIndex(i);
|
||||
MS_ASSERT(input != nullptr);
|
||||
if (input->data_type() != kTypeUnknown) {
|
||||
status = MergeShape(input->shape());
|
||||
if (status != RET_OK) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!IsFullyDefined(element_shape_)) {
|
||||
MS_LOG(ERROR) << "element_shape_ is not fullyDefined!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
output->set_data_type(GetElementDType());
|
||||
output->set_shape(element_shape_);
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,49 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TensorListGetItem : public PrimitiveC {
|
||||
public:
|
||||
TensorListGetItem() = default;
|
||||
~TensorListGetItem() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(TensorListGetItem, PrimitiveC);
|
||||
void SetElementDType(int type);
|
||||
explicit TensorListGetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
TypeId GetElementDType() const;
|
||||
int MergeShape(const std::vector<int> &tmp);
|
||||
bool IsFullyDefined(const std::vector<int> &shape) const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
|
||||
private:
|
||||
int index_ = -1;
|
||||
std::vector<int> element_shape_;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTGETITEM_H_
|
@ -0,0 +1,131 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/tensorlistreserve.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
TypeId TensorListReserve::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value.AsTensorListReserve()->elementDType);
|
||||
}
|
||||
|
||||
void TensorListReserve::SetElementDType(int type) {
|
||||
this->primitive_->value.AsTensorListReserve()->elementDType = type;
|
||||
}
|
||||
|
||||
int TensorListReserve::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_TensorListReserve;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_TensorListReserve) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::TensorListReserveT();
|
||||
if (attr == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
MS_LOG(ERROR) << "new TensorListReserveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prim.GetAttr("elementDType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListReserve's attr elementDType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
#else
|
||||
TypeId TensorListReserve::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value_as_TensorListReserve()->elementDType());
|
||||
}
|
||||
|
||||
int TensorListReserve::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(primitive != nullptr);
|
||||
MS_ASSERT(fbb != nullptr);
|
||||
auto attr = primitive->value_as_TensorListReserve();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_TensorListReserve return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListReserve(*fbb, attr->elementDType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListReserve, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *TensorListReserveCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<TensorListReserve>(primitive);
|
||||
}
|
||||
Registry TensorListReserveRegistry(schema::PrimitiveType_TensorListReserve, TensorListReserveCreator);
|
||||
#endif
|
||||
|
||||
int TensorListReserve::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
// input0: element_shape_tensor
|
||||
// input1: num_elements
|
||||
auto input0 = inputs_.front();
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto ele_shape_type = input0->data_type();
|
||||
if (ele_shape_type != kNumberTypeInt) {
|
||||
MS_LOG(ERROR) << "ele_shape_tensor.data_type():" << ele_shape_type
|
||||
<< " must be \"kNumberTypeInt\":" << kNumberTypeInt;
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ele_shape_ptr = reinterpret_cast<int *>(input0->data_c());
|
||||
|
||||
auto input1 = inputs_[1];
|
||||
MS_ASSERT(input1 != nullptr);
|
||||
auto num_ele_type = input1->data_type();
|
||||
if (num_ele_type != kNumberTypeInt) {
|
||||
MS_LOG(ERROR) << "num_ele_tensor.data_type():" << num_ele_type << " must be \"kNumberTypeInt\":" << kNumberTypeInt;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (input1->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "input1->ElementsNum() must be equal to 1";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int num_elements = reinterpret_cast<int *>(input1->data_c())[0];
|
||||
|
||||
auto output = reinterpret_cast<TensorList *>(outputs_[0]);
|
||||
output->set_data_type(kObjectTypeTensorType);
|
||||
std::vector<std::vector<int> > tmp_shape(num_elements, std::vector<int>());
|
||||
output->set_element_shape(std::vector<int>(ele_shape_ptr, ele_shape_ptr + input0->ElementsNum()));
|
||||
output->set_shape(std::vector<int>(1, num_elements));
|
||||
output->MallocTensorListData(kTypeUnknown, tmp_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TensorListReserve : public PrimitiveC {
|
||||
public:
|
||||
TensorListReserve() = default;
|
||||
~TensorListReserve() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(TensorListReserve, PrimitiveC);
|
||||
void SetElementDType(int type);
|
||||
explicit TensorListReserve(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
TypeId GetElementDType() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTRESERVE_H_
|
@ -0,0 +1,140 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/tensorlistsetitem.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
TypeId TensorListSetItem::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value.AsTensorListSetItem()->elementDType);
|
||||
}
|
||||
|
||||
void TensorListSetItem::SetElementDType(int type) {
|
||||
this->primitive_->value.AsTensorListSetItem()->elementDType = type;
|
||||
}
|
||||
|
||||
int TensorListSetItem::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_TensorListSetItem;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_TensorListSetItem) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::TensorListSetItemT();
|
||||
if (attr == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
MS_LOG(ERROR) << "new TensorListSetItemT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prim.GetAttr("elementDType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListSetItem's attr elementDType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
TypeId TensorListSetItem::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value_as_TensorListSetItem()->elementDType());
|
||||
}
|
||||
|
||||
int TensorListSetItem::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_TensorListSetItem();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_TensorListSetItem return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListSetItem(*fbb, attr->elementDType());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListSetItem, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *TensorListSetItemCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<TensorListSetItem>(primitive);
|
||||
}
|
||||
Registry TensorListSetItemRegistry(schema::PrimitiveType_TensorListSetItem, TensorListSetItemCreator);
|
||||
#endif
|
||||
|
||||
int TensorListSetItem::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
auto input0 = reinterpret_cast<TensorList *>(inputs_[0]);
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
auto get_index = inputs_[1];
|
||||
MS_ASSERT(get_index != nullptr);
|
||||
if (get_index->data_type() != kNumberTypeInt) {
|
||||
MS_LOG(ERROR) << "inputs_[1]->data_type():" << get_index->data_type()
|
||||
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (get_index->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "inputs_[1].ElementsNum():" << get_index->ElementsNum() << " must be equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
int index = reinterpret_cast<int *>(get_index->data_c())[0];
|
||||
if (index < 0 || index > (input0->ElementsNum() - 1)) {
|
||||
MS_LOG(ERROR) << "index_:" << index << "must in [0, " << input0->ElementsNum() - 1 << "]";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto value_tensor = inputs_[2];
|
||||
MS_ASSERT(value_tensor != nullptr);
|
||||
auto output0 = reinterpret_cast<TensorList *>(outputs_[0]);
|
||||
MS_ASSERT(output0 != nullptr);
|
||||
output0->set_element_shape(input0->element_shape());
|
||||
output0->set_max_elements_num(input0->max_elements_num());
|
||||
output0->set_shape(input0->shape());
|
||||
output0->set_data_type(input0->data_type());
|
||||
std::vector<std::vector<int> > out_shape;
|
||||
for (int i = 0; i < input0->ElementsNum(); ++i) {
|
||||
auto src_ptr = input0->GetTensorIndex(i);
|
||||
if (src_ptr == nullptr) {
|
||||
MS_LOG(ERROR) << "input0->tensors_[" << i << "] is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (src_ptr->data_type() != kTypeUnknown) {
|
||||
out_shape.push_back(src_ptr->shape());
|
||||
} else {
|
||||
out_shape.push_back(std::vector<int>());
|
||||
}
|
||||
}
|
||||
out_shape[index] = value_tensor->shape();
|
||||
output0->MallocTensorListData(input0->tensors_data_type(), out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/primitive_c.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "ir/dtype/type_id.h"
|
||||
|
||||
#ifndef LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_
|
||||
#define LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class TensorListSetItem : public PrimitiveC {
|
||||
public:
|
||||
TensorListSetItem() = default;
|
||||
~TensorListSetItem() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(TensorListSetItem, PrimitiveC);
|
||||
void SetElementDType(int type);
|
||||
explicit TensorListSetItem(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
TypeId GetElementDType() const;
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // LITE_MINDSPORE_LITE_C_OPS_TENSORLISTSETITEM_H_
|
@ -0,0 +1,188 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <vector>
|
||||
#include "src/ops/tensorliststack.h"
|
||||
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
TypeId TensorListStack::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value.AsTensorListStack()->elementDType);
|
||||
}
|
||||
|
||||
int TensorListStack::GetNumElements() const { return this->primitive_->value.AsTensorListStack()->numElements; }
|
||||
|
||||
void TensorListStack::SetElementDType(int type) { this->primitive_->value.AsTensorListStack()->elementDType = type; }
|
||||
|
||||
void TensorListStack::SetNumElements(int num_elements) {
|
||||
this->primitive_->value.AsTensorListStack()->numElements = num_elements;
|
||||
}
|
||||
|
||||
int TensorListStack::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_TensorListStack;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_TensorListStack) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
auto attr = new (std::nothrow) schema::TensorListStackT();
|
||||
if (attr == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
MS_LOG(ERROR) << "new TensorListStackT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (prim.GetAttr("elementDType") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListStack's attr elementDType is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->elementDType = CastToInt(prim.GetAttr("elementDType")).front();
|
||||
}
|
||||
if (prim.GetAttr("numElements") == nullptr) {
|
||||
delete this->primitive_;
|
||||
this->primitive_ = nullptr;
|
||||
delete attr;
|
||||
MS_LOG(ERROR) << "TensorListStack's attr numElements is not set";
|
||||
return RET_ERROR;
|
||||
} else {
|
||||
attr->numElements = CastToInt(prim.GetAttr("numElements")).front();
|
||||
}
|
||||
this->primitive_->value.value = attr;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
TypeId TensorListStack::GetElementDType() const {
|
||||
return (TypeId)(this->primitive_->value_as_TensorListStack()->elementDType());
|
||||
}
|
||||
|
||||
int TensorListStack::GetNumElements() const { return this->primitive_->value_as_TensorListStack()->numElements(); }
|
||||
|
||||
int TensorListStack::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_TensorListStack();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_TensorListStack return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateTensorListStack(*fbb, attr->elementDType(), attr->numElements());
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_TensorListStack, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
PrimitiveC *TensorListStackCreator(const schema::Primitive *primitive) {
|
||||
return PrimitiveC::NewPrimitiveC<TensorListStack>(primitive);
|
||||
}
|
||||
Registry TensorListStackRegistry(schema::PrimitiveType_TensorListStack, TensorListStackCreator);
|
||||
#endif
|
||||
|
||||
bool TensorListStack::IsFullyDefined(const std::vector<int> &shape) const {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if (shape[i] < 0) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int TensorListStack::InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) {
|
||||
auto input0 = reinterpret_cast<TensorList *>(inputs_.front());
|
||||
MS_ASSERT(input0 != nullptr);
|
||||
if (input0->ElementsNum() == 0) {
|
||||
MS_LOG(ERROR) << "Try to stack a empty tensorlist!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ele_shape = inputs_[1]; // element shape
|
||||
MS_ASSERT(ele_shape != nullptr);
|
||||
auto ele_shape_ptr = reinterpret_cast<int *>(ele_shape->data_c());
|
||||
for (int i = 0; ele_shape->ElementsNum(); ++i) {
|
||||
output_shape_.push_back(ele_shape_ptr[i]);
|
||||
}
|
||||
|
||||
auto status = MergeShape(input0->element_shape());
|
||||
if (status == RET_ERROR) {
|
||||
MS_LOG(ERROR) << "Merge element_shape is error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsFullyDefined(output_shape_)) {
|
||||
MS_LOG(ERROR) << "output_shape_ Is Not FullyDefined!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (!IsFullyDefined(input0->element_shape())) {
|
||||
for (int i = 0; i < input0->ElementsNum(); ++i) {
|
||||
auto tensor_ele = input0->GetTensorIndex(i);
|
||||
MS_ASSERT(tensor_ele != nullptr);
|
||||
if (tensor_ele->data_type() != kTypeUnknown) {
|
||||
status = MergeShape(tensor_ele->shape());
|
||||
if (status == RET_ERROR) {
|
||||
MS_LOG(ERROR) << "Merge input0->tensors_[" << i << "] is error!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input0->tensors_data_type());
|
||||
output->set_shape(std::vector<int>(
|
||||
1,
|
||||
input0->ElementsNum() * std::accumulate(output_shape_.begin(), output_shape_.end(), 1LL, std::multiplies<int>())));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TensorListStack::MergeShape(const std::vector<int> &shape) {
|
||||
size_t dim0 = shape.size();
|
||||
size_t dim1 = output_shape_.size();
|
||||
if (dim1 >= unKnownRank_) {
|
||||
output_shape_ = shape;
|
||||
return RET_OK;
|
||||
}
|
||||
if (dim1 != dim0) {
|
||||
MS_LOG(ERROR) << "shape.size():" << dim1 << " must be equal output_shape_.size():" << dim0;
|
||||
return RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < dim0; ++i) {
|
||||
int dim0_size = shape[i];
|
||||
int dim1_size = output_shape_[i];
|
||||
if (dim0_size >= 0 && dim1_size >= 0 && dim0_size != dim1_size) {
|
||||
MS_LOG(ERROR) << "shape[" << i << "]:" << dim0_size << " is incompatible with output_shape_[" << i
|
||||
<< "]:" << dim1_size;
|
||||
return RET_ERROR;
|
||||
}
|
||||
output_shape_[i] = dim1_size >= 0 ? dim1_size : dim0_size;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/errorcode.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/kernel/arm/fp32/TensorListSetItem.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_NULL_PTR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_TensorListSetItem;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
|
||||
int TensorListSetItemCPUKernel::Init() {
|
||||
input0_ = reinterpret_cast<lite::TensorList *>(in_tensors_[0]);
|
||||
if (dtype_ != input0_->data_type()) {
|
||||
MS_LOG(ERROR) << "op dtype:" << dtype_ << " is not equal in_tensors[0] dtype:" << input0_->data_type();
|
||||
return RET_ERROR;
|
||||
}
|
||||
int dim0 = input0_->ElementsNum() - 1;
|
||||
if (in_tensors_[1]->data_type() != kNumberTypeInt) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->data_type():" << in_tensors_[1]->data_type()
|
||||
<< " must be equal to \"kNumberTypeInt\":" << kNumberTypeInt;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_[1]->ElementsNum() != 1) {
|
||||
MS_LOG(ERROR) << "in_tensors_[1]->ElementsNum():" << in_tensors_[1]->ElementsNum() << " must be equal to 1!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
index_ = reinterpret_cast<int *>(in_tensors_[1]->data_c())[0];
|
||||
if (index_ < 0 || index_ > dim0) {
|
||||
MS_LOG(ERROR) << "index tensor:[" << index_ << "] must be in [0, " << dim0 << "]!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
input2_ = in_tensors_[2];
|
||||
MS_ASSERT(input2_ != nullptr);
|
||||
if (!input0_->IsCompatibleShape(input2_->shape())) {
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TensorListSetItemCPUKernel::Run() {
|
||||
output0_ = reinterpret_cast<lite::TensorList *>(out_tensors_[0]);
|
||||
MS_ASSERT(output0_ != nullptr);
|
||||
// copy each tensor in tensors_
|
||||
for (int i = 0; i < output0_->ElementsNum(); ++i) {
|
||||
auto dst = output0_->GetTensorIndex(i);
|
||||
MS_ASSERT(dst != nullptr);
|
||||
auto src = input0_->GetTensorIndex(i);
|
||||
if (i == index_) {
|
||||
// copy input2_ data buff
|
||||
src = input2_;
|
||||
}
|
||||
MS_ASSERT(src != nullptr);
|
||||
if (src->data_type() != kTypeUnknown) {
|
||||
if (src->Size() != dst->Size()) {
|
||||
MS_LOG(ERROR) << "src->Size():" << src->Size() << " must be equal to dst->Size():" << dst->Size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto ret = dst->CopyTensorData(*src);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "CopyTensorData[" << i << "] is failed!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int TensorListSetItemCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
kernel::LiteKernel *CpuTensorListSetItemFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs,
|
||||
OpParameter *op_parameter, const lite::InnerContext *ctx,
|
||||
const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (op_parameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Input context is nullptr!";
|
||||
free(op_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
MS_ASSERT(desc.type == schema::PrimitiveType_TensorListSetItem);
|
||||
auto *kernel = new (std::nothrow) TensorListSetItemCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new TensorListSetItemCPUKernel fail!";
|
||||
free(op_parameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_TensorListSetItem, CpuTensorListSetItemFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,49 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/tensorlist.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "nnacl/tensorlist_parameter.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class TensorListSetItemCPUKernel : public LiteKernel {
|
||||
public:
|
||||
TensorListSetItemCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive),
|
||||
dtype_(reinterpret_cast<TensorListParameter *>(parameter)->element_dtype_) {}
|
||||
~TensorListSetItemCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
private:
|
||||
lite::TensorList *input0_ = nullptr;
|
||||
lite::Tensor *input2_ = nullptr;
|
||||
lite::TensorList *output0_ = nullptr;
|
||||
int index_ = 0;
|
||||
TypeId dtype_ = kTypeUnknown;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_TENSORLISTSETITEM_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue