!7279 [MSLITE][Develop] add HashtableLookup kernel
Merge pull request !7279 from sunsuodong/hashtable_lookuppull/7279/MERGE
commit
438cd08016
@ -0,0 +1,96 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "src/runtime/kernel/arm/string/hashtable_lookup.h"
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include "src/kernel_registry.h"
|
||||
#include "src/common/string_util.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::schema::PrimitiveType_HashtableLookup;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int HashtableLookupCPUKernel::Init() {
|
||||
if (!InferShapeDone()) {
|
||||
return RET_OK;
|
||||
}
|
||||
return ReSize();
|
||||
}
|
||||
|
||||
int HashtableLookupCPUKernel::ReSize() { return RET_OK; }
|
||||
|
||||
static int CmpKeyFunc(const void *lhs, const void *rhs) {
|
||||
return *static_cast<const int *>(lhs) - *static_cast<const int *>(rhs);
|
||||
}
|
||||
|
||||
int HashtableLookupCPUKernel::Run() {
|
||||
auto ret = Prepare();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
|
||||
return ret;
|
||||
}
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
auto keys_tensor = in_tensors_.at(1);
|
||||
auto values_tensor = in_tensors_.at(2);
|
||||
auto output_tensor = out_tensors_.at(0);
|
||||
auto hits_tensor = out_tensors_.at(1);
|
||||
|
||||
int rows = values_tensor->DimensionSize(0);
|
||||
int32_t *input_data = reinterpret_cast<int32_t *>(input_tensor->MutableData());
|
||||
uint8_t *hits_data = reinterpret_cast<uint8_t *>(hits_tensor->MutableData());
|
||||
std::vector<lite::StringPack> output_string_pack;
|
||||
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);
|
||||
|
||||
for (int i = 0; i < input_tensor->ElementsNum(); i++) {
|
||||
int index = -1;
|
||||
void *p = bsearch(&(input_data[i]), keys_tensor->MutableData(), rows, sizeof(int32_t), CmpKeyFunc);
|
||||
if (p != nullptr) {
|
||||
index = reinterpret_cast<int32_t *>(p) - reinterpret_cast<int32_t *>(keys_tensor->MutableData());
|
||||
}
|
||||
if (index >= rows || index < 0) {
|
||||
lite::StringPack tmp = {0, nullptr};
|
||||
output_string_pack.push_back(tmp);
|
||||
hits_data[i] = 0;
|
||||
} else {
|
||||
output_string_pack.push_back(all_string_pack[i]);
|
||||
hits_data[i] = 1;
|
||||
}
|
||||
}
|
||||
WriteStringsToTensor(output_tensor, output_string_pack);
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
kernel::LiteKernel *CpuHashtableLookupKernelCreator(const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
|
||||
const lite::InnerContext *ctx, const kernel::KernelKey &desc,
|
||||
const mindspore::lite::PrimitiveC *primitive) {
|
||||
auto *kernel = new (std::nothrow) HashtableLookupCPUKernel(parameter, inputs, outputs, ctx, primitive);
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(ERROR) << "new HashtableLookupCPUKernel fail!";
|
||||
return nullptr;
|
||||
}
|
||||
auto ret = kernel->Init();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
|
||||
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
|
||||
delete kernel;
|
||||
return nullptr;
|
||||
}
|
||||
return kernel;
|
||||
}
|
||||
|
||||
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_HashtableLookup, CpuHashtableLookupKernelCreator)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "include/context.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
class HashtableLookupCPUKernel : public LiteKernel {
|
||||
public:
|
||||
HashtableLookupCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
|
||||
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
|
||||
const mindspore::lite::PrimitiveC *primitive)
|
||||
: LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
|
||||
~HashtableLookupCPUKernel() {}
|
||||
|
||||
int Init() override;
|
||||
int ReSize() override;
|
||||
int Run() override;
|
||||
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_HASHTABLE_LOOKUP_H_
|
Loading…
Reference in new issue