From 48a4f31c9b1e74f3f11108dfe3c71a4d4d85c6ad Mon Sep 17 00:00:00 2001 From: Zhenglong Li Date: Fri, 11 Dec 2020 18:01:50 +0800 Subject: [PATCH] Complete development of Dvpp feature which can be deployed on Ascend 310 Device, the function of this operation is decode + resize + center crop, the output is in format of YUV color space. --- mindspore/ccsrc/CMakeLists.txt | 7 + .../ccsrc/minddata/dataset/CMakeLists.txt | 20 + .../ccsrc/minddata/dataset/api/vision.cc | 85 + .../ccsrc/minddata/dataset/include/vision.h | 35 + .../dataset/kernels/image/CMakeLists.txt | 9 +- .../dataset/kernels/image/dvpp/CMakeLists.txt | 6 + .../dvpp/dvpp_decode_resize_crop_jpeg_op.cc | 104 ++ .../dvpp/dvpp_decode_resize_crop_jpeg_op.h | 60 + .../kernels/image/dvpp/utils/AclProcess.cc | 355 +++++ .../kernels/image/dvpp/utils/AclProcess.h | 86 + .../kernels/image/dvpp/utils/CMakeLists.txt | 11 + .../kernels/image/dvpp/utils/CommonDataType.h | 187 +++ .../kernels/image/dvpp/utils/DvppCommon.cc | 1417 +++++++++++++++++ .../kernels/image/dvpp/utils/DvppCommon.h | 220 +++ .../kernels/image/dvpp/utils/ErrorCode.cpp | 51 + .../kernels/image/dvpp/utils/ErrorCode.h | 258 +++ .../image/dvpp/utils/ResourceManager.cc | 136 ++ .../image/dvpp/utils/ResourceManager.h | 88 + .../minddata/dataset/kernels/tensor_op.h | 1 + tests/cxx_st/dataset/test_de.cc | 22 +- 20 files changed, 3154 insertions(+), 4 deletions(-) create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.h create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc create mode 100644 mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h diff --git a/mindspore/ccsrc/CMakeLists.txt b/mindspore/ccsrc/CMakeLists.txt index e48e5a7e80..08602f778f 100644 --- a/mindspore/ccsrc/CMakeLists.txt +++ b/mindspore/ccsrc/CMakeLists.txt @@ -3,6 +3,13 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_BINARY_DIR}) +if (ENABLE_ACL) + set(ASCEND_PATH /usr/local/Ascend) + include_directories(${ASCEND_PATH}/acllib/include) + link_directories(${ASCEND_PATH}/acllib/lib64/) + find_library(ascendcl acl_dvpp ${ASCEND_PATH}/acllib/lib64) +endif () + if (NOT(CMAKE_SYSTEM_NAME MATCHES "Darwin")) link_directories(${CMAKE_SOURCE_DIR}/build/mindspore/graphengine) endif () diff --git a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt index 64461cd175..2c80f15437 100644 --- a/mindspore/ccsrc/minddata/dataset/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/CMakeLists.txt @@ -14,6 +14,10 @@ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-attributes") if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") add_definitions(-D _CRT_RAND_S) endif () +if (ENABLE_ACL) + add_definitions(-D ENABLE_ACL) + message(STATUS "ACL module is enabled") +endif () if (ENABLE_GPUQUE) add_definitions(-D ENABLE_GPUQUE) message(STATUS "GPU queue is enabled") @@ -91,6 +95,10 @@ if (NOT(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")) add_dependencies(text-kernels core) endif () +if (ENABLE_ACL) + add_dependencies(kernels-dvpp-image core dvpp-utils) +endif () + if (ENABLE_PYTHON) add_dependencies(APItoPython core) endif () @@ -140,6 +148,13 @@ if (NOT(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")) ) endif () +if (ENABLE_ACL) + set(submodules + ${submodules} + $ + $) +endif () + if (ENABLE_PYTHON) set(submodules ${submodules} @@ -163,6 +178,11 @@ endif () ################# Link with external libraries ######################## target_link_libraries(_c_dataengine PRIVATE mindspore mindspore_gvar) + +if (ENABLE_ACL) + target_link_libraries(_c_dataengine PRIVATE ascendcl acl_dvpp) +endif () + if (${CMAKE_SYSTEM_NAME} MATCHES "Windows") if (ENABLE_PYTHON) target_link_libraries(_c_dataengine PRIVATE mindspore::pybind11_module ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}) diff --git a/mindspore/ccsrc/minddata/dataset/api/vision.cc b/mindspore/ccsrc/minddata/dataset/api/vision.cc index 13e40143f7..edfc551157 100644 --- a/mindspore/ccsrc/minddata/dataset/api/vision.cc +++ b/mindspore/ccsrc/minddata/dataset/api/vision.cc @@ -31,6 +31,9 @@ #include "minddata/dataset/kernels/image/cut_out_op.h" #endif #include "minddata/dataset/kernels/image/decode_op.h" +#ifdef ENABLE_ACL +#include "minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h" +#endif #ifndef ENABLE_ANDROID #include "minddata/dataset/kernels/image/equalize_op.h" #include "minddata/dataset/kernels/image/hwc_to_chw_op.h" @@ -132,6 +135,16 @@ std::shared_ptr Decode(bool rgb) { return op->ValidateParams() ? op : nullptr; } +#ifdef ENABLE_ACL +// Function to create DvppDecodeResizeCropOperation. +std::shared_ptr DvppDecodeResizeCropJpeg(std::vector crop, + std::vector resize) { + auto op = std::make_shared(crop, resize); + // Input validation + return op->ValidateParams() ? op : nullptr; +} +#endif + // Function to create EqualizeOperation. std::shared_ptr Equalize() { auto op = std::make_shared(); @@ -647,6 +660,78 @@ Status MixUpBatchOperation::ValidateParams() { std::shared_ptr MixUpBatchOperation::Build() { return std::make_shared(alpha_); } #endif + +#ifdef ENABLE_ACL +// DvppDecodeResizeCropOperation +DvppDecodeResizeCropOperation::DvppDecodeResizeCropOperation(const std::vector &crop, + const std::vector &resize) + : crop_(crop), resize_(resize) {} + +Status DvppDecodeResizeCropOperation::ValidateParams() { + // size + if (crop_.empty() || crop_.size() > 2) { + std::string err_msg = "DvppDecodeResizeCropJpeg: crop size must be a vector of one or two elements, got: " + + std::to_string(crop_.size()); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (resize_.empty() || resize_.size() > 2) { + std::string err_msg = "DvppDecodeResizeCropJpeg: resize size must be a vector of one or two elements, got: " + + std::to_string(resize_.size()); + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + if (crop_.size() < resize_.size()) { + if (crop_[0] >= MIN(resize_[0], resize_[1])) { + std::string err_msg = "crop size must be smaller than resize size"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + } + if (crop_.size() > resize_.size()) { + if (MAX(crop_[0], crop_[1]) >= resize_[0]) { + std::string err_msg = "crop size must be smaller than resize size"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + } + if (crop_.size() == resize_.size()) { + for (int32_t i = 0; i < crop_.size(); ++i) { + if (crop_[i] >= resize_[i]) { + std::string err_msg = "crop size must be smaller than resize size"; + MS_LOG(ERROR) << err_msg; + RETURN_STATUS_SYNTAX_ERROR(err_msg); + } + } + } + return Status::OK(); +} + +std::shared_ptr DvppDecodeResizeCropOperation::Build() { + // If size is a single value, the smaller edge of the image will be + // resized to this value with the same image aspect ratio. + uint32_t cropHeight, cropWidth, resizeHeight, resizeWidth; + if (crop_.size() == 1) { + cropHeight = crop_[0]; + cropWidth = crop_[0]; + } else { + cropHeight = crop_[0]; + cropWidth = crop_[1]; + } + // User specified the width value. + if (resize_.size() == 1) { + resizeHeight = resize_[0]; + resizeWidth = resize_[0]; + } else { + resizeHeight = resize_[0]; + resizeWidth = resize_[1]; + } + std::shared_ptr tensor_op = + std::make_shared(cropHeight, cropWidth, resizeHeight, resizeWidth); + return tensor_op; +} +#endif + // NormalizeOperation NormalizeOperation::NormalizeOperation(std::vector mean, std::vector std) : mean_(mean), std_(std) {} diff --git a/mindspore/ccsrc/minddata/dataset/include/vision.h b/mindspore/ccsrc/minddata/dataset/include/vision.h index 1e4527166c..f1805dc1c2 100644 --- a/mindspore/ccsrc/minddata/dataset/include/vision.h +++ b/mindspore/ccsrc/minddata/dataset/include/vision.h @@ -38,6 +38,7 @@ constexpr char kAutoContrastOperation[] = "AutoContrast"; constexpr char kBoundingBoxAugmentOperation[] = "BoundingBoxAugment"; constexpr char kCutMixBatchOperation[] = "CutMixBatch"; constexpr char kCutOutOperation[] = "CutOut"; +constexpr char kDvppDecodeResizeCropOperation[] = "DvppDecodeResizeCrop"; constexpr char kEqualizeOperation[] = "Equalize"; constexpr char kHwcToChwOperation[] = "HwcToChw"; constexpr char kInvertOperation[] = "Invert"; @@ -75,6 +76,7 @@ class AutoContrastOperation; class BoundingBoxAugmentOperation; class CutMixBatchOperation; class CutOutOperation; +class DvppDecodeResizeCropOperation; class EqualizeOperation; class HwcToChwOperation; class InvertOperation; @@ -140,6 +142,22 @@ std::shared_ptr CutMixBatch(ImageBatchFormat image_batch_f /// \return Shared pointer to the current TensorOp std::shared_ptr CutOut(int32_t length, int32_t num_patches = 1); +/// \brief Function to create a DvppDecodeResizeCropJpeg TensorOperation. +/// \notes Tensor operation to decode and resize JPEG image using the simulation algorithm of Ascend series +/// chip DVPP module. It is recommended to use this algorithm in the following scenarios: +/// When training, the DVPP of the Ascend chip is not used, +/// and the DVPP of the Ascend chip is used during inference, +/// and the accuracy of inference is lower than the accuracy of training; +/// and the input image size should be in range [16*16, 4096*4096]. +/// Only images with an even resolution can be output. The output of odd resolution is not supported. +/// \param[in] crop vector representing the output size of the final crop image. +/// \param[in] size A vector representing the output size of the intermediate resized image. +/// If size is a single value, smaller edge of the image will be resized to this value with +/// the same image aspect ratio. If size has 2 values, it should be (height, width). +/// \return Shared pointer to the current TensorOperation. +std::shared_ptr DvppDecodeResizeCropJpeg(std::vector crop = {224, 224}, + std::vector resize = {256, 256}); + /// \brief Function to create a Equalize TensorOperation. /// \notes Apply histogram equalization on input image. /// \return Shared pointer to the current TensorOperation. @@ -538,6 +556,23 @@ class CutOutOperation : public TensorOperation { ImageBatchFormat image_batch_format_; }; +class DvppDecodeResizeCropOperation : public TensorOperation { + public: + explicit DvppDecodeResizeCropOperation(const std::vector &crop, const std::vector &resize); + + ~DvppDecodeResizeCropOperation() = default; + + std::shared_ptr Build() override; + + Status ValidateParams() override; + + std::string Name() const override { return kDvppDecodeResizeCropOperation; } + + private: + std::vector crop_; + std::vector resize_; +}; + class EqualizeOperation : public TensorOperation { public: ~EqualizeOperation() = default; diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt index 763bb71327..b916a687ea 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/CMakeLists.txt @@ -2,6 +2,9 @@ file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc" set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) add_subdirectory(soft_dvpp) add_subdirectory(lite_cv) +if (ENABLE_ACL) + add_subdirectory(dvpp) +endif () add_library(kernels-image OBJECT affine_op.cc auto_contrast_op.cc @@ -50,4 +53,8 @@ add_library(kernels-image OBJECT random_resize_with_bbox_op.cc random_color_op.cc ) -add_dependencies(kernels-image kernels-soft-dvpp-image ) +if (ENABLE_ACL) + add_dependencies(kernels-image kernels-soft-dvpp-image kernels-dvpp-image) +else() + add_dependencies(kernels-image kernels-soft-dvpp-image) +endif () diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt new file mode 100644 index 0000000000..e531e4d9ee --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/CMakeLists.txt @@ -0,0 +1,6 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) +add_subdirectory(utils) +add_library(kernels-dvpp-image OBJECT + dvpp_decode_resize_crop_jpeg_op.cc) +add_dependencies(kernels-dvpp-image dvpp-utils) \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc new file mode 100644 index 0000000000..a74f3d6479 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include "minddata/dataset/kernels/image/dvpp/utils/AclProcess.h" +#include "minddata/dataset/core/cv_tensor.h" +#include "minddata/dataset/kernels/image/image_utils.h" +#include "minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h" +#include "minddata/dataset/core/data_type.h" +#include "minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h" + +namespace mindspore { +namespace dataset { +Status DvppDecodeResizeCropJpegOp::Compute(const std::shared_ptr &input, std::shared_ptr *output) { + IO_CHECK(input, output); + if (!IsNonEmptyJPEG(input)) { + RETURN_STATUS_UNEXPECTED("SoftDvppDecodeReiszeJpegOp only support process jpeg image."); + } + try { + CHECK_FAIL_RETURN_UNEXPECTED(input->GetBuffer() != nullptr, "The input image buffer is empty."); + unsigned char *buffer = const_cast(input->GetBuffer()); + RawData imageInfo; + uint32_t filesize = input->SizeInBytes(); + imageInfo.lenOfByte = filesize; + imageInfo.data = std::make_shared(); + imageInfo.data.reset(new uint8_t[filesize], std::default_delete()); + memcpy_s(imageInfo.data.get(), filesize, buffer, filesize); + // First part end, whose function is to transform data from a Tensor to imageinfo data structure which can be + // applied on device + ResourceInfo resource; + resource.aclConfigPath = ""; + resource.deviceIds.insert(0); // 0 is device id which should be refined later! + std::shared_ptr instance = ResourceManager::GetInstance(); + APP_ERROR ret = instance->InitResource(resource); + if (ret != APP_ERR_OK) { + instance->Release(); + std::string error = "Error in Init D-chip:" + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + int deviceId = *(resource.deviceIds.begin()); + aclrtContext context = instance->GetContext(deviceId); + // Second part end where we initialize the resource of D chip and set up all configures + AclProcess process(resized_width_, resized_height_, crop_width_, crop_height_, context); + process.set_mode(true); + ret = process.InitResource(); + if (ret != APP_ERR_OK) { + instance->Release(); + std::string error = "Error in Init resource:" + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + ret = process.Process(imageInfo); + if (ret != APP_ERR_OK) { + instance->Release(); + std::string error = "Error in dvpp processing:" + std::to_string(ret); + RETURN_STATUS_UNEXPECTED(error); + } + // Third part end where we execute the core function of dvpp + auto data = std::static_pointer_cast(process.Get_Memory_Data()); + unsigned char *ret_ptr = data.get(); + std::shared_ptr(DvppDataInfo) CropOut = process.Get_Device_Memory_Data(); + dsize_t dvpp_length = CropOut->dataSize; + const TensorShape dvpp_shape({dvpp_length, 1, 1}); + const DataType dvpp_data_type(DataType::DE_UINT8); + mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, output); + if (!((*output)->HasData())) { + std::string error = "[ERROR] Fail to get the Output result from memory!"; + RETURN_STATUS_UNEXPECTED(error); + } + process.device_memory_release(); + // Last part end where we transform the processed data into a tensor which can be applied in later units. + } catch (const cv::Exception &e) { + std::string error = "[ERROR] Fail in DvppDecodeResizeCropJpegOp:" + std::string(e.what()); + RETURN_STATUS_UNEXPECTED(error); + } + return Status::OK(); +} + +Status DvppDecodeResizeCropJpegOp::OutputShape(const std::vector &inputs, + std::vector &outputs) { + RETURN_IF_NOT_OK(TensorOp::OutputShape(inputs, outputs)); + outputs.clear(); + TensorShape out({-1, 1, 1}); // we don't know what is output image size, but we know it should be 3 channels + if (inputs[0].Rank() == 1) outputs.emplace_back(out); + if (!outputs.empty()) return Status::OK(); + return Status(StatusCode::kUnexpectedError, "Input has a wrong shape"); +} + +} // namespace dataset +} // namespace mindspore diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h new file mode 100644 index 0000000000..aae9c77f6d --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/dvpp_decode_resize_crop_jpeg_op.h @@ -0,0 +1,60 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_DVPP_DECODE_RESIZE_CROP_JPEG_OP_H +#define MINDSPORE_DVPP_DECODE_RESIZE_CROP_JPEG_OP_H + +#include +#include +#include + +#include "minddata/dataset/core/tensor.h" +#include "minddata/dataset/kernels/tensor_op.h" +#include "minddata/dataset/util/status.h" +#include "minddata/dataset/core/data_type.h" +#include "mindspore/core/utils/log_adapter.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h" +#include "minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h" +#include "acl/acl.h" + +namespace mindspore { +namespace dataset { +class DvppDecodeResizeCropJpegOp : public TensorOp { + public: + DvppDecodeResizeCropJpegOp(int32_t crop_height, int32_t crop_width, int32_t resized_height, int32_t resized_width) + : crop_height_(crop_height), + crop_width_(crop_width), + resized_height_(resized_height), + resized_width_(resized_width) {} + + /// \brief Destructor + ~DvppDecodeResizeCropJpegOp() = default; + + Status Compute(const std::shared_ptr &input, std::shared_ptr *output) override; + Status OutputShape(const std::vector &inputs, std::vector &outputs) override; + + std::string Name() const override { return kDvppDecodeResizeCropJpegOp; } + + private: + int32_t crop_height_; + int32_t crop_width_; + int32_t resized_height_; + int32_t resized_width_; +}; +} // namespace dataset +} // namespace mindspore + +#endif // MINDSPORE_DVPP_DECODE_RESIZE_CROP_JPEG_OP_H diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.cc new file mode 100644 index 0000000000..9edd5d0708 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.cc @@ -0,0 +1,355 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "AclProcess.h" +#include +#include +#include + +namespace { +const int BUFFER_SIZE = 2048; +const mode_t DEFAULT_FILE_PERMISSION = 0077; +} // namespace + +mode_t SetFileDefaultUmask() { return umask(DEFAULT_FILE_PERMISSION); } + +/* + * @description: Constructor + * @param: resizeWidth specifies the resized width + * @param: resizeHeight specifies the resized hegiht + * @param: stream is used to maintain the execution order of operations + * @param: context is used to manage the life cycle of objects + * @param: dvppCommon is a class for decoding and resizing + */ +AclProcess::AclProcess(uint32_t resizeWidth, uint32_t resizeHeight, uint32_t cropWidth, uint32_t cropHeight, + aclrtContext context, aclrtStream stream, std::shared_ptr dvppCommon) + : resizeWidth_(resizeWidth), + resizeHeight_(resizeHeight), + cropWidth_(cropWidth), + cropHeight_(cropHeight), + context_(context), + stream_(stream), + dvppCommon_(dvppCommon) { + repeat_ = true; +} + +/* + * @description: Release AclProcess resources + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::Release() { + // Release objects resource + APP_ERROR ret = dvppCommon_->DeInit(); + dvppCommon_->ReleaseDvppBuffer(); + + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to deinitialize dvppCommon_, ret = " << ret; + return ret; + } + MS_LOG(INFO) << "dvppCommon_ object deinitialized successfully"; + dvppCommon_.reset(); + + // Release stream + if (stream_ != nullptr) { + ret = aclrtDestroyStream(stream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to destroy stream, ret = " << ret; + stream_ = nullptr; + return ret; + } + stream_ = nullptr; + } + MS_LOG(INFO) << "The stream is destroyed successfully"; + return APP_ERR_OK; +} + +/* + * @description: Initialize DvppCommon object + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::InitModule() { + // Create Dvpp JpegD object + dvppCommon_ = std::make_shared(stream_); + if (dvppCommon_ == nullptr) { + MS_LOG(ERROR) << "Failed to create dvppCommon_ object"; + return APP_ERR_COMM_INIT_FAIL; + } + MS_LOG(INFO) << "DvppCommon object created successfully"; + APP_ERROR ret = dvppCommon_->Init(); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to initialize dvppCommon_ object, ret = " << ret; + return ret; + } + MS_LOG(INFO) << "DvppCommon object initialized successfully"; + return APP_ERR_OK; +} + +/* + * @description: Initialize AclProcess resources + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::InitResource() { + APP_ERROR ret = aclrtSetCurrentContext(context_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get ACL context, ret = " << ret; + return ret; + } + MS_LOG(INFO) << "The context is created successfully"; + ret = aclrtCreateStream(&stream_); // Create stream for application + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to create ACL stream, ret = " << ret; + return ret; + } + MS_LOG(INFO) << "The stream is created successfully"; + // Initialize dvpp module + if (InitModule() != APP_ERR_OK) { + return APP_ERR_COMM_INIT_FAIL; + } + return APP_ERR_OK; +} + +/* + * @description: Read image files, and perform decoding and scaling + * @param: imageFile specifies the image path to be processed + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::Preprocess(RawData &ImageInfo) { + // Decode process + APP_ERROR ret = dvppCommon_->CombineJpegdProcess(ImageInfo, PIXEL_FORMAT_YUV_SEMIPLANAR_420, true); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to process decode, ret = " << ret << "."; + return ret; + } + // Get output of decoded jpeg image + std::shared_ptr decodeOutData = dvppCommon_->GetDecodedImage(); + if (decodeOutData == nullptr) { + MS_LOG(ERROR) << "Decode output buffer is null."; + return APP_ERR_COMM_INVALID_POINTER; + } + // Define output of resize jpeg image + DvppDataInfo resizeOut; + resizeOut.width = resizeWidth_; + resizeOut.height = resizeHeight_; + resizeOut.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; + // Run resize application function + ret = dvppCommon_->CombineResizeProcess(*(decodeOutData.get()), resizeOut, true); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to process resize, ret = " << ret << "."; + return ret; + } + // Get output of resize jpeg image + std::shared_ptr resizeOutData = dvppCommon_->GetResizedImage(); + if (resizeOutData == nullptr) { + MS_LOG(ERROR) << "resize output buffer is null."; + return APP_ERR_COMM_INVALID_POINTER; + } + // Define output of crop jpeg image + DvppDataInfo cropOut; + cropOut.width = cropWidth_; + cropOut.height = cropHeight_; + cropOut.format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; + // Define input of crop jpeg image + DvppCropInputInfo cropInfo; + cropInfo.dataInfo = *(resizeOutData.get()); + // Define crop parameters + CropRoiConfig cropCfg; + CropConfigFilter(cropCfg, cropInfo); + ret = dvppCommon_->CombineCropProcess(cropInfo, cropOut, true); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to process center crop, ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +/* + * @description: Decode and scale the picture, and write the result to a file + * @param: imageFile specifies the image path to be processed + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::Process(RawData &ImageInfo) { + struct timeval begin = {0}; + struct timeval end = {0}; + gettimeofday(&begin, nullptr); + // deal with image + APP_ERROR ret = Preprocess(ImageInfo); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to preprocess, ret = " << ret; + return ret; + } + gettimeofday(&end, nullptr); + // Calculate the time cost of preprocess + const double costMs = SEC2MS * (end.tv_sec - begin.tv_sec) + (end.tv_usec - begin.tv_usec) / SEC2MS; + const double fps = 1 * SEC2MS / costMs; + MS_LOG(INFO) << "[dvpp Delay] cost: " << costMs << "ms\tfps: " << fps; + // Get output of resize module + std::shared_ptr CropOutData = dvppCommon_->GetCropedImage(); + if (CropOutData->dataSize == 0) { + MS_LOG(ERROR) << "CropOutData return NULL"; + return APP_ERR_COMM_INVALID_POINTER; + } + // Alloc host memory for the inference output according to the size of output + void *resHostBuf = nullptr; + ret = aclrtMallocHost(&resHostBuf, CropOutData->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to allocate memory from host ret = " << ret; + return ret; + } + std::shared_ptr outBuf(resHostBuf, aclrtFreeHost); + processedInfo_ = outBuf; + // Memcpy the output data from device to host + ret = aclrtMemcpy(outBuf.get(), CropOutData->dataSize, CropOutData->data, CropOutData->dataSize, + ACL_MEMCPY_DEVICE_TO_HOST); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to copy memory from device to host, ret = " << ret; + return ret; + } + return APP_ERR_OK; +} + +/* + * @description: Rename the image for saving + * @Param: primary name of image + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::RenameFile(std::string &filename) { + std::string delimiter = "/"; + size_t pos = 0; + std::string token; + while ((pos = filename.find(delimiter)) != std::string::npos) { + token = filename.substr(0, pos); + filename.erase(0, pos + delimiter.length()); + } + delimiter = "."; + pos = filename.find(delimiter); + filename = filename.substr(0, pos); + if (filename.length() == 0) { + return APP_ERR_COMM_WRITE_FAIL; + } + return APP_ERR_OK; +} + +/* + * @description: Write result image to file + * @param: resultSize specifies the size of the result image + * @param: outBuf specifies the memory on the host to save the result image + * @return: aclError which is error code of ACL API + */ +APP_ERROR AclProcess::WriteResult(uint32_t resultSize, std::shared_ptr outBuf, std::string filename) { + std::string resultPathName = "result"; + // Create result directory when it does not exist + if (access(resultPathName.c_str(), 0) != 0) { + int ret = mkdir(resultPathName.c_str(), S_IRUSR | S_IWUSR | S_IXUSR); // for linux + if (ret != 0) { + MS_LOG(ERROR) << "Failed to create result directory: " << resultPathName << ", ret = " << ret; + return ret; + } + } + APP_ERROR ret = RenameFile(filename); + if (ret != 0) { + MS_LOG(ERROR) << "Failed to rename file: " << resultPathName << ", ret = " << ret; + return ret; + } + resultPathName = resultPathName + "/" + filename + ".bin"; + SetFileDefaultUmask(); + FILE *fp = fopen(resultPathName.c_str(), "wb"); + if (fp == nullptr) { + MS_LOG(ERROR) << "Failed to open file"; + return APP_ERR_COMM_OPEN_FAIL; + } + uint32_t result = fwrite(outBuf.get(), 1, resultSize, fp); + if (result != resultSize) { + MS_LOG(ERROR) << "Failed to write file"; + return APP_ERR_COMM_WRITE_FAIL; + } + MS_LOG(INFO) << "Write result to file successfully"; + // After write info onto desk, release the memory on device + dvppCommon_->ReleaseDvppBuffer(); + uint32_t ff = fflush(fp); + if (ff != 0) { + MS_LOG(ERROR) << "Failed to fflush file"; + return APP_ERR_COMM_DESTORY_FAIL; + } + uint32_t fc = fclose(fp); + if (fc != 0) { + MS_LOG(ERROR) << "Failed to fclose file"; + return APP_ERR_COMM_DESTORY_FAIL; + } + return APP_ERR_OK; +} + +void AclProcess::YUV420TOYUV444(unsigned char *inputBuffer, unsigned char *outputBuffer, int w, int h) { + unsigned char *srcY = nullptr, *srcU = nullptr, *srcV = nullptr; + unsigned char *desY = nullptr, *desU = nullptr, *desV = nullptr; + srcY = inputBuffer; // Y + if (srcY == nullptr) std::cout << "Failure pointer transfer!"; + srcU = srcY + w * h; // U + srcV = srcU + w * h / 4; + ; // V + + desY = outputBuffer; + desU = desY + w * h; + desV = desU + w * h; + memcpy(desY, srcY, w * h * sizeof(unsigned char)); + for (int i = 0; i < h; i += 2) { //行 + for (int j = 0; j < w; j += 2) { //列 + // U + desU[i * w + j] = srcU[i / 2 * w / 2 + j / 2]; + desU[i * w + j + 1] = srcU[i / 2 * w / 2 + j / 2]; + desU[(i + 1) * w + j] = srcU[i / 2 * w / 2 + j / 2]; + desU[(i + 1) * w + j + 1] = srcU[i / 2 * w / 2 + j / 2]; + // V + desV[i * w + j] = srcV[i / 2 * w / 2 + j / 2]; + desV[i * w + j + 1] = srcV[i / 2 * w / 2 + j / 2]; + desV[(i + 1) * w + j] = srcV[i / 2 * w / 2 + j / 2]; + desV[(i + 1) * w + j + 1] = srcV[i / 2 * w / 2 + j / 2]; + } + } +} + +void AclProcess::CropConfigFilter(CropRoiConfig &cfg, DvppCropInputInfo &cropinfo) { + cfg.up = (resizeHeight_ - cropHeight_) / 2; + if (cfg.up % 2 != 0) { + cfg.up++; + } + cfg.down = resizeHeight_ - (resizeHeight_ - cropHeight_) / 2; + if (cfg.down % 2 == 0) { + cfg.down--; + } + cfg.left = (resizeWidth_ - cropWidth_) / 2; + if (cfg.left % 2 != 0) { + cfg.left++; + } + cfg.right = resizeWidth_ - (resizeWidth_ - cropWidth_) / 2; + if (cfg.right % 2 == 0) { + cfg.right--; + } + cropinfo.roi = cfg; +} + +/* + * @description: Obtain result data of memory + * @param: processed_data is result data info pointer + * @return: Address of data in the memory + */ +std::shared_ptr AclProcess::Get_Memory_Data() { return processedInfo_; } + +std::shared_ptr AclProcess::Get_Device_Memory_Data() { return dvppCommon_->GetCropedImage(); } + +void AclProcess::set_mode(bool flag) { repeat_ = flag; } + +bool AclProcess::get_mode() { return repeat_; } + +void AclProcess::device_memory_release() { dvppCommon_->ReleaseDvppBuffer(); } \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.h new file mode 100644 index 0000000000..e4b0e1ef3c --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/AclProcess.h @@ -0,0 +1,86 @@ +#include +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ACLMANAGER_H +#define ACLMANAGER_H + +#include +#include +#include +#include +#include +#include "acl/acl.h" +#include "CommonDataType.h" +#include "mindspore/core/utils/log_adapter.h" +#include "ErrorCode.h" +#include "DvppCommon.h" +#include +#include +#include +#include + +mode_t SetFileDefaultUmask(); + +class AclProcess { + public: + AclProcess(uint32_t resizeWidth, uint32_t resizeHeight, uint32_t cropWidth, uint32_t cropHeight, aclrtContext context, + aclrtStream stream = nullptr, std::shared_ptr dvppCommon = nullptr); + + ~AclProcess(){}; + + // Release all the resource + APP_ERROR Release(); + // Create resource for this sample + APP_ERROR InitResource(); + // Process the result + APP_ERROR Process(RawData &ImageInfo); + // API for access memory + std::shared_ptr Get_Memory_Data(); + // API for access device memory + std::shared_ptr Get_Device_Memory_Data(); + // change output method + void set_mode(bool flag); + // Get the mode of Acl process + bool get_mode(); + // Save the result + APP_ERROR WriteResult(uint32_t fileSize, std::shared_ptr outBuf, std::string filename); + // Color space reform + void YUV420TOYUV444(unsigned char *inputBuffer, unsigned char *outputBuffer, int w, int h); + // Crop definition + void CropConfigFilter(CropRoiConfig &cfg, DvppCropInputInfo &cropinfo); + // D-chip memory release + void device_memory_release(); + + private: + // Initialize the modules used by this sample + APP_ERROR InitModule(); + // Preprocess the input image + APP_ERROR Preprocess(RawData &ImageInfo); + // Filename process + APP_ERROR RenameFile(std::string &filename); + + aclrtContext context_; + aclrtStream stream_; + std::shared_ptr dvppCommon_; // dvpp object + std::shared_ptr processedInfo_; // processed data + uint32_t resizeWidth_; // dvpp resize width + uint32_t resizeHeight_; // dvpp resize height + uint32_t cropWidth_; // dvpp crop width + uint32_t cropHeight_; // dvpp crop height + bool repeat_; // Repeatly process image or not +}; + +#endif diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt new file mode 100644 index 0000000000..2cd7ccce49 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CMakeLists.txt @@ -0,0 +1,11 @@ +file(GLOB_RECURSE _CURRENT_SRC_FILES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "*.cc") +set_property(SOURCE ${_CURRENT_SRC_FILES} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_MD) + +add_definitions(-DENABLE_DVPP_INTERFACE) + +add_library(dvpp-utils OBJECT + AclProcess.cc + DvppCommon.cc + ErrorCode.cpp + ResourceManager.cc + ) diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h new file mode 100644 index 0000000000..2aefe99f79 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/CommonDataType.h @@ -0,0 +1,187 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMMONDATATYPE_H +#define COMMONDATATYPE_H +#ifndef ENABLE_DVPP_INTERFACE +#define ENABLE_DVPP_INTERFACE +#endif +#include +#include +#include +#include +#include "acl/acl.h" +#include "acl/ops/acl_dvpp.h" + +#define DVPP_ALIGN_UP(x, align) ((((x) + ((align)-1)) / (align)) * (align)) + +const uint32_t VIDEO_H264 = 0; +const uint32_t VIDEO_H265 = 1; + +const float SEC2MS = 1000.0; +const uint32_t VIDEO_PROCESS_THREAD = 16; +const int YUV_BGR_SIZE_CONVERT_3 = 3; +const int YUV_BGR_SIZE_CONVERT_2 = 2; +const int DVPP_JPEG_OFFSET = 8; +const int VPC_WIDTH_ALIGN = 16; +const int VPC_HEIGHT_ALIGN = 2; +const int JPEG_WIDTH_ALIGN = 128; +const int JPEG_HEIGHT_ALIGN = 16; +const int VPC_OFFSET_ALIGN = 2; + +// Tensor Descriptor +struct Tensor { + aclDataType dataType; // Tensor data type + int numDim; // Number of dimensions of Tensor + std::vector dims; // Dimension vector + aclFormat format; // Format of tensor, e.g. ND, NCHW, NC1HWC0 + std::string name; // Name of tensor +}; + +// Data type of tensor +enum OpAttrType { + BOOL = 0, + INT = 1, + FLOAT = 2, + STRING = 3, + LIST_BOOL = 4, + LIST_INT = 6, + LIST_FLOAT = 7, + LIST_STRING = 8, + LIST_LIST_INT = 9, +}; + +// operator attribution describe +// type decide whether the other attribute needed to set a value +struct OpAttr { + std::string name; + OpAttrType type; + int num; // LIST_BOOL/INT/FLOAT/STRING/LIST_LIST_INT need + uint8_t numBool; // BOOL need + int64_t numInt; // INT need + float numFloat; // FLOAT need + std::string numString; // STRING need + std::vector valuesBool; // LIST_BOOL need + std::vector valuesInt; // LIST_INT need + std::vector valuesFloat; // LIST_FLOAT need + std::vector valuesString; // LIST_STRING need + std::vector numLists; // LIST_LIST_INT need + std::vector> valuesListList; // LIST_LIST_INT need +}; + +// Description of image data +struct ImageInfo { + uint32_t width; // Image width + uint32_t height; // Image height + uint32_t lenOfByte; // Size of image data, bytes + std::shared_ptr data; // Smart pointer of image data +}; + +// Description of data in device +struct RawData { + size_t lenOfByte; // Size of memory, bytes + std::shared_ptr data; // Smart pointer of data +}; + +// Description of data in device +struct StreamData { + size_t size; // Size of memory, bytes + std::shared_ptr data; // Smart pointer of data +}; + +// Description of stream data +struct StreamInfo { + std::string format; + uint32_t height; + uint32_t width; + uint32_t channelId; + std::string streamPath; +}; + +// define the structure of an rectangle +struct Rectangle { + uint32_t leftTopX; + uint32_t leftTopY; + uint32_t rightBottomX; + uint32_t rightBottomY; +}; + +struct ObjectDetectInfo { + int32_t classId; + float confidence; + Rectangle location; +}; + +enum VpcProcessType { + VPC_PT_DEFAULT = 0, + VPC_PT_PADDING, // Resize with locked ratio and paste on upper left corner + VPC_PT_FIT, // Resize with locked ratio and paste on middle location + VPC_PT_FILL, // Resize with locked ratio and paste on whole locatin, the input image may be cropped +}; + +struct DvppDataInfo { + uint32_t width = 0; // Width of image + uint32_t height = 0; // Height of image + uint32_t widthStride = 0; // Width after align up + uint32_t heightStride = 0; // Height after align up + acldvppPixelFormat format = PIXEL_FORMAT_YUV_SEMIPLANAR_420; // Format of image + uint32_t frameId = 0; // Needed by video + uint32_t dataSize = 0; // Size of data in byte + uint8_t *data = nullptr; // Image data +}; + +struct CropRoiConfig { + uint32_t left; + uint32_t right; + uint32_t down; + uint32_t up; +}; + +struct DvppCropInputInfo { + DvppDataInfo dataInfo; + CropRoiConfig roi; +}; + +// Description of matrix info +struct MatrixInfo { + uint32_t row = 0; // row of matrix + uint32_t col = 0; // col of matrix + uint32_t dataSize = 0; // size of memory, bytes + std::shared_ptr data = nullptr; // data of matrix + aclDataType dataType = ACL_FLOAT16; // data Type of matrix +}; + +// Description of coefficient info +struct CoefficientInfo { + std::shared_ptr data = nullptr; // data of coefficient + aclDataType dataType = ACL_FLOAT16; // dataType +}; + +// define the input of BLAS operator such as producing: +// C = alpha * A * B + beta * C +struct BlasInput { + MatrixInfo A; + MatrixInfo B; + MatrixInfo C; + CoefficientInfo alpha; + CoefficientInfo beta; +}; + +extern bool g_vdecNotified[VIDEO_PROCESS_THREAD]; +extern bool g_vpcNotified[VIDEO_PROCESS_THREAD]; +extern bool g_inferNotified[VIDEO_PROCESS_THREAD]; +extern bool g_postNotified[VIDEO_PROCESS_THREAD]; + +#endif diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc new file mode 100644 index 0000000000..76ad45628b --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.cc @@ -0,0 +1,1417 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "mindspore/core/utils/log_adapter.h" +#include "DvppCommon.h" +#include "CommonDataType.h" + +static auto g_resizeConfigDeleter = [](acldvppResizeConfig *p) { acldvppDestroyResizeConfig(p); }; +static auto g_picDescDeleter = [](acldvppPicDesc *picDesc) { acldvppDestroyPicDesc(picDesc); }; +static auto g_roiConfigDeleter = [](acldvppRoiConfig *p) { acldvppDestroyRoiConfig(p); }; +static auto g_jpegeConfigDeleter = [](acldvppJpegeConfig *p) { acldvppDestroyJpegeConfig(p); }; + +DvppCommon::DvppCommon(aclrtStream dvppStream) { dvppStream_ = dvppStream; } + +DvppCommon::DvppCommon(const VdecConfig &vdecConfig) { vdecConfig_ = vdecConfig; } + +/* + * @description: Create a channel for processing image data, + * the channel description is created by acldvppCreateChannelDesc + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::Init(void) { + dvppChannelDesc_ = acldvppCreateChannelDesc(); + if (dvppChannelDesc_ == nullptr) { + return -1; + } + APP_ERROR ret = acldvppCreateChannel(dvppChannelDesc_); + if (ret != 0) { + MS_LOG(ERROR) << "Failed to create dvpp channel: " << GetAppErrCodeInfo(ret) << "."; + acldvppDestroyChannelDesc(dvppChannelDesc_); + dvppChannelDesc_ = nullptr; + return ret; + } + + return APP_ERR_OK; +} + +/* + * @description: Create a channel for processing video data, + * the channel description is created by aclvdecCreateChannelDesc + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::InitVdec() { + isVdec_ = true; + // create vdec channelDesc + vdecChannelDesc_ = aclvdecCreateChannelDesc(); + if (vdecChannelDesc_ == nullptr) { + MS_LOG(ERROR) << "Failed to create vdec channel description."; + return APP_ERR_ACL_FAILURE; + } + + // channelId: 0-15 + aclError ret = aclvdecSetChannelDescChannelId(vdecChannelDesc_, vdecConfig_.channelId); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to set vdec channel id, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + ret = aclvdecSetChannelDescThreadId(vdecChannelDesc_, vdecConfig_.threadId); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to set thread id, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + // callback func + ret = aclvdecSetChannelDescCallback(vdecChannelDesc_, vdecConfig_.callback); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to set vdec callback function, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + ret = aclvdecSetChannelDescEnType(vdecChannelDesc_, vdecConfig_.inFormat); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to set encoded type of input video, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + ret = aclvdecSetChannelDescOutPicFormat(vdecChannelDesc_, vdecConfig_.outFormat); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to set vdec output format, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + // create vdec channel + ret = aclvdecCreateChannel(vdecChannelDesc_); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Failed to create vdec channel, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + MS_LOG(INFO) << "Vdec init resource successfully."; + return APP_ERR_OK; +} + +/* + * @description: If isVdec_ is true, destroy the channel and the channel description used by video. + * Otherwise destroy the channel and the channel description used by image. + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::DeInit(void) { + if (isVdec_) { + return DestroyResource(); + } + + APP_ERROR ret = aclrtSynchronizeStream(dvppStream_); // APP_ERROR ret + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to synchronize stream, ret = " << ret << "."; + return ret; + } + + ret = acldvppDestroyChannel(dvppChannelDesc_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to destory dvpp channel, ret = " << ret << "."; + return ret; + } + + ret = acldvppDestroyChannelDesc(dvppChannelDesc_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to destroy dvpp channel description, ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +/* + * @description: Destroy the channel and the channel description used by video. + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::DestroyResource() { + APP_ERROR ret = APP_ERR_OK; + isVdec_ = true; + if (vdecChannelDesc_ != nullptr) { + ret = aclvdecDestroyChannel(vdecChannelDesc_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to destory dvpp channel, ret = " << ret; + } + aclvdecDestroyChannelDesc(vdecChannelDesc_); + vdecChannelDesc_ = nullptr; + } + return ret; +} + +/* + * @description: Release the memory that is allocated in the interfaces which are started with "Combine" + */ +void DvppCommon::ReleaseDvppBuffer() { + if (cropImage_ != nullptr) { + RELEASE_DVPP_DATA(cropImage_->data); + } + if (resizedImage_ != nullptr) { + RELEASE_DVPP_DATA(resizedImage_->data); + } + if (decodedImage_ != nullptr) { + RELEASE_DVPP_DATA(decodedImage_->data); + } + if (inputImage_ != nullptr) { + RELEASE_DVPP_DATA(inputImage_->data); + } + if (encodedImage_ != nullptr) { + RELEASE_DVPP_DATA(encodedImage_->data); + } +} + +/* + * @description: Get the size of buffer used to save image for VPC according to width, height and format + * @param width specifies the width of the output image + * @param height specifies the height of the output image + * @param format specifies the format of the output image + * @param: vpcSize is used to save the result size + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetVpcDataSize(uint32_t width, uint32_t height, acldvppPixelFormat format, uint32_t &vpcSize) { + // Check the invalid format of VPC function and calculate the output buffer size + if (format != PIXEL_FORMAT_YUV_SEMIPLANAR_420 && format != PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + MS_LOG(ERROR) << "Format[" << format << "] for VPC is not supported, just support NV12 or NV21."; + return APP_ERR_COMM_INVALID_PARAM; + } + uint32_t widthStride = DVPP_ALIGN_UP(width, VPC_WIDTH_ALIGN); + uint32_t heightStride = DVPP_ALIGN_UP(height, VPC_HEIGHT_ALIGN); + vpcSize = widthStride * heightStride * YUV_BGR_SIZE_CONVERT_3 / YUV_BGR_SIZE_CONVERT_2; + return APP_ERR_OK; +} + +/* + * @description: Get the aligned width and height of the input image according to the image format + * @param: width specifies the width before alignment + * @param: height specifies the height before alignment + * @param: format specifies the image format + * @param: widthStride is used to save the width after alignment + * @param: heightStride is used to save the height after alignment + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetVpcInputStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride) { + uint32_t inputWidthStride; + // Check the invalidty of input format and calculate the input width stride + if (format >= PIXEL_FORMAT_YUV_400 && format <= PIXEL_FORMAT_YVU_SEMIPLANAR_444) { + // If format is YUV SP, keep widthStride not change. + inputWidthStride = DVPP_ALIGN_UP(width, VPC_STRIDE_WIDTH); + } else if (format >= PIXEL_FORMAT_YUYV_PACKED_422 && format <= PIXEL_FORMAT_VYUY_PACKED_422) { + // If format is YUV422 packed, image size = H x W * 2; + inputWidthStride = DVPP_ALIGN_UP(width, VPC_STRIDE_WIDTH) * YUV422_WIDTH_NU; + } else if (format >= PIXEL_FORMAT_YUV_PACKED_444 && format <= PIXEL_FORMAT_BGR_888) { + // If format is YUV444 packed or RGB, image size = H x W * 3; + inputWidthStride = DVPP_ALIGN_UP(width, VPC_STRIDE_WIDTH) * YUV444_RGB_WIDTH_NU; + } else if (format >= PIXEL_FORMAT_ARGB_8888 && format <= PIXEL_FORMAT_BGRA_8888) { + // If format is XRGB8888, image size = H x W * 4 + inputWidthStride = DVPP_ALIGN_UP(width, VPC_STRIDE_WIDTH) * XRGB_WIDTH_NU; + } else { + MS_LOG(ERROR) << "Input format[" << format << "] for VPC is invalid, please check it."; + return APP_ERR_COMM_INVALID_PARAM; + } + uint32_t inputHeightStride = DVPP_ALIGN_UP(height, VPC_STRIDE_HEIGHT); + // Check the input validity width stride. + if (inputWidthStride > MAX_RESIZE_WIDTH || inputWidthStride < MIN_RESIZE_WIDTH) { + MS_LOG(ERROR) << "Input width stride " << inputWidthStride << " is invalid, not in [" << MIN_RESIZE_WIDTH << ", " + << MAX_RESIZE_WIDTH << "]."; + return APP_ERR_COMM_INVALID_PARAM; + } + // Check the input validity height stride. + if (inputHeightStride > MAX_RESIZE_HEIGHT || inputHeightStride < MIN_RESIZE_HEIGHT) { + MS_LOG(ERROR) << "Input height stride " << inputHeightStride << " is invalid, not in [" << MIN_RESIZE_HEIGHT << ", " + << MAX_RESIZE_HEIGHT << "]."; + return APP_ERR_COMM_INVALID_PARAM; + } + widthStride = inputWidthStride; + heightStride = inputHeightStride; + return APP_ERR_OK; +} + +/* + * @description: Get the aligned width and height of the output image according to the image format + * @param: width specifies the width before alignment + * @param: height specifies the height before alignment + * @param: format specifies the image format + * @param: widthStride is used to save the width after alignment + * @param: heightStride is used to save the height after alignment + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetVpcOutputStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride) { + // Check the invalidty of output format and calculate the output width and height + if (format != PIXEL_FORMAT_YUV_SEMIPLANAR_420 && format != PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + MS_LOG(ERROR) << "Output format[" << format << "] for VPC is not supported, just support NV12 or NV21."; + return APP_ERR_COMM_INVALID_PARAM; + } + + widthStride = DVPP_ALIGN_UP(width, VPC_STRIDE_WIDTH); + heightStride = DVPP_ALIGN_UP(height, VPC_STRIDE_HEIGHT); + return APP_ERR_OK; +} + +/* + * @description: Set picture description information and execute resize function + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: withSynchronize specifies whether to execute synchronously + * @param: processType specifies whether to perform proportional scaling, default is non-proportional resize + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::VpcResize(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize, + VpcProcessType processType) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) << "VpcResize cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + acldvppPicDesc *inputDesc = acldvppCreatePicDesc(); + acldvppPicDesc *outputDesc = acldvppCreatePicDesc(); + resizeInputDesc_.reset(inputDesc, g_picDescDeleter); + resizeOutputDesc_.reset(outputDesc, g_picDescDeleter); + + // Set dvpp picture descriptin info of input image + APP_ERROR ret = SetDvppPicDescData(input, *resizeInputDesc_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set dvpp input picture description, ret = " << ret << "."; + return ret; + } + + // Set dvpp picture descriptin info of output image + ret = SetDvppPicDescData(output, *resizeOutputDesc_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set dvpp output picture description, ret = " << ret << "."; + return ret; + } + if (processType == VPC_PT_DEFAULT) { + return ResizeProcess(*resizeInputDesc_, *resizeOutputDesc_, withSynchronize); + } + + // Get crop area according to the processType + // When the processType is VPC_PT_FILL, the image will be cropped if the image size is different from the target + // resolution + CropRoiConfig cropRoi = {0}; + GetCropRoi(input, output, processType, cropRoi); + + // The width and height of the original image will be resized by the same ratio + // The cropped image will be pasted on the upper left corner or the middle location or the whole location according to + // the processType + CropRoiConfig pasteRoi = {0}; + GetPasteRoi(input, output, processType, pasteRoi); + + return ResizeWithPadding(*resizeInputDesc_, *resizeOutputDesc_, cropRoi, pasteRoi, withSynchronize); +} + +/* + * @description: Set image description information + * @param: dataInfo specifies the image information + * @param: picsDesc specifies the picture description information to be set + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::SetDvppPicDescData(const DvppDataInfo &dataInfo, acldvppPicDesc &picDesc) { + APP_ERROR ret = acldvppSetPicDescData(&picDesc, dataInfo.data); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set data for dvpp picture description, ret = " << ret << "."; + return ret; + } + ret = acldvppSetPicDescSize(&picDesc, dataInfo.dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set size for dvpp picture description, ret = " << ret << "."; + return ret; + } + ret = acldvppSetPicDescFormat(&picDesc, dataInfo.format); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set format for dvpp picture description, ret = " << ret << "."; + return ret; + } + ret = acldvppSetPicDescWidth(&picDesc, dataInfo.width); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set width for dvpp picture description, ret = " << ret << "."; + return ret; + } + ret = acldvppSetPicDescHeight(&picDesc, dataInfo.height); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set height for dvpp picture description, ret = " << ret << "."; + return ret; + } + if (!isVdec_) { + ret = acldvppSetPicDescWidthStride(&picDesc, dataInfo.widthStride); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set aligned width for dvpp picture description, ret = " << ret << "."; + return ret; + } + ret = acldvppSetPicDescHeightStride(&picDesc, dataInfo.heightStride); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set aligned height for dvpp picture description, ret = " << ret << "."; + return ret; + } + } + + return APP_ERR_OK; +} + +/* + * @description: Check whether the image format and zoom ratio meet the requirements + * @param: input specifies the input image information + * @param: output specifies the output image information + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::CheckResizeParams(const DvppDataInfo &input, const DvppDataInfo &output) { + if (output.format != PIXEL_FORMAT_YUV_SEMIPLANAR_420 && output.format != PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + MS_LOG(ERROR) << "Output format[" << output.format << "] for VPC is not supported, just support NV12 or NV21."; + return APP_ERR_COMM_INVALID_PARAM; + } + if (((float)output.height / input.height) < MIN_RESIZE_SCALE || + ((float)output.height / input.height) > MAX_RESIZE_SCALE) { + MS_LOG(ERROR) << "Resize scale should be in range [1/16, 32], which is " << (output.height / input.height) << "."; + return APP_ERR_COMM_INVALID_PARAM; + } + if (((float)output.width / input.width) < MIN_RESIZE_SCALE || + ((float)output.width / input.width) > MAX_RESIZE_SCALE) { + MS_LOG(ERROR) << "Resize scale should be in range [1/16, 32], which is " << (output.width / input.width) << "."; + return APP_ERR_COMM_INVALID_PARAM; + } + return APP_ERR_OK; +} + +/* + * @description: Scale the input image to the size specified by the output image and + * saves the result to the output image (non-proportionate scaling) + * @param: inputDesc specifies the description information of the input image + * @param: outputDesc specifies the description information of the output image + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::ResizeProcess(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, bool withSynchronize) { + acldvppResizeConfig *resizeConfig = acldvppCreateResizeConfig(); + if (resizeConfig == nullptr) { + MS_LOG(ERROR) << "Failed to create dvpp resize config."; + return APP_ERR_COMM_INVALID_POINTER; + } + + resizeConfig_.reset(resizeConfig, g_resizeConfigDeleter); + APP_ERROR ret = acldvppVpcResizeAsync(dvppChannelDesc_, &inputDesc, &outputDesc, resizeConfig_.get(), dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to resize asynchronously, ret = " << ret << "."; + return ret; + } + + if (withSynchronize) { + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to synchronize stream, ret = " << ret << "."; + return ret; + } + } + return APP_ERR_OK; +} + +/* + * @description: Crop the image from the input image based on the specified area and + * paste the cropped image to the specified position of the target image + * as the output image + * @param: inputDesc specifies the description information of the input image + * @param: outputDesc specifies the description information of the output image + * @param: cropRoi specifies the cropped area + * @param: pasteRoi specifies the pasting area + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: If the width and height of the crop area are different from those of the + * paste area, the image is scaled again + */ +APP_ERROR DvppCommon::ResizeWithPadding(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, CropRoiConfig &cropRoi, + CropRoiConfig &pasteRoi, bool withSynchronize) { + acldvppRoiConfig *cropRoiCfg = acldvppCreateRoiConfig(cropRoi.left, cropRoi.right, cropRoi.up, cropRoi.down); + if (cropRoiCfg == nullptr) { + MS_LOG(ERROR) << "Failed to create dvpp roi config for corp area."; + return APP_ERR_COMM_FAILURE; + } + cropAreaConfig_.reset(cropRoiCfg, g_roiConfigDeleter); + + acldvppRoiConfig *pastRoiCfg = acldvppCreateRoiConfig(pasteRoi.left, pasteRoi.right, pasteRoi.up, pasteRoi.down); + if (pastRoiCfg == nullptr) { + MS_LOG(ERROR) << "Failed to create dvpp roi config for paster area."; + return APP_ERR_COMM_FAILURE; + } + pasteAreaConfig_.reset(pastRoiCfg, g_roiConfigDeleter); + + APP_ERROR ret = acldvppVpcCropAndPasteAsync(dvppChannelDesc_, &inputDesc, &outputDesc, cropAreaConfig_.get(), + pasteAreaConfig_.get(), dvppStream_); + if (ret != APP_ERR_OK) { + // release resource. + MS_LOG(ERROR) << "Failed to crop and paste asynchronously, ret = " << ret << "."; + return ret; + } + if (withSynchronize) { + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed tp synchronize stream, ret = " << ret << "."; + return ret; + } + } + return APP_ERR_OK; +} + +/* + * @description: Get crop area + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: processType specifies whether to perform proportional scaling + * @param: cropRoi is used to save the info of the crop roi area + * @return: APP_ERR_OK if success, other values if failure + */ +void DvppCommon::GetCropRoi(const DvppDataInfo &input, const DvppDataInfo &output, VpcProcessType processType, + CropRoiConfig &cropRoi) { + // When processType is not VPC_PT_FILL, crop area is the whole input image + if (processType != VPC_PT_FILL) { + cropRoi.right = CONVERT_TO_ODD(input.width - ODD_NUM_1); + cropRoi.down = CONVERT_TO_ODD(input.height - ODD_NUM_1); + return; + } + + bool widthRatioSmaller = true; + // The scaling ratio is based on the smaller ratio to ensure the smallest edge to fill the targe edge + float resizeRatio = static_cast(input.width) / output.width; + if (resizeRatio > (static_cast(input.height) / output.height)) { + resizeRatio = static_cast(input.height) / output.height; + widthRatioSmaller = false; + } + + const int halfValue = 2; + // The left and up must be even, right and down must be odd which is required by acl + if (widthRatioSmaller) { + cropRoi.left = 0; + cropRoi.right = CONVERT_TO_ODD(input.width - ODD_NUM_1); + cropRoi.up = CONVERT_TO_EVEN(static_cast((input.height - output.height * resizeRatio) / halfValue)); + cropRoi.down = CONVERT_TO_ODD(input.height - cropRoi.up - ODD_NUM_1); + return; + } + + cropRoi.up = 0; + cropRoi.down = CONVERT_TO_ODD(input.height - ODD_NUM_1); + cropRoi.left = CONVERT_TO_EVEN(static_cast((input.width - output.width * resizeRatio) / halfValue)); + cropRoi.right = CONVERT_TO_ODD(input.width - cropRoi.left - ODD_NUM_1); + return; +} + +/* + * @description: Get paste area + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: processType specifies whether to perform proportional scaling + * @param: pasteRio is used to save the info of the paste area + * @return: APP_ERR_OK if success, other values if failure + */ +void DvppCommon::GetPasteRoi(const DvppDataInfo &input, const DvppDataInfo &output, VpcProcessType processType, + CropRoiConfig &pasteRoi) { + if (processType == VPC_PT_FILL) { + pasteRoi.right = CONVERT_TO_ODD(output.width - ODD_NUM_1); + pasteRoi.down = CONVERT_TO_ODD(output.height - ODD_NUM_1); + return; + } + + bool widthRatioLarger = true; + // The scaling ratio is based on the larger ratio to ensure the largest edge to fill the targe edge + float resizeRatio = static_cast(input.width) / output.width; + if (resizeRatio < (static_cast(input.height) / output.height)) { + resizeRatio = static_cast(input.height) / output.height; + widthRatioLarger = false; + } + + // Left and up is 0 when the roi paste on the upper left corner + if (processType == VPC_PT_PADDING) { + pasteRoi.right = (input.width / resizeRatio) - ODD_NUM_1; + pasteRoi.down = (input.height / resizeRatio) - ODD_NUM_1; + pasteRoi.right = CONVERT_TO_ODD(pasteRoi.right); + pasteRoi.down = CONVERT_TO_ODD(pasteRoi.down); + return; + } + + const int halfValue = 2; + // Left and up is 0 when the roi paste on the middler location + if (widthRatioLarger) { + pasteRoi.left = 0; + pasteRoi.right = output.width - ODD_NUM_1; + pasteRoi.up = (output.height - (input.height / resizeRatio)) / halfValue; + pasteRoi.down = output.height - pasteRoi.up - ODD_NUM_1; + } else { + pasteRoi.up = 0; + pasteRoi.down = output.height - ODD_NUM_1; + pasteRoi.left = (output.width - (input.width / resizeRatio)) / halfValue; + pasteRoi.right = output.width - pasteRoi.left - ODD_NUM_1; + } + + // The left must be even and align to 16, up must be even, right and down must be odd which is required by acl + pasteRoi.left = DVPP_ALIGN_UP(CONVERT_TO_EVEN(pasteRoi.left), VPC_WIDTH_ALIGN); + pasteRoi.right = CONVERT_TO_ODD(pasteRoi.right); + pasteRoi.up = CONVERT_TO_EVEN(pasteRoi.up); + pasteRoi.down = CONVERT_TO_ODD(pasteRoi.down); + return; +} + +/* + * @description: Resize the image specified by input and save the result to member variable resizedImage_ + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: withSynchronize specifies whether to execute synchronously + * @param: processType specifies whether to perform proportional scaling, default is non-proportional resize + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::CombineResizeProcess(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize, + VpcProcessType processType) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) + << "CombineResizeProcess cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + APP_ERROR ret = CheckResizeParams(input, output); + if (ret != APP_ERR_OK) { + return ret; + } + // Get widthStride and heightStride for input and output image according to the format + ret = + GetVpcInputStrideSize(input.widthStride, input.heightStride, input.format, input.widthStride, input.heightStride); + if (ret != APP_ERR_OK) { + return ret; + } + + resizedImage_ = std::make_shared(); + resizedImage_->width = output.width; + resizedImage_->height = output.height; + resizedImage_->format = output.format; + ret = GetVpcOutputStrideSize(output.width, output.height, output.format, resizedImage_->widthStride, + resizedImage_->heightStride); + if (ret != APP_ERR_OK) { + return ret; + } + // Get output buffer size for resize output + ret = GetVpcDataSize(output.width, output.height, output.format, resizedImage_->dataSize); + if (ret != APP_ERR_OK) { + return ret; + } + // Malloc buffer for output of resize module + // Need to pay attention to release of the buffer + ret = acldvppMalloc((void **)(&(resizedImage_->data)), resizedImage_->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc " << resizedImage_->dataSize << " bytes on dvpp for resize, ret = " << ret + << "."; + return ret; + } + + aclrtMemset(resizedImage_->data, resizedImage_->dataSize, YUV_GREYER_VALUE, resizedImage_->dataSize); + resizedImage_->frameId = input.frameId; + ret = VpcResize(input, *resizedImage_, withSynchronize, processType); + if (ret != APP_ERR_OK) { + // Release the output buffer when resize failed, otherwise release it after use + RELEASE_DVPP_DATA(resizedImage_->data); + } + return ret; +} + +/* + * @description: Set picture description information and execute crop function + * @param: cropInput specifies the input image information and cropping area + * @param: output specifies the output image information + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::VpcCrop(const DvppCropInputInfo &cropInput, const DvppDataInfo &output, bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) << "VpcCrop cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + acldvppPicDesc *inputDesc = acldvppCreatePicDesc(); + acldvppPicDesc *outputDesc = acldvppCreatePicDesc(); + cropInputDesc_.reset(inputDesc, g_picDescDeleter); + cropOutputDesc_.reset(outputDesc, g_picDescDeleter); + + // Set dvpp picture descriptin info of input image + APP_ERROR ret = SetDvppPicDescData(cropInput.dataInfo, *cropInputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + // Set dvpp picture descriptin info of output image + ret = SetDvppPicDescData(output, *cropOutputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + return CropProcess(*cropInputDesc_, *cropOutputDesc_, cropInput.roi, withSynchronize); +} + +/* + * @description: Check whether the size of the cropped data and the cropped area meet the requirements + * @param: input specifies the image information and the information about the area to be cropped + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::CheckCropParams(const DvppCropInputInfo &input) { + APP_ERROR ret; + uint32_t payloadSize; + ret = GetVpcDataSize(input.dataInfo.widthStride, input.dataInfo.heightStride, PIXEL_FORMAT_YUV_SEMIPLANAR_420, + payloadSize); + if (ret != APP_ERR_OK) { + return ret; + } + if (payloadSize != input.dataInfo.dataSize) { + MS_LOG(ERROR) << "Input data size: " << payloadSize + << " to crop does not match input yuv image size: " << input.dataInfo.dataSize << "."; + return APP_ERR_COMM_INVALID_PARAM; + } + + if ((!CHECK_EVEN(input.roi.left)) || (!CHECK_EVEN(input.roi.up)) || (!CHECK_ODD(input.roi.right)) || + (!CHECK_ODD(input.roi.down))) { + MS_LOG(ERROR) << "Crop area left and top(" << input.roi.left << ", " << input.roi.up + << ") must be even, right bottom(" << input.roi.right << "," << input.roi.down << ") must be odd."; + return APP_ERR_COMM_INVALID_PARAM; + } + + // Calculate crop width and height according to the input location + uint32_t cropWidth = input.roi.right - input.roi.left + ODD_NUM_1; + uint32_t cropHeight = input.roi.down - input.roi.up + ODD_NUM_1; + if ((cropWidth < MIN_CROP_WIDTH) || (cropHeight < MIN_CROP_HEIGHT)) { + MS_LOG(ERROR) << "Crop area width:" << cropWidth << " need to be larger than 10 and height:" << cropHeight + << " need to be larger than 6."; + return APP_ERR_COMM_INVALID_PARAM; + } + + if ((input.roi.left + cropWidth > input.dataInfo.width) || (input.roi.up + cropHeight > input.dataInfo.height)) { + MS_LOG(ERROR) << "Target rectangle start location(" << input.roi.left << "," << input.roi.up << ") with size(" + << cropWidth << "," << cropHeight << ") is out of the input image(" << input.dataInfo.width << "," + << input.dataInfo.height << ") to be cropped."; + return APP_ERR_COMM_INVALID_PARAM; + } + + return APP_ERR_OK; +} + +/* + * @description: It is used to crop an input image based on a specified region and + * store the cropped image to the output memory as an output image + * @param: inputDesc specifies the description information of the input image + * @param: outputDesc specifies the description information of the output image + * @param: CropRoiConfig specifies the cropped area + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: if the region of the output image is inconsistent with the crop area, the image is scaled again + */ +APP_ERROR DvppCommon::CropProcess(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, const CropRoiConfig &cropArea, + bool withSynchronize) { + uint32_t leftOffset = CONVERT_TO_EVEN(cropArea.left); + uint32_t rightOffset = CONVERT_TO_ODD(cropArea.right); + uint32_t upOffset = CONVERT_TO_EVEN(cropArea.up); + uint32_t downOffset = CONVERT_TO_ODD(cropArea.down); + + auto cropRioCfg = acldvppCreateRoiConfig(leftOffset, rightOffset, upOffset, downOffset); + if (cropRioCfg == nullptr) { + MS_LOG(ERROR) << "DvppCommon: create dvpp vpc resize failed."; + return APP_ERR_DVPP_RESIZE_FAIL; + } + cropRoiConfig_.reset(cropRioCfg, g_roiConfigDeleter); + + APP_ERROR ret = acldvppVpcCropAsync(dvppChannelDesc_, &inputDesc, &outputDesc, cropRoiConfig_.get(), dvppStream_); + if (ret != APP_ERR_OK) { + // release resource. + MS_LOG(ERROR) << "Failed to crop, ret = " << ret << "."; + return ret; + } + if (withSynchronize) { + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to synchronize stream, ret = " << ret << "."; + return ret; + } + } + return APP_ERR_OK; +} + +/* + * @description: Crop the image specified by the input parameter and saves the result to member variable cropImage_ + * @param: input specifies the input image information and cropping area + * @param: output specifies the output image information + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::CombineCropProcess(DvppCropInputInfo &input, DvppDataInfo &output, bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) << "CombineCropProcess cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + // Get widthStride and heightStride for input and output image according to the format + APP_ERROR ret = GetVpcInputStrideSize(input.dataInfo.width, input.dataInfo.height, input.dataInfo.format, + input.dataInfo.widthStride, input.dataInfo.heightStride); + if (ret != APP_ERR_OK) { + return ret; + } + ret = CheckCropParams(input); + if (ret != APP_ERR_OK) { + return ret; + } + cropImage_ = std::make_shared(); + cropImage_->width = output.width; + cropImage_->height = output.height; + cropImage_->format = output.format; + ret = GetVpcOutputStrideSize(output.width, output.height, output.format, cropImage_->widthStride, + cropImage_->heightStride); + if (ret != APP_ERR_OK) { + return ret; + } + // Get output buffer size for resize output + ret = GetVpcDataSize(output.width, output.height, output.format, cropImage_->dataSize); + if (ret != APP_ERR_OK) { + return ret; + } + + // Malloc buffer for output of resize module + // Need to pay attention to release of the buffer + ret = acldvppMalloc((void **)(&(cropImage_->data)), cropImage_->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc " << cropImage_->dataSize << " bytes on dvpp for resize, ret = " << ret << "."; + return ret; + } + cropImage_->frameId = input.dataInfo.frameId; + ret = VpcCrop(input, *cropImage_, withSynchronize); + if (ret != APP_ERR_OK) { + // Release the output buffer when resize failed, otherwise release it after use + RELEASE_DVPP_DATA(cropImage_->data); + } + return ret; +} + +/* + * @description: Set the description of the output image and decode + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::JpegDecode(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) << "JpegDecode cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + acldvppPicDesc *outputDesc = acldvppCreatePicDesc(); + decodeOutputDesc_.reset(outputDesc, g_picDescDeleter); + + APP_ERROR ret = SetDvppPicDescData(output, *decodeOutputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + + ret = acldvppJpegDecodeAsync(dvppChannelDesc_, input.data, input.dataSize, decodeOutputDesc_.get(), dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to decode jpeg, ret = " << ret << "."; + return ret; + } + if (withSynchronize) { + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to synchronize stream, ret = " << ret << "."; + return APP_ERR_DVPP_JPEG_DECODE_FAIL; + } + } + return APP_ERR_OK; +} + +/* + * @description: Get the aligned width and height of the image after decoding + * @param: width specifies the width before alignment + * @param: height specifies the height before alignment + * @param: widthStride is used to save the width after alignment + * @param: heightStride is used to save the height after alignment + * @return: APP_ERR_OK if success, other values if failure + */ +void DvppCommon::GetJpegDecodeStrideSize(uint32_t width, uint32_t height, uint32_t &widthStride, + uint32_t &heightStride) { + widthStride = DVPP_ALIGN_UP(width, JPEGD_STRIDE_WIDTH); + heightStride = DVPP_ALIGN_UP(height, JPEGD_STRIDE_HEIGHT); +} + +/* + * @description: Get picture width and height and number of channels from image data + * @param: data specifies the memory to store the image data + * @param: dataSize specifies the size of the image data + * @param: width is used to save the image width + * @param: height is used to save the image height + * @param: components is used to save the number of channels + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetJpegImageInfo(const void *data, uint32_t dataSize, uint32_t &width, uint32_t &height, + int32_t &components) { + uint32_t widthTmp; + uint32_t heightTmp; + int32_t componentsTmp; + APP_ERROR ret = acldvppJpegGetImageInfo(data, dataSize, &widthTmp, &heightTmp, &componentsTmp); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get image info of jpeg, ret = " << ret << "."; + return ret; + } + if (widthTmp > MAX_JPEGD_WIDTH || widthTmp < MIN_JPEGD_WIDTH) { + MS_LOG(ERROR) << "Input width is invalid, not in [" << MIN_JPEGD_WIDTH << ", " << MAX_JPEGD_WIDTH << "]."; + return APP_ERR_COMM_INVALID_PARAM; + } + if (heightTmp > MAX_JPEGD_HEIGHT || heightTmp < MIN_JPEGD_HEIGHT) { + MS_LOG(ERROR) << "Input height is invalid, not in [" << MIN_JPEGD_HEIGHT << ", " << MAX_JPEGD_HEIGHT << "]."; + return APP_ERR_COMM_INVALID_PARAM; + } + width = widthTmp; + height = heightTmp; + components = componentsTmp; + return APP_ERR_OK; +} + +/* + * @description: Get the size of the buffer for storing decoded images based on the image data, size, and format + * @param: data specifies the memory to store the image data + * @param: dataSize specifies the size of the image data + * @param: format specifies the image format + * @param: decSize is used to store the result size + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetJpegDecodeDataSize(const void *data, uint32_t dataSize, acldvppPixelFormat format, + uint32_t &decSize) { + uint32_t outputSize; + APP_ERROR ret = acldvppJpegPredictDecSize(data, dataSize, format, &outputSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to predict decode size of jpeg image, ret = " << ret << "."; + return ret; + } + decSize = outputSize; + return APP_ERR_OK; +} + +/* + * @description: Decode the image specified by imageInfo and save the result to member variable decodedImage_ + * @param: imageInfo specifies image information + * @param: format specifies the image format + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::CombineJpegdProcess(const RawData &imageInfo, acldvppPixelFormat format, bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) + << "CombineJpegdProcess cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + int32_t components; + inputImage_ = std::make_shared(); + inputImage_->format = format; + APP_ERROR ret = + GetJpegImageInfo(imageInfo.data.get(), imageInfo.lenOfByte, inputImage_->width, inputImage_->height, components); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get input image info, ret = " << ret << "."; + return ret; + } + + // Get the buffer size of decode output according to the input data and output format + uint32_t outBuffSize; + ret = GetJpegDecodeDataSize(imageInfo.data.get(), imageInfo.lenOfByte, format, outBuffSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get size of decode output buffer, ret = " << ret << "."; + return ret; + } + + // In TransferImageH2D function, device buffer will be alloced to store the input image + // Need to pay attention to release of the buffer + ret = TransferImageH2D(imageInfo, inputImage_); + if (ret != APP_ERR_OK) { + return ret; + } + + decodedImage_ = std::make_shared(); + decodedImage_->format = format; + decodedImage_->width = inputImage_->width; + decodedImage_->height = inputImage_->height; + GetJpegDecodeStrideSize(inputImage_->width, inputImage_->height, decodedImage_->widthStride, + decodedImage_->heightStride); + decodedImage_->dataSize = outBuffSize; + // Malloc dvpp buffer to store the output data after decoding + // Need to pay attention to release of the buffer + ret = acldvppMalloc((void **)&decodedImage_->data, decodedImage_->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc memory on dvpp, ret = " << ret << "."; + RELEASE_DVPP_DATA(inputImage_->data); + return ret; + } + + ret = JpegDecode(*inputImage_, *decodedImage_, withSynchronize); + if (ret != APP_ERR_OK) { + // Release the output buffer when decode failed, otherwise release it after use + RELEASE_DVPP_DATA(inputImage_->data); + inputImage_->data = nullptr; + RELEASE_DVPP_DATA(decodedImage_->data); + decodedImage_->data = nullptr; + return ret; + } + + return APP_ERR_OK; +} + +/* + * @description: Transfer data from host to device + * @param: imageInfo specifies the image data on the host + * @param: jpegInput is used to save the buffer and its size which is allocate on the device + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::TransferImageH2D(const RawData &imageInfo, const std::shared_ptr &jpegInput) { + // Check image buffer size validity + if (imageInfo.lenOfByte <= 0) { + MS_LOG(ERROR) << "The input buffer size on host should not be empty."; + return APP_ERR_COMM_INVALID_PARAM; + } + + uint8_t *inDevBuff = nullptr; + APP_ERROR ret = acldvppMalloc((void **)&inDevBuff, imageInfo.lenOfByte); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc " << imageInfo.lenOfByte << " bytes on dvpp, ret = " << ret << "."; + return ret; + } + + // Copy the image data from host to device + ret = aclrtMemcpyAsync(inDevBuff, imageInfo.lenOfByte, imageInfo.data.get(), imageInfo.lenOfByte, + ACL_MEMCPY_HOST_TO_DEVICE, dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to copy " << imageInfo.lenOfByte << " bytes from host to device, ret = " << ret << "."; + RELEASE_DVPP_DATA(inDevBuff); + return ret; + } + // Attention: We must call the aclrtSynchronizeStream to ensure the task of memory replication has been completed + // after calling aclrtMemcpyAsync + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to synchronize stream, ret = " << ret << "."; + RELEASE_DVPP_DATA(inDevBuff); + return ret; + } + jpegInput->data = inDevBuff; + jpegInput->dataSize = imageInfo.lenOfByte; + return APP_ERR_OK; +} + +/* + * @description: Create and set the description of a video stream + * @param: data specifies the information about the video stream + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::CreateStreamDesc(std::shared_ptr data) { + // Malloc input device memory which need to be released in vdec callback function + void *modelInBuff = nullptr; + APP_ERROR ret = acldvppMalloc(&modelInBuff, data->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc dvpp data with " << data->dataSize << " bytes, ret = " << ret << "."; + return APP_ERR_ACL_BAD_ALLOC; + } + // copy input to device memory + ret = aclrtMemcpy(modelInBuff, data->dataSize, static_cast(data->data), data->dataSize, + ACL_MEMCPY_HOST_TO_DEVICE); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to copy memory with " << data->dataSize << " bytes from host to device, ret = " << ret + << "."; + acldvppFree(modelInBuff); + modelInBuff = nullptr; + return APP_ERR_ACL_FAILURE; + } + // Create input stream desc which need to be destoryed in vdec callback function + streamInputDesc_ = acldvppCreateStreamDesc(); + if (streamInputDesc_ == nullptr) { + MS_LOG(ERROR) << "Failed to create input stream description."; + return APP_ERR_ACL_FAILURE; + } + ret = acldvppSetStreamDescData(streamInputDesc_, modelInBuff); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set data for stream desdescription, ret = " << ret << "."; + return ret; + } + // set size for dvpp stream desc + ret = acldvppSetStreamDescSize(streamInputDesc_, data->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set data size for stream desdescription, ret = " << ret << "."; + return ret; + } + return APP_ERR_OK; +} + +/* + * @description: Decode the video based on the video stream specified by data and user-defined data, + * and outputs the image of each frame + * @param: data specifies the information about the video stream + * @param: userdata is specified for user-defined data + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with InitVdec + */ +APP_ERROR DvppCommon::CombineVdecProcess(std::shared_ptr data, void *userData) { + // Return special error code when the DvppCommon object is not initialized with InitVdec + if (!isVdec_) { + MS_LOG(ERROR) + << "CombineVdecProcess cannot be called by the DvppCommon object which is not initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + // create stream desc + APP_ERROR ret = CreateStreamDesc(data); + if (ret != APP_ERR_OK) { + return ret; + } + + uint32_t dataSize; + ret = GetVideoDecodeDataSize(vdecConfig_.inputWidth, vdecConfig_.inputHeight, vdecConfig_.outFormat, dataSize); + if (ret != APP_ERR_OK) { + return ret; + } + + void *picOutBufferDev = nullptr; + // picOutBufferDev need to be destoryed in vdec callback function + ret = acldvppMalloc(&picOutBufferDev, dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc memory with " << dataSize << " bytes, ret = " << ret << "."; + return APP_ERR_ACL_BAD_ALLOC; + } + + // picOutputDesc_ will be destoryed in vdec callback function + picOutputDesc_ = acldvppCreatePicDesc(); + if (picOutputDesc_ == NULL) { + return APP_ERR_ACL_BAD_ALLOC; + } + + DvppDataInfo dataInfo; + dataInfo.width = vdecConfig_.inputWidth; + dataInfo.height = vdecConfig_.inputHeight; + dataInfo.format = vdecConfig_.outFormat; + dataInfo.dataSize = dataSize; + dataInfo.data = static_cast(picOutBufferDev); + ret = SetDvppPicDescData(dataInfo, *picOutputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + + // send frame + ret = aclvdecSendFrame(vdecChannelDesc_, streamInputDesc_, picOutputDesc_, nullptr, userData); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to send frame, ret = " << ret << "."; + return APP_ERR_ACL_FAILURE; + } + + return APP_ERR_OK; +} + +/* + * @description: Send eos frame when video stream ends + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::VdecSendEosFrame() const { + // create input stream desc + acldvppStreamDesc *eosStreamDesc = acldvppCreateStreamDesc(); + if (eosStreamDesc == nullptr) { + MS_LOG(ERROR) << "Fail to create dvpp stream desc for eos."; + return ACL_ERROR_FAILURE; + } + + // set eos for eos stream desc + APP_ERROR ret = acldvppSetStreamDescEos(eosStreamDesc, true); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Fail to set eos for stream desc, ret = " << ret << "."; + acldvppDestroyStreamDesc(eosStreamDesc); + return ret; + } + + // send eos and synchronize + ret = aclvdecSendFrame(vdecChannelDesc_, eosStreamDesc, nullptr, nullptr, nullptr); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Fail to send eos, ret = " << ret << "."; + acldvppDestroyStreamDesc(eosStreamDesc); + return ret; + } + + // destory input stream desc + ret = acldvppDestroyStreamDesc(eosStreamDesc); + if (ret != ACL_ERROR_NONE) { + MS_LOG(ERROR) << "Fail to destory dvpp stream desc for eos, ret = " << ret << "."; + return ret; + } + return ret; +} + +/* + * @description: Get the aligned width and height of the output image after video decoding + * @param: width specifies the width before alignment + * @param: height specifies the height before alignment + * @param: format specifies the format of the output image + * @param: widthStride is used to save the width after alignment + * @param: heightStride is used to save the height after alignment + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetVideoDecodeStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride) { + // Check the invalidty of output format and calculate the output width and height + if (format != PIXEL_FORMAT_YUV_SEMIPLANAR_420 && format != PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + MS_LOG(ERROR) << "Input format[" << format << "] for VPC is not supported, just support NV12 or NV21."; + return APP_ERR_COMM_INVALID_PARAM; + } + widthStride = DVPP_ALIGN_UP(width, VDEC_STRIDE_WIDTH); + heightStride = DVPP_ALIGN_UP(height, VDEC_STRIDE_HEIGHT); + return APP_ERR_OK; +} + +/* + * @description: Get the buffer size for storing results after video decoding + * @param width specifies the width of the output image after video decoding + * @param height specifies the height of the output image after video decoding + * @param format specifies the format of the output image + * @param: vpcSize is used to save the result size + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetVideoDecodeDataSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &vdecSize) { + // Check the invalid format of vdec output and calculate the output buffer size + if (format != PIXEL_FORMAT_YUV_SEMIPLANAR_420 && format != PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + MS_LOG(ERROR) << "Format[" << format << "] for VPC is not supported, just support NV12 or NV21."; + return APP_ERR_COMM_INVALID_PARAM; + } + uint32_t widthStride = DVPP_ALIGN_UP(width, VDEC_STRIDE_WIDTH); + uint32_t heightStride = DVPP_ALIGN_UP(height, VDEC_STRIDE_HEIGHT); + vdecSize = widthStride * heightStride * YUV_BGR_SIZE_CONVERT_3 / YUV_BGR_SIZE_CONVERT_2; + return APP_ERR_OK; +} + +/* + * @description: Encode a YUV image into a JPG image + * @param: input specifies the input image information + * @param: output specifies the output image information + * @param: jpegeConfig specifies the encoding configuration data + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::JpegEncode(DvppDataInfo &input, DvppDataInfo &output, acldvppJpegeConfig *jpegeConfig, + bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) << "JpegEncode cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + APP_ERROR ret = SetDvppPicDescData(input, *encodeInputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + + ret = acldvppJpegEncodeAsync(dvppChannelDesc_, encodeInputDesc_.get(), output.data, &output.dataSize, jpegeConfig, + dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to encode image, ret = " << ret << "."; + return ret; + } + if (withSynchronize) { + ret = aclrtSynchronizeStream(dvppStream_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to aclrtSynchronizeStream, ret = " << ret << "."; + return APP_ERR_DVPP_JPEG_ENCODE_FAIL; + } + } + MS_LOG(INFO) << "Encode successfully."; + return APP_ERR_OK; +} + +/* + * @description: Get the aligned width, height, and data size of the input image + * @param: inputImage specifies the input image information + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::GetJpegEncodeStrideSize(std::shared_ptr &inputImage) { + uint32_t inputWidth = inputImage->width; + uint32_t inputHeight = inputImage->height; + acldvppPixelFormat format = inputImage->format; + uint32_t widthStride; + uint32_t heightStride; + uint32_t encodedBufferSize; + // Align up the input width and height and calculate buffer size of encoded input file + if (format == PIXEL_FORMAT_YUV_SEMIPLANAR_420 || format == PIXEL_FORMAT_YVU_SEMIPLANAR_420) { + widthStride = DVPP_ALIGN_UP(inputWidth, JPEGE_STRIDE_WIDTH); + heightStride = DVPP_ALIGN_UP(inputHeight, JPEGE_STRIDE_HEIGHT); + encodedBufferSize = widthStride * heightStride * YUV_BYTES_NU / YUV_BYTES_DE; + } else if (format == PIXEL_FORMAT_YUYV_PACKED_422 || format == PIXEL_FORMAT_UYVY_PACKED_422 || + format == PIXEL_FORMAT_YVYU_PACKED_422 || format == PIXEL_FORMAT_VYUY_PACKED_422) { + widthStride = DVPP_ALIGN_UP(inputWidth * YUV422_WIDTH_NU, JPEGE_STRIDE_WIDTH); + heightStride = DVPP_ALIGN_UP(inputHeight, JPEGE_STRIDE_HEIGHT); + encodedBufferSize = widthStride * heightStride; + } else { + return APP_ERR_COMM_INVALID_PARAM; + } + if (encodedBufferSize == 0) { + MS_LOG(ERROR) << "Input host buffer size is empty."; + return APP_ERR_COMM_INVALID_PARAM; + } + inputImage->widthStride = widthStride; + inputImage->heightStride = heightStride; + inputImage->dataSize = encodedBufferSize; + return APP_ERR_OK; +} + +/* + * @description: Estimate the size of the output memory required by image encoding according to + * the input image description and image encoding configuration data + * @param: input specifies specifies the input image information + * @param: jpegeConfig specifies the encoding configuration data + * @param: encSize is used to save the result size + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::GetJpegEncodeDataSize(DvppDataInfo &input, acldvppJpegeConfig *jpegeConfig, uint32_t &encSize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) + << "GetJpegEncodeDataSize cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + + acldvppPicDesc *inputDesc = acldvppCreatePicDesc(); + encodeInputDesc_.reset(inputDesc, g_picDescDeleter); + + APP_ERROR ret = SetDvppPicDescData(input, *encodeInputDesc_); + if (ret != APP_ERR_OK) { + return ret; + } + + uint32_t outputSize; + ret = acldvppJpegPredictEncSize(encodeInputDesc_.get(), jpegeConfig, &outputSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to predict encode size of jpeg image, ret = " << ret << "."; + return ret; + } + encSize = outputSize; + return APP_ERR_OK; +} + +/* + * @description: Set the encoding configuration data + * @param: level specifies the encode quality range + * @param: jpegeConfig specifies the encoding configuration data + * @return: APP_ERR_OK if success, other values if failure + */ +APP_ERROR DvppCommon::SetEncodeLevel(uint32_t level, acldvppJpegeConfig &jpegeConfig) { + // Set the encoding quality + // The coding quality range [0, 100] + // The level 0 coding quality is similar to the level 100 + // The smaller the value in [1, 100], the worse the quality of the output picture + auto ret = (APP_ERROR)acldvppSetJpegeConfigLevel(&jpegeConfig, level); + if (ret != APP_ERR_OK) { + return ret; + } + return APP_ERR_OK; +} + +/* + * @description: Encode the image specified by imageInfo and save the result to member variable encodedImage_ + * @param: imageInfo specifies image information + * @param: width specifies the width of the input image + * @param: height specifies the height of the input image + * @param: format specifies the format of the input image + * @param: withSynchronize specifies whether to execute synchronously + * @return: APP_ERR_OK if success, other values if failure + * @attention: This function can be called only when the DvppCommon object is initialized with Init + */ +APP_ERROR DvppCommon::CombineJpegeProcess(const RawData &imageInfo, uint32_t width, uint32_t height, + acldvppPixelFormat format, bool withSynchronize) { + // Return special error code when the DvppCommon object is initialized with InitVdec + if (isVdec_) { + MS_LOG(ERROR) + << "CombineJpegeProcess cannot be called by the DvppCommon object which is initialized with InitVdec."; + return APP_ERR_DVPP_OBJ_FUNC_MISMATCH; + } + inputImage_ = std::make_shared(); + inputImage_->format = format; + inputImage_->width = width; + inputImage_->height = height; + // In TransferImageH2D function, device buffer will be alloced to store the input image + // Need to pay attention to release of the buffer + APP_ERROR ret = TransferImageH2D(imageInfo, inputImage_); + if (ret != APP_ERR_OK) { + return ret; + } + // Get stride size of encoded image + ret = GetJpegEncodeStrideSize(inputImage_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get encode stride size of input image file, ret = " << ret << "."; + return ret; + } + + auto jpegeConfig = acldvppCreateJpegeConfig(); + jpegeConfig_.reset(jpegeConfig, g_jpegeConfigDeleter); + + uint32_t encodeLevel = 100; + ret = SetEncodeLevel(encodeLevel, *jpegeConfig_); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to set encode level, ret = " << ret << "."; + return ret; + } + + // Get the buffer size of encode output according to the input data and jpeg encode config + uint32_t encodeOutBufferSize; + ret = GetJpegEncodeDataSize(*inputImage_, jpegeConfig_.get(), encodeOutBufferSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to get size of encode output buffer, ret = " << ret << "."; + return ret; + } + + encodedImage_ = std::make_shared(); + encodedImage_->dataSize = encodeOutBufferSize; + // Malloc dvpp buffer to store the output data after decoding + // Need to pay attention to release of the buffer + ret = acldvppMalloc((void **)&encodedImage_->data, encodedImage_->dataSize); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to malloc memory on dvpp, ret = " << ret << "."; + acldvppFree(inputImage_->data); + return ret; + } + + // Encode input image + ret = JpegEncode(*inputImage_, *encodedImage_, jpegeConfig_.get(), withSynchronize); + if (ret != APP_ERR_OK) { + // Release the output buffer when decode failed, otherwise release it after use + acldvppFree(inputImage_->data); + acldvppFree(encodedImage_->data); + return ret; + } + return APP_ERR_OK; +} + +std::shared_ptr DvppCommon::GetInputImage() { return inputImage_; } + +std::shared_ptr DvppCommon::GetDecodedImage() { return decodedImage_; } + +std::shared_ptr DvppCommon::GetResizedImage() { return resizedImage_; } + +std::shared_ptr DvppCommon::GetEncodedImage() { return encodedImage_; } + +std::shared_ptr DvppCommon::GetCropedImage() { return cropImage_; } + +DvppCommon::~DvppCommon() {} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h new file mode 100644 index 0000000000..90aff168c8 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/DvppCommon.h @@ -0,0 +1,220 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef DVPP_COMMON_H +#define DVPP_COMMON_H + +#include "CommonDataType.h" +#include "ErrorCode.h" + +#include "acl/ops/acl_dvpp.h" + +const int MODULUS_NUM_2 = 2; +const uint32_t ODD_NUM_1 = 1; + +struct Rect { + /* left location of the rectangle */ + uint32_t x; + /* top location of the rectangle */ + uint32_t y; + /* with of the rectangle */ + uint32_t width; + /* height of the rectangle */ + uint32_t height; +}; + +struct DvppBaseData { + uint32_t dataSize; // Size of data in byte + uint8_t *data; +}; + +struct VdecConfig { + int inputWidth = 0; + int inputHeight = 0; + acldvppStreamFormat inFormat = H264_MAIN_LEVEL; // stream format renference acldvppStreamFormat + acldvppPixelFormat outFormat = PIXEL_FORMAT_YUV_SEMIPLANAR_420; // output format renference acldvppPixelFormat + uint32_t channelId = 0; // user define channelId: 0-15 + uint32_t deviceId = 0; + pthread_t threadId = 0; // thread for callback + aclvdecCallback callback = {0}; // user define how to process vdec out data + bool runflag = true; +}; + +struct DeviceStreamData { + std::vector detectResult; + uint32_t framId; + uint32_t channelId; +}; + +const uint32_t JPEGD_STRIDE_WIDTH = 128; // Jpegd module output width need to align up to 128 +const uint32_t JPEGD_STRIDE_HEIGHT = 16; // Jpegd module output height need to align up to 16 +const uint32_t JPEGE_STRIDE_WIDTH = 16; // Jpege module input width need to align up to 16 +const uint32_t JPEGE_STRIDE_HEIGHT = 1; // Jpege module input height remains unchanged +const uint32_t VPC_STRIDE_WIDTH = 16; // Vpc module output width need to align up to 16 +const uint32_t VPC_STRIDE_HEIGHT = 2; // Vpc module output height need to align up to 2 +const uint32_t VDEC_STRIDE_WIDTH = 16; // Vdec module output width need to align up to 16 +const uint32_t VDEC_STRIDE_HEIGHT = 2; // Vdec module output width need to align up to 2 +const uint32_t YUV_BYTES_NU = 3; // Numerator of yuv image, H x W x 3 / 2 +const uint32_t YUV_BYTES_DE = 2; // Denominator of yuv image, H x W x 3 / 2 +const uint32_t YUV422_WIDTH_NU = 2; // Width of YUV422, WidthStride = Width * 2 +const uint32_t YUV444_RGB_WIDTH_NU = 3; // Width of YUV444 and RGB888, WidthStride = Width * 3 +const uint32_t XRGB_WIDTH_NU = 4; // Width of XRGB8888, WidthStride = Width * 4 +const uint32_t JPEG_OFFSET = 8; // Offset of input file for jpegd module +const uint32_t MAX_JPEGD_WIDTH = 8192; // Max width of jpegd module +const uint32_t MAX_JPEGD_HEIGHT = 8192; // Max height of jpegd module +const uint32_t MIN_JPEGD_WIDTH = 32; // Min width of jpegd module +const uint32_t MIN_JPEGD_HEIGHT = 32; // Min height of jpegd module +const uint32_t MAX_JPEGE_WIDTH = 8192; // Max width of jpege module +const uint32_t MAX_JPEGE_HEIGHT = 8192; // Max height of jpege module +const uint32_t MIN_JPEGE_WIDTH = 32; // Min width of jpege module +const uint32_t MIN_JPEGE_HEIGHT = 32; // Min height of jpege module +const uint32_t MAX_RESIZE_WIDTH = 4096; // Max width stride of resize module +const uint32_t MAX_RESIZE_HEIGHT = 4096; // Max height stride of resize module +const uint32_t MIN_RESIZE_WIDTH = 32; // Min width stride of resize module +const uint32_t MIN_RESIZE_HEIGHT = 6; // Min height stride of resize module +const float MIN_RESIZE_SCALE = 0.03125; // Min resize scale of resize module +const float MAX_RESIZE_SCALE = 16.0; // Min resize scale of resize module +const uint32_t MAX_VPC_WIDTH = 4096; // Max width of picture to VPC(resize/crop) +const uint32_t MAX_VPC_HEIGHT = 4096; // Max height of picture to VPC(resize/crop) +const uint32_t MIN_VPC_WIDTH = 32; // Min width of picture to VPC(resize/crop) +const uint32_t MIN_VPC_HEIGHT = 6; // Min height of picture to VPC(resize/crop) +const uint32_t MIN_CROP_WIDTH = 10; // Min width of crop area +const uint32_t MIN_CROP_HEIGHT = 6; // Min height of crop area +const uint8_t YUV_GREYER_VALUE = 128; // Filling value of the resized YUV image + +#define CONVERT_TO_ODD(NUM) (((NUM) % MODULUS_NUM_2 != 0) ? (NUM) : ((NUM)-1)) // Convert the input to odd num +#define CONVERT_TO_EVEN(NUM) (((NUM) % MODULUS_NUM_2 == 0) ? (NUM) : ((NUM)-1)) // Convert the input to even num +#define CHECK_ODD(num) ((num) % MODULUS_NUM_2 != 0) +#define CHECK_EVEN(num) ((num) % MODULUS_NUM_2 == 0) +#define RELEASE_DVPP_DATA(dvppDataPtr) \ + do { \ + APP_ERROR retMacro; \ + if (dvppDataPtr != nullptr) { \ + retMacro = acldvppFree(dvppDataPtr); \ + if (retMacro != APP_ERR_OK) { \ + MS_LOG(ERROR) << "Failed to free memory on dvpp, ret = " << retMacro << "."; \ + } \ + dvppDataPtr = nullptr; \ + } \ + } while (0); + +class DvppCommon { + public: + explicit DvppCommon(aclrtStream dvppStream); + explicit DvppCommon(const VdecConfig &vdecConfig); // Need by vdec + ~DvppCommon(); + APP_ERROR Init(void); + APP_ERROR InitVdec(); // Needed by vdec + APP_ERROR DeInit(void); + + static APP_ERROR GetVpcDataSize(uint32_t widthVpc, uint32_t heightVpc, acldvppPixelFormat format, uint32_t &vpcSize); + static APP_ERROR GetVpcInputStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride); + static APP_ERROR GetVpcOutputStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride); + static void GetJpegDecodeStrideSize(uint32_t width, uint32_t height, uint32_t &widthStride, uint32_t &heightStride); + static APP_ERROR GetJpegImageInfo(const void *data, uint32_t dataSize, uint32_t &width, uint32_t &height, + int32_t &components); + static APP_ERROR GetJpegDecodeDataSize(const void *data, uint32_t dataSize, acldvppPixelFormat format, + uint32_t &decSize); + static APP_ERROR GetJpegEncodeStrideSize(std::shared_ptr &input); + static APP_ERROR SetEncodeLevel(uint32_t level, acldvppJpegeConfig &jpegeConfig); + static APP_ERROR GetVideoDecodeStrideSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &widthStride, uint32_t &heightStride); + static APP_ERROR GetVideoDecodeDataSize(uint32_t width, uint32_t height, acldvppPixelFormat format, + uint32_t &vdecSize); + + // The following interfaces can be called only when the DvppCommon object is initialized with Init + APP_ERROR VpcResize(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize, + VpcProcessType processType = VPC_PT_DEFAULT); + APP_ERROR VpcCrop(const DvppCropInputInfo &input, const DvppDataInfo &output, bool withSynchronize); + APP_ERROR JpegDecode(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize); + + APP_ERROR JpegEncode(DvppDataInfo &input, DvppDataInfo &output, acldvppJpegeConfig *jpegeConfig, + bool withSynchronize); + + APP_ERROR GetJpegEncodeDataSize(DvppDataInfo &input, acldvppJpegeConfig *jpegeConfig, uint32_t &encSize); + + // These functions started with "Combine" encapsulate the DVPP process together, malloc DVPP memory, + // transfer pictures from host to device, and then execute the DVPP operation. + // The caller needs to pay attention to the release of the memory alloced in these functions. + // You can call the ReleaseDvppBuffer function to release memory after use completely. + APP_ERROR CombineResizeProcess(DvppDataInfo &input, DvppDataInfo &output, bool withSynchronize, + VpcProcessType processType = VPC_PT_DEFAULT); + APP_ERROR CombineCropProcess(DvppCropInputInfo &input, DvppDataInfo &output, bool withSynchronize); + APP_ERROR CombineJpegdProcess(const RawData &imageInfo, acldvppPixelFormat format, bool withSynchronize); + APP_ERROR CombineJpegeProcess(const RawData &imageInfo, uint32_t width, uint32_t height, acldvppPixelFormat format, + bool withSynchronize); + // The following interface can be called only when the DvppCommon object is initialized with InitVdec + APP_ERROR CombineVdecProcess(std::shared_ptr data, void *userData); + + // Get the private member variables which are assigned in the interfaces which are started with "Combine" + std::shared_ptr GetInputImage(); + std::shared_ptr GetDecodedImage(); + std::shared_ptr GetResizedImage(); + std::shared_ptr GetEncodedImage(); + std::shared_ptr GetCropedImage(); + + // Release the memory that is allocated in the interfaces which are started with "Combine" + void ReleaseDvppBuffer(); + APP_ERROR VdecSendEosFrame() const; + + private: + APP_ERROR SetDvppPicDescData(const DvppDataInfo &dataInfo, acldvppPicDesc &picDesc); + APP_ERROR ResizeProcess(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, bool withSynchronize); + APP_ERROR ResizeWithPadding(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, CropRoiConfig &cropRoi, + CropRoiConfig &pasteRoi, bool withSynchronize); + void GetCropRoi(const DvppDataInfo &input, const DvppDataInfo &output, VpcProcessType processType, + CropRoiConfig &cropRoi); + void GetPasteRoi(const DvppDataInfo &input, const DvppDataInfo &output, VpcProcessType processType, + CropRoiConfig &pasteRoi); + APP_ERROR CropProcess(acldvppPicDesc &inputDesc, acldvppPicDesc &outputDesc, const CropRoiConfig &cropArea, + bool withSynchronize); + APP_ERROR CheckResizeParams(const DvppDataInfo &input, const DvppDataInfo &output); + APP_ERROR CheckCropParams(const DvppCropInputInfo &input); + APP_ERROR TransferImageH2D(const RawData &imageInfo, const std::shared_ptr &jpegInput); + APP_ERROR CreateStreamDesc(std::shared_ptr data); + APP_ERROR DestroyResource(); + + std::shared_ptr cropAreaConfig_ = nullptr; + std::shared_ptr pasteAreaConfig_ = nullptr; + + std::shared_ptr cropInputDesc_ = nullptr; + std::shared_ptr cropOutputDesc_ = nullptr; + std::shared_ptr cropRoiConfig_ = nullptr; + + std::shared_ptr encodeInputDesc_ = nullptr; + std::shared_ptr jpegeConfig_ = nullptr; + + std::shared_ptr resizeInputDesc_ = nullptr; + std::shared_ptr resizeOutputDesc_ = nullptr; + std::shared_ptr resizeConfig_ = nullptr; + + std::shared_ptr decodeOutputDesc_ = nullptr; + + acldvppChannelDesc *dvppChannelDesc_ = nullptr; + aclrtStream dvppStream_ = nullptr; + std::shared_ptr inputImage_ = nullptr; + std::shared_ptr decodedImage_ = nullptr; + std::shared_ptr encodedImage_ = nullptr; + std::shared_ptr resizedImage_ = nullptr; + std::shared_ptr cropImage_ = nullptr; + bool isVdec_ = false; + aclvdecChannelDesc *vdecChannelDesc_ = nullptr; + acldvppStreamDesc *streamInputDesc_ = nullptr; + acldvppPicDesc *picOutputDesc_ = nullptr; + VdecConfig vdecConfig_; +}; +#endif diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp new file mode 100644 index 0000000000..fa63141077 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindspore/core/utils/log_adapter.h" +#include "ErrorCode.h" + +std::string GetAppErrCodeInfo(const APP_ERROR err) { + if ((err < APP_ERR_ACL_END) && (err >= APP_ERR_ACL_FAILURE)) { + return APP_ERR_ACL_LOG_STRING[((err < 0) ? (err + APP_ERR_ACL_END + 1) : err)]; + } else if ((err < APP_ERR_COMM_END) && (err > APP_ERR_COMM_BASE)) { + return (err - APP_ERR_COMM_BASE) < + (int)sizeof(APP_ERR_COMMON_LOG_STRING) / (int)sizeof(APP_ERR_COMMON_LOG_STRING[0]) + ? APP_ERR_COMMON_LOG_STRING[err - APP_ERR_COMM_BASE] + : "Undefine the error code information"; + } else if ((err < APP_ERR_DVPP_END) && (err > APP_ERR_DVPP_BASE)) { + return (err - APP_ERR_DVPP_BASE) < (int)sizeof(APP_ERR_DVPP_LOG_STRING) / (int)sizeof(APP_ERR_DVPP_LOG_STRING[0]) + ? APP_ERR_DVPP_LOG_STRING[err - APP_ERR_DVPP_BASE] + : "Undefine the error code information"; + } else if ((err < APP_ERR_QUEUE_END) && (err > APP_ERR_QUEUE_BASE)) { + return (err - APP_ERR_QUEUE_BASE) < (int)sizeof(APP_ERR_QUEUE_LOG_STRING) / (int)sizeof(APP_ERR_QUEUE_LOG_STRING[0]) + ? APP_ERR_QUEUE_LOG_STRING[err - APP_ERR_QUEUE_BASE] + : "Undefine the error code information"; + } else { + return "Error code unknown"; + } +} + +void AssertErrorCode(int code, std::string file, std::string function, int line) { + if (code != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; + exit(code); + } +} + +void CheckErrorCode(int code, std::string file, std::string function, int line) { + if (code != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed at " << file << "->" << function << "->" << line << ": error code=" << code; + } +} diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h new file mode 100644 index 0000000000..a61b0f0304 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ErrorCode.h @@ -0,0 +1,258 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ERROR_CODE_H +#define ERROR_CODE_H +#include + +using APP_ERROR = int; +// define the data tpye of error code +enum { + APP_ERR_OK = 0, + + // define the error code of ACL model, this is same with the aclError which is + // error code of ACL API Error codes 1~999 are reserved for the ACL. Do not + // add other error codes. Add it after APP_ERR_COMMON_ERR_BASE. + APP_ERR_ACL_FAILURE = -1, // ACL: general error + APP_ERR_ACL_ERR_BASE = 0, + APP_ERR_ACL_INVALID_PARAM = 1, // ACL: invalid parameter + APP_ERR_ACL_BAD_ALLOC = 2, // ACL: memory allocation fail + APP_ERR_ACL_RT_FAILURE = 3, // ACL: runtime failure + APP_ERR_ACL_GE_FAILURE = 4, // ACL: Graph Engine failure + APP_ERR_ACL_OP_NOT_FOUND = 5, // ACL: operator not found + APP_ERR_ACL_OP_LOAD_FAILED = 6, // ACL: fail to load operator + APP_ERR_ACL_READ_MODEL_FAILURE = 7, // ACL: fail to read model + APP_ERR_ACL_PARSE_MODEL = 8, // ACL: parse model failure + APP_ERR_ACL_MODEL_MISSING_ATTR = 9, // ACL: model missing attribute + APP_ERR_ACL_DESERIALIZE_MODEL = 10, // ACL: deserialize model failure + APP_ERR_ACL_EVENT_NOT_READY = 12, // ACL: event not ready + APP_ERR_ACL_EVENT_COMPLETE = 13, // ACL: event complete + APP_ERR_ACL_UNSUPPORTED_DATA_TYPE = 14, // ACL: unsupported data type + APP_ERR_ACL_REPEAT_INITIALIZE = 15, // ACL: repeat initialize + APP_ERR_ACL_COMPILER_NOT_REGISTERED = 16, // ACL: compiler not registered + APP_ERR_ACL_IO = 17, // ACL: IO failed + APP_ERR_ACL_INVALID_FILE = 18, // ACL: invalid file + APP_ERR_ACL_INVALID_DUMP_CONFIG = 19, // ACL: invalid dump comfig + APP_ERR_ACL_INVALID_PROFILING_CONFIG = 20, // ACL: invalid profiling config + APP_ERR_ACL_OP_TYPE_NOT_MATCH = 21, // ACL: operator type not match + APP_ERR_ACL_OP_INPUT_NOT_MATCH = 22, // ACL: operator input not match + APP_ERR_ACL_OP_OUTPUT_NOT_MATCH = 23, // ACL: operator output not match + APP_ERR_ACL_OP_ATTR_NOT_MATCH = 24, // ACL: operator attribute not match + APP_ERR_ACL_API_NOT_SUPPORT = 25, // ACL: API not support + APP_ERR_ACL_CREATE_DATA_BUF_FAILED = 26, // ACL: create data buffer fail + APP_ERR_ACL_END, // Not an error code, define the range of ACL error code + + // define the common error code, range: 1001~1999 + APP_ERR_COMM_BASE = 1000, + APP_ERR_COMM_FAILURE = APP_ERR_COMM_BASE + 1, // General Failed + APP_ERR_COMM_INNER = APP_ERR_COMM_BASE + 2, // Internal error + APP_ERR_COMM_INVALID_POINTER = APP_ERR_COMM_BASE + 3, // Invalid Pointer + APP_ERR_COMM_INVALID_PARAM = APP_ERR_COMM_BASE + 4, // Invalid parameter + APP_ERR_COMM_UNREALIZED = APP_ERR_COMM_BASE + 5, // Not implemented + APP_ERR_COMM_OUT_OF_MEM = APP_ERR_COMM_BASE + 6, // Out of memory + APP_ERR_COMM_ALLOC_MEM = APP_ERR_COMM_BASE + 7, // memory allocation error + APP_ERR_COMM_FREE_MEM = APP_ERR_COMM_BASE + 8, // free memory error + APP_ERR_COMM_OUT_OF_RANGE = APP_ERR_COMM_BASE + 9, // out of range + APP_ERR_COMM_NO_PERMISSION = APP_ERR_COMM_BASE + 10, // NO Permission + APP_ERR_COMM_TIMEOUT = APP_ERR_COMM_BASE + 11, // Timed out + APP_ERR_COMM_NOT_INIT = APP_ERR_COMM_BASE + 12, // Not initialized + APP_ERR_COMM_INIT_FAIL = APP_ERR_COMM_BASE + 13, // initialize failed + APP_ERR_COMM_INPROGRESS = APP_ERR_COMM_BASE + 14, // Operation now in progress + APP_ERR_COMM_EXIST = APP_ERR_COMM_BASE + 15, // Object, file or other resource already exist + APP_ERR_COMM_NO_EXIST = APP_ERR_COMM_BASE + 16, // Object, file or other resource doesn't exist + APP_ERR_COMM_BUSY = APP_ERR_COMM_BASE + 17, // Object, file or other resource is in use + APP_ERR_COMM_FULL = APP_ERR_COMM_BASE + 18, // No available Device or resource + APP_ERR_COMM_OPEN_FAIL = APP_ERR_COMM_BASE + 19, // Device, file or resource open failed + APP_ERR_COMM_READ_FAIL = APP_ERR_COMM_BASE + 20, // Device, file or resource read failed + APP_ERR_COMM_WRITE_FAIL = APP_ERR_COMM_BASE + 21, // Device, file or resource write failed + APP_ERR_COMM_DESTORY_FAIL = APP_ERR_COMM_BASE + 22, // Device, file or resource destory failed + APP_ERR_COMM_EXIT = APP_ERR_COMM_BASE + 23, // End of data stream, stop the application + APP_ERR_COMM_CONNECTION_CLOSE = APP_ERR_COMM_BASE + 24, // Out of connection, Communication shutdown + APP_ERR_COMM_CONNECTION_FAILURE = APP_ERR_COMM_BASE + 25, // connection fail + APP_ERR_COMM_STREAM_INVALID = APP_ERR_COMM_BASE + 26, // ACL stream is null pointer + APP_ERR_COMM_END, // Not an error code, define the range of common error code + + // define the error code of DVPP + APP_ERR_DVPP_BASE = 2000, + APP_ERR_DVPP_CROP_FAIL = APP_ERR_DVPP_BASE + 1, // DVPP: crop fail + APP_ERR_DVPP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 2, // DVPP: resize fail + APP_ERR_DVPP_CROP_RESIZE_FAIL = APP_ERR_DVPP_BASE + 3, // DVPP: corp and resize fail + APP_ERR_DVPP_CONVERT_FROMAT_FAIL = APP_ERR_DVPP_BASE + 4, // DVPP: convert image fromat fail + APP_ERR_DVPP_VPC_FAIL = APP_ERR_DVPP_BASE + 5, // DVPP: VPC(crop, resize, convert fromat) fail + APP_ERR_DVPP_JPEG_DECODE_FAIL = APP_ERR_DVPP_BASE + 6, // DVPP: decode jpeg or jpg fail + APP_ERR_DVPP_JPEG_ENCODE_FAIL = APP_ERR_DVPP_BASE + 7, // DVPP: encode jpeg or jpg fail + APP_ERR_DVPP_PNG_DECODE_FAIL = APP_ERR_DVPP_BASE + 8, // DVPP: encode png fail + APP_ERR_DVPP_H26X_DECODE_FAIL = APP_ERR_DVPP_BASE + 9, // DVPP: decode H264 or H265 fail + APP_ERR_DVPP_H26X_ENCODE_FAIL = APP_ERR_DVPP_BASE + 10, // DVPP: encode H264 or H265 fail + APP_ERR_DVPP_HANDLE_NULL = APP_ERR_DVPP_BASE + 11, // DVPP: acldvppChannelDesc is nullptr + APP_ERR_DVPP_PICDESC_FAIL = APP_ERR_DVPP_BASE + 12, // DVPP: fail to create acldvppCreatePicDesc or + // fail to set acldvppCreatePicDesc + APP_ERR_DVPP_CONFIG_FAIL = APP_ERR_DVPP_BASE + 13, // DVPP: fail to set dvpp configuration,such as + // resize configuration,crop configuration + APP_ERR_DVPP_OBJ_FUNC_MISMATCH = APP_ERR_DVPP_BASE + 14, // DVPP: DvppCommon object mismatch the function + APP_ERR_DVPP_END, // Not an error code, define the range of common error code + + // define the error code of inference + APP_ERR_INFER_BASE = 3000, + APP_ERR_INFER_SET_INPUT_FAIL = APP_ERR_INFER_BASE + 1, // Infer: set input fail + APP_ERR_INFER_SET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 2, // Infer: set output fail + APP_ERR_INFER_CREATE_OUTPUT_FAIL = APP_ERR_INFER_BASE + 3, // Infer: create output fail + APP_ERR_INFER_OP_SET_ATTR_FAIL = APP_ERR_INFER_BASE + 4, // Infer: set op attribute fail + APP_ERR_INFER_GET_OUTPUT_FAIL = APP_ERR_INFER_BASE + 5, // Infer: get model output fail + APP_ERR_INFER_FIND_MODEL_ID_FAIL = APP_ERR_INFER_BASE + 6, // Infer: find model id fail + APP_ERR_INFER_FIND_MODEL_DESC_FAIL = APP_ERR_INFER_BASE + 7, // Infer: find model description fail + APP_ERR_INFER_FIND_MODEL_MEM_FAIL = APP_ERR_INFER_BASE + 8, // Infer: find model memory fail + APP_ERR_INFER_FIND_MODEL_WEIGHT_FAIL = APP_ERR_INFER_BASE + 9, // Infer: find model weight fail + + APP_ERR_INFER_END, // Not an error code, define the range of inference error + // code + + // define the error code of transmission + APP_ERR_TRANS_BASE = 4000, + + APP_ERR_TRANS_END, // Not an error code, define the range of transmission + // error code + + // define the error code of blocking queue + APP_ERR_QUEUE_BASE = 5000, + APP_ERR_QUEUE_EMPTY = APP_ERR_QUEUE_BASE + 1, // Queue: empty queue + APP_ERR_QUEUE_STOPED = APP_ERR_QUEUE_BASE + 2, // Queue: queue stoped + APP_ERROR_QUEUE_FULL = APP_ERR_QUEUE_BASE + 3, // Queue: full queue + + // define the idrecognition web error code + APP_ERROR_FACE_WEB_USE_BASE = 10000, + APP_ERROR_FACE_WEB_USE_SYSTEM_ERROR = APP_ERROR_FACE_WEB_USE_BASE + 1, // Web: system error + APP_ERROR_FACE_WEB_USE_MUL_FACE = APP_ERROR_FACE_WEB_USE_BASE + 2, // Web: multiple faces + APP_ERROR_FACE_WEB_USE_REPEAT_REG = APP_ERROR_FACE_WEB_USE_BASE + 3, // Web: repeat registration + APP_ERROR_FACE_WEB_USE_PART_SUCCESS = APP_ERROR_FACE_WEB_USE_BASE + 4, // Web: partial search succeeded + APP_ERROR_FACE_WEB_USE_NO_FACE = APP_ERROR_FACE_WEB_USE_BASE + 5, // Web: no face detected + APP_ERR_QUEUE_END, // Not an error code, define the range of blocking queue + // error code +}; +const std::string APP_ERR_ACL_LOG_STRING[] = { + [APP_ERR_OK] = "Success", + [APP_ERR_ACL_INVALID_PARAM] = "ACL: invalid parameter", + [APP_ERR_ACL_BAD_ALLOC] = "ACL: memory allocation fail", + [APP_ERR_ACL_RT_FAILURE] = "ACL: runtime failure", + [APP_ERR_ACL_GE_FAILURE] = "ACL: Graph Engine failure", + [APP_ERR_ACL_OP_NOT_FOUND] = "ACL: operator not found", + [APP_ERR_ACL_OP_LOAD_FAILED] = "ACL: fail to load operator", + [APP_ERR_ACL_READ_MODEL_FAILURE] = "ACL: fail to read model", + [APP_ERR_ACL_PARSE_MODEL] = "ACL: parse model failure", + [APP_ERR_ACL_MODEL_MISSING_ATTR] = "ACL: model missing attribute", + [APP_ERR_ACL_DESERIALIZE_MODEL] = "ACL: deserialize model failure", + [11] = "Placeholder", + [APP_ERR_ACL_EVENT_NOT_READY] = "ACL: event not ready", + [APP_ERR_ACL_EVENT_COMPLETE] = "ACL: event complete", + [APP_ERR_ACL_UNSUPPORTED_DATA_TYPE] = "ACL: unsupported data type", + [APP_ERR_ACL_REPEAT_INITIALIZE] = "ACL: repeat initialize", + [APP_ERR_ACL_COMPILER_NOT_REGISTERED] = "ACL: compiler not registered", + [APP_ERR_ACL_IO] = "ACL: IO failed", + [APP_ERR_ACL_INVALID_FILE] = "ACL: invalid file", + [APP_ERR_ACL_INVALID_DUMP_CONFIG] = "ACL: invalid dump comfig", + [APP_ERR_ACL_INVALID_PROFILING_CONFIG] = "ACL: invalid profiling config", + [APP_ERR_ACL_OP_TYPE_NOT_MATCH] = "ACL: operator type not match", + [APP_ERR_ACL_OP_INPUT_NOT_MATCH] = "ACL: operator input not match", + [APP_ERR_ACL_OP_OUTPUT_NOT_MATCH] = "ACL: operator output not match", + [APP_ERR_ACL_OP_ATTR_NOT_MATCH] = "ACL: operator attribute not match", + [APP_ERR_ACL_API_NOT_SUPPORT] = "ACL: API not supported", + [APP_ERR_ACL_CREATE_DATA_BUF_FAILED] = "ACL: create data buffer fail", +}; + +const std::string APP_ERR_COMMON_LOG_STRING[] = { + [0] = "Placeholder", + [1] = "General Failed", + [2] = "Internal error", + [3] = "Invalid Pointer", + [4] = "Invalid parameter", + [5] = "Not implemented", + [6] = "Out of memory", + [7] = "memory allocation error", + [8] = "free memory error", + [9] = "out of range", + [10] = "NO Permission ", + [11] = "Timed out", + [12] = "Not initialized", + [13] = "initialize failed", + [14] = "Operation now in progress ", + [15] = "Object, file or other resource already exist", + [16] = "Object, file or other resource already doesn't exist", + [17] = "Object, file or other resource is in use", + [18] = "No available Device or resource", + [19] = "Device, file or resource open failed", + [20] = "Device, file or resource read failed", + [21] = "Device, file or resource write failed", + [22] = "Device, file or resource destory failed", + [23] = " ", + [24] = "Out of connection, Communication shutdown", + [25] = "connection fail", + [26] = "ACL stream is null pointer", +}; + +const std::string APP_ERR_DVPP_LOG_STRING[] = { + [0] = "Placeholder", + [1] = "DVPP: crop fail", + [2] = "DVPP: resize fail", + [3] = "DVPP: corp and resize fail", + [4] = "DVPP: convert image format fail", + [5] = "DVPP: VPC(crop, resize, convert format) fail", + [6] = "DVPP: decode jpeg or jpg fail", + [7] = "DVPP: encode jpeg or jpg fail", + [8] = "DVPP: encode png fail", + [9] = "DVPP: decode H264 or H265 fail", + [10] = "DVPP: encode H264 or H265 fail", + [11] = "DVPP: acldvppChannelDesc is nullptr", + [12] = "DVPP: fail to create or set acldvppCreatePicDesc", + [13] = "DVPP: fail to set dvpp configuration", + [14] = "DVPP: DvppCommon object mismatch the function", +}; + +const std::string APP_ERR_INFER_LOG_STRING[] = { + [0] = "Placeholder", + [1] = "Infer: set input fail", + [2] = "Infer: set output fail", + [3] = "Infer: create output fail", + [4] = "Infer: set op attribute fail", + [5] = "Infer: get model output fail", + [6] = "Infer: find model id fail", + [7] = "Infer: find model description fail", + [8] = "Infer: find model memory fail", + [9] = "Infer: find model weight fail", +}; + +const std::string APP_ERR_QUEUE_LOG_STRING[] = { + [0] = "Placeholder", + [1] = "empty queue", + [2] = "queue stoped", + [3] = "full queue", +}; + +const std::string APP_ERR_FACE_LOG_STRING[] = { + [0] = "Placeholder", + [1] = "system error", + [2] = "multiple faces", + [3] = "repeat registration", + [4] = "partial search succeeded", + [5] = "no face detected", +}; + +std::string GetAppErrCodeInfo(APP_ERROR err); +void AssertErrorCode(int code, std::string file, std::string function, int line); +void CheckErrorCode(int code, std::string file, std::string function, int line); + +#define RtAssert(code) AssertErrorCode(code, __FILE__, __FUNCTION__, __LINE__); +#define RtCheckError(code) CheckErrorCode(code, __FILE__, __FUNCTION__, __LINE__); + +#endif // ERROR_CODE_H_ \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc new file mode 100644 index 0000000000..0583f657ba --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.cc @@ -0,0 +1,136 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ResourceManager.h" +#include + +bool ResourceManager::initFlag_ = true; +std::shared_ptr ResourceManager::ptr_ = nullptr; + +/** + * Check whether the file exists. + * + * @param filePath the file path we want to check + * @return APP_ERR_OK if file exists, error code otherwise + */ +APP_ERROR ExistFile(const std::string &filePath) { + struct stat fileSat = {0}; + char c[PATH_MAX + 1] = {0x00}; + size_t count = filePath.copy(c, PATH_MAX + 1); + if (count != filePath.length()) { + MS_LOG(ERROR) << "Failed to strcpy" << c; + return APP_ERR_COMM_FAILURE; + } + // Get the absolute path of input directory + char path[PATH_MAX + 1] = {0x00}; + if ((strlen(c) > PATH_MAX) || (realpath(c, path) == nullptr)) { + MS_LOG(ERROR) << "Failed to get canonicalize path"; + return APP_ERR_COMM_EXIST; + } + if (stat(c, &fileSat) == 0 && S_ISREG(fileSat.st_mode)) { + return APP_ERR_OK; + } + return APP_ERR_COMM_FAILURE; +} + +void ResourceManager::Release() { + APP_ERROR ret; + for (size_t i = 0; i < deviceIds_.size(); i++) { + if (contexts_[i] != nullptr) { + ret = aclrtDestroyContext(contexts_[i]); // Destroy context + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to destroy context, ret = " << ret << "."; + return; + } + contexts_[i] = nullptr; + } + ret = aclrtResetDevice(deviceIds_[i]); // Reset device + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to reset device, ret = " << ret << "."; + return; + } + } + ret = aclFinalize(); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to finalize acl, ret = " << ret << "."; + return; + } + MS_LOG(INFO) << "Finalized acl successfully."; +} + +std::shared_ptr ResourceManager::GetInstance() { + if (ptr_ == nullptr) { + ResourceManager *temp = new ResourceManager(); + ptr_.reset(temp); + } + return ptr_; +} + +APP_ERROR ResourceManager::InitResource(ResourceInfo &resourceInfo) { + if (!GetInitStatus()) { + return APP_ERR_OK; + } + + std::string &aclConfigPath = resourceInfo.aclConfigPath; + APP_ERROR ret; + if (aclConfigPath.length() == 0) { + // Init acl without aclconfig + ret = aclInit(nullptr); + } else { + ret = ExistFile(aclConfigPath); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Acl config file not exist, ret = " << ret << "."; + return ret; + } + ret = aclInit(aclConfigPath.c_str()); // Initialize ACL + } + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to init acl, ret = " << ret; + return ret; + } + std::copy(resourceInfo.deviceIds.begin(), resourceInfo.deviceIds.end(), std::back_inserter(deviceIds_)); + MS_LOG(INFO) << "Initialized acl successfully."; + // Open device and create context for each chip, note: it create one context for each chip + for (size_t i = 0; i < deviceIds_.size(); i++) { + deviceIdMap_[deviceIds_[i]] = i; + ret = aclrtSetDevice(deviceIds_[i]); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to open acl device: " << deviceIds_[i]; + return ret; + } + MS_LOG(INFO) << "Open device " << deviceIds_[i] << " successfully."; + aclrtContext context; + ret = aclrtCreateContext(&context, deviceIds_[i]); + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to create acl context, ret = " << ret << "."; + return ret; + } + MS_LOG(INFO) << "Created context for device " << deviceIds_[i] << " successfully"; + contexts_.push_back(context); + } + std::string singleOpPath = resourceInfo.singleOpFolderPath; + if (!singleOpPath.empty()) { + ret = aclopSetModelDir(singleOpPath.c_str()); // Set operator model directory for application + if (ret != APP_ERR_OK) { + MS_LOG(ERROR) << "Failed to aclopSetModelDir, ret = " << ret << "."; + return ret; + } + } + MS_LOG(INFO) << "Init resource successfully."; + ResourceManager::initFlag_ = false; + return APP_ERR_OK; +} + +aclrtContext ResourceManager::GetContext(int deviceId) { return contexts_[deviceIdMap_[deviceId]]; } \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h new file mode 100644 index 0000000000..bc2bfe6321 --- /dev/null +++ b/mindspore/ccsrc/minddata/dataset/kernels/image/dvpp/utils/ResourceManager.h @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2020.Huawei Technologies Co., Ltd. All rights reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef RESOURCEMANAGER_H +#define RESOURCEMANAGER_H + +#include +#include +#include +#include +#include +#include "CommonDataType.h" +#include "ErrorCode.h" +#include +#include "mindspore/core/utils/log_adapter.h" + +#define PATH_MAX 4096 + +enum ModelLoadMethod { + LOAD_FROM_FILE = 0, // Loading from file, memory of model and weights are managed by ACL + LOAD_FROM_MEM, // Loading from memory, memory of model and weights are managed by ACL + LOAD_FROM_FILE_WITH_MEM, // Loading from file, memory of model and weight are managed by user + LOAD_FROM_MEM_WITH_MEM // Loading from memory, memory of model and weight are managed by user +}; + +struct ModelInfo { + std::string modelName; + std::string modelPath; // Path of om model file + size_t modelFileSize; // Size of om model file + std::shared_ptr modelFilePtr; // Smart pointer of model file data + uint32_t modelWidth; // Input width of model + uint32_t modelHeight; // Input height of model + ModelLoadMethod method; // Loading method of model +}; + +// Device resource info, such as model infos, etc +struct DeviceResInfo { + std::vector modelInfos; +}; + +struct ResourceInfo { + std::set deviceIds; + std::string aclConfigPath; + std::string singleOpFolderPath; + std::unordered_map deviceResInfos; // map +}; + +APP_ERROR ExistFile(const std::string &filePath); + +class ResourceManager { + public: + ResourceManager(){}; + + ~ResourceManager(){}; + + // Get the Instance of resource manager + static std::shared_ptr GetInstance(); + + // Init the resource of resource manager + APP_ERROR InitResource(ResourceInfo &resourceInfo); + + aclrtContext GetContext(int deviceId); + + void Release(); + + static bool GetInitStatus() { return initFlag_; } + + private: + static std::shared_ptr ptr_; + static bool initFlag_; + std::vector deviceIds_; + std::vector contexts_; + std::unordered_map deviceIdMap_; // Map of device to index +}; + +#endif \ No newline at end of file diff --git a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h index 779a5e6f26..0b1022a113 100644 --- a/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h +++ b/mindspore/ccsrc/minddata/dataset/kernels/tensor_op.h @@ -57,6 +57,7 @@ constexpr char kCenterCropOp[] = "CenterCropOp"; constexpr char kCutMixBatchOp[] = "CutMixBatchOp"; constexpr char kCutOutOp[] = "CutOutOp"; constexpr char kCropOp[] = "CropOp"; +constexpr char kDvppDecodeResizeCropJpegOp[] = "DvppDecodeResizeCropJpegOp"; constexpr char kEqualizeOp[] = "EqualizeOp"; constexpr char kHwcToChwOp[] = "HwcToChwOp"; constexpr char kInvertOp[] = "InvertOp"; diff --git a/tests/cxx_st/dataset/test_de.cc b/tests/cxx_st/dataset/test_de.cc index 242f55a857..2d0aaf923f 100644 --- a/tests/cxx_st/dataset/test_de.cc +++ b/tests/cxx_st/dataset/test_de.cc @@ -31,10 +31,10 @@ class TestDE : public ST::Common { TEST_F(TestDE, ResNetPreprocess) { std::vector> images; - MindDataEager::LoadImageFromDir("/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764", &images); + MindDataEager::LoadImageFromDir("/home/workspace/mindspore_dataset/imagenet/imagenet_original/val/n01440764", + &images); - MindDataEager Compose({Decode(), - Resize({224, 224}), + MindDataEager Compose({Decode(), Resize({224, 224}), Normalize({0.485 * 255, 0.456 * 255, 0.406 * 255}, {0.229 * 255, 0.224 * 255, 0.225 * 255}), HWC2CHW()}); @@ -47,3 +47,19 @@ TEST_F(TestDE, ResNetPreprocess) { ASSERT_EQ(images[0]->Shape()[1], 224); ASSERT_EQ(images[0]->Shape()[2], 224); } + +TEST_F(TestDE, TestDvpp) { + std::vector> images; + MindDataEager::LoadImageFromDir("/root/Dvpp_Unit_Dev/val2014_test/", &images); + + MindDataEager Solo({DvppDecodeResizeCropJpeg({224, 224}, {256, 256})}); + + for (auto &img : images) { + img = Solo(img); + } + + ASSERT_EQ(images[0]->Shape().size(), 3); + ASSERT_EQ(images[0]->Shape()[0], 224 * 224 * 1.5); + ASSERT_EQ(images[0]->Shape()[1], 1); + ASSERT_EQ(images[0]->Shape()[2], 1); +}