!7250 Add LshProjection lite ops

Merge pull request !7250 from liuwenhao/master
pull/7250/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 97159386ae

@ -0,0 +1,35 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_
#define MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_
#include "nnacl/op_base.h"
typedef struct LshProjectionParameter {
OpParameter op_parameter_;
int lsh_type_;
int hash_shape_[2];
int in_item_num_;
size_t in_item_size_;
size_t seed_size_;
size_t key_size_;
int64_t real_dst_count;
int task_id_;
int64_t count_unit_;
} LshProjectionParameter;
#endif // MINDSPORE_LITE_NNACL_LSH_PROJECTION_PARAMETER_H_

@ -14,12 +14,16 @@
* limitations under the License.
*/
#include "src/ops/lsh_projection.h"
#include "nnacl/lsh_projection_parameter.h"
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
int LshProjection::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) { return RET_OK; }
int LshProjection::GetLshType() const { return this->primitive_->value.AsLshProjection()->type; }
#else
int LshProjection::GetLshType() const { return this->primitive_->value_as_LshProjection()->type(); }
int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);
MS_ASSERT(nullptr != fbb);
@ -29,9 +33,51 @@ int LshProjection::UnPackToFlatBuilder(const schema::Primitive *primitive, flatb
return RET_OK;
}
#endif
namespace {
constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace
int LshProjection::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
PrimitiveC::InferShape(inputs_, outputs_);
return RET_INFER_INVALID;
if (inputs_.size() != kDoubleNum || inputs_.size() != kMultiNum) {
MS_LOG(ERROR) << "inputs to LshProjection operator should be 2 or 3, but " << inputs_.size() << " is given.";
return RET_ERROR;
}
if (outputs_.size() != kSingleNum) {
MS_LOG(ERROR) << "outputs to Shape operator should be 1, but " << outputs_.size() << " is given.";
return RET_ERROR;
}
auto in_hash = inputs_.at(kSingleNum);
MS_ASSERT(in_hash->shape().size() == 2);
MS_ASSERT(in_hash->DimensionSize(1) <= 32);
MS_ASSERT(inputs_.at(kDoubleNum)->shape().size() >= 1);
if (inputs_.size() == kMultiNum) {
MS_ASSERT(inputs_.at(kMultiNum)->shape().size() == 1);
MS_ASSERT(inputs_.at(kMultiNum)->DimensionSize(0) == in_value->DimensionSize(0));
}
auto out_tensor = outputs_.front();
out_tensor->set_data_type(kNumberTypeInt32);
out_tensor->SetFormat(schema::Format::Format_NHWC);
if (!GetInferFlag()) {
return RET_OK;
}
std::vector<int> out_shape;
switch (GetLshType()) {
case kSparseType:
out_shape.push_back(in_hash->DimensionSize(0));
break;
case kDenseType:
out_shape.push_back(in_hash->DimensionSize(0) * in_hash->DimensionSize(1));
break;
default:
return RET_ERROR;
}
out_tensor->set_shape(out_shape);
return RET_OK;
}
} // namespace lite
} // namespace mindspore

@ -33,6 +33,7 @@ class LshProjection : public PrimitiveC {
int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
#endif
int InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) override;
int GetLshType() const;
};
} // namespace lite
} // namespace mindspore

