From 362daeb2b6039b8c3e68678d3e32b289413c5e95 Mon Sep 17 00:00:00 2001 From: wangdongxu Date: Mon, 12 Oct 2020 19:05:10 +0800 Subject: [PATCH] fix prelu scalar weight bug --- mindspore/lite/CMakeLists.txt | 4 +- mindspore/lite/src/CMakeLists.txt | 6 +- .../src/runtime/kernel/opencl/cl/prelu.cl | 103 +++++++---- .../kernel/opencl/kernel/convolution.cc | 17 +- .../src/runtime/kernel/opencl/kernel/prelu.cc | 163 +++++++++++------- .../src/runtime/kernel/opencl/kernel/prelu.h | 11 +- .../lite/src/runtime/kernel/opencl/utils.cc | 17 +- 7 files changed, 195 insertions(+), 126 deletions(-) diff --git a/mindspore/lite/CMakeLists.txt b/mindspore/lite/CMakeLists.txt index 0a1438b538..8f4c553e01 100644 --- a/mindspore/lite/CMakeLists.txt +++ b/mindspore/lite/CMakeLists.txt @@ -190,8 +190,8 @@ if (PLATFORM_ARM64) endif () if (BUILD_MINDDATA STREQUAL "lite" OR BUILD_MINDDATA STREQUAL "full") - # TODO: add sentencepiece dependency - #include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake) + # add sentencepiece dependency + # include(${TOP_DIR}/cmake/external_libs/sentencepiece.cmake) # opencv set(OpenCV_DIR ${TOP_DIR}/third_party/opencv/build) find_package(OpenCV REQUIRED) diff --git a/mindspore/lite/src/CMakeLists.txt b/mindspore/lite/src/CMakeLists.txt index e54d3f0d93..e9e1d43f0c 100644 --- a/mindspore/lite/src/CMakeLists.txt +++ b/mindspore/lite/src/CMakeLists.txt @@ -96,7 +96,7 @@ endif () if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND PLATFORM_ARM) add_custom_command(TARGET mindspore-lite POST_BUILD COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip - ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so) + ${CMAKE_BINARY_DIR}/src/libmindspore-lite.so) endif () if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release") @@ -124,10 +124,10 @@ endif () if ("${CMAKE_BUILD_TYPE}" STREQUAL "Release" AND (PLATFORM_ARM64)) add_custom_command(TARGET mindspore-lite-optimize POST_BUILD COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip - ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-optimize.so) + ${CMAKE_BINARY_DIR}/src/libmindspore-lite-optimize.so) add_custom_command(TARGET mindspore-lite-fp16 POST_BUILD COMMAND ${ANDROID_NDK}/toolchains/aarch64-linux-android-4.9/prebuilt/linux-x86_64/aarch64-linux-android/bin/strip - ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite-fp16.so) + ${CMAKE_BINARY_DIR}/src/libmindspore-lite-fp16.so) endif () diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl index df0bcd75e1..d66e84859c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/prelu.cl @@ -1,41 +1,80 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -#define SLICES 4 -#define UP_DIV(x, y) (((x) + (y) - (1)) / (y)) __constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; -__kernel void PRelu(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape, - __read_only image2d_t alpha, const int data_type, const int bias_dim) { - int H = input_shape.y; - int C = input_shape.w; // channel size - C = UP_DIV(C, SLICES); - if (C == 0 || H == 0) { +#define NHWC4 2 +#define NC4HW4 100 + +__kernel void PRelu_scalar(__read_only image2d_t input, __write_only image2d_t output, float weight, int4 shape, + int data_format) { + int h = get_global_id(0); + int w = get_global_id(1); + int slice = 
get_global_id(2);
+  int H = shape.y;
+  int W = shape.z;
+  int SLICES = shape.w;
+  if (h >= H || w >= W || slice >= SLICES) {
     return;
   }
-  int Y = get_global_id(0);  // height id
-  int X = get_global_id(1);  // weight id
-  FLT4 in_c4 = READ_IMAGE(input, smp_zero, (int2)(X, Y));
-  FLT4 tmp;
-  int index = 0;
-  if (data_type == 1) {  // NHWC4
-    index = X % C;
-  } else if (data_type == 2) {  // NC4HW4
-    index = Y / H;
+
+  int x, y;
+  if (data_format == NHWC4) {
+    x = w * SLICES + slice;
+    y = h;
   } else {
+    x = w;
+    y = slice * H + h;
+  }
+
+  FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
+  if (out.x < 0) {
+    out.x *= weight;
+  }
+  if (out.y < 0) {
+    out.y *= weight;
+  }
+  if (out.z < 0) {
+    out.z *= weight;
+  }
+  if (out.w < 0) {
+    out.w *= weight;
+  }
+  WRITE_IMAGE(output, (int2)(x, y), out);
+}
+
+__kernel void PRelu_vector(__read_only image2d_t input, __write_only image2d_t output, __global FLT4 *weight_vector,
+                           int4 shape, int data_format) {
+  int h = get_global_id(0);
+  int w = get_global_id(1);
+  int slice = get_global_id(2);
+  int H = shape.y;
+  int W = shape.z;
+  int SLICES = shape.w;
+  if (h >= H || w >= W || slice >= SLICES) {
     return;
   }
-  if (bias_dim == 1) {
-    index = 0;
-  }
-  FLT4 weight = READ_IMAGE(alpha, smp_zero, (int2)(index, 0));
-  FLT4 bias = weight;
-  if (bias_dim == 1) {
-    bias.y = weight.x;
-    bias.z = weight.x;
-    bias.w = weight.x;
-  }
-  tmp.x = in_c4.x > 0.0f ? in_c4.x : in_c4.x * bias.x;
-  tmp.y = in_c4.y > 0.0f ? in_c4.y : in_c4.y * bias.y;
-  tmp.z = in_c4.z > 0.0f ? in_c4.z : in_c4.z * bias.z;
-  tmp.w = in_c4.w > 0.0f ? in_c4.w : in_c4.w * bias.w;
-  WRITE_IMAGE(output, (int2)(X, Y), tmp);
+  FLT4 weight = weight_vector[slice];
+
+  int x, y;
+  if (data_format == NHWC4) {
+    x = w * SLICES + slice;
+    y = h;
+  } else {
+    x = w;
+    y = slice * H + h;
+  }
+
+  FLT4 out = READ_IMAGE(input, smp_zero, (int2)(x, y));
+  if (out.x < 0) {
+    out.x *= weight.x;
+  }
+  if (out.y < 0) {
+    out.y *= weight.y;
+  }
+  if (out.z < 0) {
+    out.z *= weight.z;
+  }
+  if (out.w < 0) {
+    out.w *= weight.w;
+  }
+  WRITE_IMAGE(output, (int2)(x, y), out);
 }
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
index 46121dd703..84cc574606 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/convolution.cc
@@ -359,6 +359,8 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() {
   code += "#define padBottom " + std::to_string(padBottom) + "\n";
   code += "#define padLeft " + std::to_string(padLeft) + "\n";
   code += "#define padRight " + std::to_string(padRight) + "\n";
+  code += "#define dilationH " + std::to_string(param->dilation_h_) + "\n";
+  code += "#define dilationW " + std::to_string(param->dilation_w_) + "\n";
   code += "#define CI_SLICES " + std::to_string(CI_SLICES_) + "\n";
   code += "#define CO_SLICES " + std::to_string(CO_SLICES_) + "\n\n";
 
@@ -398,10 +400,10 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNHWC4() {
   code +=
     "        for (int kh = 0; kh < KH; ++kh)\n"
     "        {\n"
-    "            int ih = kh + oh * strideH - padTop;\n"
+    "            int ih = kh * dilationH + oh * strideH - padTop;\n"
     "            for (int kw = 0; kw < KW; ++kw)\n"
     "            {\n"
-    "                int iw = kw + ow * strideW - padLeft;\n"
+    "                int iw = kw * dilationW + ow * strideW - padLeft;\n"
     "                if (ih >= 0 && ih < IH && iw >= 0 && iw < IW)\n"
    "                {\n"
     "                    for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n"
@@ -491,7 +493,9 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
   code += "    #define strideH " + std::to_string(strideH) + "\n";
  code += "    #define strideW " + std::to_string(strideW) + "\n";
   code += "    #define padTop " + std::to_string(padTop) + "\n";
-  code += "    #define padLeft " + std::to_string(padLeft) + "\n\n";
+  code += "    #define padLeft " + std::to_string(padLeft) + "\n";
+  code += "    #define dilationH " + std::to_string(param->dilation_h_) + "\n";
+  code += "    #define dilationW " + std::to_string(param->dilation_w_) + "\n";
 
   code +=
     "  if (n_oh >= N_OH || ow >= OW || co_slice >= CO_SLICES) {\n"
@@ -513,7 +517,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
     "\n"
     "    for (int kh = 0; kh < KH; ++kh)\n"
     "    {\n"
-    "        int ih = kh + oh * strideH - padTop;\n"
+    "        int ih = kh * dilationH + oh * strideH - padTop;\n"
     "        for (int kw = 0; kw < KW; ++kw)\n"
     "        {\n";
 
@@ -523,7 +527,7 @@ std::string ConvolutionOpenCLKernel::CodeGenConvolutionNC4HW4() {
       "{\n";
   }
 
-  code += "            int iw0 = kw + (ow + 0) * strideW - padLeft;\n";
+  code += "            int iw0 = kw * dilationW + (ow + 0) * strideW - padLeft;\n";
 
   if (check_ow) {
     code +=
       "            if (last_is_double)\n"
       "            {\n";
   }
 
   code +=
-    "            int iw1 = kw + (ow + 1) * strideW - padLeft;\n"
+    "            int iw1 = kw * dilationW + (ow + 1) * strideW - padLeft;\n"
     "            for (int ci_slice = 0; ci_slice < CI_SLICES; ci_slice++)\n"
     "            {\n"
     "                FLT4 in0 = READ_IMAGE(input, smp_zero, (int2)(iw0, (n * CI_SLICES + ci_slice) * IH + ih));\n"
@@ -916,4 +920,5 @@ kernel::LiteKernel *OpenCLConvolutionKernelCreator(const std::vector<lite::Tensor *> &inputs,
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.cc
 #include <vector>
 #include <set>
-#include <map>
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -36,85 +35,116 @@ namespace mindspore::kernel {
 
 void PReluOpenCLKernel::InitBuffer() {
   auto allocator = ocl_runtime_->GetAllocator();
-  int elem_num = in_tensors_[0]->shape().size() == 2 ? in_tensors_[0]->shape()[1] : in_tensors_[0]->shape()[3];
-  int elem_num_c4 = UP_DIV(elem_num, C4NUM);
-  size_t img_dtype = CL_FLOAT;
-  if (enable_fp16_) {
-    img_dtype = CL_HALF_FLOAT;
-  }
-  std::vector<size_t> img_size{size_t(elem_num_c4), 1, img_dtype};
-  PReluWeight_ = allocator->Malloc(elem_num_c4 * C4NUM * fp_size, img_size);
-  PReluWeight_ = allocator->MapBuffer(PReluWeight_, CL_MAP_WRITE, nullptr, true);
-  memset(PReluWeight_, 0x00, elem_num_c4 * C4NUM * fp_size);
-  if (enable_fp16_) {
-    if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
-      auto PReluWeight_fp16 = reinterpret_cast<float16_t *>(PReluWeight_);
-      auto in_tensor_data_fp32 = reinterpret_cast<float *>(in_tensors_[1]->data_c());
-      for (int i = 0; i < elem_num; i++) {
-        PReluWeight_fp16[i] = static_cast<float16_t>(in_tensor_data_fp32[i]);
-      }
+  auto weight_tensor = in_tensors_[1];
+  if (weight_is_scalar) {
+    if (weight_tensor->data_type() == kNumberTypeFloat16) {
+      weight_scalar_ = static_cast<float>(*reinterpret_cast<float16_t *>(weight_tensor->data_c()));
     } else {
-      memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size);
+      weight_scalar_ = *reinterpret_cast<float *>(weight_tensor->data_c());
     }
   } else {
-    if (in_tensors_[1]->data_type() == kNumberTypeFloat16) {
-      auto PReluWeight_fp32 = reinterpret_cast<float *>(PReluWeight_);
-      auto in_tensor_data_fp16 = reinterpret_cast<float16_t *>(in_tensors_[1]->data_c());
-      for (int i = 0; i < elem_num; i++) {
-        PReluWeight_fp32[i] = static_cast<float>(in_tensor_data_fp16[i]);
+    auto sizeof_FLT = enable_fp16_ ? sizeof(float16_t) : sizeof(float);
+    size_t weight_size = UP_ROUND(C_, C4NUM) * sizeof_FLT;
+    weight_vector_ = allocator->Malloc(weight_size);
+    allocator->MapBuffer(weight_vector_, CL_MAP_WRITE, nullptr, true);
+    memset(weight_vector_, 0x00, weight_size);
+    if (weight_tensor->data_type() == kNumberTypeFloat16) {
+      if (enable_fp16_) {
+        memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT);
+      } else {
+        auto weight_fp32 = reinterpret_cast<float *>(weight_vector_);
+        auto origin_bias_fp16 = reinterpret_cast<float16_t *>(weight_tensor->data_c());
+        for (int i = 0; i < C_; ++i) {
+          weight_fp32[i] = static_cast<float>(origin_bias_fp16[i]);
+        }
       }
     } else {
-      memcpy(PReluWeight_, in_tensors_[1]->data_c(), elem_num * fp_size);
+      if (enable_fp16_) {
+        auto weight_fp16 = reinterpret_cast<float16_t *>(weight_vector_);
+        auto origin_bias_fp32 = reinterpret_cast<float *>(weight_tensor->data_c());
+        for (int i = 0; i < C_; ++i) {
+          weight_fp16[i] = static_cast<float16_t>(origin_bias_fp32[i]);
+        }
+      } else {
+        memcpy(weight_vector_, weight_tensor->data_c(), C_ * sizeof_FLT);
+      }
     }
+    allocator->UnmapBuffer(weight_vector_);
   }
-  allocator->UnmapBuffer(PReluWeight_);
 }
 
 int PReluOpenCLKernel::Init() {
-  if (in_tensors_[0]->shape().size() != 4) {
-    MS_LOG(ERROR) << "PRelu only support dim=4, but your dim=" << in_tensors_[0]->shape().size();
+  auto input_tensor = in_tensors_[0];
+  auto weight_tensor = in_tensors_[1];
+  if (input_tensor->shape().size() != 4) {
+    MS_LOG(ERROR) << "PRelu only supports dim=4, but your dim=" << input_tensor->shape().size();
+    return RET_ERROR;
+  }
+  batch_size_ = input_tensor->Batch();
+  C_ = input_tensor->Channel();
+  H_ = input_tensor->Height();
+  W_ = input_tensor->Width();
+  if (input_tensor->GetFormat() != schema::Format_NC4HW4 && input_tensor->GetFormat() != schema::Format_NHWC4) {
+    MS_LOG(ERROR) << "PRelu only supports Format_NC4HW4 and Format_NHWC4";
     return RET_ERROR;
   }
-  int C_Weight = in_tensors_[1]->shape()[0];
-  int C = in_tensors_[0]->shape()[3];
-  if (C_Weight != 1 && UP_DIV(C_Weight, C4NUM) != UP_DIV(C, C4NUM)) {
+  if (batch_size_ != 1) {
+    MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch.";
+    return RET_ERROR;
+  }
+  auto weight_channel = weight_tensor->shape()[0];
+  if (weight_channel != 1 && weight_channel != C_) {
     MS_LOG(ERROR)
-      << "PRelu weight channel size must be 1 or must be equal with in_teneors channel size, but your weight size is "
-      << C_Weight << " and your input channel size is " << C;
+      << "PRelu weight channel size must be 1 or equal to the input channel size, but your weight size is "
+      << weight_channel << " and your input channel size is " << C_;
     return RET_ERROR;
   }
-  for (int i = 0; i < in_tensors_[0]->shape().size(); ++i) {
-    input_shape_.s[i] = in_tensors_[0]->shape()[i];
+  weight_is_scalar = weight_channel == 1;
+  if (weight_tensor->data_type() != kNumberTypeFloat16 && weight_tensor->data_type() != kNumberTypeFloat32) {
+    MS_LOG(ERROR) << "PRelu weight must be float32 or float16";
+    return RET_ERROR;
   }
+
+  enable_fp16_ = ocl_runtime_->GetFp16Enable();
+  in_ori_format_ = input_tensor->GetFormat();
+  out_ori_format_ = out_tensors_[0]->GetFormat();
+  input_tensor->SetFormat(op_format_);
+  out_tensors_[0]->SetFormat(op_format_);
+
   std::set<std::string> build_options;
   std::string source = prelu_source;
   std::string program_name = "PRelu";
-  std::string kernel_name = "PRelu";
-  enable_fp16_ = ocl_runtime_->GetFp16Enable();
-  fp_size = enable_fp16_ ? sizeof(uint16_t) : sizeof(float);
-  InitBuffer();
+  std::string kernel_name = "PRelu_" + std::string(weight_is_scalar ? "scalar" : "vector");
"scalar" : "vector"); ocl_runtime_->LoadSource(program_name, source); ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options); - in_ori_format_ = in_tensors_[0]->GetFormat(); - in_tensors_[0]->SetFormat(op_format_); - out_ori_format_ = out_tensors_[0]->GetFormat(); - out_tensors_[0]->SetFormat(op_format_); + + InitBuffer(); MS_LOG(DEBUG) << program_name << " init Done!"; return RET_OK; } int PReluOpenCLKernel::Run() { MS_LOG(DEBUG) << op_parameter_->name_ << " Running!"; - std::map data_type{{schema::Format::Format_NHWC4, 1}, {schema::Format::Format_NC4HW4, 2}}; + auto CO_SLICES_ = UP_DIV(C_, C4NUM); + cl_int4 shape = {batch_size_, H_, W_, CO_SLICES_}; + int arg_idx = 0; ocl_runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->data_c()); ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c()); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, input_shape_); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, PReluWeight_); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, data_type[op_format_]); - ocl_runtime_->SetKernelArg(kernel_, arg_idx++, reinterpret_cast(in_tensors_[1]->shape()[0])); - std::vector local = {1, 1}; - std::vector global = {static_cast(global_shape_.s[1]), static_cast(global_shape_.s[2])}; + if (weight_is_scalar) { + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_scalar_); + } else { + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, weight_vector_); + } + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, shape); + if (op_format_ == schema::Format_NHWC4) { + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 2); + } else { // Format_NC4HW4 = 100 + ocl_runtime_->SetKernelArg(kernel_, arg_idx++, 100); + } + + std::vector local = {4, 4, 1}; + std::vector global = {static_cast(H_), static_cast(W_), static_cast(CO_SLICES_)}; auto ret = ocl_runtime_->RunKernel(kernel_, global, local, nullptr); if (ret != RET_OK) { MS_LOG(ERROR) << "Run kernel " << op_parameter_->name_ << " error."; @@ -124,22 +154,26 @@ int PReluOpenCLKernel::Run() { } int PReluOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { - size_t img_dtype = CL_FLOAT; - if (enable_fp16_) { - img_dtype = CL_HALF_FLOAT; - } - global_shape_ = input_shape_; - if (op_format_ == schema::Format::Format_NC4HW4) { - global_shape_.s[1] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[1]; - } else if (op_format_ == schema::Format::Format_NHWC4) { - global_shape_.s[2] = UP_DIV(input_shape_.s[3], C4NUM) * input_shape_.s[2]; + size_t im_dst_x, im_dst_y; + auto CO_SLICES_ = UP_DIV(C_, C4NUM); + if (in_tensors_[0]->GetFormat() == schema::Format_NHWC4) { + if (W_ * CO_SLICES_ <= MAX_IMAGE2D_SIZE) { + { + im_dst_y = batch_size_ * H_; + im_dst_x = W_ * CO_SLICES_; + } + } else { + im_dst_y = W_; + im_dst_x = batch_size_ * H_ * CO_SLICES_; + } } else { - MS_LOG(ERROR) << "op_format_:" << op_format_ << " is do not support!"; - return RET_ERROR; + im_dst_y = batch_size_ * CO_SLICES_ * H_; + im_dst_x = W_; } + size_t img_dtype = enable_fp16_ ? 
   img_size->clear();
-  img_size->push_back(global_shape_.s[2]);
-  img_size->push_back(global_shape_.s[1]);
+  img_size->push_back(im_dst_x);
+  img_size->push_back(im_dst_y);
   img_size->push_back(img_dtype);
   return RET_OK;
 }
@@ -152,16 +186,11 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> &
     MS_LOG(ERROR) << "Input data size must be greater than 0, but your size is " << inputs.size();
     return nullptr;
   }
-  if (inputs[0]->shape()[0] > 1) {
-    MS_LOG(ERROR) << "Init PRelu kernel failed: Unsupported multi-batch.";
-    return nullptr;
-  }
   auto *kernel = new (std::nothrow) PReluOpenCLKernel(reinterpret_cast<OpParameter *>(opParameter), inputs, outputs);
   if (kernel == nullptr) {
     MS_LOG(ERROR) << "kernel " << opParameter->name_ << " is nullptr.";
     return nullptr;
   }
-
   auto ret = kernel->Init();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Init PRelu kernel failed!";
diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h
index 8f0fcc65fb..c4ea418ffb 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h
+++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/prelu.h
@@ -39,11 +39,14 @@ class PReluOpenCLKernel : public OpenCLKernel {
 
  private:
   cl::Kernel kernel_;
-  void *PReluWeight_;
-  cl_int4 input_shape_;
-  cl_int4 global_shape_;
-  size_t fp_size;
   bool enable_fp16_{false};
+  int batch_size_{};
+  int C_{};
+  int H_{};
+  int W_{};
+  void *weight_vector_{nullptr};
+  float weight_scalar_{0.f};
+  bool weight_is_scalar{false};
 };
 
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/opencl/utils.cc b/mindspore/lite/src/runtime/kernel/opencl/utils.cc
index 9936c2ac42..9d0c084b88 100644
--- a/mindspore/lite/src/runtime/kernel/opencl/utils.cc
+++ b/mindspore/lite/src/runtime/kernel/opencl/utils.cc
@@ -235,14 +235,16 @@ void PrintTensor(lite::Tensor *tensor, int num, const std::string &out_file) {
   if (tensor->data_c() == nullptr) {
     return;
   }
-  auto runtime = lite::opencl::OpenCLRuntimeWrapper().GetInstance();
+  auto runtime_wrapper = lite::opencl::OpenCLRuntimeWrapper();
+  auto runtime = runtime_wrapper.GetInstance();
   runtime->SyncCommandQueue();
   auto allocator = runtime->GetAllocator();
   auto origin_data = tensor->data_c();
-  allocator->MapBuffer(origin_data, CL_MAP_READ, nullptr, true);
+  allocator->MapBuffer(origin_data, CL_MAP_READ | CL_MAP_WRITE, nullptr, true);
   tensor->SetData(origin_data);
 
+  auto Batch = tensor->Batch();
   auto Height = tensor->shape().size() == 4 ? tensor->Height() : 1;
   auto Width = tensor->shape().size() == 4 ? tensor->Width() : 1;
   auto SLICES = UP_DIV(tensor->Channel(), C4NUM);
   auto dtype_size = tensor->data_type() == kNumberTypeFloat16 ? sizeof(cl_half4) : sizeof(cl_float4);
   auto row_pitch = (Width * SLICES + alignment - 1) / alignment * alignment * dtype_size;
   auto row_size = Width * SLICES * dtype_size;
-  std::cout << "tensor->GetFormat() =" << tensor->GetFormat() << "\n";
-  std::cout << "Height =" << Height << "\n";
-  std::cout << "Width =" << Width << "\n";
-  std::cout << "SLICES =" << SLICES << "\n";
-  std::cout << "image_alignment =" << alignment << "\n";
-  std::cout << "dtype_size =" << dtype_size << "\n";
-  std::cout << "row_pitch =" << row_pitch << "\n";
-  std::cout << "row_size =" << row_size << "\n";
-  std::cout << "tensor->Size() =" << tensor->Size() << "\n";
   std::vector<char> data(tensor->Size());
-  for (int i = 0; i < Height; ++i) {
+  for (int i = 0; i < Batch * Height; ++i) {
     memcpy(static_cast<char *>(data.data()) + i * row_size, static_cast<char *>(origin_data) + i * row_pitch,
            row_size);
   }
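
Reviewer note (not part of the patch): the sketch below is a minimal CPU reference for what the rewritten PRelu_scalar / PRelu_vector kernels compute, including the NHWC4 / NC4HW4 image-coordinate mapping selected by the data_format argument (2 and 100, matching the schema::Format values passed in Run()). All names here (PReluReference, ImageXY, UpDiv) are ad hoc for illustration; the behavior is inferred from the kernel code above, not taken from the MindSpore API.

    // prelu_reference.cpp -- illustrative only; compile with any C++11 compiler.
    #include <cstdio>
    #include <vector>

    constexpr int C4NUM = 4;
    static int UpDiv(int x, int y) { return (x + y - 1) / y; }

    // PReLU over a flat NHWC buffer: out = in >= 0 ? in : in * w, where w is a
    // single scalar (weight.size() == 1) or one value per channel.
    static void PReluReference(const std::vector<float> &in, std::vector<float> *out,
                               const std::vector<float> &weight, int C) {
      const bool scalar = (weight.size() == 1);
      for (size_t i = 0; i < in.size(); ++i) {
        const float w = scalar ? weight[0] : weight[i % C];  // i % C = NHWC channel index
        (*out)[i] = in[i] >= 0.0f ? in[i] : in[i] * w;
      }
    }

    // The (x, y) image coordinate a work-item (h, w, slice) touches in both kernels:
    // NHWC4 packs the C/4 slices along the row, NC4HW4 stacks them along the column.
    static void ImageXY(int h, int w, int slice, int H, int SLICES, bool nhwc4,
                        int *x, int *y) {
      if (nhwc4) {
        *x = w * SLICES + slice;
        *y = h;
      } else {
        *x = w;
        *y = slice * H + h;
      }
    }

    int main() {
      const int H = 1, W = 2, C = 3;
      std::vector<float> in = {1.0f, -2.0f, 3.0f, -4.0f, 5.0f, -6.0f};  // 1x1x2x3 NHWC
      std::vector<float> out(in.size());

      PReluReference(in, &out, {0.25f}, C);  // scalar weight: the case this patch fixes
      for (float v : out) printf("%g ", v);  // prints: 1 -0.5 3 -1 5 -1.5
      printf("\n");

      PReluReference(in, &out, {0.1f, 0.2f, 0.3f}, C);  // per-channel weights
      for (float v : out) printf("%g ", v);             // prints: 1 -0.4 3 -0.4 5 -1.8
      printf("\n");

      int x, y;
      ImageXY(0, 1, 0, H, UpDiv(C, C4NUM), true, &x, &y);  // NHWC4 mapping example
      printf("x=%d y=%d\n", x, y);                          // prints: x=1 y=0
      return 0;
    }

Under this reading, the old single PRelu kernel indexed a per-channel alpha image even when the weight tensor held one element, while the split kernels make the scalar case explicit: PRelu_scalar multiplies all four lanes of each FLT4 by the same float weight, and PRelu_vector reads one FLT4 of weights per channel slice.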