|
|
|
@ -13,12 +13,12 @@
|
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
|
* limitations under the License.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/runtime/kernel/arm/fp32/lsh_projection.h"
|
|
|
|
|
|
|
|
|
|
#include "include/errorcode.h"
|
|
|
|
|
#include "src/common/string_util.h"
|
|
|
|
|
#include "src/kernel_registry.h"
|
|
|
|
|
#include "src/runtime/runtime_api.h"
|
|
|
|
|
#include "src/common/string_util.h"
|
|
|
|
|
|
|
|
|
|
using mindspore::kernel::KERNEL_ARCH::kCPU;
|
|
|
|
|
using mindspore::lite::KernelRegistrar;
|
|
|
|
@ -28,12 +28,6 @@ using mindspore::lite::RET_OK;
|
|
|
|
|
using mindspore::schema::PrimitiveType_LshProjection;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
|
constexpr int kSparseType = 1;
|
|
|
|
|
constexpr int kDenseType = 2;
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
int LshProjectionCPUKernel::Init() {
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
@ -91,10 +85,10 @@ int LshProjectionCPUKernel::DoExecute(int task_id) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
switch (lsh_param_->lsh_type_) {
|
|
|
|
|
case kSparseType:
|
|
|
|
|
case schema::LshProjectionType_SPARSE:
|
|
|
|
|
LshProjectionSparse(hash, in_data, weight, output, lsh_param_);
|
|
|
|
|
break;
|
|
|
|
|
case kDenseType:
|
|
|
|
|
case schema::LshProjectionType_DENSE:
|
|
|
|
|
LshProjectionDense(hash, in_data, weight, output, lsh_param_);
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
@ -106,7 +100,7 @@ int LshProjectionCPUKernel::DoExecute(int task_id) {
|
|
|
|
|
int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed, LshProjectionParameter *para) {
|
|
|
|
|
double score = 0.0;
|
|
|
|
|
for (int i = 0; i < para->in_item_num_; i++) {
|
|
|
|
|
char *key = static_cast<char *>(ctx_->allocator->Malloc(lsh_param_->key_size_));
|
|
|
|
|
char *key = static_cast<char *>(context_->allocator->Malloc(lsh_param_->key_size_));
|
|
|
|
|
if (key == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "malloc key failed.";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -114,13 +108,14 @@ int LshProjectionCPUKernel::GetSignBit(char *in_data, float *weight, float seed,
|
|
|
|
|
memcpy(key, &seed, para->seed_size_);
|
|
|
|
|
memcpy(key + para->seed_size_, in_data, para->in_item_size_);
|
|
|
|
|
in_data += para->in_item_size_;
|
|
|
|
|
double hash_sign = static_cast<double>(mindspore::lite::StringHash64(key, para->key_size_));
|
|
|
|
|
int64_t hash_i = static_cast<int64_t>(mindspore::lite::StringHash64(key, para->key_size_));
|
|
|
|
|
double hash_d = static_cast<double>(hash_i);
|
|
|
|
|
if (weight == nullptr) {
|
|
|
|
|
score += hash_sign;
|
|
|
|
|
score += hash_d;
|
|
|
|
|
} else {
|
|
|
|
|
score += weight[i] * hash_sign;
|
|
|
|
|
score += weight[i] * hash_d;
|
|
|
|
|
}
|
|
|
|
|
ctx_->allocator->Free(key);
|
|
|
|
|
context_->allocator->Free(key);
|
|
|
|
|
}
|
|
|
|
|
return (score > 0) ? 1 : 0;
|
|
|
|
|
}
|
|
|
|
|