parent
2645ed3c90
commit
401d42a103
@ -0,0 +1,124 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/ops/nonzero.h"
|
||||
#include <algorithm>
|
||||
#include "include/errorcode.h"
|
||||
#include "src/common/log_adapter.h"
|
||||
#include "src/tensor.h"
|
||||
#ifndef PRIMITIVE_WRITEABLE
|
||||
#include "src/ops/ops_register.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
int NonZero::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
|
||||
if (this->primitive_ == nullptr) {
|
||||
this->primitive_ = new (std::nothrow) schema::PrimitiveT;
|
||||
if (this->primitive_ == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
this->primitive_->value.type = schema::PrimitiveType_NonZero;
|
||||
}
|
||||
if (this->primitive_->value.type != schema::PrimitiveType_NonZero) {
|
||||
MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
this->primitive_->value.value = new (std::nothrow) schema::NonZeroT();
|
||||
if (this->primitive_->value.value == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitiveT value failed";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
PopulaterQuantParam(prim, inputs);
|
||||
return RET_OK;
|
||||
}
|
||||
#else
|
||||
int NonZero::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
|
||||
MS_ASSERT(nullptr != primitive);
|
||||
MS_ASSERT(nullptr != fbb);
|
||||
auto attr = primitive->value_as_NonZero();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "value_as_NonZero return nullptr";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto val_offset = schema::CreateNonZero(*fbb);
|
||||
auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_NonZero, val_offset.o);
|
||||
fbb->Finish(prim_offset);
|
||||
return RET_OK;
|
||||
}
|
||||
PrimitiveC *NonZeroCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<NonZero>(primitive); }
|
||||
Registry NonZeroRegistry(schema::PrimitiveType_NonZero, NonZeroCreator);
|
||||
#endif
|
||||
template <typename T>
|
||||
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape) {
|
||||
int input_count = inputs[0]->ElementsNum();
|
||||
int input_dim_size = inputs[0]->shape().empty() ? 1 : inputs[0]->shape().size();
|
||||
(*out_shape)[0] = input_dim_size;
|
||||
int nonzero_size = 0;
|
||||
for (int i = 0; i < input_count; i++) {
|
||||
if (static_cast<int>(data[i]) != 0) {
|
||||
nonzero_size++;
|
||||
}
|
||||
}
|
||||
if (nonzero_size == 0) {
|
||||
*out_shape = {};
|
||||
} else {
|
||||
(*out_shape)[1] = nonzero_size / input_dim_size;
|
||||
}
|
||||
}
|
||||
int NonZero::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
|
||||
MS_ASSERT(this->primitive_ != nullptr);
|
||||
MS_ASSERT(inputs_.size() == 1);
|
||||
auto input = inputs_.front();
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto output = outputs_.front();
|
||||
MS_ASSERT(output != nullptr);
|
||||
output->set_data_type(input->data_type());
|
||||
output->set_format(input->format());
|
||||
if (!infer_flag()) {
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
|
||||
std::vector<int> out_shape;
|
||||
if (inputs_.size() == kSingleNum) {
|
||||
auto input_tensor = inputs_.at(0);
|
||||
if (input_tensor->data_c() == nullptr) {
|
||||
MS_LOG(INFO) << "Do infer shape in runtime.";
|
||||
return RET_INFER_INVALID;
|
||||
}
|
||||
switch (input_tensor->data_type()) {
|
||||
case kNumberTypeFloat: {
|
||||
auto data = reinterpret_cast<float *>(input_tensor->MutableData());
|
||||
CalShape<float>(data, inputs_, &out_shape);
|
||||
} break;
|
||||
default: {
|
||||
MS_LOG(ERROR) << "NonZero weight tensor has unsupported dataType: " << input_tensor->data_type();
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
MS_LOG(ERROR) << "inputs tensor size invalid.";
|
||||
return RET_INFER_ERR;
|
||||
}
|
||||
output->set_shape(out_shape);
|
||||
return RET_OK;
|
||||
}
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,45 @@
|
||||
/**
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
||||
#define MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
||||
|
||||
#include <cmath>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
#include "src/ops/primitive_c.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class NonZero : public PrimitiveC {
|
||||
public:
|
||||
NonZero() = default;
|
||||
~NonZero() = default;
|
||||
#ifdef PRIMITIVE_WRITEABLE
|
||||
MS_DECLARE_PARENT(NonZero, PrimitiveC);
|
||||
explicit NonZero(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {}
|
||||
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
|
||||
#else
|
||||
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
|
||||
#endif
|
||||
int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_OPS_NONZERO_H_
|
@ -0,0 +1,105 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/fp32/nonzero_fp32.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "nnacl/op_base.h"
|
||||
#include "schema/model_generated.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/runtime/runtime_api.h"
|
||||
#include "src/tensor.h"
|
||||
|
||||
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_NonZero;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int NonZeroCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int NonZeroCPUKernel::ReSize() { return RET_OK; }
|
||||
int NonZeroCPUKernel::Run() {
|
||||
auto in_tensor = in_tensors_.front();
|
||||
auto out_tensor = out_tensors_.front();
|
||||
auto input_data = reinterpret_cast<float *>(in_tensor->MutableData());
|
||||
auto output_data = reinterpret_cast<int *>(out_tensor->MutableData());
|
||||
auto input_dim_size = in_tensor->shape().size();
|
||||
if (out_tensor->shape().size() != 2) {
|
||||
MS_LOG(ERROR) << "out tensor shape size must be equal to 2!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto non_zero_nums = out_tensor->shape()[1];
|
||||
int non_zero_count = 0;
|
||||
std::vector coordiate_values(in_tensor->shape().size(), 0);
|
||||
for (int i = 0; i < in_tensor->ElementsNum(); i += 1) {
|
||||
if (input_data[i] != 0) {
|
||||
for (size_t j = 0; j < input_dim_size; j++) {
|
||||
output_data[non_zero_count + j * non_zero_nums] = coordiate_values[j];
|
||||
}
|
||||
non_zero_count++;
|
||||
}
|
||||
for (int idx = input_dim_size - 1; idx >= 0; --idx) {
|
||||
if (coordiate_values[idx] != in_tensor->shape()[idx] - 1) {
|
||||
coordiate_values[idx] = coordiate_values[idx] + 1;
|
||||
break;
|
||||
}
|
||||
coordiate_values[idx] = 0;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuNonZeroFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
if (opParameter == nullptr) {
|
||||
MS_LOG(ERROR) << "Input opParameter is nullptr!";
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx == nullptr) {
|
||||
MS_LOG(ERROR) << "Input context is nullptr!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
if (ctx->thread_num_ == 0) {
|
||||
MS_LOG(ERROR) << "context thread num is 0!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto *kernel = new (std::nothrow) NonZeroCPUKernel(opParameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new NonZeroCPUKernel fail!";
|
||||
free(opParameter);
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: "
|
||||
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_NonZero, CpuNonZeroFp32KernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class NonZeroCPUKernel : public LiteKernel {
|
||||
public:
|
||||
NonZeroCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
|
||||
~NonZeroCPUKernel() = default;
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
|
||||
protected:
|
||||
int thread_count_ = 1;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_NONZERO_H_
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_if_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxIfParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx IfParser";
|
||||
auto attr = std::make_unique<schema::IfT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_If;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxIfParser("If", new OnnxIfParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxIfParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxIfParser() : OnnxNodeParser("If") {}
|
||||
~OnnxIfParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_IF_PARSER_H
|
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_loop_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxLoopParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx LoopParser";
|
||||
auto attr = std::make_unique<schema::WhileT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_While;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxLoopParser("Loop", new OnnxLoopParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxLoopParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxLoopParser() : OnnxNodeParser("Loop") {}
|
||||
~OnnxLoopParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_LOOP_PARSER_H
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,42 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_nonzero_parser.h"
|
||||
#include <memory>
|
||||
#include "tools/converter/parser/onnx/onnx_model_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
lite::PrimitiveC *OnnxNonZeroParser::ParseLitePrimitive(const onnx::GraphProto &onnx_graph,
|
||||
const onnx::NodeProto &onnx_node) {
|
||||
MS_LOG(DEBUG) << "onnx NonZeroParser";
|
||||
auto attr = std::make_unique<schema::NonZeroT>();
|
||||
if (attr == nullptr) {
|
||||
MS_LOG(ERROR) << "new op failed";
|
||||
return nullptr;
|
||||
}
|
||||
auto primitive = std::make_unique<schema::PrimitiveT>();
|
||||
if (primitive == nullptr) {
|
||||
MS_LOG(ERROR) << "new primitive failed";
|
||||
return nullptr;
|
||||
}
|
||||
primitive->value.type = schema::PrimitiveType_NonZero;
|
||||
primitive->value.value = attr.release();
|
||||
return PrimitiveC::Create(primitive.release());
|
||||
}
|
||||
OnnxNodeRegistrar g_onnxNonZeroParser("NonZero", new OnnxNonZeroParser());
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
@ -0,0 +1,34 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
||||
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
||||
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser.h"
|
||||
#include "tools/converter/parser/onnx/onnx_node_parser_registry.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class OnnxNonZeroParser : public OnnxNodeParser {
|
||||
public:
|
||||
OnnxNonZeroParser() : OnnxNodeParser("NonZero") {}
|
||||
~OnnxNonZeroParser() override = default;
|
||||
|
||||
lite::PrimitiveC *ParseLitePrimitive(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node) override;
|
||||
};
|
||||
} // namespace lite
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_NONZERO_PARSER_H
|
Loading…
Reference in new issue