|
|
|
@ -26,12 +26,16 @@ using mindspore::schema::PrimitiveType_EmbeddingLookup;
|
|
|
|
|
|
|
|
|
|
namespace mindspore::kernel {
|
|
|
|
|
int EmbeddingLookupCPUKernel::Init() {
|
|
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) {
|
|
|
|
|
SetNeedReInit();
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
embedding_lookup_parameter_ = reinterpret_cast<EmbeddingLookupParameter *>(opParameter);
|
|
|
|
|
embedding_lookup_parameter_->thread_num = thread_count_;
|
|
|
|
|
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
return ReSize();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::ReSize() {
|
|
|
|
|
embedding_lookup_parameter_->ids_size_ = inputs_.back()->ElementsNum();
|
|
|
|
|
|
|
|
|
|
embedding_lookup_parameter_->layer_size_ = 1;
|
|
|
|
@ -45,18 +49,34 @@ int EmbeddingLookupCPUKernel::Init() {
|
|
|
|
|
embedding_lookup_parameter_->layer_num_ += inputs_[i]->shape()[0];
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
input_addr_ = reinterpret_cast<float *>(
|
|
|
|
|
std::malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
if (input_addr_ != nullptr) {
|
|
|
|
|
free(input_addr_);
|
|
|
|
|
}
|
|
|
|
|
if (context_ != nullptr && context_->allocator != nullptr) {
|
|
|
|
|
input_addr_ = reinterpret_cast<float *>(context_->allocator->Malloc(
|
|
|
|
|
sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
} else {
|
|
|
|
|
input_addr_ = reinterpret_cast<float *>(
|
|
|
|
|
malloc(sizeof(float) * embedding_lookup_parameter_->layer_size_ * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
}
|
|
|
|
|
if (input_addr_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Create memory failed";
|
|
|
|
|
return mindspore::lite::RET_MEMORY_FAILED;
|
|
|
|
|
MS_LOG(ERROR) << "Malloc buffer failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
embedding_lookup_parameter_->is_regulated_ =
|
|
|
|
|
reinterpret_cast<bool *>(std::malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
if (embedding_lookup_parameter_->is_regulated_ != nullptr) {
|
|
|
|
|
free(embedding_lookup_parameter_->is_regulated_);
|
|
|
|
|
}
|
|
|
|
|
if (context_ != nullptr && context_->allocator != nullptr) {
|
|
|
|
|
embedding_lookup_parameter_->is_regulated_ =
|
|
|
|
|
reinterpret_cast<bool *>(context_->allocator->Malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
} else {
|
|
|
|
|
embedding_lookup_parameter_->is_regulated_ =
|
|
|
|
|
reinterpret_cast<bool *>(malloc(sizeof(bool) * embedding_lookup_parameter_->layer_num_));
|
|
|
|
|
}
|
|
|
|
|
if (embedding_lookup_parameter_->is_regulated_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Create memory failed";
|
|
|
|
|
return mindspore::lite::RET_MEMORY_FAILED;
|
|
|
|
|
MS_LOG(ERROR) << "Malloc buffer failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (int i = 0; i < embedding_lookup_parameter_->layer_num_; ++i) {
|
|
|
|
@ -66,8 +86,6 @@ int EmbeddingLookupCPUKernel::Init() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::ReSize() { return RET_OK; }
|
|
|
|
|
|
|
|
|
|
int EmbeddingLookupCPUKernel::DoExcute(int task_id) {
|
|
|
|
|
int error_code = EmbeddingLookup(input_addr_, ids_addr_, output_addr_, embedding_lookup_parameter_, task_id);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|