|
|
|
@ -35,29 +35,11 @@ int GatherInt8CPUKernel::Init() {
|
|
|
|
|
axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
|
|
|
|
|
batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
|
|
|
|
|
auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
|
|
|
|
|
auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
|
|
|
|
|
auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
|
|
|
|
|
param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
|
|
|
|
|
param_.zp_in_ = in_quant_args.front().zeroPoint;
|
|
|
|
|
param_.zp_out_ = out_quant_args.front().zeroPoint;
|
|
|
|
|
|
|
|
|
|
auto indices_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->Data());
|
|
|
|
|
if (indices_ != nullptr) {
|
|
|
|
|
free(indices_);
|
|
|
|
|
indices_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
int count = in_tensors_.at(1)->ElementsNum();
|
|
|
|
|
indices_ = reinterpret_cast<int *>(malloc(count * sizeof(int)));
|
|
|
|
|
if (indices_ == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Gather Malloc indices_ error!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
(void)memset(indices_, 0, count * sizeof(int));
|
|
|
|
|
for (int i = 0; i < count; ++i) {
|
|
|
|
|
indices_[i] =
|
|
|
|
|
static_cast<int>(round((indices_ptr[i] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!InferShapeDone()) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -73,6 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
|
|
|
|
|
|
|
|
|
|
auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->Data());
|
|
|
|
|
auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->Data());
|
|
|
|
|
auto indices_ptr = reinterpret_cast<int32_t *>(out_tensor->Data());
|
|
|
|
|
|
|
|
|
|
auto in_shape = input_tensor->shape();
|
|
|
|
|
int in_rank = in_shape.size();
|
|
|
|
@ -80,8 +63,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
|
|
|
|
|
|
|
|
|
|
const int limit = in_shape[axis_];
|
|
|
|
|
for (int i = 0; i < indices_element_size; ++i) {
|
|
|
|
|
if (indices_[i] >= limit) {
|
|
|
|
|
MS_LOG(ERROR) << " indice data: " << indices_[i] << " is not in [ 0, " << limit - 1 << " ]";
|
|
|
|
|
if (indices_ptr[i] >= limit) {
|
|
|
|
|
MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -103,7 +86,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
|
|
|
|
|
int error_code;
|
|
|
|
|
input_ptr += thread_stride * limit;
|
|
|
|
|
output_ptr += thread_stride * indices_element_size;
|
|
|
|
|
error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_, indices_element_size, param_);
|
|
|
|
|
error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_);
|
|
|
|
|
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
@ -127,6 +110,7 @@ int GatherInt8CPUKernel::Run() {
|
|
|
|
|
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
|
|
|
|
|
return prepare_ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherInt8Run, this, thread_count_);
|
|
|
|
|
if (error_code != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";
|
|
|
|
|