add extract_feature

pull/7124/head
sunsuodong 5 years ago
parent 7126e316bc
commit d2559d1111

@ -102,5 +102,148 @@ int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<
return RET_OK; return RET_OK;
} }
// Returns the number of strings stored in a packed string buffer.
// The count is read from the leading int32 of `data`.
// A null buffer is reported as containing zero strings instead of being
// dereferenced.
int GetStringCount(const void *data) {
  if (data == nullptr) {
    return 0;
  }
  return *(static_cast<const int32_t *>(data));
}
// Tensor overload: fetches the tensor's data buffer via MutableData() and
// reads the string count from it.
int GetStringCount(Tensor *tensor) {
  const void *buffer = tensor->MutableData();
  return GetStringCount(buffer);
}
// Some primes between 2^63 and 2^64.
// They are only ever read by the hash helpers below, so they are declared as
// compile-time constants rather than mutable file-scope variables.
static constexpr uint64_t k0 = 0xc3a5c85c97cb3127ULL;
static constexpr uint64_t k1 = 0xb492b66fbe98f273ULL;
static constexpr uint64_t k2 = 0x9ae16a3b2f90404fULL;
// Loads eight bytes starting at `p` and returns them with the byte order
// reversed relative to the host's native 64-bit representation.
// The bytes are copied with memcpy, so `p` does not need to be aligned.
uint64_t Fetch64Bit(const char *p) {
  uint64_t raw = 0;
  memcpy(&raw, p, sizeof(raw));
  const uint64_t swapped = __builtin_bswap64(raw);
  return swapped;
}
// Loads four bytes starting at `p` and returns them with the byte order
// reversed relative to the host's native 32-bit representation.
// The bytes are copied with memcpy, so `p` does not need to be aligned.
uint32_t Fetch32Bit(const char *p) {
  uint32_t raw = 0;
  memcpy(&raw, p, sizeof(raw));
  const uint32_t swapped = __builtin_bswap32(raw);
  return swapped;
}
// Rotates `value` right by `shift` bit positions.
// The shift amount is reduced modulo 64 first, so any int is accepted:
// shifting a 64-bit value by a negative amount or by 64 or more is undefined
// behaviour in C++, which the unreduced form would hit for such arguments.
// Results for shifts already in [0, 63] are unchanged.
uint64_t Rotate64(uint64_t value, int shift) {
  const unsigned int amount = static_cast<unsigned int>(shift) & 63U;
  // A zero rotation must be special-cased: `value << 64` would be undefined.
  return amount == 0 ? value : ((value >> amount) | (value << (64U - amount)));
}
// Mixes two 64-bit values into one using the supplied multiplier.
// Each round multiplies, then folds the high bits down with a 47-bit right
// shift XOR; the final result is multiplied once more.
uint64_t HashLen16(uint64_t u, uint64_t v, uint64_t multiple) {
  const uint64_t first_product = (u ^ v) * multiple;
  const uint64_t first_mixed = first_product ^ (first_product >> 47);
  const uint64_t second_product = (v ^ first_mixed) * multiple;
  const uint64_t second_mixed = second_product ^ (second_product >> 47);
  return second_mixed * multiple;
}
// XORs `value` with itself shifted right by 47 bits, folding the top 17 bits
// into the low bits.
uint64_t ShiftMix(uint64_t value) {
  const uint64_t folded = value >> 47;
  return folded ^ value;
}
// Hashes a short input of `len` bytes starting at `s`.
// Four cases are handled, chosen by length:
//   len == 0      -> the constant k2
//   1 <= len < 4  -> mixes the first, middle and last byte with the length
//   4 <= len < 8  -> mixes the first and last 4 bytes (they may overlap)
//   len >= 8      -> mixes the first and last 8 bytes (they may overlap)
// Callers in this file only pass len <= 16.
uint64_t HashStringLen0to16(const char *s, size_t len) {
  if (len == 0) {
    return k2;
  }
  if (len < 4) {
    const uint8_t first = s[0];
    const uint8_t middle = s[len >> 1];
    const uint8_t last = s[len - 1];
    const uint32_t low = static_cast<uint32_t>(first) + (static_cast<uint32_t>(middle) << 8);
    const uint32_t high = len + (static_cast<uint32_t>(last) << 2);
    return ShiftMix((low * k2) ^ (high * k0)) * k2;
  }
  // The multiplier depends on the length so equal prefixes of different
  // lengths are mixed differently.
  const uint64_t factor = k2 + len * 2;
  if (len < 8) {
    const uint64_t head = Fetch32Bit(s);
    const uint64_t tail = Fetch32Bit(s + len - 4);
    return HashLen16(len + (head << 3), tail, factor);
  }
  const uint64_t head = Fetch64Bit(s) + k2;
  const uint64_t tail = Fetch64Bit(s + len - 8);
  const uint64_t lhs = Rotate64(tail, 37) * factor + head;
  const uint64_t rhs = (Rotate64(head, 25) + tail) * factor;
  return HashLen16(lhs, rhs, factor);
}
// Hashes `len` bytes starting at `s` by mixing the first 16 and the last 16
// bytes of the buffer. It reads s[0..15] and s[len-16..len-1], so `len` must
// be at least 16; StringHash64 calls it for 17 <= len <= 32.
uint64_t HashStringLen17to32(const char *s, size_t len) {
  const uint64_t factor = k2 + len * 2;
  const uint64_t head = Fetch64Bit(s) * k1;
  const uint64_t second = Fetch64Bit(s + 8);
  const uint64_t tail = Fetch64Bit(s + len - 8) * factor;
  const uint64_t before_tail = Fetch64Bit(s + len - 16) * k2;
  const uint64_t lhs = Rotate64(head + second, 43) + Rotate64(tail, 30) + before_tail;
  const uint64_t rhs = head + Rotate64(second + k2, 18) + tail;
  return HashLen16(lhs, rhs, factor);
}
// Hashes `len` bytes starting at `s` by mixing the first 32 and the last 32
// bytes of the buffer. It reads s[0..31] and s[len-32..len-1], so `len` must
// be at least 32; StringHash64 calls it for 33 <= len <= 64.
uint64_t HashStringLen33to64(const char *s, size_t len) {
  const uint64_t factor = k2 + len * 2;

  // First round: outer 16 bytes at each end of the buffer.
  const uint64_t head = Fetch64Bit(s) * k2;
  const uint64_t second = Fetch64Bit(s + 8);
  const uint64_t tail = Fetch64Bit(s + len - 8) * factor;
  const uint64_t before_tail = Fetch64Bit(s + len - 16) * k2;
  const uint64_t round1_lhs = Rotate64(head + second, 43) + Rotate64(tail, 30) + before_tail;
  const uint64_t round1_rhs = head + Rotate64(second + k2, 18) + tail;
  const uint64_t round1_hash = HashLen16(round1_lhs, round1_rhs, factor);

  // Second round: inner 16 bytes at each end, seeded with the first round.
  const uint64_t third = Fetch64Bit(s + 16) * factor;
  const uint64_t fourth = Fetch64Bit(s + 24);
  const uint64_t inner_tail_a = (round1_lhs + Fetch64Bit(s + len - 32)) * factor;
  const uint64_t inner_tail_b = (round1_hash + Fetch64Bit(s + len - 24)) * factor;
  const uint64_t round2_lhs = Rotate64(third + fourth, 43) + Rotate64(inner_tail_a, 30) + inner_tail_b;
  const uint64_t round2_rhs = third + Rotate64(fourth + head, 18) + inner_tail_a;
  return HashLen16(round2_lhs, round2_rhs, factor);
}
// Mixes the 32 bytes at s[0..31] with the two seeds `a` and `b` and returns
// the resulting pair of 64-bit values.
std::pair<uint64_t, uint64_t> HashLen32WithSeeds(const char *s, uint64_t a, uint64_t b) {
  const uint64_t word0 = Fetch64Bit(s);
  const uint64_t word1 = Fetch64Bit(s + 8);
  const uint64_t word2 = Fetch64Bit(s + 16);
  const uint64_t word3 = Fetch64Bit(s + 24);

  const uint64_t seeded = a + word0;
  const uint64_t accumulated = seeded + word1 + word2;
  uint64_t mixed = Rotate64(b + seeded + word3, 21);
  mixed += Rotate64(accumulated, 44);
  return std::make_pair(accumulated + word3, mixed + seeded);
}
// Computes a 64-bit hash of the `len` bytes starting at `s`.
// Inputs of at most 64 bytes are delegated to the length-specific helpers.
// Longer inputs are consumed in 64-byte chunks; the final 64 bytes of the
// buffer are then mixed in with a multiplier derived from the running state.
// NOTE(review): the constants and structure look like FarmHash/CityHash's
// 64-bit hash with byte-swapped loads - confirm provenance before relying on
// compatibility with another implementation.
uint64_t StringHash64(const char *s, size_t len) {
uint64_t seed_value = 81;
if (len <= 16) {
return HashStringLen0to16(s, len);
} else if (len <= 32) {
return HashStringLen17to32(s, len);
} else if (len <= 64) {
return HashStringLen33to64(s, len);
}
// len > 64 from here on. Internal state is 56 bytes: x, y, z, v and w.
uint64_t x = seed_value;
uint64_t y = seed_value * k1 + 113;
uint64_t tmp = y * k2 + 113;
uint64_t z = (tmp ^ (tmp >> 47)) * k2;
std::pair<uint64_t, uint64_t> v = std::make_pair(0, 0);
std::pair<uint64_t, uint64_t> w = std::make_pair(0, 0);
x = x * k2 + Fetch64Bit(s);
// `end` is the largest 64-byte boundary strictly before the last byte, so the
// loop below consumes ((len - 1) / 64) whole chunks (at least one).
const char *end = s + ((len - 1) / 64) * 64;
// `last64` points at the final 64 bytes of the input; they may overlap bytes
// the loop has already consumed.
const char *last64 = end + ((len - 1) & 63) - 63;
MS_ASSERT(s + len - 64 == last64);
do {
x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * k1;
y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * k1;
x ^= w.second;
y += v.first + Fetch64Bit(s + 40);
z = Rotate64(z + w.first, 33) * k1;
v = HashLen32WithSeeds(s, v.second * k1, x + w.first);
w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16));
std::swap(z, x);
s += 64;
} while (s != end);
// Final round over the last 64 bytes: same shape as a loop iteration, but the
// multiplier is derived from the low byte of z and the tail length is mixed in.
uint64_t mul = k1 + ((z & 0xff) << 1);
s = last64;
w.first += ((len - 1) & 63);
v.first += w.first;
w.first += v.first;
x = Rotate64(x + y + v.first + Fetch64Bit(s + 8), 37) * mul;
y = Rotate64(y + v.second + Fetch64Bit(s + 48), 42) * mul;
x ^= w.second * 9;
y += v.first * 9 + Fetch64Bit(s + 40);
z = Rotate64(z + w.first, 33) * mul;
v = HashLen32WithSeeds(s, v.second * mul, x + w.first);
w = HashLen32WithSeeds(s + 32, z + w.second, y + Fetch64Bit(s + 16));
std::swap(z, x);
// Collapse the whole state into a single 64-bit result.
return HashLen16(HashLen16(v.first, w.first, mul) + ShiftMix(y) * k0 + z, HashLen16(v.second, w.second, mul) + x,
mul);
}
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -32,12 +32,23 @@ typedef struct {
const char *data; const char *data;
} StringPack; } StringPack;
// example of string tensor:
// 3, 0, 0, 0 # int32, num of strings
// 20, 0, 0, 0 # int32, offset of 0-th string
// 23, 0, 0, 0 # int32, offset of 1-st string
// 26, 0, 0, 0 # int32, offset of 2-nd string
// 29, 0, 0, 0 # int32, total length of tensor data
// 'h', 'o', 'w', 'a', 'r', 'e', 'y', 'o', 'u' # char, how are you
std::vector<StringPack> ParseTensorBuffer(Tensor *tensor); std::vector<StringPack> ParseTensorBuffer(Tensor *tensor);
std::vector<StringPack> ParseStringBuffer(const void *data); std::vector<StringPack> ParseStringBuffer(const void *data);
int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer); int WriteStringsToTensor(Tensor *tensor, const std::vector<StringPack> &string_buffer);
int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer); int WriteSeperatedStringsToTensor(Tensor *tensor, const std::vector<std::vector<StringPack>> &string_buffer);
// Returns the number of strings recorded in a packed string buffer, read from
// the leading int32 of `data`.
int GetStringCount(const void *data);
// Same as above, reading from the tensor's data buffer.
int GetStringCount(Tensor *tensor);
// Computes a 64-bit hash of the `len` bytes starting at `s`.
uint64_t StringHash64(const char *s, size_t len);
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -14,6 +14,7 @@
* limitations under the License. * limitations under the License.
*/ */
#include "src/ops/custom_extract_features.h" #include "src/ops/custom_extract_features.h"
#include "src/common/string_util.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -30,9 +31,30 @@ int CustomExtractFeatures::UnPackToFlatBuilder(const schema::Primitive *primitiv
return RET_OK; return RET_OK;
} }
#endif #endif
int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { int CustomExtractFeatures::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
PrimitiveC::InferShape(inputs_, outputs_); auto input = inputs_.at(0);
return RET_INFER_INVALID; MS_ASSERT(input != nullptr);
if (input->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
return RET_INFER_INVALID;
}
int string_num = lite::GetStringCount(input);
auto output0 = outputs_.at(0);
auto output1 = outputs_.at(1);
MS_ASSERT(output0 != nullptr);
MS_ASSERT(output1 != nullptr);
std::vector<int> shape;
shape.push_back(string_num == 0 ? 1 : string_num);
output0->set_shape(shape);
output0->set_data_type(input->data_type());
output0->SetFormat(input->GetFormat());
output1->set_shape(shape);
output1->set_data_type(input->data_type());
output1->SetFormat(input->GetFormat());
return RET_OK;
} }
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -4,6 +4,7 @@ file(GLOB KERNEL_SRC
${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/fp32/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc ${CMAKE_CURRENT_SOURCE_DIR}/int8/*.cc
${CMAKE_CURRENT_SOURCE_DIR}/string/*.cc
) )
list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc) list(REMOVE_ITEM KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/int8/opt_op_handler.cc)

@ -0,0 +1,97 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/string/extract_feature.h"
#include <string>
#include "src/kernel_registry.h"
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_CustomExtractFeatures;
namespace mindspore::kernel {
// Initialises the kernel. If shape inference has not completed yet, there is
// nothing to do; otherwise the call is forwarded to ReSize().
int ExtractFeatureCPUKernel::Init() {
  return InferShapeDone() ? ReSize() : RET_OK;
}
// No shape-dependent state is kept by this kernel, so resizing always
// succeeds without doing any work.
int ExtractFeatureCPUKernel::ReSize() {
  return RET_OK;
}
// Returns true when `str` is exactly one of the blacklisted marker tokens
// "<S>", "<E>" or "<S> <E>".
// The token table is a static array of literals, so no vector or std::string
// is allocated on each call (Run() calls this once per input string).
bool ExtractFeatureCPUKernel::IsInBlacklist(const lite::StringPack &str) {
  struct BlacklistEntry {
    const char *text;
    int length;
  };
  // sizeof(literal) includes the trailing NUL, hence the "- 1".
  static const BlacklistEntry kBlacklist[] = {
    {"<S>", static_cast<int>(sizeof("<S>") - 1)},
    {"<E>", static_cast<int>(sizeof("<E>") - 1)},
    {"<S> <E>", static_cast<int>(sizeof("<S> <E>") - 1)},
  };
  // memcmp must not be handed a null pointer; a string with no data cannot
  // match any token.
  if (str.data == nullptr) {
    return false;
  }
  for (const auto &entry : kBlacklist) {
    if (str.len != entry.length) {
      continue;
    }
    if (memcmp(str.data, entry.text, str.len) == 0) {
      return true;
    }
  }
  return false;
}
int ExtractFeatureCPUKernel::Run() {
auto ret = Prepare();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail! Ret error code: " << ret;
return ret;
}
const int kMaxDimension = 1000000;
auto input_tensor = in_tensors_.at(0);
auto label_data = reinterpret_cast<int32_t *>(out_tensors_.at(0)->MutableData());
auto weight_data = out_tensors_.at(1)->MutableData();
int string_num = lite::GetStringCount(input_tensor);
std::vector<lite::StringPack> all_string_pack = ParseTensorBuffer(input_tensor);
for (int i = 0; i < string_num; i++) {
lite::StringPack str = all_string_pack[i];
if (IsInBlacklist(str)) {
label_data[i] = 0;
reinterpret_cast<int32_t *>(weight_data)[i] = 0;
continue;
}
int64_t hash_value = lite::StringHash64(str.data, str.len) % kMaxDimension;
label_data[i] = hash_value;
reinterpret_cast<float *>(weight_data)[i] = std::count(str.data, str.data + str.len, ' ') + 1;
}
if (string_num == 0) {
label_data[0] = 0;
reinterpret_cast<int32_t *>(weight_data)[0] = 0;
}
return RET_OK;
}
// Factory used by the kernel registry: allocates an ExtractFeatureCPUKernel
// and initialises it. Returns nullptr (after logging) if allocation or Init()
// fails; a kernel whose Init() failed is deleted before returning.
kernel::LiteKernel *CpuExtractFeatureKernelCreator(const std::vector<lite::Tensor *> &inputs,
                                                   const std::vector<lite::Tensor *> &outputs, OpParameter *parameter,
                                                   const lite::InnerContext *ctx, const kernel::KernelKey &desc,
                                                   const mindspore::lite::PrimitiveC *primitive) {
  auto *new_kernel = new (std::nothrow) ExtractFeatureCPUKernel(parameter, inputs, outputs, ctx, primitive);
  if (new_kernel == nullptr) {
    MS_LOG(ERROR) << "new ExtractFeatureCPUKernel fail!";
    return nullptr;
  }
  if (new_kernel->Init() == RET_OK) {
    return new_kernel;
  }
  MS_LOG(ERROR) << "Init kernel failed, name: " << parameter->name_
                << ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(parameter->type_));
  delete new_kernel;
  return nullptr;
}
// Registers CpuExtractFeatureKernelCreator for the CustomExtractFeatures
// primitive under the (kCPU, kNumberTypeFloat32) key.
// NOTE(review): the kernel reads a string tensor; confirm that Float32 is the
// data type the registry looks this operator up with.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_CustomExtractFeatures, CpuExtractFeatureKernelCreator)
} // namespace mindspore::kernel

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_
#include <vector>
#include "src/lite_kernel.h"
#include "include/context.h"
#include "src/common/string_util.h"
namespace mindspore::kernel {
// CPU kernel for the CustomExtractFeatures primitive. For every string in the
// input tensor, Run() writes a hashed label to output 0 and a weight to
// output 1.
class ExtractFeatureCPUKernel : public LiteKernel {
 public:
  // Forwards all arguments to the LiteKernel base; no additional state is kept.
  ExtractFeatureCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                          const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                          const mindspore::lite::PrimitiveC *primitive)
      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {}
  ~ExtractFeatureCPUKernel() = default;

  // LiteKernel lifecycle hooks.
  int Init() override;
  int ReSize() override;
  int Run() override;

 private:
  // True when `str` is one of the marker tokens that Run() zeroes out.
  bool IsInBlacklist(const lite::StringPack &str);
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_STRING_EXTRACT_FEATURE_H_

@ -28,6 +28,7 @@ file(GLOB KERNEL_OP_SRC
${LITE_DIR}/src/runtime/kernel/arm/base/*.cc ${LITE_DIR}/src/runtime/kernel/arm/base/*.cc
${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc ${LITE_DIR}/src/runtime/kernel/arm/fp32/*.cc
${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc ${LITE_DIR}/src/runtime/kernel/arm/int8/*.cc
${LITE_DIR}/src/runtime/kernel/arm/string/*.cc
${LITE_DIR}/nnacl/*.c ${LITE_DIR}/nnacl/*.c
${LITE_DIR}/nnacl/fp32/*.c ${LITE_DIR}/nnacl/fp32/*.c
${LITE_DIR}/nnacl/int8/*.c ${LITE_DIR}/nnacl/int8/*.c

Loading…
Cancel
Save