MindSpore C++ interface

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/9324/head
zhoufeng 5 years ago
parent 756594d000
commit 3204ecb7d6

@ -23,7 +23,7 @@ usage()
{
echo "Usage:"
echo "bash build.sh [-d] [-r] [-v] [-c on|off] [-t on|off] [-g on|off] [-h] [-b ge] [-m infer|train] \\"
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|acl] \\"
echo " [-a on|off] [-p on|off] [-i] [-L] [-R] [-D on|off] [-j[n]] [-e gpu|ascend|cpu|ascend310] \\"
echo " [-P on|off] [-z [on|off]] [-M on|off] [-V 9.2|10.1] [-I arm64|arm32|x86_64] [-K] \\"
echo " [-B on|off] [-E] [-l on|off] [-n full|lite|off] [-T on|off] \\"
echo " [-A [cpp|java|object-c] [-C on|off] [-o on|off] [-S on|off] [-k on|off] [-W sse|neon|avx|off] \\"
@ -45,7 +45,7 @@ usage()
echo " -i Enable increment building, default off"
echo " -L Enable load ANF-IR as input of 'infer', default off"
echo " -j[n] Set the threads when building (Default: -j8)"
echo " -e Use cpu, gpu, ascend or acl"
echo " -e Use cpu, gpu, ascend or ascend310"
echo " -P Enable dump anf graph to file in ProtoBuffer format, default on"
echo " -D Enable dumping of function graph ir, default on"
echo " -z Compile dataset & mindrecord, default on"
@ -224,7 +224,7 @@ checkopts()
ENABLE_D="on"
ENABLE_CPU="on"
ENABLE_SERVING="on"
elif [[ "X$OPTARG" == "Xacl" ]]; then
elif [[ "X$OPTARG" == "Xascend310" ]]; then
ENABLE_SERVING="on"
ENABLE_ACL="on"
elif [[ "X$OPTARG" == "Xcpu" ]]; then

@ -21,6 +21,7 @@
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
namespace mindspore {
namespace api {
@ -34,6 +35,7 @@ class MS_API CellBase {
virtual ~CellBase() = default;
virtual std::vector<Output> Construct(const std::vector<Input> &inputs) { return {}; }
virtual std::shared_ptr<CellBase> Clone() const = 0;
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) { return SUCCESS; }
std::vector<Output> operator()(const std::vector<Input> &inputs) const;
};
@ -41,9 +43,7 @@ template <class T>
class MS_API Cell : public CellBase {
public:
virtual ~Cell() = default;
std::shared_ptr<CellBase> Clone() const override {
return std::make_shared<T>(static_cast<const T&>(*this));
}
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};
class MS_API ParameterCell final : public Cell<ParameterCell> {
@ -84,9 +84,33 @@ class MS_API OpCell : public OpCellBase, public std::enable_shared_from_this<T>
public:
explicit OpCell(const std::string &name) : OpCellBase(name) {}
~OpCell() override = default;
std::shared_ptr<CellBase> Clone() const override {
return std::make_shared<T>(static_cast<const T&>(*this));
}
std::shared_ptr<CellBase> Clone() const override { return std::make_shared<T>(static_cast<const T &>(*this)); }
};
// Executable cell wrapping a compiled Graph.  Execution is delegated to a
// backend-specific executor (GraphImpl) selected from the global Context's
// device target (see cell.cc).
class MS_API GraphCell final : public Cell<GraphCell> {
public:
// Backend execution interface; concrete implementations are registered per device.
class GraphImpl;
GraphCell() = default;
~GraphCell() override = default;
// Construct from a graph; an executor for the current device target is created.
explicit GraphCell(const Graph &);
explicit GraphCell(Graph &&);
explicit GraphCell(const std::shared_ptr<Graph> &);
// Underlying graph (null for a default-constructed cell).
const std::shared_ptr<Graph> &GetGraph() const { return graph_; }
// Execute the graph on the given input buffers; results are written to *outputs.
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
// Query metadata of the graph inputs (names, shapes, data types, byte sizes).
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
// Query metadata of the graph outputs; same out-parameter conventions as above.
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
private:
friend class ModelImpl;
// Compile/load the graph on the device (invoked through ModelImpl).
Status Load();
std::shared_ptr<Graph> graph_;
std::shared_ptr<GraphImpl> executor_;
};
class MS_API InputAndOutput {
@ -96,7 +120,7 @@ class MS_API InputAndOutput {
// no explicit
InputAndOutput(const Tensor &); // NOLINT(runtime/explicit)
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
InputAndOutput(Tensor &&); // NOLINT(runtime/explicit)
InputAndOutput(const std::shared_ptr<CellBase> &, const std::vector<InputAndOutput> &, int32_t index);

@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
#define MINDSPORE_INCLUDE_API_CONTEXT_H
#include <string>
#include <memory>
#include "include/api/types.h"
namespace mindspore {
namespace api {
// Process-wide singleton holding the execution configuration (device target
// string and device id) consulted when graph executors are created.
class MS_API Context {
public:
// Global accessor; the instance is created on first use.
static Context &Instance();
const std::string &GetDeviceTarget() const;
// Setters return *this so calls can be chained.
Context &SetDeviceTarget(const std::string &device_target);
uint32_t GetDeviceID() const;
Context &SetDeviceID(uint32_t device_id);
private:
Context();
~Context();
// Pimpl: keeps implementation state out of the public header.
class ContextImpl;
std::shared_ptr<ContextImpl> impl_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H

@ -0,0 +1,43 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
#define MINDSPORE_INCLUDE_API_GRAPH_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
namespace mindspore {
namespace api {
// Handle to a loaded model graph.  The payload lives in the opaque GraphData
// object, which is shared between copies of Graph.
class MS_API Graph {
public:
// Opaque payload type (defined in cxx_api/graph/graph_data.h).
class GraphData;
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
// Format of the held model (delegates to GraphData).
enum ModelType ModelType() const;
private:
friend class GraphCell;
friend class ModelImpl;
std::shared_ptr<GraphData> graph_data_;
};
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_GRAPH_H

@ -22,42 +22,39 @@
#include <memory>
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/cell.h"
namespace mindspore {
namespace api {
class ModelImpl;
// todo: minddata c++ interface
class DataSet {};
class NetWork {};
class MS_API Model {
public:
Model(const std::string &device_type, uint32_t device_id);
Model(NetWork network, const std::string &device_type, uint32_t device_id);
explicit Model(const std::vector<Output> &network);
explicit Model(const GraphCell &graph);
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
Status LoadModel(const Buffer &model_data, ModelType type, const std::map<std::string, std::string> &options);
Status LoadModel(const std::string &file_name, ModelType type, const std::map<std::string, std::string> &options);
Status UnloadModel();
Status Build(const std::map<std::string, std::string> &options);
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs);
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
Status Predict(const std::vector<Buffer> &inputs, std::map<std::string, Buffer> *outputs);
Status Train(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
Status Eval(const DataSet &dataset, bool data_sink, std::map<std::string, Buffer> *outputs);
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
static bool CheckModelSupport(const std::string& device_type, ModelType model_type);
static bool CheckModelSupport(const std::string &device_type, ModelType model_type);
private:
std::shared_ptr<ModelImpl> impl_;
};
extern MS_API const char* kDeviceTypeAscendCL;
extern MS_API const char* kDeviceTypeAscendMS;
} // namespace api
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H

@ -23,11 +23,13 @@
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/model.h"
#include "include/api/graph.h"
namespace mindspore {
namespace api {
class MS_API Serialization {
public:
static Graph LoadModel(const std::string &file, ModelType model_type);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);

@ -102,6 +102,9 @@ class MS_API Buffer {
std::shared_ptr<Impl> impl_;
};
extern MS_API const char *kDeviceTypeAscend310;
extern MS_API const char *kDeviceTypeAscend910;
constexpr auto kModelOptionDumpCfgPath = "mindspore.option.dump_config_file_path";
constexpr auto kModelOptionDvppCfgPath = "mindspore.option.dvpp_config_file_path";
constexpr auto kModelOptionInsertOpCfgPath = "mindspore.option.insert_op_config_file_path"; // aipp config file

@ -6,22 +6,35 @@ set(LOAD_MINDIR_SRC
file(GLOB_RECURSE API_OPS_SRC ${CMAKE_CURRENT_SOURCE_DIR} "ops/*.cc")
if (ENABLE_ACL)
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/acl/*.cc" "model/model_converter_utils/*.cc")
elseif (ENABLE_D)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc")
add_compile_definitions(ENABLE_ACL)
include_directories(${CMAKE_SOURCE_DIR}/graphengine/src/ge)
include_directories(${CMAKE_BINARY_DIR}/proto/ge)
file(GLOB_RECURSE API_ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"model/acl/*.cc"
"model/model_converter_utils/*.cc"
"graph/acl/*.cc"
)
endif ()
if (ENABLE_D)
file(GLOB_RECURSE API_MS_INFER_SRC ${CMAKE_CURRENT_SOURCE_DIR} "model/ms/*.cc" "graph/ms/*.cc")
endif ()
set(MSLIB_SRC ${CMAKE_CURRENT_SOURCE_DIR}/types.cc
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
${API_MS_INFER_SRC}
${API_ACL_SRC}
${API_OPS_SRC}
${LOAD_MINDIR_SRC})
${CMAKE_CURRENT_SOURCE_DIR}/context.cc
${CMAKE_CURRENT_SOURCE_DIR}/cell.cc
${CMAKE_CURRENT_SOURCE_DIR}/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/python_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph.cc
${CMAKE_CURRENT_SOURCE_DIR}/graph/graph_data.cc
${CMAKE_CURRENT_SOURCE_DIR}/model/model.cc
${API_MS_INFER_SRC}
${API_ACL_SRC}
${API_OPS_SRC}
${LOAD_MINDIR_SRC})
add_library(mindspore_shared_lib SHARED ${MSLIB_SRC})
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore PUBLIC_HEADER "${API_INCLUDE}")
set_target_properties(mindspore_shared_lib PROPERTIES OUTPUT_NAME mindspore)
target_link_libraries(mindspore_shared_lib PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
-Wl,--whole-archive mindspore -Wl,--no-whole-archive proto_input mindspore_gvar mindspore::protobuf)
@ -69,5 +82,6 @@ endif ()
if (ENABLE_D)
find_library(adump_server libadump_server.a ${ASCEND_RUNTIME_PATH} ${ASCEND_TOOLKIT_RUNTIME_PATH})
target_link_libraries(mindspore_shared_lib PRIVATE ${adump_server})
target_link_libraries(mindspore_shared_lib PRIVATE mindspore_core hccl_adapter)
endif ()

@ -14,6 +14,9 @@
* limitations under the License.
*/
#include "include/api/cell.h"
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/graph_impl.h"
namespace mindspore::api {
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
@ -51,6 +54,52 @@ ParameterCell &ParameterCell::operator=(Tensor &&tensor) {
return *this;
}
// Copy the given graph and create a device-specific executor chosen by the
// current Context device target; throws if either allocation/lookup fails.
GraphCell::GraphCell(const Graph &graph)
: graph_(std::make_shared<Graph>(graph)),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
// Share ownership of an existing graph and create the device-specific executor;
// throws if graph is null or no executor is registered for the device target.
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
: graph_(graph),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
// Take ownership of the given graph and create the device-specific executor
// chosen by the current Context device target.
GraphCell::GraphCell(Graph &&graph)
// std::move is required: a named rvalue reference is an lvalue, so the
// original make_shared<Graph>(graph) silently copied instead of moving.
: graph_(std::make_shared<Graph>(std::move(graph))),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(Context::Instance().GetDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
// Forward execution to the backend executor; throws if no executor exists.
Status GraphCell::Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->Run(inputs, outputs);
}
// Compile/load the graph on the device via the backend executor.
Status GraphCell::Load() {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->Load();
}
// Forward input-metadata query (names/shapes/types/byte sizes) to the executor.
Status GraphCell::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->GetInputsInfo(names, shapes, data_types, mem_sizes);
}
// Forward output-metadata query (names/shapes/types/byte sizes) to the executor.
Status GraphCell::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->GetOutputsInfo(names, shapes, data_types, mem_sizes);
}
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(const Tensor &tensor)

@ -0,0 +1,63 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string_view>

#include "include/api/context.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
// Internal state backing the Context singleton: device target name and id.
class Context::ContextImpl {
public:
// Default target is the sentinel "NotSet" with device id 0.
ContextImpl() : device_target_("NotSet"), device_id_(0) {}
const std::string &GetDeviceTarget() const { return device_target_; }
void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
uint32_t GetDeviceID() const { return device_id_; }
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
private:
std::string device_target_;
uint32_t device_id_;
};
// Global singleton accessor (C++11 magic static: initialization is thread-safe).
Context &Context::Instance() {
static Context context;
return context;
}
// Current device target string; remains "NotSet" until SetDeviceTarget is called.
const std::string &Context::GetDeviceTarget() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetDeviceTarget();
}
// Set the device target used to select executors; returns *this for chaining.
Context &Context::SetDeviceTarget(const std::string &device_target) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDeviceTarget(device_target);
return *this;
}
// Currently configured device id (defaults to 0).
uint32_t Context::GetDeviceID() const {
MS_EXCEPTION_IF_NULL(impl_);
return impl_->GetDeviceID();
}
// Set the device id; returns *this for chaining.
Context &Context::SetDeviceID(uint32_t device_id) {
MS_EXCEPTION_IF_NULL(impl_);
impl_->SetDeviceID(device_id);
return *this;
}
// Construct the pimpl eagerly; the singleton lives for the whole process.
Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
// Out-of-line destructor so ContextImpl is a complete type at destruction.
Context::~Context() {}
} // namespace mindspore::api

@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H
#define MINDSPORE_CCSRC_CXX_API_FACTORY_H
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "utils/utils.h"
namespace mindspore::api {
// Generic per-device registry of creator functions.  Backends register a
// creator under their device-target name (via API_FACTORY_REG) and the runtime
// instantiates the matching implementation with Create().
template <class T>
class Factory {
using U = std::function<std::shared_ptr<T>()>;
public:
Factory(const Factory &) = delete;
void operator=(const Factory &) = delete;
// Meyers singleton: one registry per base class T.
static Factory &Instance() {
static Factory instance;
return instance;
}
// Register a creator for a device; the first registration wins, later
// duplicates are silently ignored.
void Register(const std::string &device_name, U &&creator) {
if (creators_.find(device_name) == creators_.end()) {
// move the creator instead of copying it (parameter is an rvalue reference)
(void)creators_.emplace(device_name, std::move(creator));
}
}
// True if a creator was registered for this device.
bool CheckModelSupport(const std::string &device_name) {
// map::find is O(log n) and, unlike the previous std::any_of scan, does not
// depend on <algorithm>, which this header never included.
return creators_.find(device_name) != creators_.end();
}
// Instantiate the implementation registered for device_name; returns null
// (after logging an error) if the device is unknown.
std::shared_ptr<T> Create(const std::string &device_name) {
auto iter = creators_.find(device_name);
if (creators_.end() != iter) {
MS_EXCEPTION_IF_NULL(iter->second);
return (iter->second)();
}
MS_LOG(ERROR) << "Unsupported device target " << device_name;
return nullptr;
}
private:
Factory() = default;
~Factory() = default;
std::map<std::string, U> creators_;
};
// Helper whose constructor registers a creator with Factory<T>; one static
// Registrar is instantiated per backend via API_FACTORY_REG below.
template <class T>
class Registrar {
using U = std::function<std::shared_ptr<T>()>;
public:
Registrar(const std::string &device_name, U creator) {
Factory<T>::Instance().Register(device_name, std::move(creator));
}
~Registrar() = default;
};
// Registers DERIVE_CLASS as the BASE_CLASS implementation for DEVICE_NAME at
// static-initialization time (the lambda default-constructs the derived type).
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H

File diff suppressed because it is too large Load Diff

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "include/api/graph.h"
#include "cxx_api/graph/acl/model_process.h"
#include "cxx_api/graph/graph_impl.h"
#include "cxx_api/factory.h"
namespace mindspore::api {
// ACL (Ascend) backend implementation of GraphCell::GraphImpl.  The graph is
// converted to an offline model (OM) and executed through ModelProcess.
class AclGraphImpl : public GraphCell::GraphImpl {
public:
AclGraphImpl();
~AclGraphImpl() override;
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status Load() override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
private:
// Guard managing the process-wide ACL environment (defined below).
class AclEnvGuard;
// Convert the held graph into an OM model consumable by ACL.
Status ConvertToOM();
Status InitEnv();
Status FinalizeEnv();
Status LoadAclModel(Buffer om_data);
// init_flag_/load_flag_ presumably track one-time init and lazy model load —
// confirm against the .cc implementation.
bool init_flag_;
bool load_flag_;
std::string device_type_;
int32_t device_id_;
aclrtContext context_;
// The ACL environment is shared by all instances: a weak global plus a mutex
// lets the environment be torn down when the last owner is destroyed.
std::shared_ptr<AclEnvGuard> acl_env_;
static std::weak_ptr<AclEnvGuard> global_acl_env_;
static std::mutex global_acl_env_mutex_;
ModelProcess model_process_;
};
// RAII guard for the ACL environment: construction performs initialization
// with the given config file and records the result; destruction tears the
// environment down (see the .cc for details).
class AclGraphImpl::AclEnvGuard {
public:
explicit AclEnvGuard(std::string_view cfg_file);
~AclEnvGuard();
// Result of the initialization call; callers must check this before use.
aclError GetErrno() const { return errno_; }
private:
aclError errno_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H

@ -14,7 +14,7 @@
* limitations under the License.
*/
#include "cxx_api/model/acl/model_process.h"
#include "cxx_api/graph/acl/model_process.h"
#include <algorithm>
#include <map>
#include "utils/utils.h"
@ -35,17 +35,33 @@ static DataType TransToApiType(aclDataType data_type) {
}
}
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<Tensor> *tensor_list) {
MS_EXCEPTION_IF_NULL(tensor_list);
tensor_list->clear();
// Clear the container behind vec; a null pointer means the caller did not
// request this output and is silently ignored.
template <class T>
inline static void ClearIfNotNull(T *vec) {
  if (vec == nullptr) {
    return;
  }
  vec->clear();
}
// Append item to *vec, tolerating a null vec (caller opted out of that output).
template <class T, class U = std::vector<T>>
inline static void PushbackIfNotNull(U *vec, T &&item) {
  if (vec != nullptr) {
    // std::forward preserves the argument's value category: without it the
    // forwarding-reference parameter was always copied, never moved.
    vec->emplace_back(std::forward<T>(item));
  }
}
static void ConstructTensorDesc(const std::vector<AclTensorInfo> &acl_tensor_list, std::vector<std::string> *names,
std::vector<std::vector<int64_t>> *shapes, std::vector<DataType> *data_types,
std::vector<size_t> *mem_sizes) {
ClearIfNotNull(names);
ClearIfNotNull(shapes);
ClearIfNotNull(data_types);
ClearIfNotNull(mem_sizes);
for (size_t i = 0; i < acl_tensor_list.size(); ++i) {
const auto &info = acl_tensor_list[i];
Tensor tensor_desc;
tensor_desc.SetName(info.name);
tensor_desc.SetDataType(TransToApiType(info.data_type));
tensor_desc.SetShape(info.dims);
tensor_list->push_back(tensor_desc);
PushbackIfNotNull(names, info.name);
PushbackIfNotNull(shapes, info.dims);
PushbackIfNotNull(data_types, TransToApiType(info.data_type));
PushbackIfNotNull(mem_sizes, info.buffer_size);
}
}
@ -272,7 +288,7 @@ Status ModelProcess::UnLoad() {
return SUCCESS;
}
Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inputs) {
Status ModelProcess::CheckAndInitInput(const std::vector<Buffer> &inputs) {
aclError ret;
inputs_ = aclmdlCreateDataset();
// check inputs
@ -282,29 +298,16 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu
return INVALID_INPUTS;
}
for (size_t i = 0; i < input_infos_.size(); ++i) {
const std::string &input_name = input_infos_[i].name;
auto iter = inputs.find(input_name);
if (iter == inputs.end()) {
MS_LOG(ERROR) << "Model missing input " << input_name;
return INVALID_INPUTS;
}
if (iter->second.DataSize() != input_infos_[i].buffer_size) {
if (inputs[i].DataSize() != input_infos_[i].buffer_size) {
MS_LOG(ERROR) << "input " << i << " data size not match, required size " << input_infos_[i].buffer_size
<< ", given count " << iter->second.DataSize();
<< ", given count " << inputs[i].DataSize();
return INVALID_INPUTS;
}
}
// copy inputs
for (size_t i = 0; i < input_infos_.size(); ++i) {
const auto &info = input_infos_[i];
auto iter = inputs.find(info.name);
if (iter == inputs.end()) {
MS_LOG(ERROR) << "Model missing input " << info.name;
return INVALID_INPUTS;
}
const auto &input = iter->second;
const auto &input = inputs[i];
const void *data = input.Data();
void *input_buffer = nullptr;
@ -333,42 +336,7 @@ Status ModelProcess::CheckAndInitInput(const std::map<std::string, Buffer> &inpu
return SUCCESS;
}
Status ModelProcess::CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
size_t input_index) {
aclError ret;
inputs_ = aclmdlCreateDataset();
// check inputs
if (input_index >= input_infos_.size()) {
MS_LOG(ERROR) << "inputs count not match, required count " << input_infos_.size() << ", given index "
<< input_index;
return INVALID_INPUTS;
}
if (dvpp_outputs_buffer_dev == nullptr) {
MS_LOG(ERROR) << "input " << 0 << " cannot be null";
return FAILED;
}
if (dvpp_outputs_buffer_size != input_infos_[input_index].buffer_size) {
MS_LOG(ERROR) << "input " << 0 << " data size not match, required size " << input_infos_[input_index].buffer_size
<< ", given count " << dvpp_outputs_buffer_size;
return INVALID_INPUTS;
}
// copy inputs
auto &info = input_infos_[input_index];
auto data_buffer = aclCreateDataBuffer(const_cast<void *>(dvpp_outputs_buffer_dev), info.buffer_size);
if (data_buffer == nullptr) {
MS_LOG(ERROR) << "Create Data Buffer failed";
return FAILED;
}
ret = aclmdlAddDatasetBuffer(inputs_, data_buffer);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "add data buffer failed";
aclDestroyDataBuffer(data_buffer);
return FAILED;
}
return SUCCESS;
}
Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) {
Status ModelProcess::PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError acl_ret;
Status ret = CheckAndInitInput(inputs);
@ -392,18 +360,7 @@ Status ModelProcess::Predict(const std::map<std::string, Buffer> &inputs, std::m
return SUCCESS;
}
size_t ModelProcess::GetBatchSize() const {
if (input_infos_.empty()) {
MS_LOG(ERROR) << "Model is not loaded";
return 0;
}
if (input_infos_[0].dims.empty()) {
return 1;
}
return static_cast<size_t>(input_infos_[0].dims[0]);
}
Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
Status ModelProcess::BuildOutputs(std::vector<Buffer> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
aclError ret;
// copy outputs
@ -411,14 +368,13 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
aclrtMemcpyKind kind = is_run_on_device_ ? ACL_MEMCPY_HOST_TO_HOST : ACL_MEMCPY_DEVICE_TO_HOST;
for (size_t i = 0; i < output_infos_.size(); ++i) {
const auto &info = output_infos_[i];
// todo
outputs->emplace(info.name, Buffer());
auto output = outputs->rbegin()->second;
if (!output.ResizeData(info.buffer_size)) {
outputs->emplace_back(Buffer());
auto output = outputs->rbegin();
if (!output->ResizeData(info.buffer_size)) {
MS_LOG(ERROR) << "new output data buffer failed, data size " << info.buffer_size;
return FAILED;
}
ret = aclrtMemcpy(output.MutableData(), output.DataSize(), info.device_data, info.buffer_size, kind);
ret = aclrtMemcpy(output->MutableData(), output->DataSize(), info.device_data, info.buffer_size, kind);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Memcpy output " << i << " from " << (is_run_on_device_ ? "host" : "device")
<< " to host failed, memory size " << info.buffer_size;
@ -428,13 +384,15 @@ Status ModelProcess::BuildOutputs(std::map<std::string, Buffer> *outputs) {
return SUCCESS;
}
Status ModelProcess::GetInputsInfo(std::vector<Tensor> *tensor_list) const {
ConstructTensorDesc(input_infos_, tensor_list);
Status ModelProcess::GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(input_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
}
Status ModelProcess::GetOutputsInfo(std::vector<Tensor> *tensor_list) const {
ConstructTensorDesc(output_infos_, tensor_list);
Status ModelProcess::GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const {
ConstructTensorDesc(output_infos_, names, shapes, data_types, mem_sizes);
return SUCCESS;
}
} // namespace mindspore::api

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
#ifndef MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
#define MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H
#include <vector>
#include <string>
#include <map>
@ -34,12 +34,6 @@ struct AclTensorInfo {
std::string name;
};
struct ImagesDvppOutput {
void *buffer_device = nullptr;
size_t buffer_size = 0;
size_t input_index = 0;
};
class ModelProcess {
public:
ModelProcess()
@ -53,24 +47,23 @@ class ModelProcess {
~ModelProcess() {}
Status LoadModelFromFile(const std::string &file_name, uint32_t *model_id);
Status UnLoad();
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs);
Status PredictFromHost(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
Status PreInitModelResource();
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const;
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const;
// override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
size_t GetBatchSize() const;
void set_model_id(uint32_t model_id) { model_id_ = model_id; }
uint32_t model_id() const { return model_id_; }
private:
Status CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
Status CheckAndInitInput(const std::map<std::string, Buffer> &inputs);
Status CheckAndInitDvppInput(const void *dvpp_outputs_buffer_dev, size_t dvpp_outputs_buffer_size,
size_t input_index);
Status BuildOutputs(std::map<std::string, Buffer> *outputs);
Status CheckAndInitInput(const std::vector<Buffer> &inputs);
Status BuildOutputs(std::vector<Buffer> *outputs);
Status InitInputsBuffer();
Status InitOutputsBuffer();
@ -90,4 +83,4 @@ class ModelProcess {
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_MODEL_PROCESS_H
#endif // MINDSPORE_CCSRC_CXXAPI_GRAPH_ACL_MODEL_PROCESS_H

@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"
namespace mindspore::api {
// Share ownership of an existing GraphData payload.
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
// Move overload: steal the caller's GraphData pointer.  std::move is required —
// initializing from the named rvalue reference copied the shared_ptr (an
// atomic refcount increment) instead of moving it.
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(std::move(graph_data)) {}
// Format of the held model, delegated to GraphData; throws if there is no data.
ModelType Graph::ModelType() const {
MS_EXCEPTION_IF_NULL(graph_data_);
return graph_data_->ModelType();
}
} // namespace mindspore::api

@ -0,0 +1,73 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cxx_api/graph/graph_data.h"
#include "utils/log_adapter.h"
#ifdef ENABLE_ACL
#include "framework/common/helper/model_helper.h"
#endif
namespace mindspore::api {
// Builds GraphData from an in-memory function graph (the MindIR path).
// Only ModelType::kMindIR is accepted; any other type raises via MS_LOG(EXCEPTION).
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
    : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
  if (model_type != ModelType::kMindIR) {
    MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
  }
  // Members are assigned only after validation so a throwing constructor
  // never leaves a half-initialized object with a bogus model_type_.
  func_graph_ = func_graph;
  model_type_ = model_type;
}

// Builds GraphData from a serialized OM model buffer.
// Only ModelType::kOM is accepted. When built without ENABLE_ACL this
// constructor always raises, since OM models require the Ascend CL stack.
Graph::GraphData::GraphData(Buffer om_data, enum ModelType model_type)
    : func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
  if (model_type != ModelType::kOM) {
    MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
  }
#ifdef ENABLE_ACL
  // Eagerly parse the buffer with GE's ModelHelper so an invalid OM payload
  // fails here, at construction, rather than later at load/run time.
  // NOTE(review): model_data points into om_data's buffer — presumably
  // LoadModel only reads it during this call; confirm it keeps no reference.
  ge::ModelHelper helper;
  ge::ModelData model_data;
  model_data.model_data = om_data.MutableData();
  model_data.model_len = om_data.DataSize();
  ge::Status ret = helper.LoadModel(model_data);
  if (ret != ge::SUCCESS) {
    MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om.";
  }
  om_data_ = om_data;
  model_type_ = model_type;
#else
  MS_LOG(EXCEPTION) << "Unsupported ModelType OM.";
#endif
}

// Returns the held function graph, or nullptr (with an ERROR log) when this
// GraphData does not wrap a MindIR model.
FuncGraphPtr Graph::GraphData::GetFuncGraph() const {
  if (model_type_ != ModelType::kMindIR) {
    MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
    return nullptr;
  }
  return func_graph_;
}

// Returns the held OM buffer, or an empty Buffer (with an ERROR log) when
// this GraphData does not wrap an OM model.
Buffer Graph::GraphData::GetOMData() const {
  if (model_type_ != ModelType::kOM) {
    MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
    return Buffer();
  }
  return om_data_;
}
}  // namespace mindspore::api

@ -0,0 +1,48 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
#include <string>
#include <vector>
#include <map>
#include <memory>
#include "include/api/graph.h"
#include "include/api/types.h"
#include "ir/func_graph.h"
namespace mindspore::api {
// Internal payload of a Graph: exactly one of a MindIR function graph or a
// serialized OM model buffer, tagged by model_type_. The mismatched getter
// returns an empty value (see graph_data.cc).
class Graph::GraphData {
 public:
  GraphData();

  // Wraps a MindIR function graph; model_type must be kMindIR.
  explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR);

  // Wraps a serialized OM model; model_type must be kOM (ACL builds only).
  GraphData(Buffer om_data, enum ModelType model_type);

  // Which of the two representations this instance holds.
  enum ModelType ModelType() const { return model_type_; }

  // Valid only when ModelType() == kMindIR; otherwise returns nullptr.
  FuncGraphPtr GetFuncGraph() const;

  // Valid only when ModelType() == kOM; otherwise returns an empty Buffer.
  Buffer GetOMData() const;

 private:
  FuncGraphPtr func_graph_;   // set when model_type_ == kMindIR
  Buffer om_data_;            // set when model_type_ == kOM
  enum ModelType model_type_;
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "include/api/cell.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_data.h"
#include "utils/utils.h"
namespace mindspore::api {
// Backend-agnostic execution interface behind GraphCell. Concrete backends
// (e.g. the MindSpore session backend, the ACL backend) implement Load/Run
// and the input/output introspection calls.
class GraphCell::GraphImpl {
 public:
  GraphImpl() = default;
  virtual ~GraphImpl() = default;

  // Access the wrapped Graph's data handle (used by backends to load it).
  // NOTE(review): returns a mutable reference from a const method and
  // dereferences graph_ unchecked — SetGraph() must be called first; confirm
  // callers guarantee this.
  std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }

  // Attach the Graph this implementation will execute.
  void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }

  // Execute the loaded graph on the given input buffers, filling *outputs.
  virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;

  // Compile/load the attached graph onto the backend device.
  virtual Status Load() = 0;

  // Describe the graph's inputs: names, shapes, element types, byte sizes.
  virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                               std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;

  // Describe the graph's outputs, same layout as GetInputsInfo.
  virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                                std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;

 protected:
  std::shared_ptr<Graph> graph_;  // the graph being executed; set via SetGraph()
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H

File diff suppressed because it is too large Load Diff

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
#include <functional>
#include <map>
#include <string>
#include <vector>
#include <memory>
#include <utility>
#include "include/api/status.h"
#include "include/api/graph.h"
#include "cxx_api/graph/graph_impl.h"
#include "backend/session/session_basic.h"
#include "ir/anf.h"
#include "cxx_api/model/model_impl.h"
#include "runtime/context.h"
namespace mindspore::api {
// GraphImpl backed by a MindSpore session (session::SessionBasic). Compiles
// the wrapped FuncGraph into a backend graph and runs it; the rtContext_t
// member suggests the Ascend runtime is the target device — TODO confirm.
class MsGraphImpl : public GraphCell::GraphImpl {
 public:
  MsGraphImpl();
  ~MsGraphImpl() override;

  // Execute the compiled graph; lazily calls Load() on first use
  // (presumably — load_flag_ tracks it; verify in the .cc).
  Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;

  // Initialize the environment and compile the attached graph.
  Status Load() override;

  Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                       std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
  Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
                        std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;

 private:
  Status InitEnv();      // set up session/runtime state
  Status FinalizeEnv();  // tear down session/runtime state
  Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
  Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
  std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
  Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);

  std::shared_ptr<session::SessionBasic> session_impl_;
  uint32_t graph_id_;                         // id returned by graph compilation
  std::string device_type_;
  uint32_t device_id_;
  rtContext_t context_;                       // Ascend runtime context handle
  std::vector<tensor::TensorPtr> inputs_;     // cached input tensor templates
  std::vector<tensor::TensorPtr> outputs_;    // cached output tensor templates
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;
  bool load_flag_;                            // true once Load() succeeded
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H

File diff suppressed because it is too large Load Diff

@ -23,77 +23,38 @@
#include <memory>
#include <map>
#include "ir/anf.h"
#include "include/api/cell.h"
#include "include/api/status.h"
#include "cxx_api/model/model_impl.h"
#include "cxx_api/model/acl/dvpp_process.h"
#include "cxx_api/model/acl/model_process.h"
#include "cxx_api/model/acl/model_converter.h"
#include "cxx_api/model/acl/acl_model_options.h"
#include "ir/tensor.h"
#include "ir/anf.h"
namespace mindspore::api {
class AclModel : public ModelImpl {
public:
explicit AclModel(uint32_t device_id)
: init_flag_(false),
load_flag_(false),
device_type_("AscendCL"),
device_id_(device_id),
context_(nullptr),
stream_(nullptr),
acl_env_(nullptr),
model_process_(),
dvpp_process_(),
model_converter_(),
options_(nullptr) {}
AclModel() : model_converter_(), options_(nullptr), options_str_() {}
~AclModel() = default;
Status LoadModel(const Buffer &model_data, ModelType type,
const std::map<std::string, std::string> &options) override;
Status LoadModel(const std::string &file_name, ModelType type,
const std::map<std::string, std::string> &options) override;
Status UnloadModel() override;
Status Build(const std::map<std::string, std::string> &options_map) override;
Status Train(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Eval(const DataSet &dataset, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::map<std::string, Buffer> &inputs, std::map<std::string, Buffer> *outputs) override;
Status Predict(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
Status GetInputsInfo(std::vector<Tensor> *tensor_list) const override;
Status GetOutputsInfo(std::vector<Tensor> *tensor_list) const override;
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) const override;
private:
bool init_flag_;
bool load_flag_;
std::string device_type_;
int32_t device_id_;
aclrtContext context_;
aclrtStream stream_;
class AclEnvGuard;
std::shared_ptr<AclEnvGuard> acl_env_;
static std::weak_ptr<AclEnvGuard> global_acl_env_;
static std::mutex global_acl_env_mutex_;
static std::string GenerateOptionsStr(const std::map<std::string, std::string> &options);
ModelProcess model_process_;
DvppProcess dvpp_process_;
std::shared_ptr<GraphCell> graph_cell_;
ModelConverter model_converter_;
std::unique_ptr<AclModelOptions> options_;
Status InitEnv();
Status FinalizeEnv();
};
class AclModel::AclEnvGuard {
public:
explicit AclEnvGuard(const std::string &cfg_file);
~AclEnvGuard();
aclError GetErrno() const { return errno_; }
private:
aclError errno_;
std::string options_str_;
};
API_REG_MODEL(AscendCL, AclModel);
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXX_API_ACL_MODEL_H

File diff suppressed because it is too large Load Diff

@ -1,160 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
#include <vector>
#include <string>
#include <map>
#include "acl/acl.h"
#include "acl/acl_mdl.h"
#include "acl/acl_rt.h"
#include "acl/ops/acl_dvpp.h"
#include "include/api/status.h"
namespace mindspore::api {
// Parameters for JPEG decode: the pixel format of the decoded image.
struct DvppDecodePara {
  acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
};

// Parameters for resize after decode: target width/height in pixels.
struct DvppResizePara {
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};

// How the crop rectangle is determined.
enum DvppCropType {
  // crop left,top,right,bottom is given in config
  kDvppCropTypeOffset = 0,
  // crop left,top,right,bottom is calculated by image width/height and output crop width/height
  kDvppCropTypeCentre = 1,
};

// A rectangular region of interest, in pixel coordinates.
struct DvppRoiArea {
  uint32_t left = 0;
  uint32_t top = 0;
  uint32_t right = 0;
  uint32_t bottom = 0;
};

// Crop specification; which fields apply depends on crop_type.
struct DvppCropInfo {
  DvppCropType crop_type = kDvppCropTypeOffset;
  DvppRoiArea crop_area;     // when kDvppCropTypeOffset
  uint32_t crop_width = 0;   // when kDvppCropTypeCentre
  uint32_t crop_height = 0;  // when kDvppCropTypeCentre
};

// Parameters for decode + crop: where to crop and the output size.
struct DvppCropPara {
  DvppCropInfo crop_info;
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};

// Parameters for decode + crop&paste: crop a region and paste it into
// paste_area of an output image of the given size.
struct DvppCropAndPastePara {
  DvppCropInfo crop_info;
  DvppRoiArea paste_area;
  uint32_t output_width = 0;
  uint32_t output_height = 0;
};
// Ascend DVPP (Digital Vision Pre-Processing) pipeline wrapper: JPEG decode
// followed by exactly one of resize, crop, or crop-and-paste (see the
// to_*_flag_ members — only one may be true). Configure with one Init*Para
// call (or InitWithJsonConfig), then call Process() per image or per batch.
class DvppProcess {
 public:
  DvppProcess();
  ~DvppProcess();

  // Bind DVPP channel/resources to the given ACL stream. Must precede Process().
  Status InitResource(aclrtStream stream);
  // Release all DVPP resources; safe to call before destruction.
  void Finalize();

  Status InitJpegDecodePara(const DvppDecodePara &decode_para);  // jpeg decode + (resize | crop)
  Status InitResizePara(const DvppResizePara &resize_para);      // jpeg decode + resize
  Status InitCropPara(const DvppCropPara &crop_para);            // jpeg decode + crop
  Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para);  // jpeg decode + crop&paste

  // Configure the whole pipeline from a JSON description.
  // NOTE(review): the accepted schema is defined in the .cc — not visible here.
  Status InitWithJsonConfig(const std::string &json_config);

  // Run the pipeline on one encoded picture.
  // output device buffer will be destroy by DvppProcess itself.
  Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size);
  // Batch variant: processes each picture and packs results contiguously
  // into one device buffer (presumably batch_size_ * vpc_output_buffer_size_
  // bytes — confirm against the implementation).
  Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list,
                 void **output_device_buffer, size_t *output_size);

  // True once an Init*Para / InitWithJsonConfig call has configured the pipeline.
  bool HasLoaded() const { return loaded_flag_; }

 private:
  bool loaded_flag_ = false;
  uint32_t pic_width_ = 0;   // dimensions of the current input picture
  uint32_t pic_height_ = 0;

  // One configuration slot per supported operation.
  DvppDecodePara decode_para_;
  DvppResizePara resize_para_;
  DvppCropPara crop_para_;
  DvppCropAndPastePara crop_and_paste_para_;
  // only one of the resize or crop flag can be true
  bool to_resize_flag_ = false;
  bool to_crop_flag_ = false;
  bool to_crop_and_paste_flag_ = false;

  // Device-side buffers and descriptors for the decode stage.
  void *input_pic_dev_buffer_ = nullptr;
  uint32_t input_pic_buffer_size_ = 0;

  uint32_t decode_output_buffer_size_ = 0;
  void *decode_output_buffer_dev_ = nullptr;
  acldvppPicDesc *decode_output_desc_ = nullptr;

  // VPC (resize/crop/paste) stage configuration and output buffers.
  acldvppResizeConfig *resize_config_ = nullptr;
  acldvppRoiConfig *crop_area_ = nullptr;
  acldvppRoiConfig *paste_area_ = nullptr;

  acldvppPicDesc *vpc_output_desc_ = nullptr;
  void *vpc_output_buffer_dev_ = nullptr;  // vpc_output_buffer_size_ length
  uint32_t vpc_output_buffer_size_ = 0;

  void *batch_vpc_output_buffer_dev_ = nullptr;  // batch_size_ * vpc_output_buffer_size_ length
  uint32_t batch_size_ = 0;

  aclrtStream stream_ = nullptr;
  acldvppChannelDesc *dvpp_channel_desc_ = nullptr;

  // --- helpers (implemented in the .cc) ---
  // Round org_size up to a multiple of alignment (DVPP stride requirement).
  uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const;
  uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const;
  Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
  Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
  Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size);
  Status InitDecodeOutputDesc(uint32_t image_width,
                              uint32_t image_height);  // decode_output_desc_, decode_output_buffer_dev_
  Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height);
  Status CheckAndAdjustRoiArea(DvppRoiArea *area);
  Status UpdateCropArea(uint32_t image_width, uint32_t image_height);
  Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const;
  void DestroyDecodeDesc();

  Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height,
                           acldvppPixelFormat pixel_format);  // vpc_output_desc_, vpc_output_buffer_dev_batch_
  Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area);
  Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info);
  Status InitResizeOutputDesc();        // vpc_output_desc_, vpc_output_buffer_dev_, resize_config
  Status InitCropOutputDesc();          // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_
  Status InitCropAndPasteOutputDesc();  // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_
  void DestroyVpcOutputDesc();

  // One Process* per configured operation; dispatched based on the flags above.
  Status ProcessDecode();
  Status ProcessResize();
  Status ProcessCrop();
  Status ProcessCropAndPaste();
  void DestroyResource();

  Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width,
                            uint32_t *image_height);
};
} // namespace mindspore::api
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save