!6055 [MSLITE] Fix inference bugs in several quantized operators.

Merge pull request !6055 from wangshaocong/lite_bugfix
pull/6055/MERGE
commit 6240189190, by mindspore-ci-bot (committed via Gitee)

@@ -79,7 +79,12 @@ int QuantDTypeCastCPUKernel::QuantDTypeCast(int task_id) {
     return RET_OK;
   }
   int thread_offset = task_id * thread_n_stride_;
-  auto quant_arg = in_tensors_.front()->GetQuantParams().front();
+  if (in_tensors_.front()->GetQuantParams().empty() && out_tensors_.front()->GetQuantParams().empty()) {
+    MS_LOG(ERROR) << "QuantDTypeCast needs quantization parameters, but none were found.";
+    return RET_ERROR;
+  }
+  auto quant_arg = !in_tensors_.front()->GetQuantParams().empty() ? in_tensors_.front()->GetQuantParams().front() :
+                                                                    out_tensors_.front()->GetQuantParams().front();
   int ret;
   if (inverse_) {
     ret = DoDequantizeInt8(int8_ptr_ + thread_offset, float32_ptr_ + thread_offset, quant_arg.scale,

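This first hunk makes QuantDTypeCast tolerate models that attach quantization parameters to only one side of the cast: the scale/zero-point pair is taken from the input tensor when present, falling back to the output tensor, and the kernel now fails with a clear error instead of calling front() on an empty vector. A minimal sketch of the same fallback, using simplified stand-ins (QuantArg, Tensor, and PickQuantArg are illustrative, not the lite runtime's types):

#include <vector>

struct QuantArg { double scale; int zero_point; };
struct Tensor { std::vector<QuantArg> quant_params; };

// Take the quant parameters from whichever tensor actually carries them,
// preferring the input side; report failure when neither side has any.
bool PickQuantArg(const Tensor &in, const Tensor &out, QuantArg *arg) {
  if (!in.quant_params.empty()) { *arg = in.quant_params.front(); return true; }
  if (!out.quant_params.empty()) { *arg = out.quant_params.front(); return true; }
  return false;  // caller logs the error and returns RET_ERROR
}
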
@@ -92,17 +92,10 @@ int QuantizedAddCPUKernel::Run() {
     input0_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
     input1_data_ = static_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
-    ArithmeticParameter tile_para;
-    tile_para.ndim_ = out_tensors_.at(0)->shape().size();
-    for (size_t i = 0; i < tile_para.ndim_; i++) {
-      tile_para.in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
-      tile_para.in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
-      tile_para.out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
-    }
     TileDimensionsUint8(static_cast<uint8_t *>(in_tensors_.at(0)->MutableData()),
                         static_cast<uint8_t *>(in_tensors_.at(1)->MutableData()),
                         reinterpret_cast<uint8_t *>(input0_data_), reinterpret_cast<uint8_t *>(input1_data_),
-                        &tile_para);
+                        arith_para_);
     ret = ParallelLaunch(THREAD_POOL_DEFAULT, AddInt8Run, this, thread_count_);
     ctx_->allocator->Free(input0_data_);
     ctx_->allocator->Free(input1_data_);

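Here the stack-local tile_para that Run() rebuilt on every call is dropped in favor of arith_para_, the ArithmeticParameter already attached to the node (see the header hunks below), so the broadcast shapes are filled in once rather than re-derived from DimensionSize on each run. TileDimensionsUint8 expands both int8 operands to the output shape so the add kernel can proceed element-wise; a simplified one-dimensional illustration of that broadcast-tile step (TileTo is a toy stand-in, not the nnacl routine):

#include <cstddef>
#include <cstdint>

// Toy stand-in for nnacl's TileDimensionsUint8, reduced to one dimension:
// broadcast a source of length 1 or out_len up to out_len so the int8 add
// kernel can run element-wise with no per-element shape checks.
void TileTo(const uint8_t *src, size_t src_len, uint8_t *dst, size_t out_len) {
  for (size_t i = 0; i < out_len; ++i) {
    dst[i] = src[src_len == 1 ? 0 : i];
  }
}
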
@@ -19,6 +19,7 @@
 #include <vector>
 #include "src/lite_kernel.h"
 #include "nnacl/int8/add_int8.h"
+#include "nnacl/arithmetic_common.h"
 #include "src/runtime/runtime_api.h"
 
 namespace mindspore::kernel {
@@ -27,7 +28,9 @@ class QuantizedAddCPUKernel : public LiteKernel {
   explicit QuantizedAddCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                  const std::vector<lite::Tensor *> &outputs, const lite::Context *ctx,
                                  const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {}
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
+    arith_para_ = reinterpret_cast<ArithmeticParameter *>(parameter);
+  }
   ~QuantizedAddCPUKernel() override {}
 
   int Init() override;
@@ -38,6 +41,7 @@ class QuantizedAddCPUKernel : public LiteKernel {
  private:
   const lite::Context *ctx_;
   AddQuantParameter para_;
+  ArithmeticParameter *arith_para_;
   int thread_count_;
   int64_t elements_num_;
   int64_t count_unit_;

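The constructor change above caches the incoming OpParameter as an ArithmeticParameter pointer so Run() can hand it straight to TileDimensionsUint8. The reinterpret_cast leans on nnacl's usual layout convention, in which each concrete parameter struct embeds OpParameter as its first member, so a pointer to the base struct is also a valid pointer to the derived one. A sketch of that convention with abbreviated fields (not the real nnacl definitions):

#include <cstddef>

struct OpParameter { int type_; };  // common header shared by all kernels
struct ArithmeticParameter {
  OpParameter op_parameter_;        // first member => same address as the struct
  size_t ndim_;
  int in_shape0_[10];
  int in_shape1_[10];
  int out_shape_[10];
};

// Because op_parameter_ is the first member, a pointer passed around as
// OpParameter* can be reinterpreted as the concrete parameter type.
ArithmeticParameter *AsArithmetic(OpParameter *p) {
  return reinterpret_cast<ArithmeticParameter *>(p);
}
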
@@ -91,7 +91,7 @@ ParameterPtr CreateNewParamter(const FuncGraphPtr &func_graph, Tensor *tensor) {
     MS_LOG(ERROR) << "tensor_data is nullptr";
     return nullptr;
   }
-  auto ret = memcpy_s(tensor_data, size * sizeof(float), tensor->MutableData(), size * sizeof(float));
+  auto ret = memcpy_s(tensor_data, tensor->Size(), tensor->MutableData(), tensor->Size());
   if (ret != EOK) {
     delete[] tensor_data;
     MS_LOG(ERROR) << "memcpy error: " << ret;
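In CreateNewParamter, the copy bound previously assumed every constant tensor holds floats; for a quantized int8 tensor that would overstate the copy length by a factor of sizeof(float), so memcpy_s could read past the end of the source buffer. Using tensor->Size(), the tensor's size in bytes regardless of dtype, keeps the bound honest. A tiny arithmetic check of the mismatch (plain C++, no lite types):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  const size_t kElements = 100;
  size_t old_bytes = kElements * sizeof(float);   // old bound: 400 bytes
  size_t new_bytes = kElements * sizeof(int8_t);  // actual int8 payload: 100 bytes
  std::printf("old bound = %zu bytes, actual = %zu bytes\n", old_bytes, new_bytes);
  return 0;
}
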
@@ -234,6 +234,9 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
     return nullptr;
   }
   lite::Context context;
+  if (context.allocator == nullptr) {
+    context.allocator = lite::Allocator::Create();
+  }
   auto lite_kernel = GetLiteKernel(input_tensors, output_tensors, parameter, &context, lite_primitive.get());
   if (lite_kernel == nullptr) {
     MS_LOG(ERROR) << "constant_folding schedule node lite kernel nullptr";

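Finally, constant folding runs the scheduled kernels offline inside the converter, where the freshly constructed lite::Context starts without an allocator; any kernel that calls ctx_->allocator->Malloc (such as QuantizedAddCPUKernel above) would then dereference a null pointer. The guard creates one on demand. A toy sketch of the intent (Allocator and Context here are simplified stand-ins for the lite classes):

#include <cstddef>
#include <memory>

struct Allocator {
  static std::shared_ptr<Allocator> Create() { return std::make_shared<Allocator>(); }
  void *Malloc(size_t size) { return ::operator new(size); }
  void Free(void *ptr) { ::operator delete(ptr); }
};

struct Context {
  std::shared_ptr<Allocator> allocator;  // empty when the context is hand-built
};

// Give an offline-built context a working allocator before running kernels.
void EnsureAllocator(Context *ctx) {
  if (ctx->allocator == nullptr) {
    ctx->allocator = Allocator::Create();
  }
}
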