!13780 add mnist stm32f746 example
From: @zoloft Reviewed-by: @wangchengyuan,@hangangqiang Signed-off-by: @wangchengyuanpull/13780/MERGE
commit
5d1a340fb4
@ -0,0 +1,71 @@
|
||||
/* USER CODE BEGIN Header */
|
||||
/**
|
||||
******************************************************************************
|
||||
* @file : main.h
|
||||
* @brief : Header for main.c file.
|
||||
* This file contains the common defines of the application.
|
||||
******************************************************************************
|
||||
* @attention
|
||||
*
|
||||
* <h2><center>© Copyright (c) 2021 STMicroelectronics.
|
||||
* All rights reserved.</center></h2>
|
||||
*
|
||||
* This software component is licensed by ST under BSD 3-Clause license,
|
||||
* the "License"; You may not use this file except in compliance with the
|
||||
* License. You may obtain a copy of the License at:
|
||||
* opensource.org/licenses/BSD-3-Clause
|
||||
*
|
||||
******************************************************************************
|
||||
*/
|
||||
/* USER CODE END Header */
|
||||
|
||||
/* Define to prevent recursive inclusion -------------------------------------*/
|
||||
#ifndef __MAIN_H
|
||||
#define __MAIN_H
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/* Includes ------------------------------------------------------------------*/
|
||||
#include "stm32f7xx_hal.h"
|
||||
|
||||
/* Private includes ----------------------------------------------------------*/
|
||||
/* USER CODE BEGIN Includes */
|
||||
|
||||
/* USER CODE END Includes */
|
||||
|
||||
/* Exported types ------------------------------------------------------------*/
|
||||
/* USER CODE BEGIN ET */
|
||||
|
||||
/* USER CODE END ET */
|
||||
|
||||
/* Exported constants --------------------------------------------------------*/
|
||||
/* USER CODE BEGIN EC */
|
||||
|
||||
/* USER CODE END EC */
|
||||
|
||||
/* Exported macro ------------------------------------------------------------*/
|
||||
/* USER CODE BEGIN EM */
|
||||
|
||||
/* USER CODE END EM */
|
||||
|
||||
/* Exported functions prototypes ---------------------------------------------*/
|
||||
void Error_Handler(void);
|
||||
|
||||
/* USER CODE BEGIN EFP */
|
||||
|
||||
/* USER CODE END EFP */
|
||||
|
||||
/* Private defines -----------------------------------------------------------*/
|
||||
/* USER CODE BEGIN Private defines */
|
||||
|
||||
/* USER CODE END Private defines */
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* __MAIN_H */
|
||||
|
||||
/************************ (C) COPYRIGHT STMicroelectronics *****END OF FILE****/
|
@ -0,0 +1,214 @@
|
||||
/* USER CODE BEGIN Header */
|
||||
/**
|
||||
******************************************************************************
|
||||
* @file : main.c
|
||||
* @brief : Main program body
|
||||
******************************************************************************
|
||||
* @attention
|
||||
*
|
||||
* <h2><center>© Copyright (c) 2021 STMicroelectronics.
|
||||
* All rights reserved.</center></h2>
|
||||
*
|
||||
* This software component is licensed by ST under BSD 3-Clause license,
|
||||
* the "License"; You may not use this file except in compliance with the
|
||||
* License. You may obtain a copy of the License at:
|
||||
* opensource.org/licenses/BSD-3-Clause
|
||||
*
|
||||
******************************************************************************
|
||||
*/
|
||||
/* USER CODE END Header */
|
||||
/* Includes ------------------------------------------------------------------*/
|
||||
#include "main.h"
|
||||
#include "SEGGER_RTT.h"
|
||||
#include "include/errorcode.h"
|
||||
#include "include/lite_session.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "mnist_input_data.h"
|
||||
// #include <stdio.h>
|
||||
|
||||
using namespace mindspore;
|
||||
/* Private includes ----------------------------------------------------------*/
|
||||
/* USER CODE BEGIN Includes */
|
||||
|
||||
/* USER CODE END Includes */
|
||||
|
||||
/* Private typedef -----------------------------------------------------------*/
|
||||
/* USER CODE BEGIN PTD */
|
||||
|
||||
/* USER CODE END PTD */
|
||||
|
||||
/* Private define ------------------------------------------------------------*/
|
||||
/* USER CODE BEGIN PD */
|
||||
/* USER CODE END PD */
|
||||
|
||||
/* Private macro -------------------------------------------------------------*/
|
||||
/* USER CODE BEGIN PM */
|
||||
|
||||
/* USER CODE END PM */
|
||||
|
||||
/* Private variables ---------------------------------------------------------*/
|
||||
|
||||
/* USER CODE BEGIN PV */
|
||||
|
||||
/* USER CODE END PV */
|
||||
|
||||
/* Private function prototypes -----------------------------------------------*/
|
||||
void SystemClock_Config(void);
|
||||
/* USER CODE BEGIN PFP */
|
||||
|
||||
/* USER CODE END PFP */
|
||||
|
||||
/* Private user code ---------------------------------------------------------*/
|
||||
/* USER CODE BEGIN 0 */
|
||||
|
||||
/* USER CODE END 0 */
|
||||
|
||||
/**
|
||||
* @brief The application entry point.
|
||||
* @retval int
|
||||
*/
|
||||
int main(void) {
|
||||
/* USER CODE BEGIN 1 */
|
||||
|
||||
/* USER CODE END 1 */
|
||||
|
||||
/* MCU Configuration--------------------------------------------------------*/
|
||||
|
||||
/* Reset of all peripherals, Initializes the Flash interface and the Systick. */
|
||||
HAL_Init();
|
||||
|
||||
/* USER CODE BEGIN Init */
|
||||
|
||||
/* USER CODE END Init */
|
||||
|
||||
/* Configure the system clock */
|
||||
SystemClock_Config();
|
||||
|
||||
/* USER CODE BEGIN SysInit */
|
||||
|
||||
/* USER CODE END SysInit */
|
||||
|
||||
/* Initialize all configured peripherals */
|
||||
/* USER CODE BEGIN 2 */
|
||||
|
||||
/* USER CODE END 2 */
|
||||
|
||||
/* Infinite loop */
|
||||
/* USER CODE BEGIN WHILE */
|
||||
// float inputs_binbuf[784] = {0};
|
||||
while (1) {
|
||||
/* USER CODE END WHILE */
|
||||
SEGGER_RTT_printf(0, "***********mnist test start***********\n");
|
||||
const char *model_buffer = nullptr;
|
||||
int model_size = 0;
|
||||
session::LiteSession *session = mindspore::session::LiteSession::CreateSession(model_buffer, model_size, nullptr);
|
||||
Vector<tensor::MSTensor *> inputs = session->GetInputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
void *inputs_binbuf[inputs_num];
|
||||
int inputs_size[inputs_num];
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
inputs_size[i] = inputs[i]->Size();
|
||||
}
|
||||
// here mnist only have one input data,just hard code to it's array;
|
||||
inputs_binbuf[0] = mnist_inputs_data;
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
void *input_data = inputs[i]->MutableData();
|
||||
memcpy(input_data, inputs_binbuf[i], inputs_size[i]);
|
||||
}
|
||||
int ret = session->RunGraph();
|
||||
if (ret != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
Vector<String> outputs_name = session->GetOutputTensorNames();
|
||||
for (int i = 0; i < outputs_name.size(); ++i) {
|
||||
tensor::MSTensor *output_tensor = session->GetOutputByTensorName(outputs_name[i]);
|
||||
if (output_tensor == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
SEGGER_RTT_printf(0, "***********mnist test start5.2***********\n");
|
||||
float *casted_data = static_cast<float *>(output_tensor->MutableData());
|
||||
if (casted_data == nullptr) {
|
||||
return -1;
|
||||
}
|
||||
SEGGER_RTT_printf(0, "***********mnist test start5.3***********\n");
|
||||
for (size_t j = 0; j < 10 && j < output_tensor->ElementsNum(); j++) {
|
||||
SEGGER_RTT_printf(0, "output: [%d] is : [%d]/100\n", i, casted_data[i] * 100);
|
||||
}
|
||||
}
|
||||
delete session;
|
||||
SEGGER_RTT_printf(0, "***********mnist test end***********\n");
|
||||
/* USER CODE BEGIN 3 */
|
||||
}
|
||||
/* USER CODE END 3 */
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief System Clock Configuration
|
||||
* @retval None
|
||||
*/
|
||||
void SystemClock_Config(void) {
|
||||
RCC_OscInitTypeDef RCC_OscInitStruct = {0};
|
||||
RCC_ClkInitTypeDef RCC_ClkInitStruct = {0};
|
||||
|
||||
/** Configure the main internal regulator output voltage
|
||||
*/
|
||||
__HAL_RCC_PWR_CLK_ENABLE();
|
||||
__HAL_PWR_VOLTAGESCALING_CONFIG(PWR_REGULATOR_VOLTAGE_SCALE3);
|
||||
/** Initializes the RCC Oscillators according to the specified parameters
|
||||
* in the RCC_OscInitTypeDef structure.
|
||||
*/
|
||||
RCC_OscInitStruct.OscillatorType = RCC_OSCILLATORTYPE_HSI;
|
||||
RCC_OscInitStruct.HSIState = RCC_HSI_ON;
|
||||
RCC_OscInitStruct.HSICalibrationValue = RCC_HSICALIBRATION_DEFAULT;
|
||||
RCC_OscInitStruct.PLL.PLLState = RCC_PLL_NONE;
|
||||
if (HAL_RCC_OscConfig(&RCC_OscInitStruct) != HAL_OK) {
|
||||
Error_Handler();
|
||||
}
|
||||
/** Initializes the CPU, AHB and APB buses clocks
|
||||
*/
|
||||
RCC_ClkInitStruct.ClockType = RCC_CLOCKTYPE_HCLK | RCC_CLOCKTYPE_SYSCLK
|
||||
| RCC_CLOCKTYPE_PCLK1 | RCC_CLOCKTYPE_PCLK2;
|
||||
RCC_ClkInitStruct.SYSCLKSource = RCC_SYSCLKSOURCE_HSI;
|
||||
RCC_ClkInitStruct.AHBCLKDivider = RCC_SYSCLK_DIV1;
|
||||
RCC_ClkInitStruct.APB1CLKDivider = RCC_HCLK_DIV1;
|
||||
RCC_ClkInitStruct.APB2CLKDivider = RCC_HCLK_DIV1;
|
||||
|
||||
if (HAL_RCC_ClockConfig(&RCC_ClkInitStruct, FLASH_LATENCY_0) != HAL_OK) {
|
||||
Error_Handler();
|
||||
}
|
||||
}
|
||||
|
||||
/* USER CODE BEGIN 4 */
|
||||
|
||||
/* USER CODE END 4 */
|
||||
|
||||
/**
|
||||
* @brief This function is executed in case of error occurrence.
|
||||
* @retval None
|
||||
*/
|
||||
void Error_Handler(void) {
|
||||
/* USER CODE BEGIN Error_Handler_Debug */
|
||||
/* User can add his own implementation to report the HAL error return state */
|
||||
__disable_irq();
|
||||
while (1) {
|
||||
}
|
||||
/* USER CODE END Error_Handler_Debug */
|
||||
}
|
||||
|
||||
#ifdef USE_FULL_ASSERT
|
||||
/**
|
||||
* @brief Reports the name of the source file and the source line number
|
||||
* where the assert_param error has occurred.
|
||||
* @param file: pointer to the source file name
|
||||
* @param line: assert_param error line source number
|
||||
* @retval None
|
||||
*/
|
||||
void assert_failed(uint8_t *file, uint32_t line) {
|
||||
/* USER CODE BEGIN 6 */
|
||||
/* User can add his own implementation to report the file name and line number,
|
||||
ex: printf("Wrong parameters value: file %s on line %d\r\n", file, line) */
|
||||
/* USER CODE END 6 */
|
||||
}
|
||||
#endif /* USE_FULL_ASSERT */
|
||||
|
||||
/************************ (C) COPYRIGHT STMicroelectronics *****END OF FILE****/
|
@ -0,0 +1,59 @@
|
||||
|
||||
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
project(benchmark)
|
||||
|
||||
if(NOT DEFINED PKG_PATH)
|
||||
message(FATAL_ERROR "PKG_PATH not set")
|
||||
endif()
|
||||
|
||||
get_filename_component(PKG_PATH ${PKG_PATH} ABSOLUTE BASE_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
set(HEADER_PATH ${PKG_PATH}/inference)
|
||||
|
||||
option(MICRO_BUILD_ARM64 "build android arm64" OFF)
|
||||
option(MICRO_BUILD_ARM32A "build android arm32" OFF)
|
||||
|
||||
add_compile_definitions(NOT_USE_STL)
|
||||
|
||||
if(MICRO_BUILD_ARM64 OR MICRO_BUILD_ARM32A)
|
||||
add_compile_definitions(ENABLE_NEON)
|
||||
add_compile_definitions(ENABLE_ARM)
|
||||
endif()
|
||||
|
||||
if(MICRO_BUILD_ARM64)
|
||||
add_compile_definitions(ENABLE_ARM64)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod")
|
||||
endif()
|
||||
|
||||
if(MICRO_BUILD_ARM32A)
|
||||
add_compile_definitions(ENABLE_ARM32)
|
||||
add_definitions(-mfloat-abi=softfp -mfpu=neon)
|
||||
endif()
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_ENABLE_C99} ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
|
||||
if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
||||
message(STATUS "build benchmark with debug info")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
||||
else()
|
||||
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes \
|
||||
-Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O2 -Wall -Werror -fstack-protector-strong -Wno-attributes \
|
||||
-Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
|
||||
add_subdirectory(src)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../src/)
|
||||
include_directories(${HEADER_PATH})
|
||||
set(SRC_FILES
|
||||
benchmark/benchmark.cc
|
||||
benchmark/load_input.c
|
||||
)
|
||||
add_executable(benchmark ${SRC_FILES})
|
||||
target_link_libraries(benchmark net -lm -pthread)
|
||||
|
@ -0,0 +1,65 @@
|
||||
|
||||
cmake_minimum_required(VERSION 3.14)
|
||||
project(benchmark)
|
||||
|
||||
if(NOT DEFINED MODEL_LIB)
|
||||
message(FATAL_ERROR "MODEL_LIB not set")
|
||||
endif()
|
||||
|
||||
if(NOT DEFINED HEADER_PATH)
|
||||
message(FATAL_ERROR "HEADER_PATH not set")
|
||||
endif()
|
||||
|
||||
get_filename_component(MODEL_LIB ${MODEL_LIB} ABSOLUTE BASE_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
get_filename_component(HEADER_PATH ${HEADER_PATH} ABSOLUTE BASE_DIR ${CMAKE_CURRENT_BINARY_DIR})
|
||||
|
||||
function(parse_lib_info lib_full_path lib_name lib_path)
|
||||
string(FIND "${lib_full_path}" "/" POS REVERSE)
|
||||
math(EXPR POS "${POS} + 1")
|
||||
string(SUBSTRING ${lib_full_path} 0 ${POS} path)
|
||||
set(${lib_path} ${path} PARENT_SCOPE)
|
||||
string(SUBSTRING ${lib_full_path} "${POS}" "-1" name)
|
||||
set(${lib_name} ${name} PARENT_SCOPE)
|
||||
endfunction(parse_lib_info)
|
||||
|
||||
parse_lib_info(${MODEL_LIB} MODEL_LIB_NAME MODEL_LIB_PATH)
|
||||
|
||||
message("project name: ${MODEL_LIB_NAME}")
|
||||
|
||||
option(MICRO_BUILD_ARM64 "build android arm64" OFF)
|
||||
option(MICRO_BUILD_ARM32A "build android arm32" OFF)
|
||||
|
||||
if(MICRO_BUILD_ARM64 OR MICRO_BUILD_ARM32A)
|
||||
add_compile_definitions(ENABLE_NEON)
|
||||
add_compile_definitions(ENABLE_ARM)
|
||||
endif()
|
||||
|
||||
if(MICRO_BUILD_ARM64)
|
||||
add_compile_definitions(ENABLE_ARM64)
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -march=armv8.2-a+dotprod")
|
||||
endif()
|
||||
|
||||
if(MICRO_BUILD_ARM32A)
|
||||
add_compile_definitions(ENABLE_ARM32)
|
||||
add_definitions(-mfloat-abi=softfp -mfpu=neon)
|
||||
endif()
|
||||
|
||||
set(CMAKE_C_FLAGS "${CMAKE_ENABLE_C99} ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17")
|
||||
if("${CMAKE_BUILD_TYPE}" STREQUAL "Debug")
|
||||
message(STATUS "build benchmark with debug info")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DDebug -g")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDebug -g")
|
||||
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -fvisibility=default")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=default")
|
||||
else()
|
||||
set(CMAKE_C_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \
|
||||
-Wno-deprecated-declarations -Wno-missing-braces ${CMAKE_C_FLAGS}")
|
||||
set(CMAKE_CXX_FLAGS "-fPIC -fPIE -D_FORTIFY_SOURCE=2 -O3 -Wall -Werror -fstack-protector-strong -Wno-attributes \
|
||||
-Wno-deprecated-declarations -Wno-missing-braces -Wno-overloaded-virtual ${CMAKE_CXX_FLAGS}")
|
||||
endif()
|
||||
link_directories(${MODEL_LIB_PATH})
|
||||
include(benchmark.cmake)
|
||||
add_executable(benchmark ${SRC_FILES})
|
||||
target_link_libraries(benchmark ${MODEL_LIB_NAME} -lm -pthread)
|
||||
|
@ -0,0 +1,136 @@
|
||||
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <cstring>
|
||||
|
||||
#include "include/lite_session.h"
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/errorcode.h"
|
||||
|
||||
#include "load_input.h"
|
||||
|
||||
using namespace mindspore;
|
||||
|
||||
void usage() {
|
||||
printf(
|
||||
"-- mindspore benchmark params usage:\n"
|
||||
"args[0]: executable file\n"
|
||||
"args[1]: inputs binary file\n"
|
||||
"args[2]: model weight binary file\n"
|
||||
"args[3]: loop count for performance test\n"
|
||||
"args[4]: runtime thread num\n"
|
||||
"args[5]: runtime thread bind mode\n\n");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PrintData(void *data, size_t data_number) {
|
||||
if (data == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto casted_data = static_cast<T *>(data);
|
||||
for (size_t i = 0; i < 10 && i < data_number; i++) {
|
||||
std::cout << std::to_string(casted_data[i]) << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void TensorToString(tensor::MSTensor *tensor) {
|
||||
std::cout << ", DataType: " << tensor->data_type();
|
||||
std::cout << ", Size: " << tensor->Size();
|
||||
std::cout << ", Shape:";
|
||||
for (auto &dim : tensor->shape()) {
|
||||
std::cout << " " << dim;
|
||||
}
|
||||
std::cout << ", Data:" << std::endl;
|
||||
switch (tensor->data_type()) {
|
||||
case kNumberTypeFloat32: {
|
||||
PrintData<float>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
case kNumberTypeFloat16: {
|
||||
PrintData<int16_t>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
case kNumberTypeInt32: {
|
||||
PrintData<int32_t>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
case kNumberTypeInt16: {
|
||||
PrintData<int16_t>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
case kNumberTypeInt8: {
|
||||
PrintData<int8_t>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
case kNumberTypeUInt8: {
|
||||
PrintData<uint8_t>(tensor->MutableData(), tensor->ElementsNum());
|
||||
} break;
|
||||
default:
|
||||
std::cout << "Unsupported data type to print" << std::endl;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, const char **argv) {
|
||||
if (argc < 2) {
|
||||
std::cout << "input command is invalid\n" << std::endl;
|
||||
usage();
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
std::cout << "start run benchmark" << std::endl;
|
||||
|
||||
const char *model_buffer = nullptr;
|
||||
int model_size = 0;
|
||||
// read .bin file by ReadBinaryFile;
|
||||
if (argc >= 3) {
|
||||
model_buffer = static_cast<const char *>(ReadInputData(argv[2], &model_size));
|
||||
}
|
||||
session::LiteSession *session = mindspore::session::LiteSession::CreateSession(model_buffer, model_size, nullptr);
|
||||
if (session == nullptr) {
|
||||
std::cerr << "create lite session failed" << std::endl;
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
// set model inputs tensor data
|
||||
Vector<tensor::MSTensor *> inputs = session->GetInputs();
|
||||
size_t inputs_num = inputs.size();
|
||||
void *inputs_binbuf[inputs_num];
|
||||
int inputs_size[inputs_num];
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
inputs_size[i] = inputs[i]->Size();
|
||||
}
|
||||
int ret = ReadInputsFile(const_cast<char *>(argv[1]), inputs_binbuf, inputs_size, inputs_num);
|
||||
if (ret != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
void *input_data = inputs[i]->MutableData();
|
||||
memcpy(input_data, inputs_binbuf[i], inputs_size[i]);
|
||||
}
|
||||
|
||||
ret = session->RunGraph();
|
||||
if (ret != lite::RET_OK) {
|
||||
return lite::RET_ERROR;
|
||||
}
|
||||
|
||||
std::cout << "run benchmark success" << std::endl;
|
||||
delete session;
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
free(inputs_binbuf[i]);
|
||||
}
|
||||
return lite::RET_OK;
|
||||
}
|
||||
|
@ -0,0 +1,7 @@
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../src/)
|
||||
include_directories(${HEADER_PATH})
|
||||
set(SRC_FILES
|
||||
benchmark.cc
|
||||
load_input.c
|
||||
)
|
@ -0,0 +1,95 @@
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "load_input.h"
|
||||
#include <stdlib.h>
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
void *ReadInputData(const char *real_input_path, int *size) {
|
||||
if (real_input_path == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
if (strstr(real_input_path, ".bin") || strstr(real_input_path, ".net")) {
|
||||
FILE *file;
|
||||
file = fopen(real_input_path, "rb+");
|
||||
if (!file) {
|
||||
printf("Can't find %s\n", real_input_path);
|
||||
return NULL;
|
||||
}
|
||||
int curr_file_posi = ftell(file);
|
||||
fseek(file, 0, SEEK_END);
|
||||
*size = ftell(file);
|
||||
unsigned char *buf = malloc((*size));
|
||||
(void)memset(buf, 0, (*size));
|
||||
fseek(file, curr_file_posi, SEEK_SET);
|
||||
int read_size = (int)(fread(buf, 1, *size, file));
|
||||
if (read_size != (*size)) {
|
||||
printf("read file failed, total file size: %d, read_size: %d\n", (*size), read_size);
|
||||
fclose(file);
|
||||
free(buf);
|
||||
return NULL;
|
||||
}
|
||||
fclose(file);
|
||||
return (void *)buf;
|
||||
} else {
|
||||
printf("input data file should be .bin , .net");
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
|
||||
void SaveOutputData(char *final_name, unsigned char *output_data, unsigned int out_size) {
|
||||
FILE *output_file;
|
||||
output_file = fopen(final_name, "w");
|
||||
if (output_file == NULL) {
|
||||
printf("fopen output file: %s failed\n", final_name);
|
||||
return;
|
||||
}
|
||||
unsigned char str[out_size];
|
||||
for (unsigned int i = 0; i < out_size; ++i) {
|
||||
str[i] = output_data[i];
|
||||
fprintf(output_file, "%d\t", str[i]);
|
||||
}
|
||||
fclose(output_file);
|
||||
}
|
||||
|
||||
int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int inputs_num) {
|
||||
char *inputs_path[inputs_num];
|
||||
char *delim = ",";
|
||||
char *token;
|
||||
int i = 0;
|
||||
while ((token = strtok_r(path, delim, &path))) {
|
||||
if (i >= inputs_num) {
|
||||
printf("inputs num is error, need: %d\n", inputs_num);
|
||||
return -1;
|
||||
}
|
||||
inputs_path[i] = token;
|
||||
printf("input %d: %s\n", i, inputs_path[i]);
|
||||
i++;
|
||||
}
|
||||
|
||||
for (i = 0; i < inputs_num; ++i) {
|
||||
int size = 0;
|
||||
buffers[i] = ReadInputData(inputs_path[i], &size);
|
||||
if (size != inputs_size[i] || buffers[i] == NULL) {
|
||||
printf("size mismatch, %s, input: %d, needed: %d\n", inputs_path[i], size, inputs_size[i]);
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
@ -0,0 +1,36 @@
|
||||
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MICRO_EXAMPLE_LOAD_INPUT_LOAD_INPUT_H_
|
||||
#define MICRO_EXAMPLE_LOAD_INPUT_LOAD_INPUT_H_
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
void *ReadInputData(const char *real_input_path, int *size);
|
||||
|
||||
void SaveOutputData(char *final_name, unsigned char *output_data, unsigned int out_size);
|
||||
|
||||
int ReadInputsFile(char *path, void **buffers, const int *inputs_size, int inputs_num);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif // MICRO_EXAMPLE_LOAD_INPUT_LOAD_INPUT_H_
|
||||
|
@ -0,0 +1,133 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_CELL_H
|
||||
#define MINDSPORE_INCLUDE_API_CELL_H
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
|
||||
namespace mindspore {
|
||||
class InputAndOutput;
|
||||
using Input = InputAndOutput;
|
||||
using Output = InputAndOutput;
|
||||
|
||||
class MS_API CellBase {
|
||||
public:
|
||||
CellBase() = default;
|
||||
virtual ~CellBase() = default;
|
||||
virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
|
||||
virtual std::shared_ptr<CellBase> Clone() const = 0;
|
||||
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) { return kSuccess; }
|
||||
std::vector<Output> operator()(const std::vector<Input> &inputs) const;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MS_API Cell : public CellBase {
|
||||
public:
|
||||
virtual ~Cell() = default;
|
||||
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
|
||||
};
|
||||
|
||||
class MS_API ParameterCell final : public Cell<ParameterCell> {
|
||||
public:
|
||||
ParameterCell() = default;
|
||||
~ParameterCell() override = default;
|
||||
|
||||
ParameterCell(const ParameterCell &);
|
||||
ParameterCell &operator=(const ParameterCell &);
|
||||
|
||||
ParameterCell(ParameterCell &&);
|
||||
ParameterCell &operator=(ParameterCell &&);
|
||||
|
||||
explicit ParameterCell(const MSTensor &);
|
||||
ParameterCell &operator=(const MSTensor &);
|
||||
|
||||
explicit ParameterCell(MSTensor &&);
|
||||
ParameterCell &operator=(MSTensor &&);
|
||||
|
||||
MSTensor GetTensor() const { return tensor_; }
|
||||
|
||||
private:
|
||||
MSTensor tensor_;
|
||||
};
|
||||
|
||||
class MS_API OpCellBase : public CellBase {
|
||||
public:
|
||||
explicit OpCellBase(const std::string &name) : name_(name) {}
|
||||
~OpCellBase() override = default;
|
||||
const std::string &GetOpType() const { return name_; }
|
||||
|
||||
protected:
|
||||
std::string name_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T> {
|
||||
public:
|
||||
explicit OpCell(const std::string &name) : OpCellBase(name) {}
|
||||
~OpCell() override = default;
|
||||
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
|
||||
};
|
||||
|
||||
class MS_API GraphCell final : public Cell<GraphCell> {
|
||||
public:
|
||||
class GraphImpl;
|
||||
|
||||
GraphCell() = default;
|
||||
~GraphCell() override = default;
|
||||
|
||||
explicit GraphCell(const Graph &);
|
||||
explicit GraphCell(Graph &&);
|
||||
explicit GraphCell(const std::shared_ptr<Graph> &);
|
||||
|
||||
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
|
||||
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
|
||||
std::vector<MSTensor> GetInputs();
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
|
||||
private:
|
||||
friend class ModelImpl;
|
||||
Status Load();
|
||||
|
||||
std::shared_ptr<Graph> graph_;
|
||||
std::shared_ptr<GraphImpl> executor_;
|
||||
};
|
||||
|
||||
class MS_API InputAndOutput {
|
||||
public:
|
||||
InputAndOutput();
|
||||
~InputAndOutput() = default;
|
||||
|
||||
// no explicit
|
||||
InputAndOutput(const MSTensor &); // NOLINT(runtime/explicit)
|
||||
InputAndOutput(MSTensor &&); // NOLINT(runtime/explicit)
|
||||
|
||||
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);
|
||||
|
||||
int32_t GetIndex() const { return index_; }
|
||||
void SetIndex(int32_t index) { index_ = index; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<CellBase> cell_;
|
||||
std::vector<InputAndOutput> prev_;
|
||||
int32_t index_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CELL_H
|
@ -0,0 +1,185 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
constexpr auto kDeviceTypeAscend310 = "Ascend310";
|
||||
constexpr auto kDeviceTypeAscend910 = "Ascend910";
|
||||
constexpr auto kDeviceTypeGPU = "GPU";
|
||||
|
||||
struct MS_API Context {
|
||||
public:
|
||||
Context();
|
||||
virtual ~Context() = default;
|
||||
struct Data;
|
||||
std::shared_ptr<Data> data;
|
||||
};
|
||||
|
||||
struct MS_API GlobalContext : public Context {
|
||||
public:
|
||||
static std::shared_ptr<Context> GetGlobalContext();
|
||||
|
||||
static inline void SetGlobalDeviceTarget(const std::string &device_target);
|
||||
static inline std::string GetGlobalDeviceTarget();
|
||||
|
||||
static void SetGlobalDeviceID(const uint32_t &device_id);
|
||||
static uint32_t GetGlobalDeviceID();
|
||||
|
||||
static inline void SetGlobalDumpConfigPath(const std::string &cfg_path);
|
||||
static inline std::string GetGlobalDumpConfigPath();
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static void SetGlobalDeviceTarget(const std::vector<char> &device_target);
|
||||
static std::vector<char> GetGlobalDeviceTargetChar();
|
||||
|
||||
static void SetGlobalDumpConfigPath(const std::vector<char> &cfg_path);
|
||||
static std::vector<char> GetGlobalDumpConfigPathChar();
|
||||
};
|
||||
|
||||
struct MS_API ModelContext : public Context {
|
||||
public:
|
||||
static inline void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
||||
static inline std::string GetInsertOpConfigPath(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format);
|
||||
static inline std::string GetInputFormat(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape);
|
||||
static inline std::string GetInputShape(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputShapeMap(const std::shared_ptr<Context> &context, const std::map<int, std::vector<int>> &shape);
|
||||
static std::map<int, std::vector<int>> GetInputShapeMap(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetDynamicBatchSize(const std::shared_ptr<Context> &context,
|
||||
const std::vector<size_t> &dynamic_batch_size);
|
||||
static inline std::string GetDynamicBatchSize(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetOutputType(const std::shared_ptr<Context> &context, enum DataType output_type);
|
||||
static enum DataType GetOutputType(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode);
|
||||
static inline std::string GetPrecisionMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::string &op_select_impl_mode);
|
||||
static inline std::string GetOpSelectImplMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path);
|
||||
static inline std::string GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context);
|
||||
|
||||
static inline void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode);
|
||||
static inline std::string GetGpuTrtInferMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static void SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
|
||||
static std::vector<char> GetInsertOpConfigPathChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputFormat(const std::shared_ptr<Context> &context, const std::vector<char> &format);
|
||||
static std::vector<char> GetInputFormatChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetInputShape(const std::shared_ptr<Context> &context, const std::vector<char> &shape);
|
||||
static std::vector<char> GetInputShapeChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetPrecisionMode(const std::shared_ptr<Context> &context, const std::vector<char> &precision_mode);
|
||||
static std::vector<char> GetPrecisionModeChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::vector<char> &op_select_impl_mode);
|
||||
static std::vector<char> GetOpSelectImplModeChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::vector<char> &cfg_path);
|
||||
static std::vector<char> GetFusionSwitchConfigPathChar(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::vector<char> &gpu_trt_infer_mode);
|
||||
static std::vector<char> GetGpuTrtInferModeChar(const std::shared_ptr<Context> &context);
|
||||
static std::vector<char> GetDynamicBatchSizeChar(const std::shared_ptr<Context> &context);
|
||||
};
|
||||
|
||||
void GlobalContext::SetGlobalDeviceTarget(const std::string &device_target) {
|
||||
SetGlobalDeviceTarget(StringToChar(device_target));
|
||||
}
|
||||
std::string GlobalContext::GetGlobalDeviceTarget() { return CharToString(GetGlobalDeviceTargetChar()); }
|
||||
|
||||
void GlobalContext::SetGlobalDumpConfigPath(const std::string &cfg_path) {
|
||||
SetGlobalDumpConfigPath(StringToChar(cfg_path));
|
||||
}
|
||||
std::string GlobalContext::GetGlobalDumpConfigPath() { return CharToString(GetGlobalDumpConfigPathChar()); }
|
||||
|
||||
void ModelContext::SetInsertOpConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
||||
SetInsertOpConfigPath(context, StringToChar(cfg_path));
|
||||
}
|
||||
std::string ModelContext::GetInsertOpConfigPath(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInsertOpConfigPathChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetInputFormat(const std::shared_ptr<Context> &context, const std::string &format) {
|
||||
SetInputFormat(context, StringToChar(format));
|
||||
}
|
||||
std::string ModelContext::GetInputFormat(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInputFormatChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetInputShape(const std::shared_ptr<Context> &context, const std::string &shape) {
|
||||
SetInputShape(context, StringToChar(shape));
|
||||
}
|
||||
std::string ModelContext::GetInputShape(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetInputShapeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetPrecisionMode(const std::shared_ptr<Context> &context, const std::string &precision_mode) {
|
||||
SetPrecisionMode(context, StringToChar(precision_mode));
|
||||
}
|
||||
std::string ModelContext::GetPrecisionMode(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetPrecisionModeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetOpSelectImplMode(const std::shared_ptr<Context> &context,
|
||||
const std::string &op_select_impl_mode) {
|
||||
SetOpSelectImplMode(context, StringToChar(op_select_impl_mode));
|
||||
}
|
||||
std::string ModelContext::GetOpSelectImplMode(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetOpSelectImplModeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetFusionSwitchConfigPath(const std::shared_ptr<Context> &context, const std::string &cfg_path) {
|
||||
SetFusionSwitchConfigPath(context, StringToChar(cfg_path));
|
||||
}
|
||||
std::string ModelContext::GetFusionSwitchConfigPath(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetFusionSwitchConfigPathChar(context));
|
||||
}
|
||||
|
||||
std::string ModelContext::GetDynamicBatchSize(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetDynamicBatchSizeChar(context));
|
||||
}
|
||||
|
||||
void ModelContext::SetGpuTrtInferMode(const std::shared_ptr<Context> &context, const std::string &gpu_trt_infer_mode) {
|
||||
SetGpuTrtInferMode(context, StringToChar(gpu_trt_infer_mode));
|
||||
}
|
||||
std::string ModelContext::GetGpuTrtInferMode(const std::shared_ptr<Context> &context) {
|
||||
return CharToString(GetGpuTrtInferModeChar(context));
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_DATA_TYPE_H_
|
||||
#define MINDSPORE_INCLUDE_API_DATA_TYPE_H_
|
||||
|
||||
namespace mindspore {
|
||||
enum class DataType : int {
|
||||
kTypeUnknown = 0,
|
||||
kObjectTypeString = 12,
|
||||
kObjectTypeList = 13,
|
||||
kObjectTypeTuple = 14,
|
||||
kObjectTypeTensorType = 17,
|
||||
kNumberTypeBool = 30,
|
||||
kNumberTypeInt8 = 32,
|
||||
kNumberTypeInt16 = 33,
|
||||
kNumberTypeInt32 = 34,
|
||||
kNumberTypeInt64 = 35,
|
||||
kNumberTypeUInt8 = 37,
|
||||
kNumberTypeUInt16 = 38,
|
||||
kNumberTypeUInt32 = 39,
|
||||
kNumberTypeUInt64 = 40,
|
||||
kNumberTypeFloat16 = 42,
|
||||
kNumberTypeFloat32 = 43,
|
||||
kNumberTypeFloat64 = 44,
|
||||
kNumberTypeEnd = 46,
|
||||
// add new enum here
|
||||
kInvalidType = INT32_MAX,
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_DATA_TYPE_H_
|
@ -0,0 +1,164 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
||||
#define MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <set>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace mindspore {
|
||||
inline std::vector<char> StringToChar(const std::string &s) { return std::vector<char>(s.begin(), s.end()); }
|
||||
|
||||
inline std::string CharToString(const std::vector<char> &c) { return std::string(c.begin(), c.end()); }
|
||||
|
||||
inline std::optional<std::vector<char>> OptionalStringToChar(const std::optional<std::string> &s) {
|
||||
if (s == std::nullopt) return std::nullopt;
|
||||
std::optional<std::vector<char>> ret = std::vector<char>(s->begin(), s->end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::optional<std::string> OptionalCharToString(const std::optional<std::vector<char>> &c) {
|
||||
if (c == std::nullopt) return std::nullopt;
|
||||
std::optional<std::string> ret = std::string(c->begin(), c->end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::pair<std::vector<char>, int32_t> PairStringToChar(const std::pair<std::string, int32_t> &s) {
|
||||
return std::pair<std::vector<char>, int32_t>(std::vector<char>(s.first.begin(), s.first.end()), s.second);
|
||||
}
|
||||
|
||||
inline std::pair<std::string, int32_t> PairCharToString(const std::pair<std::vector<char>, int32_t> &c) {
|
||||
return std::pair<std::string, int32_t>(std::string(c.first.begin(), c.first.end()), c.second);
|
||||
}
|
||||
|
||||
inline std::vector<std::vector<char>> VectorStringToChar(const std::vector<std::string> &s) {
|
||||
std::vector<std::vector<char>> ret;
|
||||
std::transform(s.begin(), s.end(), std::back_inserter(ret),
|
||||
[](auto str) { return std::vector<char>(str.begin(), str.end()); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::vector<std::string> VectorCharToString(const std::vector<std::vector<char>> &c) {
|
||||
std::vector<std::string> ret;
|
||||
std::transform(c.begin(), c.end(), std::back_inserter(ret),
|
||||
[](auto ch) { return std::string(ch.begin(), ch.end()); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::set<std::vector<char>> SetStringToChar(const std::set<std::string> &s) {
|
||||
std::set<std::vector<char>> ret;
|
||||
std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()),
|
||||
[](auto str) { return std::vector<char>(str.begin(), str.end()); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::set<std::string> SetCharToString(const std::set<std::vector<char>> &c) {
|
||||
std::set<std::string> ret;
|
||||
std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()),
|
||||
[](auto ch) { return std::string(ch.begin(), ch.end()); });
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::map<std::vector<char>, int32_t> MapStringToChar(const std::map<std::string, int32_t> &s) {
|
||||
std::map<std::vector<char>, int32_t> ret;
|
||||
std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
|
||||
return std::pair<std::vector<char>, int32_t>(std::vector<char>(str.first.begin(), str.first.end()), str.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::map<std::string, int32_t> MapCharToString(const std::map<std::vector<char>, int32_t> &c) {
|
||||
std::map<std::string, int32_t> ret;
|
||||
std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
|
||||
return std::pair<std::string, int32_t>(std::string(ch.first.begin(), ch.first.end()), ch.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::map<std::vector<char>, std::vector<char>> UnorderedMapStringToChar(
|
||||
const std::unordered_map<std::string, std::string> &s) {
|
||||
std::map<std::vector<char>, std::vector<char>> ret;
|
||||
std::transform(s.begin(), s.end(), std::inserter(ret, ret.begin()), [](auto str) {
|
||||
return std::pair<std::vector<char>, std::vector<char>>(std::vector<char>(str.first.begin(), str.first.end()),
|
||||
std::vector<char>(str.second.begin(), str.second.end()));
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::unordered_map<std::string, std::string> UnorderedMapCharToString(
|
||||
const std::map<std::vector<char>, std::vector<char>> &c) {
|
||||
std::unordered_map<std::string, std::string> ret;
|
||||
std::transform(c.begin(), c.end(), std::inserter(ret, ret.begin()), [](auto ch) {
|
||||
return std::pair<std::string, std::string>(std::string(ch.first.begin(), ch.first.end()),
|
||||
std::string(ch.second.begin(), ch.second.end()));
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ClassIndexStringToChar(
|
||||
const std::vector<std::pair<std::string, std::vector<int32_t>>> &s) {
|
||||
std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> ret;
|
||||
std::transform(s.begin(), s.end(), std::back_inserter(ret), [](auto str) {
|
||||
return std::pair<std::vector<char>, std::vector<int32_t>>(std::vector<char>(str.first.begin(), str.first.end()),
|
||||
str.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
inline std::vector<std::pair<std::string, std::vector<int32_t>>> ClassIndexCharToString(
|
||||
const std::vector<std::pair<std::vector<char>, std::vector<int32_t>>> &c) {
|
||||
std::vector<std::pair<std::string, std::vector<int32_t>>> ret;
|
||||
std::transform(c.begin(), c.end(), std::back_inserter(ret), [](auto ch) {
|
||||
return std::pair<std::string, std::vector<int32_t>>(std::string(ch.first.begin(), ch.first.end()), ch.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline std::map<std::vector<char>, T> PadInfoStringToChar(const std::map<std::string, T> &s_pad_info) {
|
||||
std::map<std::vector<char>, T> ret;
|
||||
std::transform(s_pad_info.begin(), s_pad_info.end(), std::inserter(ret, ret.begin()), [](auto str) {
|
||||
return std::pair<std::vector<char>, T>(std::vector<char>(str.first.begin(), str.first.end()), str.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline std::map<std::string, T> PadInfoCharToString(const std::map<std::vector<char>, T> &c_pad_info) {
|
||||
std::map<std::string, T> ret;
|
||||
std::transform(c_pad_info.begin(), c_pad_info.end(), std::inserter(ret, ret.begin()), [](auto ch) {
|
||||
return std::pair<std::string, T>(std::string(ch.first.begin(), ch.first.end()), ch.second);
|
||||
});
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
inline void TensorMapCharToString(const std::map<std::vector<char>, T> *c, std::unordered_map<std::string, T> *s) {
|
||||
for (auto ch : *c) {
|
||||
auto key = std::string(ch.first.begin(), ch.first.end());
|
||||
auto val = ch.second;
|
||||
s->insert(std::pair<std::string, T>(key, val));
|
||||
}
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_DUAL_ABI_HELPER_H_
|
@ -0,0 +1,44 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
#define MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MS_API Graph {
|
||||
public:
|
||||
class GraphData;
|
||||
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
||||
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
||||
explicit Graph(std::nullptr_t);
|
||||
~Graph();
|
||||
|
||||
enum ModelType ModelType() const;
|
||||
bool operator==(std::nullptr_t) const;
|
||||
|
||||
private:
|
||||
friend class GraphCell;
|
||||
friend class ModelImpl;
|
||||
std::shared_ptr<GraphData> graph_data_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_GRAPH_H
|
@ -0,0 +1,71 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
||||
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include <any>
|
||||
#include "include/api/types.h"
|
||||
#include "include/lite_types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace lite {
|
||||
class Allocator;
|
||||
} // namespace lite
|
||||
|
||||
struct MS_API Context {
|
||||
public:
|
||||
static void Clear(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetAsDefault(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetVendorName(const std::shared_ptr<Context> &context, const std::string &name);
|
||||
static std::string GetVendorName(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetThreadNum(const std::shared_ptr<Context> &context, int num);
|
||||
static int GetThreadNum(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetAllocator(const std::shared_ptr<Context> &context, std::shared_ptr<lite::Allocator> alloc);
|
||||
static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void ConfigCPU(const std::shared_ptr<Context> &context, bool config);
|
||||
static bool IfCPUEnabled(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void ConfigCPUFp16(const std::shared_ptr<Context> &context, bool config);
|
||||
static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetCPUBindMode(const std::shared_ptr<Context> &context, lite::CpuBindMode mode);
|
||||
static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void ConfigGPU(const std::shared_ptr<Context> &context, bool config);
|
||||
static bool IfGPUEnabled(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void ConfigGPUFp16(const std::shared_ptr<Context> &context, bool config);
|
||||
static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void ConfigNPU(const std::shared_ptr<Context> &context, bool config);
|
||||
static bool IfNPUEnabled(const std::shared_ptr<Context> &context);
|
||||
|
||||
static void SetNPUFrequency(const std::shared_ptr<Context> &context, int freq);
|
||||
static int GetNPUFrequency(const std::shared_ptr<Context> &context);
|
||||
|
||||
private:
|
||||
std::map<std::string, std::any> context_;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
|
@ -0,0 +1,62 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_MODEL_H
|
||||
#define MINDSPORE_INCLUDE_API_MODEL_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
class ModelImpl;
|
||||
struct Context;
|
||||
|
||||
class MS_API Model {
|
||||
public:
|
||||
explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
|
||||
explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
|
||||
~Model();
|
||||
Model(const Model &) = delete;
|
||||
void operator=(const Model &) = delete;
|
||||
|
||||
Status Build();
|
||||
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
|
||||
|
||||
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
|
||||
|
||||
std::vector<MSTensor> GetInputs();
|
||||
std::vector<MSTensor> GetOutputs();
|
||||
|
||||
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
|
||||
std::shared_ptr<ModelImpl> impl_;
|
||||
};
|
||||
|
||||
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
|
||||
return CheckModelSupport(StringToChar(device_type), model_type);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_MODEL_H
|
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_OPS_OPS_H
|
||||
#define MINDSPORE_INCLUDE_API_OPS_OPS_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/cell.h"
|
||||
|
||||
namespace mindspore {
|
||||
struct MS_API Conv2D : public OpCell<Conv2D> {
|
||||
Conv2D() : OpCell("Conv2D") {}
|
||||
~Conv2D() override = default;
|
||||
std::vector<Output> Construct(const std::vector<Input> &inputs) override;
|
||||
Conv2D(int out_channel, const std::vector<int> &kernel_size, int mode = 1, const std::string &pad_mode = "valid",
|
||||
const std::vector<int> &pad = {0, 0, 0, 0}, const std::vector<int> &stride = {1, 1, 1, 1},
|
||||
const std::vector<int> &dilation = {1, 1, 1, 1}, int group = 1);
|
||||
|
||||
Output operator()(const Input &, const Input &) const;
|
||||
|
||||
int out_channel;
|
||||
std::vector<int> kernel_size;
|
||||
int mode = 1;
|
||||
std::string pad_mode = "valid";
|
||||
std::vector<int> pad = {0, 0, 0, 0};
|
||||
std::vector<int> stride = {1, 1, 1, 1};
|
||||
std::vector<int> dilation = {1, 1, 1, 1};
|
||||
int group = 1;
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_OPS_OPS_H
|
@ -0,0 +1,47 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
#define MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
#include "include/api/model.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
class MS_API Serialization {
|
||||
public:
|
||||
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
|
||||
inline static Graph LoadModel(const std::string &file, ModelType model_type);
|
||||
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
|
||||
static Status SetParameters(const std::map<std::string, Buffer> ¶meters, Model *model);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
|
||||
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
|
||||
|
||||
private:
|
||||
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
|
||||
};
|
||||
|
||||
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
|
||||
return LoadModel(StringToChar(file), model_type);
|
||||
}
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
|
@ -0,0 +1,122 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_TYPES_H
|
||||
#define MINDSPORE_INCLUDE_API_TYPES_H
|
||||
|
||||
#include <cstddef>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "include/api/data_type.h"
|
||||
#include "include/api/dual_abi_helper.h"
|
||||
|
||||
#ifdef _WIN32
|
||||
#define MS_API __declspec(dllexport)
|
||||
#else
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
enum ModelType : uint32_t {
|
||||
kMindIR = 0,
|
||||
kAIR = 1,
|
||||
kOM = 2,
|
||||
kONNX = 3,
|
||||
// insert new data type here
|
||||
kUnknownType = 0xFFFFFFFF
|
||||
};
|
||||
|
||||
class MS_API MSTensor {
|
||||
public:
|
||||
class Impl;
|
||||
|
||||
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
|
||||
MSTensor();
|
||||
explicit MSTensor(const std::shared_ptr<Impl> &impl);
|
||||
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len);
|
||||
~MSTensor();
|
||||
|
||||
inline std::string Name() const;
|
||||
enum DataType DataType() const;
|
||||
const std::vector<int64_t> &Shape() const;
|
||||
int64_t ElementNum() const;
|
||||
|
||||
std::shared_ptr<const void> Data() const;
|
||||
void *MutableData();
|
||||
size_t DataSize() const;
|
||||
|
||||
bool IsDevice() const;
|
||||
|
||||
MSTensor Clone() const;
|
||||
bool operator==(std::nullptr_t) const;
|
||||
|
||||
private:
|
||||
// api without std::string
|
||||
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept;
|
||||
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len);
|
||||
std::vector<char> CharName() const;
|
||||
|
||||
friend class ModelImpl;
|
||||
explicit MSTensor(std::nullptr_t);
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
class MS_API Buffer {
|
||||
public:
|
||||
Buffer();
|
||||
Buffer(const void *data, size_t data_len);
|
||||
~Buffer();
|
||||
|
||||
const void *Data() const;
|
||||
void *MutableData();
|
||||
size_t DataSize() const;
|
||||
|
||||
bool ResizeData(size_t data_len);
|
||||
bool SetData(const void *data, size_t data_len);
|
||||
|
||||
Buffer Clone() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::shared_ptr<Impl> impl_;
|
||||
};
|
||||
|
||||
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
return CreateTensor(StringToChar(name), type, shape, data, data_len);
|
||||
}
|
||||
|
||||
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
|
||||
const void *data, size_t data_len) noexcept {
|
||||
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
|
||||
}
|
||||
|
||||
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
|
||||
size_t data_len)
|
||||
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
|
||||
|
||||
std::string MSTensor::Name() const { return CharToString(CharName()); }
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_TYPES_H
|
@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
||||
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/lite_utils.h"
|
||||
#include "include/lite_types.h"
|
||||
|
||||
namespace mindspore::lite {
|
||||
/// \brief CpuDeviceInfo defined for CPU's configuration information.
|
||||
typedef struct {
|
||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||
CpuBindMode cpu_bind_mode_ = MID_CPU;
|
||||
} CpuDeviceInfo;
|
||||
|
||||
/// \brief GpuDeviceInfo defined for GPU's configuration information.
|
||||
typedef struct {
|
||||
bool enable_float16_ = false; /**< prior enable float16 inference */
|
||||
} GpuDeviceInfo;
|
||||
|
||||
/// \brief NpuDeviceInfo defined for NPU's configuration information.
|
||||
typedef struct {
|
||||
int frequency_ = 3; /**< npu frequency inference */
|
||||
} NpuDeviceInfo;
|
||||
|
||||
/// \brief DeviceInfo defined for backend's configuration information.
|
||||
union DeviceInfo {
|
||||
CpuDeviceInfo cpu_device_info_;
|
||||
GpuDeviceInfo gpu_device_info_;
|
||||
NpuDeviceInfo npu_device_info_;
|
||||
};
|
||||
|
||||
/// \brief DeviceContext defined for holding backend's configuration information.
|
||||
struct DeviceContext {
|
||||
DeviceType device_type_ = DT_CPU;
|
||||
DeviceInfo device_info_;
|
||||
};
|
||||
|
||||
/// \brief Context defined for holding environment variables during runtime.
|
||||
struct Context {
|
||||
String vendor_name_;
|
||||
int thread_num_ = 2; /**< thread number config for thread pool */
|
||||
AllocatorPtr allocator = nullptr;
|
||||
#ifndef NOT_USE_STL
|
||||
DeviceContextVector device_list_ = {{DT_CPU, {false, MID_CPU}}};
|
||||
#else
|
||||
DeviceContextVector device_list_;
|
||||
#endif // NOT_USE_STL
|
||||
};
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_INCLUDE_CONTEXT_H_
|
@ -0,0 +1,95 @@
|
||||
/**
|
||||
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
||||
*
|
||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
#define MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
||||
|
||||
namespace mindspore {
|
||||
//
|
||||
// Supported meta type
|
||||
//
|
||||
enum TypeId : int {
|
||||
kTypeUnknown = 0,
|
||||
kMetaTypeBegin = kTypeUnknown,
|
||||
kMetaTypeType, // Type
|
||||
kMetaTypeAnything,
|
||||
kMetaTypeObject,
|
||||
kMetaTypeTypeType, // TypeType
|
||||
kMetaTypeProblem,
|
||||
kMetaTypeExternal,
|
||||
kMetaTypeNone,
|
||||
kMetaTypeNull,
|
||||
kMetaTypeEllipsis,
|
||||
kMetaTypeEnd,
|
||||
//
|
||||
// Object types
|
||||
//
|
||||
kObjectTypeBegin = kMetaTypeEnd,
|
||||
kObjectTypeNumber,
|
||||
kObjectTypeString,
|
||||
kObjectTypeList,
|
||||
kObjectTypeTuple,
|
||||
kObjectTypeSlice,
|
||||
kObjectTypeKeyword,
|
||||
kObjectTypeTensorType,
|
||||
kObjectTypeRowTensorType,
|
||||
kObjectTypeSparseTensorType,
|
||||
kObjectTypeUndeterminedType,
|
||||
kObjectTypeClass,
|
||||
kObjectTypeDictionary,
|
||||
kObjectTypeFunction,
|
||||
kObjectTypeJTagged,
|
||||
kObjectTypeSymbolicKeyType,
|
||||
kObjectTypeEnvType,
|
||||
kObjectTypeRefKey,
|
||||
kObjectTypeRef,
|
||||
kObjectTypeEnd,
|
||||
//
|
||||
// Number Types
|
||||
//
|
||||
kNumberTypeBegin = kObjectTypeEnd,
|
||||
kNumberTypeBool,
|
||||
kNumberTypeInt,
|
||||
kNumberTypeInt8,
|
||||
kNumberTypeInt16,
|
||||
kNumberTypeInt32,
|
||||
kNumberTypeInt64,
|
||||
kNumberTypeUInt,
|
||||
kNumberTypeUInt8,
|
||||
kNumberTypeUInt16,
|
||||
kNumberTypeUInt32,
|
||||
kNumberTypeUInt64,
|
||||
kNumberTypeFloat,
|
||||
kNumberTypeFloat16,
|
||||
kNumberTypeFloat32,
|
||||
kNumberTypeFloat64,
|
||||
kNumberTypeComplex64,
|
||||
kNumberTypeEnd,
|
||||
//
|
||||
// Monad Types
|
||||
//
|
||||
// Monad types is placed at the end of enum,
|
||||
// in order to keep fit with the type of existing model on the lite side.
|
||||
kMonadTypeBegin = kNumberTypeEnd,
|
||||
kObjectTypeMonad,
|
||||
kObjectTypeUMonad,
|
||||
kObjectTypeIOMonad,
|
||||
kMonadTypeEnd
|
||||
};
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CORE_IR_DTYPE_TYPE_ID_H_
|
@ -0,0 +1,125 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
|
||||
#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
|
||||
|
||||
#ifndef NOT_USE_STL
|
||||
#include <unordered_map>
|
||||
#endif // NOT_USE_STL
|
||||
#include "include/ms_tensor.h"
|
||||
#include "include/model.h"
|
||||
#include "include/context.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace session {
|
||||
/// \brief LiteSession defined session in MindSpore Lite for compiling Model and forwarding model.
|
||||
class MS_API LiteSession {
|
||||
public:
|
||||
/// \brief Static method to create a LiteSession pointer.
|
||||
///
|
||||
/// \param[in] context Define the context of session to be created.
|
||||
///
|
||||
/// \return Pointer of MindSpore Lite LiteSession.
|
||||
static LiteSession *CreateSession(const lite::Context *context);
|
||||
|
||||
/// \brief Static method to create a LiteSession pointer which has already compiled a model.
|
||||
///
|
||||
/// \param[in] model_buf Define the buffer read from a model file.
|
||||
/// \param[in] size Define bytes number of model buffer.
|
||||
/// \param[in] context Define the context of session to be created.
|
||||
///
|
||||
/// \return Pointer of MindSpore Lite LiteSession.
|
||||
static LiteSession *CreateSession(const char *model_buf, size_t size, const lite::Context *context);
|
||||
|
||||
/// \brief Destructor of MindSpore Lite LiteSession.
|
||||
virtual ~LiteSession() = default;
|
||||
|
||||
/// \brief Attempt to bind or unbind threads in the thread pool to or from the specified cpu core.
|
||||
///
|
||||
/// \param[in] if_bind Define whether to bind or unbind threads.
|
||||
virtual void BindThread(bool if_bind) = 0;
|
||||
|
||||
/// \brief Compile MindSpore Lite model.
|
||||
///
|
||||
/// \note CompileGraph should be called before RunGraph.
|
||||
///
|
||||
/// \param[in] model Define the model to be compiled.
|
||||
///
|
||||
/// \return STATUS as an error code of compiling graph, STATUS is defined in errorcode.h.
|
||||
virtual int CompileGraph(lite::Model *model) = 0;
|
||||
|
||||
/// \brief Get input MindSpore Lite MSTensors of model.
|
||||
///
|
||||
/// \return The vector of MindSpore Lite MSTensor.
|
||||
virtual Vector<tensor::MSTensor *> GetInputs() const = 0;
|
||||
|
||||
/// \brief Get input MindSpore Lite MSTensors of model by tensor name.
|
||||
///
|
||||
/// \param[in] node_name Define tensor name.
|
||||
///
|
||||
/// \return The vector of MindSpore Lite MSTensor.
|
||||
virtual mindspore::tensor::MSTensor *GetInputsByTensorName(const String &tensor_name) const = 0;
|
||||
|
||||
/// \brief Run session with callback.
|
||||
///
|
||||
/// \param[in] before Define a call_back_function to be called before running each node.
|
||||
/// \param[in] after Define a call_back_function called after running each node.
|
||||
///
|
||||
/// \note RunGraph should be called after CompileGraph.
|
||||
///
|
||||
/// \return STATUS as an error code of running graph, STATUS is defined in errorcode.h.
|
||||
virtual int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model by node name.
|
||||
///
|
||||
/// \param[in] node_name Define node name.
|
||||
///
|
||||
/// \note Deprecated, replace with GetOutputByTensorName
|
||||
///
|
||||
/// \return The vector of MindSpore Lite MSTensor.
|
||||
virtual Vector<tensor::MSTensor *> GetOutputsByNodeName(const String &node_name) const = 0;
|
||||
|
||||
#ifndef NOT_USE_STL
|
||||
/// \brief Get output MindSpore Lite MSTensors of model mapped by tensor name.
|
||||
///
|
||||
/// \return The map of output tensor name and MindSpore Lite MSTensor.
|
||||
virtual std::unordered_map<String, mindspore::tensor::MSTensor *> GetOutputs() const = 0;
|
||||
#endif
|
||||
|
||||
/// \brief Get name of output tensors of model compiled by this session.
|
||||
///
|
||||
/// \return The vector of string as output tensor names in order.
|
||||
virtual Vector<String> GetOutputTensorNames() const = 0;
|
||||
|
||||
/// \brief Get output MindSpore Lite MSTensors of model by tensor name.
|
||||
///
|
||||
/// \param[in] tensor_name Define tensor name.
|
||||
///
|
||||
/// \return Pointer of MindSpore Lite MSTensor.
|
||||
virtual mindspore::tensor::MSTensor *GetOutputByTensorName(const String &tensor_name) const = 0;
|
||||
|
||||
/// \brief Resize inputs shape.
|
||||
///
|
||||
/// \param[in] inputs Define the inputs of the model.
|
||||
/// \param[in] dims Define the inputs new shape.
|
||||
///
|
||||
/// \return STATUS as an error code of resize inputs, STATUS is defined in errorcode.h.
|
||||
virtual int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) = 0;
|
||||
};
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_INCLUDE_LITE_TYPES_H_
|
||||
#define MINDSPORE_LITE_INCLUDE_LITE_TYPES_H_
|
||||
|
||||
namespace mindspore::lite {
|
||||
/// \brief CpuBindMode defined for holding bind cpu strategy argument.
|
||||
typedef enum {
|
||||
NO_BIND, /**< no bind */
|
||||
HIGHER_CPU, /**< bind higher cpu first */
|
||||
MID_CPU /**< bind middle cpu first */
|
||||
} CpuBindMode;
|
||||
|
||||
/// \brief DeviceType defined for holding user's preferred backend.
|
||||
typedef enum {
|
||||
DT_CPU, /**< CPU device type */
|
||||
DT_GPU, /**< GPU device type */
|
||||
DT_NPU /**< NPU device type */
|
||||
} DeviceType;
|
||||
|
||||
} // namespace mindspore::lite
|
||||
#endif // MINDSPORE_LITE_INCLUDE_LITE_TYPES_H_
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue