From d871702bf33dcc47082560ab327e37dc1f833ef2 Mon Sep 17 00:00:00 2001 From: yefeng Date: Mon, 8 Mar 2021 10:51:30 +0800 Subject: [PATCH] fix_gpu_create_1 --- .../examples/train_lenet/src/net_runner.cc | 1 + .../src/runtime/kernel/arm/base/crop_base.cc | 36 +++---------------- .../src/runtime/kernel/arm/base/crop_base.h | 3 +- .../src/runtime/kernel/arm/fp16/crop_fp16.cc | 1 - .../src/runtime/kernel/arm/fp32/crop_fp32.cc | 1 - .../src/runtime/kernel/arm/int8/crop_int8.cc | 12 ------- .../src/runtime/kernel/arm/int8/crop_int8.h | 2 +- .../runtime/kernel/opencl/kernel/conv2d.cc | 7 +--- .../runtime/kernel/opencl/kernel/matmul.cc | 3 +- .../runtime/kernel/opencl/opencl_kernel.cc | 7 ++++ mindspore/lite/src/tensor.cc | 2 +- mindspore/lite/tools/common/protobuf_utils.cc | 2 +- .../converter/quantizer/quantize_util.cc | 2 +- .../tools/converter/quantizer/quantize_util.h | 2 +- .../optimizer/fusion/conv_transform_fusion.cc | 2 +- 15 files changed, 23 insertions(+), 60 deletions(-) diff --git a/mindspore/lite/examples/train_lenet/src/net_runner.cc b/mindspore/lite/examples/train_lenet/src/net_runner.cc index edeeb1ff49..a8690c490a 100644 --- a/mindspore/lite/examples/train_lenet/src/net_runner.cc +++ b/mindspore/lite/examples/train_lenet/src/net_runner.cc @@ -45,6 +45,7 @@ class Rescaler : public mindspore::session::TrainLoopCallBack { explicit Rescaler(float scale) : scale_(scale) { if (scale_ == 0) scale_ = 1.0; } + ~Rescaler() override = default; void StepBegin(const mindspore::session::TrainLoopCallBackData &cb_data) override { auto inputs = cb_data.session_->GetInputs(); auto *input_data = reinterpret_cast(inputs.at(0)->MutableData()); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc index ee74e32460..3dbd2a6c12 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.cc @@ -32,26 +32,11 @@ int CropBaseCPUKernel::Init() { return RET_OK; } int CropBaseCPUKernel::ReSize() { auto *input_tensor = in_tensors_.at(kInputIndex); auto *out_tensor = out_tensors_.at(kOutputIndex); - auto input_shape = input_tensor->shape(); - auto output_shape = out_tensor->shape(); - size_t input_dim = input_shape.size(); - size_t output_dim = output_shape.size(); - FreeTmpBuffer(); - - crop_para_->in_shape_ = reinterpret_cast(malloc(input_dim * sizeof(int))); - if (crop_para_->in_shape_ == nullptr) { - MS_LOG(ERROR) << "in_shape_ is nullptr"; - return RET_ERROR; - } - memcpy(crop_para_->in_shape_, input_shape.data(), sizeof(int) * input_dim); - - crop_para_->out_shape_ = reinterpret_cast(malloc(output_dim * sizeof(int))); - if (crop_para_->out_shape_ == nullptr) { - MS_LOG(ERROR) << "out_shape_ is nullptr"; - return RET_ERROR; - } - memcpy(crop_para_->out_shape_, output_shape.data(), sizeof(int) * output_dim); - + input_shape_ = input_tensor->shape(); + output_shape_ = out_tensor->shape(); + size_t input_dim = input_shape_.size(); + crop_para_->in_shape_ = input_shape_.data(); + crop_para_->out_shape_ = output_shape_.data(); MS_ASSERT(input_dim <= CROP_OFFSET_MAX_SIZE); crop_para_->input_dim_ = input_dim; PadOffset(input_dim, crop_para_); @@ -77,15 +62,4 @@ void CropBaseCPUKernel::PadOffset(int input_dim, CropParameter *crop_para) { crop_para->in_offset_[i] = crop_offset; } } - -void CropBaseCPUKernel::FreeTmpBuffer() { - if (crop_para_->in_shape_ != nullptr) { - free(crop_para_->in_shape_); - crop_para_->in_shape_ = nullptr; - } - if (crop_para_->out_shape_ != nullptr) { - free(crop_para_->out_shape_); - crop_para_->out_shape_ = nullptr; - } -} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h index 88f170ad74..af208296c8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/crop_base.h @@ -35,9 +35,10 @@ class CropBaseCPUKernel : public LiteKernel { int Init() override; int ReSize() override; int Run() override { return 0; } - void FreeTmpBuffer(); protected: + std::vector input_shape_; + std::vector output_shape_; CropParameter *crop_para_; void PadOffset(int input_dim, CropParameter *crop_para); }; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc index 3d17806249..02622c55b7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/crop_fp16.cc @@ -57,7 +57,6 @@ int CropFp16CPUKernel::Run() { if (ret != RET_OK) { MS_LOG(ERROR) << "ParallelLaunch failed: " << ret; } - FreeTmpBuffer(); return ret; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc index 01cb0a570b..ab2e1d46a0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop_fp32.cc @@ -67,7 +67,6 @@ int CropCPUKernel::Run() { MS_LOG(ERROR) << "Crop launch fail!ret: " << ret; return RET_ERROR; } - FreeTmpBuffer(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc index 663560937d..49bd21b62d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.cc @@ -49,18 +49,6 @@ int CropInt8CPUKernel::Init() { return ReSize(); } -CropInt8CPUKernel::~CropInt8CPUKernel() { - if (crop_para_->in_shape_ != nullptr) { - free(const_cast(crop_para_->in_shape_)); - crop_para_->in_shape_ = nullptr; - } - - if (crop_para_->out_shape_ != nullptr) { - free(const_cast(crop_para_->out_shape_)); - crop_para_->out_shape_ = nullptr; - } -} - int CropInt8CPUKernel::ReSize() { return CropBaseCPUKernel::ReSize(); } int CropInt8CPUKernel::Run() { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h index bee5aa5ac3..5768ea4b83 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/crop_int8.h @@ -32,7 +32,7 @@ class CropInt8CPUKernel : public CropBaseCPUKernel { CropInt8CPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const mindspore::lite::InnerContext *ctx) : CropBaseCPUKernel(parameter, inputs, outputs, ctx) {} - ~CropInt8CPUKernel(); + ~CropInt8CPUKernel() = default; int Init() override; int ReSize() override; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc index 95aa5a87ae..8eef95289a 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/conv2d.cc @@ -506,12 +506,10 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector &input // case 3: common conv2d kernel::OpenCLKernel *kernel; - OpParameter *real_param; bool infer_shape_done = opParameter->infer_flag_; if (infer_shape_done && UseFcReplaceConv(inputs, outputs, conv_param)) { auto *fc_param = CreateFcParam(conv_param, inputs); kernel = new (std::nothrow) FullConnectionOpenCLKernel(fc_param, inputs, outputs, ctx); - real_param = fc_param; if (kernel == nullptr) { MS_LOG(ERROR) << "Create FullConnection kernel failed."; free(fc_param); @@ -529,7 +527,6 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector &input } else { kernel = new (std::nothrow) Conv2DOpenCLKernel(reinterpret_cast(conv_param), inputs, outputs, ctx); } - real_param = reinterpret_cast(conv_param); if (kernel == nullptr) { MS_LOG(ERROR) << "Create Convolution kernel failed."; free(conv_param); @@ -540,11 +537,9 @@ kernel::LiteKernel *OpenCLConv2DCreator(const std::vector &input MS_LOG(WARNING) << "kernel don't infer shape yet!"; return kernel; } - int ret = kernel->CheckSpecs(); - if (ret != mindspore::lite::RET_OK) { + if (kernel->CheckSpecs() != RET_OK || kernel->OpenCLKernel::CheckSpecs() != RET_OK) { MS_LOG(ERROR) << "Init Convolution kernel failed."; delete kernel; - free(real_param); return nullptr; } return kernel; diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc index 1aaf62fff2..adf7684efb 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/matmul.cc @@ -243,8 +243,7 @@ kernel::LiteKernel *OpenCLMatMulKernelCreator(const std::vector MS_LOG(WARNING) << "kernel don't infer shape yet!"; return kernel; } - auto ret = kernel->CheckSpecs(); - if (ret != RET_OK) { + if (kernel->CheckSpecs() != RET_OK || kernel->OpenCLKernel::CheckSpecs() != RET_OK) { MS_LOG(ERROR) << "Check " << opParameter->name_ << " specification failed!"; delete kernel; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc index ff26eb29c6..5df0f21765 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/opencl_kernel.cc @@ -406,6 +406,13 @@ int OpenCLKernel::CheckSpecs() { return RET_ERROR; } } + if (in_tensors_.size() > 0) { + if (in_tensors_[0]->data_type() != kNumberTypeFloat32 && in_tensors_[0]->data_type() != kNumberTypeFloat16 && + in_tensors_[0]->data_type() != kNumberTypeInt32) { + MS_LOG(WARNING) << "Unsupported data type: " << in_tensors_[0]->data_type(); + return RET_ERROR; + } + } return RET_OK; } } // namespace mindspore::kernel diff --git a/mindspore/lite/src/tensor.cc b/mindspore/lite/src/tensor.cc index a4a23c503c..f71fb7c94f 100644 --- a/mindspore/lite/src/tensor.cc +++ b/mindspore/lite/src/tensor.cc @@ -30,7 +30,7 @@ Tensor::Tensor(const TypeId data_type, std::vector shape, const schema::For : data_type_(data_type), shape_(std::move(shape)), format_(format), category_(category) {} Tensor::Tensor(const std::string &name, enum TypeId type, const std::vector &shape, const void *data) - : tensor_name_(name), data_type_(type), shape_(std::move(shape)) { + : tensor_name_(name), data_type_(type), shape_(std::move(shape)), category_(VAR) { data_ = const_cast(data); } diff --git a/mindspore/lite/tools/common/protobuf_utils.cc b/mindspore/lite/tools/common/protobuf_utils.cc index 3bfd73b3a7..2afe0ae022 100644 --- a/mindspore/lite/tools/common/protobuf_utils.cc +++ b/mindspore/lite/tools/common/protobuf_utils.cc @@ -91,7 +91,7 @@ STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *mess fs.close(); if (!success) { - MS_LOG(ERROR) << "Parse " << file << " failed."; + MS_LOG(DEBUG) << "Parse " << file << " failed."; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 752b8605e4..32474b8f38 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -940,7 +940,7 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int return RET_OK; } -void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, +void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas, bool channel_at_first, float *desired_max, float *desired_min) { float min = FLT_MAX; float max = -FLT_MAX; diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 52e1adb556..c6ba5c15f4 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -110,7 +110,7 @@ std::vector KMeans(float *data, size_t elem_count, size_t k, size_t epoc STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int new_size); -void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas, +void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, const float *raw_datas, bool channel_at_first, float *desired_max, float *desired_min); template diff --git a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc index 872052d0d9..eddbd60bcf 100644 --- a/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc +++ b/mindspore/lite/tools/optimizer/fusion/conv_transform_fusion.cc @@ -226,7 +226,7 @@ void ConvTransformFusion::GenNewConvTensor(const FuncGraphPtr &func_graph, const MS_LOG(ERROR) << "memcpy_s error:" << ret; return; } - new_weight_tensor->set_tensor_addr(temp_weight_data); + new_weight_tensor->SetTensorData(temp_weight_data, new_weight_tensor->tensor_size()); CalNewWeightTensor(conv_node, new_weight_tensor, kernel_num, trans_scale); float *bias_data = nullptr; // conv has bias,bias_flag true