diff --git a/mindspore/lite/src/ops/populate/select_populate.cc b/mindspore/lite/src/ops/populate/select_populate.cc new file mode 100644 index 0000000000..c93674a9ae --- /dev/null +++ b/mindspore/lite/src/ops/populate/select_populate.cc @@ -0,0 +1,36 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "src/ops/select.h" +#include "src/ops/primitive_c.h" +#include "src/ops/populate/populate_register.h" + +namespace mindspore { +namespace lite { +OpParameter *PopulateSelectParameter(const mindspore::lite::PrimitiveC *primitive) { + OpParameter *select_parameter = reinterpret_cast(malloc(sizeof(OpParameter))); + if (select_parameter == nullptr) { + MS_LOG(ERROR) << "malloc SelectParameter failed."; + return nullptr; + } + memset(select_parameter, 0, sizeof(OpParameter)); + select_parameter->type_ = primitive->Type(); + + return reinterpret_cast(select_parameter); +} +Registry SelectParameterRegistry(schema::PrimitiveType_Select, PopulateSelectParameter); +} // namespace lite +} // namespace mindspore diff --git a/mindspore/lite/src/ops/select.cc b/mindspore/lite/src/ops/select.cc new file mode 100644 index 0000000000..c5f4bb4dda --- /dev/null +++ b/mindspore/lite/src/ops/select.cc @@ -0,0 +1,106 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "src/ops/select.h" + +#ifndef PRIMITIVE_WRITEABLE +#include "src/ops/ops_register.h" +#endif +#include "src/tensorlist.h" + +namespace mindspore { +namespace lite { +#ifdef PRIMITIVE_WRITEABLE +int Select::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Select; + } + if (this->primitive_->value.type != schema::PrimitiveType_Select) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::SelectT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} +#else +int Select::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { + MS_ASSERT(nullptr != primitive); + MS_ASSERT(nullptr != fbb); + auto attr = primitive->value_as_Select(); + if (attr == nullptr) { + MS_LOG(ERROR) << "value_as_Select return nullptr"; + return RET_ERROR; + } + auto val_offset = schema::CreateSelect(*fbb); + auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_Select, val_offset.o); + fbb->Finish(prim_offset); + return RET_OK; +} + +PrimitiveC *SelectCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC