|
|
|
@ -24,50 +24,32 @@ namespace mindspore {
|
|
|
|
|
namespace kernel {
|
|
|
|
|
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
CheckParam(kernel_node);
|
|
|
|
|
|
|
|
|
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
|
|
|
|
input_lens_ = 1;
|
|
|
|
|
for (auto shape : input_shape_) {
|
|
|
|
|
MS_LOG(INFO) << "input shape: " << shape;
|
|
|
|
|
input_lens_ = input_lens_ * shape;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "input lens: " << input_lens_;
|
|
|
|
|
|
|
|
|
|
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
|
|
|
|
|
indices_lens_ = 1;
|
|
|
|
|
for (auto shape : indices_shape_) {
|
|
|
|
|
MS_LOG(INFO) << "indice shape: " << shape;
|
|
|
|
|
indices_lens_ = indices_lens_ * shape;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "indice lens: " << indices_lens_;
|
|
|
|
|
|
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
|
|
|
|
for (auto shape : output_shape_) {
|
|
|
|
|
MS_LOG(INFO) << "output shape: " << shape;
|
|
|
|
|
}
|
|
|
|
|
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
|
|
|
|
|
MS_LOG(INFO) << "output type: " << output_type;
|
|
|
|
|
|
|
|
|
|
axis_ = 4 - input_shape_.size();
|
|
|
|
|
MS_LOG(INFO) << "axis_: " << axis_;
|
|
|
|
|
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
|
|
|
|
|
MS_LOG(INFO) << "reduce_scatter_flag: " << reduce_scatter_flag_;
|
|
|
|
|
#ifdef ENABLE_MPI
|
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
|
size_t gatherv2_out_lens = 1;
|
|
|
|
|
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
|
|
|
|
|
if (i == 0) {
|
|
|
|
|
for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
|
|
|
|
|
MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j];
|
|
|
|
|
gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(DEBUG) << "gatherv2 out shape: " << input_shape_[i];
|
|
|
|
|
gatherv2_out_lens = gatherv2_out_lens * input_shape_[i];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float);
|
|
|
|
|
MS_LOG(INFO) << "gatherv2 out lens: " << gatherv2_out_lens_;
|
|
|
|
|
gather_v2_out_ = malloc(gatherv2_out_lens_);
|
|
|
|
|
if (gather_v2_out_ == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
|
|
|
|
@ -76,9 +58,7 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
if (ret != 0) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
|
|
|
|
|
MS_LOG(INFO) << "split_num: " << split_num_;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
@ -86,7 +66,6 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
|
|
|
|
|
MS_LOG(INFO) << "offset: " << offset_;
|
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&input_shape_);
|
|
|
|
|
CPUKernelUtils::ExpandDimsTo4(&output_shape_);
|
|
|
|
|
}
|
|
|
|
@ -94,21 +73,11 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
|
|
|
|
|
bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
|
|
|
|
const std::vector<kernel::AddressPtr> & /*workspace*/,
|
|
|
|
|
const std::vector<kernel::AddressPtr> &outputs) {
|
|
|
|
|
#if defined(_WIN32) || defined(_WIN64)
|
|
|
|
|
auto start_time = std::chrono::steady_clock::now();
|
|
|
|
|
#else
|
|
|
|
|
struct timeval start_time, end_time;
|
|
|
|
|
(void)gettimeofday(&start_time, nullptr);
|
|
|
|
|
#endif
|
|
|
|
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
|
|
|
|
|
MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size;
|
|
|
|
|
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
|
|
|
|
|
MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr;
|
|
|
|
|
|
|
|
|
|
size_t dim0 = input_shape_[0];
|
|
|
|
|
size_t dim1 = input_shape_[1];
|
|
|
|
|
size_t dim2 = input_shape_[2];
|
|
|
|
|
|
|
|
|
|
if (axis_ == 3) {
|
|
|
|
|
for (size_t i = 0; i < dim0; ++i) {
|
|
|
|
|
for (size_t j = 0; j < dim1; ++j) {
|
|
|
|
@ -130,7 +99,6 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
} else if (axis_ == 0) {
|
|
|
|
|
LookUpTable(inputs, 0, 0, 0, &gather_out_addr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_MPI
|
|
|
|
|
if (reduce_scatter_flag_) {
|
|
|
|
|
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
|
|
|
|
@ -143,21 +111,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#if defined(_WIN32) || defined(_WIN64)
|
|
|
|
|
auto end_time = std::chrono::steady_clock::now();
|
|
|
|
|
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
|
|
|
|
|
MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << cost.count() << " us";
|
|
|
|
|
#else
|
|
|
|
|
(void)gettimeofday(&end_time, nullptr);
|
|
|
|
|
uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
|
|
|
|
|
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
|
|
|
|
|
MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << time << " us";
|
|
|
|
|
#endif
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
|
|
|
|
|
void LookUpTable_task(const float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
|
|
|
|
|
size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, std::vector<size_t> input_shape,
|
|
|
|
|
size_t input_lens) {
|
|
|
|
|
size_t lens = num * sizeof(float);
|
|
|
|
@ -182,7 +139,6 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
auto ret = memset_s(output_addr, lens, 0, lens);
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
@ -204,6 +160,7 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
|
|
|
|
|
output_addr += num;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
|
|
|
|
|
size_t dim2, float **output_addr) {
|
|
|
|
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);
|
|
|
|
|