@ -54,6 +54,7 @@
#include "src/ops/resize.h"
#include "src/ops/tile.h"
#include "src/ops/one_hot.h"
#include "src/ops/lsh_projection.h"
#include "src/ops/space_to_depth.h"
#include "src/ops/split.h"
#include "src/ops/argmax.h"
@ -131,6 +132,7 @@
#include "nnacl/unstack.h"
#include "nnacl/depth_to_space.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/lsh_projection_parameter.h"
#include "nnacl/fp32/pooling.h"
#include "nnacl/matmul_parameter.h"
#include "nnacl/fp32/roi_pooling.h"
@ -1323,6 +1325,20 @@ OpParameter *PopulateCropParameter(const mindspore::lite::PrimitiveC *primitive)
return reinterpret_cast<OpParameter *>(crop_param);
}
OpParameter *PopulateLshProjectionParameter(const mindspore::lite::PrimitiveC *primitive) {
LshProjectionParameter *lsh_project_param =
reinterpret_cast<LshProjectionParameter *>(malloc(sizeof(LshProjectionParameter)));
if (lsh_project_param == nullptr) {
MS_LOG(ERROR) << "malloc LshProjectionParameter failed.";
return nullptr;
}
memset(lsh_project_param, 0, sizeof(LshProjectionParameter));
lsh_project_param->op_parameter_.type_ = primitive->Type();
auto param = reinterpret_cast<mindspore::lite::LshProjection *>(const_cast<mindspore::lite::PrimitiveC *>(primitive));
lsh_project_param->lsh_type_ = param->GetLshType();
return reinterpret_cast<OpParameter *>(lsh_project_param);
}
OpParameter *PopulateOneHotParameter(const mindspore::lite::PrimitiveC *primitive) {
OneHotParameter *one_hot_param = reinterpret_cast<OneHotParameter *>(malloc(sizeof(OneHotParameter)));
if (one_hot_param == nullptr) {
@ -1747,6 +1763,7 @@ PopulateParameterRegistry::PopulateParameterRegistry() {
populate_parameter_funcs_[schema::PrimitiveType_CustomExtractFeatures] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_CustomPredict] = PopulateCustomPredictParameter;
populate_parameter_funcs_[schema::PrimitiveType_HashtableLookup] = PopulateCommonOpParameter;
populate_parameter_funcs_[schema::PrimitiveType_LshProjection] = PopulateLshProjectionParameter;
}
PopulateParameterRegistry *PopulateParameterRegistry::GetInstance() {

@ -0,0 +1,184 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/lsh_projection.h"
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/runtime_api.h"
#include "src/common/string_util.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LshProjection;
namespace mindspore::kernel {
namespace {
constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace
int LshProjectionCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int LshProjectionCPUKernel::ReSize() { return RET_OK; }
int LshProjectionCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
return ret;
}
auto input_tensor0 = in_tensors_.at(0);
auto input_tensor1 = in_tensors_.at(1);
auto out_tensor0 = out_tensors_.at(0);
hash = reinterpret_cast<float *>(input_tensor0->MutableData());
in_data = reinterpret_cast<char *>(input_tensor1->MutableData());
weight = in_tensors_.size() == 2 ? nullptr : reinterpret_cast<float *>(in_tensors_.at(2)->MutableData());
output = reinterpret_cast<int32_t *>(out_tensor0->MutableData());
const size_t seed_size = sizeof(float);
const size_t input_item_size =
input_tensor1->ElementsNum() * sizeof(input_tensor1->data_type()) / input_tensor1->DimensionSize(0);
const size_t key_size = seed_size + input_item_size;
lsh_param_->seed_size_ = seed_size;
lsh_param_->in_item_size_ = input_item_size;
lsh_param_->key_size_ = key_size;
lsh_param_->in_item_num_ = input_tensor1->DimensionSize(0);
memcpy(lsh_param_->hash_shape_, input_tensor0->shape().data(), sizeof(int) * input_tensor0->shape().size());
elements_num_ = input_tensor0->DimensionSize(0);
count_unit_ = thread_num_ > 1 ? UP_DIV(elements_num_, thread_num_) : elements_num_;
ret = ParallelLaunch(this->context_->thread_pool_, LshProjectionRun, this, thread_num_);
return ret;
}
int LshProjectionRun(void *cdata, int task_id) {
auto lsh_projection = reinterpret_cast<LshProjectionCPUKernel *>(cdata);
lsh_projection->DoExecute(task_id);
return RET_OK;
}
int LshProjectionCPUKernel::DoExecute(int task_id) {
int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
lsh_param_->real_dst_count = real_dst_count;
lsh_param_->task_id_ = task_id;
lsh_param_->count_unit_ = count_unit_;
if (real_dst_count <= 0) {
return lite::RET_OK;
}
switch (lsh_param_->lsh_type_) {
case kSparseType:
LshProjectionSparse(hash, in_data, weight, output, lsh_param_);
break;
case kDenseType:
LshProjectionDense(hash, in_data, weight, output, lsh_param_);
break;
default:
return RET_ERROR;
}
return RET_OK;
}
int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para) {
double score = 0.0;
for (int i = 0; i < para->in_item_num_; i++) {
char *key = static_cast<char *>(ctx_->allocator->Malloc(lsh_param_->key_size_));
if (key == nullptr) {
MS_LOG(ERROR) << "malloc key failed.";
return RET_ERROR;
}
memcpy(key, &seed, para->seed_size_);
memcpy(key + para->seed_size_, in_data, para->in_item_size_);
in_data += para->in_item_size_;
double hash_sign = static_cast<double>(mindspore::lite::StringHash64(key, para->key_size_));
if (weight == nullptr) {
score += hash_sign;
} else {
score += weight[i] * hash_sign;
}
ctx_->allocator->Free(key);
}
return (score > 0) ? 1 : 0;
}
void LshProjectionCPUKernel::LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output,
LshProjectionParameter *para) {
int start = para->task_id_ * para->count_unit_;
int end = start + para->real_dst_count;
for (int i = start; i < end; i++) {
int32_t hash_sign = 0;
for (int j = 0; j < para->hash_shape_[1]; j++) {
int bit = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para);
hash_sign = (hash_sign << 1) | bit;
}
output[i] = hash_sign + i * (1 << para->hash_shape_[1]);
}
}
void LshProjectionCPUKernel::LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output,
LshProjectionParameter *para) {
int start = para->task_id_ * para->count_unit_;
int end = start + para->real_dst_count;
for (int i = start; i < end; i++) {
for (int j = 0; j < para->hash_shape_[1]; j++) {
output[i * para->hash_shape_[1] + j] = GetSignBit(in_data, weight, hash[i * para->hash_shape_[1] + j], para);
}
}
}
kernel::LiteKernel *CpuLshProjectionFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs,
OpParameter *op_parameter, const lite::InnerContext *ctx,
const kernel::KernelKey &desc,
const mindspore::lite::PrimitiveC *primitive) {
if (op_parameter == nullptr) {
MS_LOG(ERROR) << "Input op_parameter is nullptr!";
return nullptr;
}
if (ctx == nullptr) {
MS_LOG(ERROR) << "Input context is nullptr!";
return nullptr;
}
MS_ASSERT(desc.type == schema::PrimitiveType_LshProjection);
auto *kernel = new (std::nothrow) LshProjectionCPUKernel(op_parameter, inputs, outputs, ctx, primitive);
if (kernel == nullptr) {
MS_LOG(ERROR) << "new LshProjectionCPUKernel fail!";
return nullptr;
}
auto ret = kernel->Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init kernel failed! name: " << op_parameter->name_ << ", type: "
<< schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(op_parameter->type_));
delete kernel;
return nullptr;
}
return kernel;
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LshProjection, CpuLshProjectionFp32KernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_
#include <vector>
#include "nnacl/lsh_projection_parameter.h"
#include "src/lite_kernel.h"
#include "schema/model_generated.h"
namespace mindspore::kernel {
class LshProjectionCPUKernel : public LiteKernel {
public:
LshProjectionCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_num_(ctx->thread_num_) {
lsh_param_ = reinterpret_cast<LshProjectionParameter *>(op_parameter_);
}
~LshProjectionCPUKernel() = default;
int Init() override;
int ReSize() override;
int Run() override;
int DoExecute(int task_id);
int GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para);
void LshProjectionSparse(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param);
void LshProjectionDense(float *hash, char *in_data, float *weight, int32_t *output, LshProjectionParameter *param);
private:
LshProjectionParameter *lsh_param_ = nullptr;
const lite::InnerContext *ctx_;
int thread_num_;
int64_t elements_num_;
int64_t count_unit_;
float *hash = nullptr;
char *in_data = nullptr;
float *weight = nullptr;
int32_t *output = nullptr;
};
int LshProjectionRun(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_LSH_PROJECTION_H_

@ -0,0 +1,164 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include "schema/inner/model_generated.h"
#include "src/common/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/nnacl/lsh_projection_parameter.h"
#include "mindspore/lite/src/kernel_registry.h"
#include "mindspore/lite/src/lite_kernel.h"
#include "mindspore/lite/src/tensor.h"
namespace mindspore {
namespace {
constexpr int kSparseType = 1;
constexpr int kDenseType = 2;
} // namespace
class TestLshProjectionFp32 : public mindspore::CommonTest {
public:
TestLshProjectionFp32() {}
};
TEST_F(TestLshProjectionFp32, Dense1DInputs) {
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
lite::Tensor in_tensor1(kNumberTypeInt32, {5});
lite::Tensor in_tensor2(kNumberTypeFloat, {5});
lite::Tensor out_tensor(kNumberTypeInt32, {6});
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678};
float input_data2[] = {1.0, 1.0, 1.0, 1.0, 1.0};
int32_t output_data[6] = {0};
in_tensor0.SetData(input_data0);
in_tensor1.SetData(input_data1);
in_tensor2.SetData(input_data2);
out_tensor.SetData(output_data);
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2};
std::vector<lite::Tensor *> outputs = {&out_tensor};
LshProjectionParameter parameter = {};
parameter.lsh_type_ = kDenseType;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
auto ctx = std::make_shared<lite::InnerContext>();
ctx->thread_num_ = 3;
ASSERT_EQ(lite::RET_OK, ctx->Init());
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();
EXPECT_EQ(0, ret);
std::vector<int32_t> except_result = {0, 0, 0, 1, 0, 0};
PrintData("output data", output_data, 6);
CompareOutputData(output_data, except_result.data(), 6, 0.000001);
in_tensor0.SetData(nullptr);
in_tensor1.SetData(nullptr);
out_tensor.SetData(nullptr);
}
TEST_F(TestLshProjectionFp32, Sparse1DInputs) {
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
lite::Tensor in_tensor1(kNumberTypeInt32, {5});
lite::Tensor out_tensor(kNumberTypeInt32, {3});
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
int32_t input_data1[] = {12345, 54321, 67890, 9876, -12345678};
int32_t output_data[3] = {0};
in_tensor0.SetData(input_data0);
in_tensor1.SetData(input_data1);
out_tensor.SetData(output_data);
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1};
std::vector<lite::Tensor *> outputs = {&out_tensor};
LshProjectionParameter parameter = {};
parameter.lsh_type_ = kSparseType;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
auto ctx = std::make_shared<lite::InnerContext>();
ctx->thread_num_ = 1;
ASSERT_EQ(lite::RET_OK, ctx->Init());
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();
EXPECT_EQ(0, ret);
std::vector<int32_t> except_result = {0, 5, 8};
PrintData("output data", output_data, 3);
CompareOutputData(output_data, except_result.data(), 3, 0.000001);
in_tensor0.SetData(nullptr);
in_tensor1.SetData(nullptr);
out_tensor.SetData(nullptr);
}
TEST_F(TestLshProjectionFp32, Sparse3DInputs) {
lite::Tensor in_tensor0(kNumberTypeFloat, {3, 2});
lite::Tensor in_tensor1(kNumberTypeInt32, {5, 2, 2});
lite::Tensor in_tensor2(kNumberTypeFloat, {5});
lite::Tensor out_tensor(kNumberTypeInt32, {3});
float input_data0[] = {0.123, 0.456, -0.321, 1.234, 5.678, -4.321};
int32_t input_data1[] = {1234, 2345, 3456, 1234, 4567, 5678, 6789, 4567, 7891, 8912,
9123, 7890, -987, -876, -765, -987, -543, -432, -321, -543};
float input_data2[] = {0.12, 0.34, 0.56, 0.67, 0.78};
int32_t output_data[3] = {0};
in_tensor0.SetData(input_data0);
in_tensor1.SetData(input_data1);
in_tensor2.SetData(input_data2);
out_tensor.SetData(output_data);
std::vector<lite::Tensor *> inputs = {&in_tensor0, &in_tensor1, &in_tensor2};
std::vector<lite::Tensor *> outputs = {&out_tensor};
LshProjectionParameter parameter = {};
parameter.lsh_type_ = kSparseType;
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeFloat32, schema::PrimitiveType_LshProjection};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);
ASSERT_NE(creator, nullptr);
auto ctx = std::make_shared<lite::InnerContext>();
ctx->thread_num_ = 3;
ASSERT_EQ(lite::RET_OK, ctx->Init());
auto kernel = creator(inputs, outputs, reinterpret_cast<OpParameter *>(&parameter), ctx.get(), desc, nullptr);
ASSERT_NE(kernel, nullptr);
auto ret = kernel->Run();
EXPECT_EQ(0, ret);
std::vector<int32_t> except_result = {2, 5, 9};
PrintData("output data", output_data, 3);
CompareOutputData(output_data, except_result.data(), 3, 0.000001);
in_tensor0.SetData(nullptr);
in_tensor1.SetData(nullptr);
out_tensor.SetData(nullptr);
}
} // namespace mindspore

@ -290,7 +290,7 @@ TEST_F(TestMulInt8, Mul_quant1_thread1) {
MulParameter op_param;
op_param.op_parameter_.type_ = schema::PrimitiveType_Mul;
lite::InnerContext *ctx = new lite::InnerContext;
ctx->thread_num_ = 2;
ctx->thread_num_ = 3;
ASSERT_EQ(lite::RET_OK, ctx->Init());
kernel::KernelKey desc = {kernel::KERNEL_ARCH::kCPU, kNumberTypeInt8, schema::PrimitiveType_Mul};
auto creator = lite::KernelRegistry::GetInstance()->GetCreator(desc);

Loading…
Cancel
Save