parent
756594d000
commit
3204ecb7d6
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
#define MINDSPORE_INCLUDE_API_CONTEXT_H
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Context {
|
||||
public:
|
||||
static Context &Instance();
|
||||
const std::string &GetDeviceTarget() const;
|
||||
Context &SetDeviceTarget(const std::string &device_target);
|
||||
uint32_t GetDeviceID() const;
|
||||
Context &SetDeviceID(uint32_t device_id);
|
||||
|
||||
private:
|
||||
Context();
|
||||
~Context();
|
||||
class ContextImpl;
|
||||
std::shared_ptr<ContextImpl> impl_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_CONTEXT_H
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
#define MINDSPORE_INCLUDE_API_GRAPH_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/types.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace api {
|
||||
class MS_API Graph {
|
||||
public:
|
||||
class GraphData;
|
||||
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
|
||||
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
|
||||
|
||||
enum ModelType ModelType() const;
|
||||
|
||||
private:
|
||||
friend class GraphCell;
|
||||
friend class ModelImpl;
|
||||
std::shared_ptr<GraphData> graph_data_;
|
||||
};
|
||||
} // namespace api
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_INCLUDE_API_GRAPH_H
|
@ -0,0 +1,63 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/context.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class Context::ContextImpl {
|
||||
public:
|
||||
ContextImpl() : device_target_("NotSet"), device_id_(0) {}
|
||||
const std::string &GetDeviceTarget() const { return device_target_; }
|
||||
void SetDeviceTarget(std::string_view device_target) { device_target_ = device_target; }
|
||||
uint32_t GetDeviceID() const { return device_id_; }
|
||||
void SetDeviceID(uint32_t device_id) { device_id_ = device_id; }
|
||||
|
||||
private:
|
||||
std::string device_target_;
|
||||
uint32_t device_id_;
|
||||
};
|
||||
|
||||
Context &Context::Instance() {
|
||||
static Context context;
|
||||
return context;
|
||||
}
|
||||
|
||||
const std::string &Context::GetDeviceTarget() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetDeviceTarget();
|
||||
}
|
||||
|
||||
Context &Context::SetDeviceTarget(const std::string &device_target) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetDeviceTarget(device_target);
|
||||
return *this;
|
||||
}
|
||||
|
||||
uint32_t Context::GetDeviceID() const {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
return impl_->GetDeviceID();
|
||||
}
|
||||
|
||||
Context &Context::SetDeviceID(uint32_t device_id) {
|
||||
MS_EXCEPTION_IF_NULL(impl_);
|
||||
impl_->SetDeviceID(device_id);
|
||||
return *this;
|
||||
}
|
||||
|
||||
Context::Context() : impl_(std::make_shared<Context::ContextImpl>()) { MS_EXCEPTION_IF_NULL(impl_); }
|
||||
|
||||
Context::~Context() {}
|
||||
} // namespace mindspore::api
|
@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
template <class T>
|
||||
class Factory {
|
||||
using U = std::function<std::shared_ptr<T>()>;
|
||||
|
||||
public:
|
||||
Factory(const Factory &) = delete;
|
||||
void operator=(const Factory &) = delete;
|
||||
|
||||
static Factory &Instance() {
|
||||
static Factory instance;
|
||||
return instance;
|
||||
}
|
||||
|
||||
void Register(const std::string &device_name, U &&creator) {
|
||||
if (creators_.find(device_name) == creators_.end()) {
|
||||
(void)creators_.emplace(device_name, creator);
|
||||
}
|
||||
}
|
||||
|
||||
bool CheckModelSupport(const std::string &device_name) {
|
||||
return std::any_of(creators_.begin(), creators_.end(),
|
||||
[&device_name](const std::pair<std::string, U> &item) { return item.first == device_name; });
|
||||
}
|
||||
|
||||
std::shared_ptr<T> Create(const std::string &device_name) {
|
||||
auto iter = creators_.find(device_name);
|
||||
if (creators_.end() != iter) {
|
||||
MS_EXCEPTION_IF_NULL(iter->second);
|
||||
return (iter->second)();
|
||||
}
|
||||
|
||||
MS_LOG(ERROR) << "Unsupported device target " << device_name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
Factory() = default;
|
||||
~Factory() = default;
|
||||
std::map<std::string, U> creators_;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class Registrar {
|
||||
using U = std::function<std::shared_ptr<T>()>;
|
||||
|
||||
public:
|
||||
Registrar(const std::string &device_name, U creator) {
|
||||
Factory<T>::Instance().Register(device_name, std::move(creator));
|
||||
}
|
||||
~Registrar() = default;
|
||||
};
|
||||
|
||||
#define API_FACTORY_REG(BASE_CLASS, DEVICE_NAME, DERIVE_CLASS) \
|
||||
static const Registrar<BASE_CLASS> g_api_##DERIVE_CLASS##_registrar_##DEVICE_NAME##_reg( \
|
||||
#DEVICE_NAME, []() { return std::make_shared<DERIVE_CLASS>(); });
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_FACTORY_H
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,73 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/acl/model_process.h"
|
||||
#include "cxx_api/graph/graph_impl.h"
|
||||
#include "cxx_api/factory.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class AclGraphImpl : public GraphCell::GraphImpl {
|
||||
public:
|
||||
AclGraphImpl();
|
||||
~AclGraphImpl() override;
|
||||
|
||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
Status Load() override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
|
||||
private:
|
||||
class AclEnvGuard;
|
||||
|
||||
Status ConvertToOM();
|
||||
Status InitEnv();
|
||||
Status FinalizeEnv();
|
||||
Status LoadAclModel(Buffer om_data);
|
||||
|
||||
bool init_flag_;
|
||||
bool load_flag_;
|
||||
std::string device_type_;
|
||||
int32_t device_id_;
|
||||
aclrtContext context_;
|
||||
|
||||
std::shared_ptr<AclEnvGuard> acl_env_;
|
||||
static std::weak_ptr<AclEnvGuard> global_acl_env_;
|
||||
static std::mutex global_acl_env_mutex_;
|
||||
|
||||
ModelProcess model_process_;
|
||||
};
|
||||
|
||||
class AclGraphImpl::AclEnvGuard {
|
||||
public:
|
||||
explicit AclEnvGuard(std::string_view cfg_file);
|
||||
~AclEnvGuard();
|
||||
aclError GetErrno() const { return errno_; }
|
||||
|
||||
private:
|
||||
aclError errno_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_ACL_ACL_GRAPH_IMPL_H
|
@ -0,0 +1,29 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
|
||||
|
||||
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
|
||||
|
||||
ModelType Graph::ModelType() const {
|
||||
MS_EXCEPTION_IF_NULL(graph_data_);
|
||||
return graph_data_->ModelType();
|
||||
}
|
||||
} // namespace mindspore::api
|
@ -0,0 +1,73 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#ifdef ENABLE_ACL
|
||||
#include "framework/common/helper/model_helper.h"
|
||||
#endif
|
||||
|
||||
namespace mindspore::api {
|
||||
Graph::GraphData::GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type)
|
||||
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
|
||||
if (model_type != ModelType::kMindIR) {
|
||||
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
|
||||
}
|
||||
func_graph_ = func_graph;
|
||||
model_type_ = model_type;
|
||||
}
|
||||
|
||||
Graph::GraphData::GraphData(Buffer om_data, enum ModelType model_type)
|
||||
: func_graph_(nullptr), om_data_(), model_type_(ModelType::kUnknownType) {
|
||||
if (model_type != ModelType::kOM) {
|
||||
MS_LOG(EXCEPTION) << "Invalid ModelType " << model_type;
|
||||
}
|
||||
|
||||
#ifdef ENABLE_ACL
|
||||
// check om
|
||||
ge::ModelHelper helper;
|
||||
ge::ModelData model_data;
|
||||
model_data.model_data = om_data.MutableData();
|
||||
model_data.model_len = om_data.DataSize();
|
||||
ge::Status ret = helper.LoadModel(model_data);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input data cannot parse to om.";
|
||||
}
|
||||
|
||||
om_data_ = om_data;
|
||||
model_type_ = model_type;
|
||||
#else
|
||||
MS_LOG(EXCEPTION) << "Unsupported ModelType OM.";
|
||||
#endif
|
||||
}
|
||||
|
||||
FuncGraphPtr Graph::GraphData::GetFuncGraph() const {
|
||||
if (model_type_ != ModelType::kMindIR) {
|
||||
MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
return func_graph_;
|
||||
}
|
||||
|
||||
Buffer Graph::GraphData::GetOMData() const {
|
||||
if (model_type_ != ModelType::kOM) {
|
||||
MS_LOG(ERROR) << "Invalid ModelType " << model_type_;
|
||||
return Buffer();
|
||||
}
|
||||
|
||||
return om_data_;
|
||||
}
|
||||
} // namespace mindspore::api
|
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "include/api/graph.h"
|
||||
#include "include/api/types.h"
|
||||
#include "ir/func_graph.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class Graph::GraphData {
|
||||
public:
|
||||
GraphData();
|
||||
|
||||
explicit GraphData(const FuncGraphPtr &func_graph, enum ModelType model_type = kMindIR);
|
||||
|
||||
GraphData(Buffer om_data, enum ModelType model_type);
|
||||
|
||||
enum ModelType ModelType() const { return model_type_; }
|
||||
|
||||
FuncGraphPtr GetFuncGraph() const;
|
||||
|
||||
Buffer GetOMData() const;
|
||||
|
||||
private:
|
||||
FuncGraphPtr func_graph_;
|
||||
Buffer om_data_;
|
||||
enum ModelType model_type_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_DATA_H
|
@ -0,0 +1,51 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/cell.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_data.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class GraphCell::GraphImpl {
|
||||
public:
|
||||
GraphImpl() = default;
|
||||
virtual ~GraphImpl() = default;
|
||||
|
||||
std::shared_ptr<Graph::GraphData> &MutableGraphData() const { return graph_->graph_data_; }
|
||||
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
|
||||
|
||||
virtual Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) = 0;
|
||||
virtual Status Load() = 0;
|
||||
|
||||
virtual Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
virtual Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) = 0;
|
||||
|
||||
protected:
|
||||
std::shared_ptr<Graph> graph_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_GRAPH_IMPL_H
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,65 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
||||
#define MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include "include/api/status.h"
|
||||
#include "include/api/graph.h"
|
||||
#include "cxx_api/graph/graph_impl.h"
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "ir/anf.h"
|
||||
#include "cxx_api/model/model_impl.h"
|
||||
#include "runtime/context.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
class MsGraphImpl : public GraphCell::GraphImpl {
|
||||
public:
|
||||
MsGraphImpl();
|
||||
~MsGraphImpl() override;
|
||||
|
||||
Status Run(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs) override;
|
||||
Status Load() override;
|
||||
Status GetInputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
Status GetOutputsInfo(std::vector<std::string> *names, std::vector<std::vector<int64_t>> *shapes,
|
||||
std::vector<DataType> *data_types, std::vector<size_t> *mem_sizes) override;
|
||||
|
||||
private:
|
||||
Status InitEnv();
|
||||
Status FinalizeEnv();
|
||||
Status CompileGraph(const std::shared_ptr<FuncGraph> &funcGraphPtr);
|
||||
Status CheckModelInputs(const std::vector<tensor::TensorPtr> &inputs) const;
|
||||
std::vector<tensor::TensorPtr> RunGraph(const std::vector<tensor::TensorPtr> &inputs);
|
||||
Status ExecuteModel(const std::vector<Buffer> &inputs, std::vector<Buffer> *outputs);
|
||||
|
||||
std::shared_ptr<session::SessionBasic> session_impl_;
|
||||
uint32_t graph_id_;
|
||||
std::string device_type_;
|
||||
uint32_t device_id_;
|
||||
rtContext_t context_;
|
||||
std::vector<tensor::TensorPtr> inputs_;
|
||||
std::vector<tensor::TensorPtr> outputs_;
|
||||
std::vector<std::string> input_names_;
|
||||
std::vector<std::string> output_names_;
|
||||
bool load_flag_;
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
#endif // MINDSPORE_CCSRC_CXX_API_GRAPH_MS_MS_GRAPH_IMPL_H
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,160 +0,0 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#define MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include "acl/acl.h"
|
||||
#include "acl/acl_mdl.h"
|
||||
#include "acl/acl_rt.h"
|
||||
#include "acl/ops/acl_dvpp.h"
|
||||
#include "include/api/status.h"
|
||||
|
||||
namespace mindspore::api {
|
||||
struct DvppDecodePara {
|
||||
acldvppPixelFormat pixel_format = PIXEL_FORMAT_YUV_SEMIPLANAR_420;
|
||||
};
|
||||
|
||||
struct DvppResizePara {
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
enum DvppCropType {
|
||||
// crop left,top,right,bottom is given in config
|
||||
kDvppCropTypeOffset = 0,
|
||||
// crop left,top,right,bottom is calculated by image width/height and output crop width/height
|
||||
kDvppCropTypeCentre = 1,
|
||||
};
|
||||
|
||||
struct DvppRoiArea {
|
||||
uint32_t left = 0;
|
||||
uint32_t top = 0;
|
||||
uint32_t right = 0;
|
||||
uint32_t bottom = 0;
|
||||
};
|
||||
|
||||
struct DvppCropInfo {
|
||||
DvppCropType crop_type = kDvppCropTypeOffset;
|
||||
DvppRoiArea crop_area; // when kDvppCropTypeOffset
|
||||
uint32_t crop_width = 0; // when kDvppCropTypeCentre
|
||||
uint32_t crop_height = 0; // when kDvppCropTypeCentre
|
||||
};
|
||||
|
||||
struct DvppCropPara {
|
||||
DvppCropInfo crop_info;
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
struct DvppCropAndPastePara {
|
||||
DvppCropInfo crop_info;
|
||||
DvppRoiArea paste_area;
|
||||
uint32_t output_width = 0;
|
||||
uint32_t output_height = 0;
|
||||
};
|
||||
|
||||
class DvppProcess {
|
||||
public:
|
||||
DvppProcess();
|
||||
~DvppProcess();
|
||||
|
||||
Status InitResource(aclrtStream stream);
|
||||
void Finalize();
|
||||
Status InitJpegDecodePara(const DvppDecodePara &decode_para); // jpeg decode + (resize | crop)
|
||||
Status InitResizePara(const DvppResizePara &resize_para); // jpeg decode + resize
|
||||
Status InitCropPara(const DvppCropPara &crop_para); // jpeg decode + crop
|
||||
Status InitCropAndPastePara(const DvppCropAndPastePara &crop_and_paste_para); // jpeg decode + crop&paste
|
||||
|
||||
Status InitWithJsonConfig(const std::string &json_config);
|
||||
|
||||
// output device buffer will be destroy by DvppProcess itself.
|
||||
Status Process(const void *pic_buffer, size_t pic_buffer_size, void **output_device_buffer, size_t *output_size);
|
||||
Status Process(const std::vector<const void *> &pic_buffer_list, const std::vector<size_t> &pic_buffer_size_list,
|
||||
void **output_device_buffer, size_t *output_size);
|
||||
bool HasLoaded() const { return loaded_flag_; }
|
||||
|
||||
private:
|
||||
bool loaded_flag_ = false;
|
||||
uint32_t pic_width_ = 0;
|
||||
uint32_t pic_height_ = 0;
|
||||
|
||||
DvppDecodePara decode_para_;
|
||||
DvppResizePara resize_para_;
|
||||
DvppCropPara crop_para_;
|
||||
DvppCropAndPastePara crop_and_paste_para_;
|
||||
// only one of the resize or crop flag can be true
|
||||
bool to_resize_flag_ = false;
|
||||
bool to_crop_flag_ = false;
|
||||
bool to_crop_and_paste_flag_ = false;
|
||||
|
||||
void *input_pic_dev_buffer_ = nullptr;
|
||||
uint32_t input_pic_buffer_size_ = 0;
|
||||
|
||||
uint32_t decode_output_buffer_size_ = 0;
|
||||
void *decode_output_buffer_dev_ = nullptr;
|
||||
acldvppPicDesc *decode_output_desc_ = nullptr;
|
||||
|
||||
acldvppResizeConfig *resize_config_ = nullptr;
|
||||
acldvppRoiConfig *crop_area_ = nullptr;
|
||||
acldvppRoiConfig *paste_area_ = nullptr;
|
||||
|
||||
acldvppPicDesc *vpc_output_desc_ = nullptr;
|
||||
void *vpc_output_buffer_dev_ = nullptr; // vpc_output_buffer_size_ length
|
||||
uint32_t vpc_output_buffer_size_ = 0;
|
||||
|
||||
void *batch_vpc_output_buffer_dev_ = nullptr; // batch_size_ * vpc_output_buffer_size_ length
|
||||
uint32_t batch_size_ = 0;
|
||||
|
||||
aclrtStream stream_ = nullptr;
|
||||
acldvppChannelDesc *dvpp_channel_desc_ = nullptr;
|
||||
|
||||
uint32_t AlignmentHelper(uint32_t org_size, uint32_t alignment) const;
|
||||
uint32_t GetImageBufferSize(uint32_t stride_width, uint32_t stride_height, acldvppPixelFormat pixel_format) const;
|
||||
Status GetPicDescStride(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
|
||||
Status GetPicDescStrideDecode(uint32_t width, uint32_t height, uint32_t *stride_width, uint32_t *stride_height);
|
||||
Status InputInputBuffer(const void *pic_buffer, size_t pic_buffer_size);
|
||||
Status InitDecodeOutputDesc(uint32_t image_width,
|
||||
uint32_t image_height); // decode_output_desc_, decode_output_buffer_dev_
|
||||
Status CheckRoiAreaWidthHeight(uint32_t width, uint32_t height);
|
||||
Status CheckAndAdjustRoiArea(DvppRoiArea *area);
|
||||
Status UpdateCropArea(uint32_t image_width, uint32_t image_height);
|
||||
Status CheckResizeImageInfo(uint32_t image_width, uint32_t image_height) const;
|
||||
void DestroyDecodeDesc();
|
||||
|
||||
Status InitVpcOutputDesc(uint32_t output_width, uint32_t output_height,
|
||||
acldvppPixelFormat pixel_format); // vpc_output_desc_, vpc_output_buffer_dev_batch_
|
||||
Status InitRoiAreaConfig(const DvppRoiArea &init_para, acldvppRoiConfig **roi_area);
|
||||
Status InitCommonCropPara(uint32_t out_width, uint32_t out_height, DvppCropInfo *crop_info);
|
||||
Status InitResizeOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, resize_config
|
||||
Status InitCropOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_
|
||||
Status InitCropAndPasteOutputDesc(); // vpc_output_desc_, vpc_output_buffer_dev_, crop_area_, paste_area_
|
||||
void DestroyVpcOutputDesc();
|
||||
|
||||
Status ProcessDecode();
|
||||
Status ProcessResize();
|
||||
Status ProcessCrop();
|
||||
Status ProcessCropAndPaste();
|
||||
void DestroyResource();
|
||||
|
||||
Status GetJpegWidthHeight(const void *pic_buffer, size_t pic_buffer_size, uint32_t *image_width,
|
||||
uint32_t *image_height);
|
||||
};
|
||||
} // namespace mindspore::api
|
||||
|
||||
#endif // MINDSPORE_CCSRC_CXXAPI_SESSION_ACL_DVPP_PROCESS_H
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue