|
|
|
@ -48,19 +48,14 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
|
|
|
|
|
MS_LOG(DEBUG) << "output type: " << output_type;
|
|
|
|
|
|
|
|
|
|
int axis = AnfAlgo::GetNodeAttr<int>(kernel_node, "axis");
|
|
|
|
|
MS_LOG(DEBUG) << "axis: " << axis;
|
|
|
|
|
if (axis_ < 0) {
|
|
|
|
|
axis = axis + SizeToInt(input_shape_.size());
|
|
|
|
|
}
|
|
|
|
|
axis_ = 4 - input_shape_.size() + axis;
|
|
|
|
|
axis_ = 4 - input_shape_.size();
|
|
|
|
|
MS_LOG(DEBUG) << "axis_: " << axis_;
|
|
|
|
|
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
|
|
|
|
|
MS_LOG(DEBUG) << "reduce_scatter_flag: " << reduce_scatter_flag_;
|
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
|
size_t gatherv2_out_lens = 1;
|
|
|
|
|
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
|
|
|
|
|
if (i == axis) {
|
|
|
|
|
if (i == 0) {
|
|
|
|
|
for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
|
|
|
|
|
MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j];
|
|
|
|
|
gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
|
|
|
|
@ -76,7 +71,10 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
if (gather_v2_out_ == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
|
|
|
|
|
}
|
|
|
|
|
memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
|
|
|
|
|
auto ret = memset_s(gather_v2_out_, gatherv2_out_lens_, 0, gatherv2_out_lens_);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
|
|
|
|
|
MS_LOG(DEBUG) << "split_num: " << split_num_;
|
|
|
|
@ -99,6 +97,12 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
|
|
|
|
MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size;
|
|
|
|
|
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
|
|
|
|
|
if (!reduce_scatter_flag_) {
|
|
|
|
|
auto ret = memset_s(gather_out_addr, outputs[0]->size, 0, outputs[0]->size);
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset out buff failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr;
|
|
|
|
|
size_t dim0 = input_shape_[0];
|
|
|
|
|
size_t dim1 = input_shape_[1];
|
|
|
|
@ -149,10 +153,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void memcpy_task(std::vector<float *> mem_dest_addr_list, std::vector<float *> mem_src_addr_list, size_t start,
|
|
|
|
|
void memcpy_task(std::vector<float *> *mem_dest_addr_list, std::vector<float *> *mem_src_addr_list, size_t start,
|
|
|
|
|
size_t end, size_t lens) {
|
|
|
|
|
for (size_t i = start; i < end; i++) {
|
|
|
|
|
auto ret = memcpy_s(mem_dest_addr_list[i], lens, mem_src_addr_list[i], lens);
|
|
|
|
|
auto ret = memcpy_s((*mem_dest_addr_list)[i], lens, (*mem_src_addr_list)[i], lens);
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "memery copy failed.";
|
|
|
|
|
}
|
|
|
|
@ -204,7 +208,7 @@ void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr>
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
auto end = (start + ones_copy_lens) > memcpy_lens ? memcpy_lens : start + ones_copy_lens;
|
|
|
|
|
threads[i] = std::thread(memcpy_task, mem_dest_addr_list, mem_src_addr_list, start, end, lens);
|
|
|
|
|
threads[i] = std::thread(memcpy_task, &mem_dest_addr_list, &mem_src_addr_list, start, end, lens);
|
|
|
|
|
start = start + ones_copy_lens;
|
|
|
|
|
}
|
|
|
|
|
for (size_t j = 0; j < i; j++) {
|
|
|
|
|