@@ -35,7 +35,7 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
   // 1 create tensor
   auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
   auto last_dim = shape[shape.size() - 1];
-  std::vector<int> indices_shape = {SizeToInt(last_dim)};
+  std::vector<int> indices_shape = {SizeToInt(last_dim * 2)};
   TensorTypePtr tensor_type = std::make_shared<TensorType>(kFloat16);
   MS_EXCEPTION_IF_NULL(tensor_type);
   tensor::DeviceInfo device_info{kOpFormat_DEFAULT, tensor_type};
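
Note: the indices tensor is allocated at twice the last dimension (last_dim * 2) so it can hold, for every index, both the fp16-encoded value and the rounding gap computed in the next hunk; see the sketch after that hunk for why the gap is needed.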
@@ -50,7 +50,11 @@ tensor::TensorPtr CreateTensor(const AnfNodePtr &node) {
   for (size_t i = 0; i < last_dim; ++i) {
     half_data.emplace_back(Eigen::half(static_cast<float>(i)));
   }
-  auto elem_num = last_dim * kFloat16Len;
+  for (size_t i = 0; i < last_dim; ++i) {
+    auto gap = static_cast<int>(i) - static_cast<int>(Eigen::half(static_cast<float>(i)));
+    half_data.emplace_back(Eigen::half(static_cast<float>(gap)));
+  }
+  auto elem_num = last_dim * kFloat16Len * 2;
   auto ret_code = memcpy_s(data_ptr, static_cast<size_t>(indices_tensor->data().nbytes()), half_data.data(), elem_num);
   if (ret_code != 0) {
     MS_LOG(ERROR) << "Failed to copy data into Tensor.";
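
Why the gap is stored: float16 has 11 significand bits, so only integers up to 2048 are exactly representable; larger indices are rounded when cast to Eigen::half. Recording gap = i - (int)Eigen::half(i) next to each rounded index lets the consumer recover the exact index as rounded + gap, and the gap itself is a small integer (at most half an fp16 ulp, i.e. at most 16 for values below 65504), so it round-trips through fp16 exactly. A minimal standalone sketch of the idea (not part of the patch; assumes only Eigen is available):

#include <iostream>
#include <Eigen/Core>

int main() {
  // Indices above 2048 are rounded by the fp16 cast; adding the stored gap
  // back reconstructs the original index exactly.
  for (int i : {2047, 2049, 60001}) {
    int rounded = static_cast<int>(Eigen::half(static_cast<float>(i)));
    int gap = i - rounded;  // the value the patch appends after the indices
    std::cout << "index " << i << " -> fp16 " << rounded << ", gap " << gap
              << ", recovered " << (rounded + gap) << "\n";
  }
  return 0;
}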
@@ -108,6 +112,13 @@ const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNod
     MS_LOG(INFO) << "The input k of topk has been converted to attr";
     return nullptr;
   }
+  auto shape = AnfAlgo::GetPrevNodeOutputInferShape(node, 0);
+  auto last_dim = shape[shape.size() - 1];
+  const size_t kMaxFloat16 = 65500;
+  if (last_dim > kMaxFloat16) {
+    MS_LOG(INFO) << "The last dim is more than 65500, switch to aicpu ops.";
+    return nullptr;
+  }
   // Copy a new node to check supported.
   std::vector<AnfNodePtr> new_inputs{NewValueNode(std::make_shared<Primitive>(kTopKOpName))};
   new_inputs.insert(new_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
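
Why the 65500 cut-off: the largest finite float16 value is 65504, so an index near or beyond that point cannot be encoded in the fp16 indices tensor at all (the cast overflows to inf) and the gap trick no longer helps; kMaxFloat16 = 65500 reads as a small safety margin below 65504. Such nodes are left to the AiCPU TopK instead. A standalone sketch (not part of the patch):

#include <iostream>
#include <Eigen/Core>

int main() {
  // 65504 is the largest finite fp16 value; anything that rounds past it
  // overflows to inf, so indices that large are unrepresentable in fp16.
  std::cout << static_cast<float>(Eigen::half(65504.0f)) << "\n";  // prints 65504
  std::cout << static_cast<float>(Eigen::half(70000.0f)) << "\n";  // prints inf
  return 0;
}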