!13515 C++ API modification

From: @zhoufeng54
Reviewed-by: 
Signed-off-by:
pull/13515/MERGE
Committed by mindspore-ci-bot via Gitee
commit 5ff9c08680

@ -148,7 +148,7 @@ if(PLATFORM_ARM64)
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend* ops*" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(ENABLE_TOOLS)
@ -173,7 +173,7 @@ elseif(PLATFORM_ARM32)
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(DIRECTORY ${TOP_DIR}/mindspore/lite/build/operator_library DESTINATION ${CODEGEN_ROOT_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
if(ENABLE_TOOLS)
@ -213,7 +213,7 @@ elseif(WIN32)
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/build/mindspore/src/libmindspore-lite.dll.a DESTINATION ${RUNTIME_LIB_DIR}
@ -231,7 +231,7 @@ else()
install(FILES ${TOP_DIR}/mindspore/core/ir/dtype/type_id.h DESTINATION ${RUNTIME_INC_DIR}/ir/dtype
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(DIRECTORY ${TOP_DIR}/include/api/ DESTINATION ${RUNTIME_INC_DIR}/api
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ascend*" EXCLUDE)
COMPONENT ${RUNTIME_COMPONENT_NAME} FILES_MATCHING PATTERN "*.h" PATTERN "ops*" EXCLUDE)
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${RUNTIME_LIB_DIR}
COMPONENT ${RUNTIME_COMPONENT_NAME})
install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.a DESTINATION ${RUNTIME_LIB_DIR}

@ -103,8 +103,9 @@ class MS_API GraphCell final : public Cell<GraphCell> {
std::vector<MSTensor> GetOutputs();
private:
friend class Model;
friend class ModelImpl;
Status Load();
Status Load(uint32_t device_id);
std::shared_ptr<Graph> graph_;
std::shared_ptr<GraphImpl> executor_;

File diff suppressed because it is too large.

@ -27,6 +27,7 @@ namespace mindspore {
class MS_API Graph {
public:
class GraphData;
Graph();
explicit Graph(const std::shared_ptr<GraphData> &graph_data);
explicit Graph(std::shared_ptr<GraphData> &&graph_data);
explicit Graph(std::nullptr_t);
@ -34,6 +35,7 @@ class MS_API Graph {
enum ModelType ModelType() const;
bool operator==(std::nullptr_t) const;
bool operator!=(std::nullptr_t) const;
private:
friend class GraphCell;

@ -1,71 +0,0 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#define MINDSPORE_INCLUDE_API_LITE_CONTEXT_H
#include <string>
#include <memory>
#include <map>
#include <any>
#include "include/api/types.h"
#include "include/lite_types.h"
namespace mindspore {
namespace lite {
class Allocator;
} // namespace lite
struct MS_API Context {
public:
static void Clear(const std::shared_ptr<Context> &context);
static void SetAsDefault(const std::shared_ptr<Context> &context);
static void SetVendorName(const std::shared_ptr<Context> &context, const std::string &name);
static std::string GetVendorName(const std::shared_ptr<Context> &context);
static void SetThreadNum(const std::shared_ptr<Context> &context, int num);
static int GetThreadNum(const std::shared_ptr<Context> &context);
static void SetAllocator(const std::shared_ptr<Context> &context, std::shared_ptr<lite::Allocator> alloc);
static std::shared_ptr<lite::Allocator> GetAllocator(const std::shared_ptr<Context> &context);
static void ConfigCPU(const std::shared_ptr<Context> &context, bool config);
static bool IfCPUEnabled(const std::shared_ptr<Context> &context);
static void ConfigCPUFp16(const std::shared_ptr<Context> &context, bool config);
static bool IfCPUFp16Enabled(const std::shared_ptr<Context> &context);
static void SetCPUBindMode(const std::shared_ptr<Context> &context, lite::CpuBindMode mode);
static lite::CpuBindMode GetCPUBindMode(const std::shared_ptr<Context> &context);
static void ConfigGPU(const std::shared_ptr<Context> &context, bool config);
static bool IfGPUEnabled(const std::shared_ptr<Context> &context);
static void ConfigGPUFp16(const std::shared_ptr<Context> &context, bool config);
static bool IfGPUFp16Enabled(const std::shared_ptr<Context> &context);
static void ConfigNPU(const std::shared_ptr<Context> &context, bool config);
static bool IfNPUEnabled(const std::shared_ptr<Context> &context);
static void SetNPUFrequency(const std::shared_ptr<Context> &context, int freq);
static int GetNPUFrequency(const std::shared_ptr<Context> &context);
private:
std::map<std::string, std::any> context_;
};
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_LITE_CONTEXT_H

@ -24,39 +24,52 @@
#include "include/api/status.h"
#include "include/api/types.h"
#include "include/api/graph.h"
#include "include/api/context.h"
#include "include/api/cell.h"
#include "include/api/dual_abi_helper.h"
namespace mindspore {
class ModelImpl;
struct Context;
class MS_API Model {
public:
explicit Model(const std::vector<Output> &network, const std::shared_ptr<Context> &model_context = nullptr);
explicit Model(const GraphCell &graph, const std::shared_ptr<Context> &model_context = nullptr);
Model();
~Model();
Model(const Model &) = delete;
void operator=(const Model &) = delete;
Status Build();
Status Build(GraphCell graph, const std::shared_ptr<Context> &model_context = nullptr);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs);
std::vector<MSTensor> GetInputs();
inline MSTensor GetInputByTensorName(const std::string &tensor_name);
std::vector<MSTensor> GetOutputs();
inline std::vector<std::string> GetOutputTensorNames();
inline MSTensor GetOutputByTensorName(const std::string &tensor_name);
static inline bool CheckModelSupport(const std::string &device_type, ModelType model_type);
static bool CheckModelSupport(enum DeviceType device_type, ModelType model_type);
private:
// api without std::string
static bool CheckModelSupport(const std::vector<char> &device_type, ModelType model_type);
MSTensor GetInputByTensorName(const std::vector<char> &tensor_name);
std::vector<std::vector<char>> GetOutputTensorNamesChar();
MSTensor GetOutputByTensorName(const std::vector<char> &tensor_name);
std::vector<MSTensor> GetOutputsByNodeName(const std::vector<char> &node_name);
std::shared_ptr<ModelImpl> impl_;
};
bool Model::CheckModelSupport(const std::string &device_type, ModelType model_type) {
return CheckModelSupport(StringToChar(device_type), model_type);
MSTensor Model::GetInputByTensorName(const std::string &tensor_name) {
return GetInputByTensorName(StringToChar(tensor_name));
}
std::vector<std::string> Model::GetOutputTensorNames() { return VectorCharToString(GetOutputTensorNamesChar()); }
MSTensor Model::GetOutputByTensorName(const std::string &tensor_name) {
return GetOutputByTensorName(StringToChar(tensor_name));
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_MODEL_H
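
Taken together, the reworked Model surface splits construction from building: a default-constructed Model is bound to a GraphCell and a Context through Build and then driven with Predict. A minimal usage sketch under those declarations; the helper name RunOnce and the abbreviated error handling are illustrative, not part of the change, and the graph/context are assumed to have been prepared beforehand.

#include <memory>
#include <vector>
#include "include/api/context.h"
#include "include/api/model.h"

// Sketch only: assumes <graph> was loaded earlier (e.g. via Serialization::Load)
// and <context> already carries a single device info entry.
mindspore::Status RunOnce(const mindspore::Graph &graph,
                          const std::shared_ptr<mindspore::Context> &context) {
  mindspore::Model model;  // default-constructed; Build() attaches graph and context
  mindspore::Status status = model.Build(mindspore::GraphCell(graph), context);
  if (status != mindspore::kSuccess) {
    return status;
  }
  std::vector<mindspore::MSTensor> inputs = model.GetInputs();
  // ... fill the input tensors here ...
  std::vector<mindspore::MSTensor> outputs;
  return model.Predict(inputs, &outputs);  // outputs are written through the pointer
}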

@ -29,19 +29,19 @@
namespace mindspore {
class MS_API Serialization {
public:
static Graph LoadModel(const void *model_data, size_t data_size, ModelType model_type);
inline static Graph LoadModel(const std::string &file, ModelType model_type);
static Status Load(const void *model_data, size_t data_size, ModelType model_type, Graph *graph);
inline static Status Load(const std::string &file, ModelType model_type, Graph *graph);
static Status LoadCheckPoint(const std::string &ckpt_file, std::map<std::string, Buffer> *parameters);
static Status SetParameters(const std::map<std::string, Buffer> &parameters, Model *model);
static Status ExportModel(const Model &model, ModelType model_type, Buffer *model_data);
static Status ExportModel(const Model &model, ModelType model_type, const std::string &model_file);
private:
static Graph LoadModel(const std::vector<char> &file, ModelType model_type);
static Status Load(const std::vector<char> &file, ModelType model_type, Graph *graph);
};
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
return LoadModel(StringToChar(file), model_type);
Status Serialization::Load(const std::string &file, ModelType model_type, Graph *graph) {
return Load(StringToChar(file), model_type, graph);
}
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_API_SERIALIZATION_H
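
The Status-returning Load overloads above replace LoadModel: the caller now passes in a Graph and inspects the returned Status instead of checking a returned object. A hedged sketch of the new call pattern; the wrapper name LoadMindIR is illustrative.

#include <string>
#include "include/api/graph.h"
#include "include/api/serialization.h"

// Sketch only: wraps the new Load() overload for a MindIR file.
bool LoadMindIR(const std::string &file, mindspore::Graph *graph) {
  mindspore::Status status =
      mindspore::Serialization::Load(file, mindspore::ModelType::kMindIR, graph);
  return status == mindspore::kSuccess;
}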

@ -43,15 +43,19 @@ class MS_API MSTensor {
public:
class Impl;
static inline MSTensor CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor *CreateRefTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static inline MSTensor *StringsToTensor(const std::string &name, const std::vector<std::string> &str);
static inline std::vector<std::string> TensorToStrings(const MSTensor &tensor);
static void DestroyTensorPtr(MSTensor *tensor) noexcept;
MSTensor();
explicit MSTensor(const std::shared_ptr<Impl> &impl);
inline MSTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
explicit MSTensor(std::nullptr_t);
~MSTensor();
inline std::string Name() const;
@ -65,21 +69,24 @@ class MS_API MSTensor {
bool IsDevice() const;
MSTensor Clone() const;
MSTensor *Clone() const;
bool operator==(std::nullptr_t) const;
bool operator!=(std::nullptr_t) const;
private:
// api without std::string
static MSTensor CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor *CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor *CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
static MSTensor *CharStringsToTensor(const std::vector<char> &name, const std::vector<std::vector<char>> &str);
static std::vector<std::vector<char>> TensorToStringChars(const MSTensor &tensor);
MSTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len);
std::vector<char> CharName() const;
friend class ModelImpl;
explicit MSTensor(std::nullptr_t);
std::shared_ptr<Impl> impl_;
};
@ -103,16 +110,24 @@ class MS_API Buffer {
std::shared_ptr<Impl> impl_;
};
MSTensor MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
MSTensor *MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
MSTensor *MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateRefTensor(StringToChar(name), type, shape, data, data_len);
}
MSTensor *MSTensor::StringsToTensor(const std::string &name, const std::vector<std::string> &str) {
return CharStringsToTensor(StringToChar(name), VectorStringToChar(str));
}
std::vector<std::string> MSTensor::TensorToStrings(const MSTensor &tensor) {
return VectorCharToString(TensorToStringChars(tensor));
}
MSTensor::MSTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
size_t data_len)
: MSTensor(StringToChar(name), type, shape, data, data_len) {}
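
Because CreateTensor, CreateRefTensor, and StringsToTensor now return heap-allocated MSTensor pointers, every successful call has to be paired with DestroyTensorPtr, mirroring what cell.cc does with Clone() further below. A minimal sketch of the new ownership pattern; the tensor name, shape, and the kNumberTypeFloat32 value are illustrative and assume the existing DataType enum.

#include <vector>
#include "include/api/types.h"

// Sketch only: pointer-returning factory paired with DestroyTensorPtr.
void CreateAndRelease() {
  std::vector<float> data = {1.0f, 2.0f, 3.0f, 4.0f};
  mindspore::MSTensor *tensor = mindspore::MSTensor::CreateTensor(
      "x", mindspore::DataType::kNumberTypeFloat32, {2, 2}, data.data(),
      data.size() * sizeof(float));
  if (tensor == nullptr) {  // the factories return nullptr on failure
    return;
  }
  // ... use *tensor ...
  mindspore::MSTensor::DestroyTensorPtr(tensor);  // required now that ownership is transferred
}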

@ -1,134 +0,0 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INFERENCE_LOG_H_
#define MINDSPORE_INFERENCE_LOG_H_
#include <stdarg.h>
#include <stdint.h>
#include <string>
#include <sstream>
#include <memory>
#include <iostream>
#include <chrono>
#include <vector>
#ifndef ENABLE_ACL
#include "mindspore/core/utils/log_adapter.h"
#else // ENABLE_ACL
#include "acl/acl.h"
#endif
namespace mindspore::inference {
class LogStream {
public:
LogStream() { sstream_ = std::make_shared<std::stringstream>(); }
~LogStream() = default;
template <typename T>
LogStream &operator<<(const T &val) noexcept {
(*sstream_) << val;
return *this;
}
template <typename T>
LogStream &operator<<(const std::vector<T> &val) noexcept {
(*sstream_) << "[";
for (size_t i = 0; i < val.size(); i++) {
(*this) << val[i];
if (i + 1 < val.size()) {
(*sstream_) << ", ";
}
}
(*sstream_) << "]";
return *this;
}
LogStream &operator<<(std::ostream &func(std::ostream &os)) noexcept {
(*sstream_) << func;
return *this;
}
friend class LogWriter;
friend class Status;
private:
std::shared_ptr<std::stringstream> sstream_;
};
#ifndef ENABLE_ACL
#define MSI_LOG(level) MS_LOG(level)
#define MSI_LOG_DEBUG MSI_LOG(DEBUG)
#define MSI_LOG_INFO MSI_LOG(INFO)
#define MSI_LOG_WARNING MSI_LOG(WARNING)
#define MSI_LOG_ERROR MSI_LOG(ERROR)
#define MSI_ASSERT(item) MS_ASSERT(item)
#else // ENABLE_ACL
class LogWriter {
public:
LogWriter(const char *file, int line, const char *func, aclLogLevel log_level)
: file_(file), line_(line), func_(func), log_level_(log_level) {}
~LogWriter() = default;
void operator<(const LogStream &stream) const noexcept __attribute__((visibility("default"))) {
std::ostringstream msg;
msg << stream.sstream_->rdbuf();
OutputLog(msg);
}
private:
void OutputLog(const std::ostringstream &msg) const { aclAppLog(log_level_, func_, file_, line_, msg.str().c_str()); }
const char *file_;
int line_;
const char *func_;
aclLogLevel log_level_;
};
#define MSILOG_IF(level) inference::LogWriter(__FILE__, __LINE__, __FUNCTION__, ACL_##level) < inference::LogStream()
#define MSI_LOG(level) MSI_LOG_##level
#define MSI_LOG_DEBUG MSILOG_IF(DEBUG)
#define MSI_LOG_INFO MSILOG_IF(INFO)
#define MSI_LOG_WARNING MSILOG_IF(WARNING)
#define MSI_LOG_ERROR MSILOG_IF(ERROR)
#define MSI_ASSERT(item)
#endif // ENABLE_ACL
#define MSI_TIME_STAMP_START(name) auto time_start_##name = std::chrono::steady_clock::now();
#define MSI_TIME_STAMP_END(name) \
{ \
auto time_end_##name = std::chrono::steady_clock::now(); \
auto time_cost = std::chrono::duration<double, std::milli>(time_end_##name - time_start_##name).count(); \
MSI_LOG_INFO << #name " Time Cost # " << time_cost << " ms ---------------------"; \
}
#define INFER_STATUS(code) inference::Status(code) < inference::LogStream()
#define ERROR_INFER_STATUS(status, type, msg) \
MSI_LOG_ERROR << msg; \
status = inference::Status(type, msg)
} // namespace mindspore::inference
#endif // MINDSPORE_INFERENCE_LOG_H_

@ -1,217 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_INFER_TENSOR_H_
#define MINDSPORE_INCLUDE_INFER_TENSOR_H_
#include <utility>
#include <vector>
#include <memory>
#include <numeric>
#include <map>
#include <functional>
#include "securec/include/securec.h"
#include "include/infer_log.h"
namespace mindspore {
#define MS_API __attribute__((visibility("default")))
namespace inference {
enum DataType {
kMSI_Unknown = 0,
kMSI_Bool = 1,
kMSI_Int8 = 2,
kMSI_Int16 = 3,
kMSI_Int32 = 4,
kMSI_Int64 = 5,
kMSI_Uint8 = 6,
kMSI_Uint16 = 7,
kMSI_Uint32 = 8,
kMSI_Uint64 = 9,
kMSI_Float16 = 10,
kMSI_Float32 = 11,
kMSI_Float64 = 12,
};
class InferTensorBase {
public:
InferTensorBase() = default;
virtual ~InferTensorBase() = default;
virtual DataType data_type() const = 0;
virtual void set_data_type(DataType type) = 0;
virtual std::vector<int64_t> shape() const = 0;
virtual void set_shape(const std::vector<int64_t> &shape) = 0;
virtual const void *data() const = 0;
virtual size_t data_size() const = 0;
virtual bool resize_data(size_t data_len) = 0;
virtual void *mutable_data() = 0;
bool set_data(const void *data, size_t data_len) {
resize_data(data_len);
if (mutable_data() == nullptr) {
MSI_LOG_ERROR << "set data failed, data len " << data_len;
return false;
}
if (data_size() != data_len) {
MSI_LOG_ERROR << "set data failed, tensor current data size " << data_size() << " not match data len "
<< data_len;
return false;
}
if (data_len == 0) {
return true;
}
auto ret = memcpy_s(mutable_data(), data_size(), data, data_len);
if (ret != 0) {
MSI_LOG_ERROR << "Set data memcpy_s failed";
return false;
}
return true;
}
int64_t ElementNum() const {
std::vector<int64_t> shapex = shape();
return std::accumulate(shapex.begin(), shapex.end(), 1LL, std::multiplies<int64_t>());
}
int GetTypeSize(DataType type) const {
const std::map<DataType, size_t> type_size_map{
{kMSI_Bool, sizeof(bool)}, {kMSI_Float64, sizeof(double)}, {kMSI_Int8, sizeof(int8_t)},
{kMSI_Uint8, sizeof(uint8_t)}, {kMSI_Int16, sizeof(int16_t)}, {kMSI_Uint16, sizeof(uint16_t)},
{kMSI_Int32, sizeof(int32_t)}, {kMSI_Uint32, sizeof(uint32_t)}, {kMSI_Int64, sizeof(int64_t)},
{kMSI_Uint64, sizeof(uint64_t)}, {kMSI_Float16, sizeof(uint16_t)}, {kMSI_Float32, sizeof(float)},
};
auto it = type_size_map.find(type);
if (it != type_size_map.end()) {
return it->second;
}
return 0;
}
};
class InferTensor : public InferTensorBase {
public:
DataType type_;
std::vector<int64_t> shape_;
std::vector<uint8_t> data_;
public:
InferTensor() = default;
~InferTensor() = default;
InferTensor(DataType type, std::vector<int64_t> shape, const void *data, size_t data_len) {
set_data_type(type);
set_shape(shape);
set_data(data, data_len);
}
void set_data_type(DataType type) override { type_ = type; }
DataType data_type() const override { return type_; }
void set_shape(const std::vector<int64_t> &shape) override { shape_ = shape; }
std::vector<int64_t> shape() const override { return shape_; }
const void *data() const override { return data_.data(); }
size_t data_size() const override { return data_.size(); }
bool resize_data(size_t data_len) override {
data_.resize(data_len);
return true;
}
void *mutable_data() override { return data_.data(); }
};
class InferImagesBase {
public:
InferImagesBase() = default;
virtual ~InferImagesBase() = default;
virtual size_t batch_size() const = 0;
virtual bool get(size_t index, const void *&pic_buffer, uint32_t &pic_size) const = 0;
virtual size_t input_index() const = 0; // the index of images as input in model
};
class RequestBase {
public:
RequestBase() = default;
virtual ~RequestBase() = default;
virtual size_t size() const = 0;
virtual const InferTensorBase *operator[](size_t index) const = 0;
};
class ImagesRequestBase {
public:
ImagesRequestBase() = default;
virtual ~ImagesRequestBase() = default;
virtual size_t size() const = 0;
virtual const InferImagesBase *operator[](size_t index) const = 0;
};
class ReplyBase {
public:
ReplyBase() = default;
virtual ~ReplyBase() = default;
virtual size_t size() const = 0;
virtual InferTensorBase *operator[](size_t index) = 0;
virtual const InferTensorBase *operator[](size_t index) const = 0;
virtual InferTensorBase *add() = 0;
virtual void clear() = 0;
};
class VectorInferTensorWrapReply : public ReplyBase {
public:
explicit VectorInferTensorWrapReply(std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
~VectorInferTensorWrapReply() = default;
size_t size() const { return tensor_list_.size(); }
InferTensorBase *operator[](size_t index) {
if (index >= tensor_list_.size()) {
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
return nullptr;
}
return &(tensor_list_[index]);
}
const InferTensorBase *operator[](size_t index) const {
if (index >= tensor_list_.size()) {
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
return nullptr;
}
return &(tensor_list_[index]);
}
InferTensorBase *add() {
tensor_list_.push_back(InferTensor());
return &(tensor_list_.back());
}
void clear() { tensor_list_.clear(); }
std::vector<InferTensor> &tensor_list_;
};
class VectorInferTensorWrapRequest : public RequestBase {
public:
explicit VectorInferTensorWrapRequest(const std::vector<InferTensor> &tensor_list) : tensor_list_(tensor_list) {}
~VectorInferTensorWrapRequest() = default;
size_t size() const { return tensor_list_.size(); }
const InferTensorBase *operator[](size_t index) const {
if (index >= tensor_list_.size()) {
MSI_LOG_ERROR << "visit invalid index " << index << " total size " << tensor_list_.size();
return nullptr;
}
return &(tensor_list_[index]);
}
const std::vector<InferTensor> &tensor_list_;
};
} // namespace inference
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_INFER_TENSOR_H_

@ -1,86 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_INCLUDE_MS_SESSION_H
#define MINDSPORE_INCLUDE_MS_SESSION_H
#include <memory>
#include <vector>
#include <string>
#include "include/infer_tensor.h"
#include "include/infer_log.h"
namespace mindspore {
namespace inference {
enum StatusCode { SUCCESS = 0, FAILED, INVALID_INPUTS };
class Status {
public:
Status() : status_code_(FAILED) {}
Status(enum StatusCode status_code, const std::string &status_msg = "")
: status_code_(status_code), status_msg_(status_msg) {}
~Status() = default;
bool IsSuccess() const { return status_code_ == SUCCESS; }
enum StatusCode StatusCode() const { return status_code_; }
std::string StatusMessage() const { return status_msg_; }
bool operator==(const Status &other) const { return status_code_ == other.status_code_; }
bool operator==(enum StatusCode other_code) const { return status_code_ == other_code; }
bool operator!=(const Status &other) const { return status_code_ != other.status_code_; }
bool operator!=(enum StatusCode other_code) const { return status_code_ != other_code; }
operator bool() const = delete;
Status &operator<(const LogStream &stream) noexcept __attribute__((visibility("default"))) {
status_msg_ = stream.sstream_->str();
return *this;
}
private:
enum StatusCode status_code_;
std::string status_msg_;
};
class MS_API InferSession {
public:
InferSession() = default;
virtual ~InferSession() = default;
virtual Status InitEnv(const std::string &device_type, uint32_t device_id) = 0;
virtual Status FinalizeEnv() = 0;
virtual Status LoadModelFromFile(const std::string &file_name, uint32_t &model_id) = 0;
virtual Status UnloadModel(uint32_t model_id) = 0;
// override this method to avoid request/reply data copy
virtual Status ExecuteModel(uint32_t model_id, const RequestBase &request, ReplyBase &reply) = 0;
virtual Status ExecuteModel(uint32_t model_id, const std::vector<InferTensor> &inputs,
std::vector<InferTensor> &outputs) {
VectorInferTensorWrapRequest request(inputs);
VectorInferTensorWrapReply reply(outputs);
return ExecuteModel(model_id, request, reply);
}
// default not support input data preprocess(decode, resize, crop, crop&paste, etc.)
virtual Status ExecuteModel(uint32_t /*model_id*/,
const ImagesRequestBase & /*images_inputs*/, // images for preprocess
const RequestBase & /*request*/, ReplyBase & /*reply*/) {
return FAILED;
}
virtual Status GetModelInputsInfo(uint32_t graph_id, std::vector<inference::InferTensor> *tensor_list) const {
Status status(SUCCESS);
return status;
}
static std::shared_ptr<InferSession> CreateSession(const std::string &device, uint32_t device_id);
};
} // namespace inference
} // namespace mindspore
#endif // MINDSPORE_INCLUDE_MS_SESSION_H

@ -21,12 +21,19 @@
namespace mindspore {
std::vector<Output> CellBase::operator()(const std::vector<Input> &inputs) const { return Clone()->Construct(inputs); }
ParameterCell::ParameterCell(const ParameterCell &cell) : tensor_(cell.tensor_.Clone()) {}
ParameterCell::ParameterCell(const ParameterCell &cell) {
auto tmp_ptr = cell.tensor_.Clone();
tensor_ = *tmp_ptr;
MSTensor::DestroyTensorPtr(tmp_ptr);
}
ParameterCell &ParameterCell::operator=(const ParameterCell &cell) {
if (&cell == this) {
return *this;
}
tensor_ = cell.tensor_.Clone();
auto tmp_ptr = cell.tensor_.Clone();
tensor_ = *tmp_ptr;
MSTensor::DestroyTensorPtr(tmp_ptr);
return *this;
}
@ -40,10 +47,16 @@ ParameterCell &ParameterCell::operator=(ParameterCell &&cell) {
return *this;
}
ParameterCell::ParameterCell(const MSTensor &tensor) : tensor_(tensor.Clone()) {}
ParameterCell::ParameterCell(const MSTensor &tensor) {
auto tmp_ptr = tensor.Clone();
tensor_ = *tmp_ptr;
MSTensor::DestroyTensorPtr(tmp_ptr);
}
ParameterCell &ParameterCell::operator=(const MSTensor &tensor) {
tensor_ = tensor.Clone();
auto tmp_ptr = tensor.Clone();
tensor_ = *tmp_ptr;
MSTensor::DestroyTensorPtr(tmp_ptr);
return *this;
}
@ -54,54 +67,67 @@ ParameterCell &ParameterCell::operator=(MSTensor &&tensor) {
return *this;
}
GraphCell::GraphCell(const Graph &graph)
: graph_(std::make_shared<Graph>(graph)),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
GraphCell::GraphCell(const Graph &graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph)
: graph_(graph),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
GraphCell::GraphCell(const std::shared_ptr<Graph> &graph) : graph_(graph) { MS_EXCEPTION_IF_NULL(graph_); }
GraphCell::GraphCell(Graph &&graph)
: graph_(std::make_shared<Graph>(graph)),
executor_(Factory<GraphCell::GraphImpl>::Instance().Create(GlobalContext::GetGlobalDeviceTarget())) {
MS_EXCEPTION_IF_NULL(graph_);
MS_EXCEPTION_IF_NULL(executor_);
executor_->SetGraph(graph_);
}
GraphCell::GraphCell(Graph &&graph) : graph_(std::make_shared<Graph>(graph)) { MS_EXCEPTION_IF_NULL(graph_); }
Status GraphCell::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(executor_);
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
if (executor_ == nullptr) {
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
return kMEFailed;
}
executor_->SetGraph(graph_);
}
return executor_->Run(inputs, outputs);
}
Status GraphCell::Load() {
MS_EXCEPTION_IF_NULL(executor_);
return executor_->Load();
Status GraphCell::Load(uint32_t device_id) {
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
if (executor_ == nullptr) {
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
return kMEFailed;
}
executor_->SetGraph(graph_);
}
return executor_->Load(device_id);
}
std::vector<MSTensor> GraphCell::GetInputs() {
MS_EXCEPTION_IF_NULL(executor_);
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
if (executor_ == nullptr) {
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
return {};
}
executor_->SetGraph(graph_);
}
return executor_->GetInputs();
}
std::vector<MSTensor> GraphCell::GetOutputs() {
MS_EXCEPTION_IF_NULL(executor_);
if (executor_ == nullptr) {
executor_ = Factory<GraphCell::GraphImpl>::Instance().Create(g_device_target);
if (executor_ == nullptr) {
MS_LOG(ERROR) << "Create graph impl for device target " << g_device_target << " failed.";
return {};
}
executor_->SetGraph(graph_);
}
return executor_->GetOutputs();
}
InputAndOutput::InputAndOutput() : cell_(nullptr), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(const MSTensor &tensor)
: cell_(std::make_shared<ParameterCell>(tensor.Clone())), prev_(), index_(-1) {}
InputAndOutput::InputAndOutput(const MSTensor &tensor) : prev_(), index_(-1) {
auto tmp_ptr = tensor.Clone();
cell_ = std::make_shared<ParameterCell>(*tmp_ptr);
MSTensor::DestroyTensorPtr(tmp_ptr);
}
InputAndOutput::InputAndOutput(MSTensor &&tensor)
: cell_(std::make_shared<ParameterCell>(tensor)), prev_(), index_(-1) {}

File diff suppressed because it is too large.

@ -24,6 +24,8 @@
#include "utils/utils.h"
namespace mindspore {
inline std::string g_device_target = "Default";
template <class T>
class Factory {
using U = std::function<std::shared_ptr<T>()>;

@ -45,6 +45,9 @@ std::shared_ptr<AclEnvGuard> AclEnvGuard::GetAclEnv(std::string_view cfg_file) {
acl_env = global_acl_env_;
if (acl_env != nullptr) {
MS_LOG(INFO) << "Acl has been initialized, skip.";
if (!cfg_file.empty()) {
MS_LOG(WARNING) << "Dump config file option " << cfg_file << " is ignored.";
}
} else {
acl_env = std::make_shared<AclEnvGuard>(cfg_file);
aclError ret = acl_env->GetErrno();

@ -25,7 +25,7 @@ AclGraphImpl::AclGraphImpl()
: init_flag_(false),
load_flag_(false),
device_type_("AscendCL"),
device_id_(GlobalContext::GetGlobalDeviceID()),
device_id_(0),
context_(nullptr),
acl_env_(nullptr) {}
@ -33,7 +33,7 @@ AclGraphImpl::~AclGraphImpl() { (void)FinalizeEnv(); }
Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return ret;
@ -43,7 +43,7 @@ Status AclGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTens
}
std::vector<MSTensor> AclGraphImpl::GetInputs() {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return {};
@ -53,7 +53,7 @@ std::vector<MSTensor> AclGraphImpl::GetInputs() {
}
std::vector<MSTensor> AclGraphImpl::GetOutputs() {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return {};
@ -90,7 +90,7 @@ Status AclGraphImpl::InitEnv() {
return kSuccess;
}
acl_env_ = AclEnvGuard::GetAclEnv(GlobalContext::GetGlobalDumpConfigPath());
acl_env_ = AclEnvGuard::GetAclEnv("");
if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed.";
return kMCDeviceError;
@ -161,7 +161,7 @@ Status AclGraphImpl::FinalizeEnv() {
return kSuccess;
}
Status AclGraphImpl::Load() {
Status AclGraphImpl::Load(uint32_t device_id) {
// check graph type
if (graph_->ModelType() != ModelType::kOM) {
Status ret = ConvertToOM();
@ -176,6 +176,7 @@ Status AclGraphImpl::Load() {
auto om_data = graph_data->GetOMData();
// init
device_id_ = device_id;
Status ret = InitEnv();
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";

@ -34,7 +34,7 @@ class AclGraphImpl : public GraphCell::GraphImpl {
~AclGraphImpl() override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

@ -39,7 +39,7 @@ AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_type_("Ascend"),
device_id_(GlobalContext::GetGlobalDeviceID()),
device_id_(0),
context_(nullptr),
inputs_info_(),
outputs_info_(),
@ -142,7 +142,7 @@ Status AscendGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::
std::vector<MSTensor> AscendGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return {};
@ -166,7 +166,7 @@ std::vector<MSTensor> AscendGraphImpl::GetInputs() {
std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return {};
@ -188,7 +188,7 @@ std::vector<MSTensor> AscendGraphImpl::GetOutputs() {
return result;
}
Status AscendGraphImpl::Load() {
Status AscendGraphImpl::Load(uint32_t device_id) {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
@ -200,6 +200,7 @@ Status AscendGraphImpl::Load() {
auto func_graph = graph_data->GetFuncGraph();
// init
device_id_ = device_id;
Status ret = InitEnv();
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
@ -247,7 +248,7 @@ Status AscendGraphImpl::Load() {
Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;

@ -36,7 +36,7 @@ class AscendGraphImpl : public GraphCell::GraphImpl {
~AscendGraphImpl() override;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

@ -30,7 +30,7 @@ API_FACTORY_REG(GraphCell::GraphImpl, GPU, GPUGraphImpl);
GPUGraphImpl::GPUGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
device_id_(GlobalContext::GetGlobalDeviceID()),
device_id_(0),
inputs_info_(),
outputs_info_(),
input_names_(),
@ -83,7 +83,7 @@ Status GPUGraphImpl::FinalizeEnv() {
return kSuccess;
}
Status GPUGraphImpl::Load() {
Status GPUGraphImpl::Load(uint32_t device_id) {
// check graph type
if (graph_->ModelType() != ModelType::kMindIR) {
MS_LOG(ERROR) << "Unsupported model type " << graph_->ModelType();
@ -95,6 +95,7 @@ Status GPUGraphImpl::Load() {
auto func_graph = graph_data->GetFuncGraph();
// init
device_id_ = device_id;
Status ret = InitEnv();
if (ret != kSuccess) {
MS_LOG(ERROR) << "InitEnv failed.";
@ -176,7 +177,7 @@ Status GPUGraphImpl::ExecuteModel(const std::vector<MSTensor> &request, std::vec
Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return ret;
@ -211,7 +212,7 @@ Status GPUGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MSTens
std::vector<MSTensor> GPUGraphImpl::GetInputs() {
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return {};
@ -235,7 +236,7 @@ std::vector<MSTensor> GPUGraphImpl::GetInputs() {
std::vector<MSTensor> GPUGraphImpl::GetOutputs() {
if (!load_flag_) {
Status ret = Load();
Status ret = Load(device_id_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "PrepareModel failed.";
return {};

@ -33,7 +33,7 @@ class GPUGraphImpl : public GraphCell::GraphImpl {
~GPUGraphImpl() override = default;
Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override;
Status Load() override;
Status Load(uint32_t device_id) override;
std::vector<MSTensor> GetInputs() override;
std::vector<MSTensor> GetOutputs() override;

@ -18,6 +18,8 @@
#include "utils/log_adapter.h"
namespace mindspore {
Graph::Graph() : graph_data_(nullptr) {}
Graph::Graph(const std::shared_ptr<GraphData> &graph_data) : graph_data_(graph_data) {}
Graph::Graph(std::shared_ptr<GraphData> &&graph_data) : graph_data_(graph_data) {}
@ -28,6 +30,8 @@ Graph::Graph(std::nullptr_t) : graph_data_(nullptr) {}
bool Graph::operator==(std::nullptr_t) const { return graph_data_ == nullptr; }
bool Graph::operator!=(std::nullptr_t) const { return graph_data_ != nullptr; }
ModelType Graph::ModelType() const {
MS_EXCEPTION_IF_NULL(graph_data_);
return graph_data_->ModelType();

@ -36,7 +36,7 @@ class GraphCell::GraphImpl {
void SetGraph(const std::shared_ptr<Graph> &graph) { graph_ = graph; }
virtual Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) = 0;
virtual Status Load() = 0;
virtual Status Load(uint32_t device_id) = 0;
virtual std::vector<MSTensor> GetInputs() = 0;
virtual std::vector<MSTensor> GetOutputs() = 0;
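
For backend authors, the updated GraphImpl contract means the target device id arrives through Load rather than through a global context. A minimal conforming sketch, assuming Run, Load, GetInputs, and GetOutputs are the only pure virtuals (as in the AclGraphImpl, AscendGraphImpl, and GPUGraphImpl headers above); the class name and include path are illustrative.

#include <vector>
#include "cxx_api/graph/graph_impl.h"  // assumed path of the header shown above

namespace mindspore {
class DummyGraphImpl : public GraphCell::GraphImpl {
 public:
  Status Load(uint32_t device_id) override {
    device_id_ = device_id;  // the device id now arrives through Load()
    return kSuccess;
  }
  Status Run(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs) override {
    *outputs = inputs;  // placeholder: echo inputs back unchanged
    return kSuccess;
  }
  std::vector<MSTensor> GetInputs() override { return {}; }
  std::vector<MSTensor> GetOutputs() override { return {}; }

 private:
  uint32_t device_id_ = 0;
};
}  // namespace mindspore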

@ -18,6 +18,7 @@
#include <memory>
#include "include/api/context.h"
#include "cxx_api/factory.h"
#include "cxx_api/graph/acl/acl_env_guard.h"
namespace mindspore {
API_FACTORY_REG(ModelImpl, Ascend310, AclModel);
@ -40,6 +41,11 @@ Status AclModel::Build() {
std::unique_ptr<AclModelOptions> options = std::make_unique<AclModelOptions>(model_context_);
MS_EXCEPTION_IF_NULL(options);
std::string dump_cfg = options->GetDumpCfgPath();
if (!dump_cfg.empty()) {
MS_LOG(INFO) << "Options dump config file path " << dump_cfg;
(void)AclEnvGuard::GetAclEnv(dump_cfg);
}
std::string options_key = options->GenAclOptionsKey();
std::shared_ptr<Graph> graph;
if (auto iter = dynamic_size_graph_map_.find(options_key); iter != dynamic_size_graph_map_.end()) {
@ -75,7 +81,7 @@ Status AclModel::Build() {
MS_EXCEPTION_IF_NULL(graph);
auto graph_cell = std::make_shared<GraphCell>(graph);
MS_EXCEPTION_IF_NULL(graph_cell);
auto ret = ModelImpl::Load(graph_cell);
auto ret = ModelImpl::Load(graph_cell, options->GetDeviceID());
if (ret != kSuccess) {
MS_LOG(ERROR) << "Load failed.";
return ret;
@ -108,7 +114,8 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
}
if (model_context_ == nullptr) {
model_context_ = std::make_shared<ModelContext>();
model_context_ = std::make_shared<Context>();
model_context_->MutableDeviceInfo().emplace_back(std::make_shared<Ascend310DeviceInfo>());
}
std::string input_shape_option;
@ -130,7 +137,14 @@ Status AclModel::Resize(const std::vector<MSTensor> &inputs, const std::vector<s
}
}
MS_LOG(INFO) << "Set input size option is " << input_shape_option;
ModelContext::SetInputShape(model_context_, input_shape_option);
auto &device_infos = model_context_->MutableDeviceInfo();
if (device_infos.size() != 1) {
MS_LOG(ERROR) << "Invalid model context, only single device info is supported.";
return kMCInvalidArgs;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
MS_EXCEPTION_IF_NULL(ascend310_info);
ascend310_info->SetInputShape(input_shape_option);
auto graph_cell_bak = std::move(graph_cell_);
auto ret = Build();
if (ret != kSuccess) {
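
The Resize path above shows the direction of the context rework: device-specific options no longer go through static ModelContext setters but live on a DeviceInfo object attached to the Context. A hedged sketch of building such a context on the caller side; SetDeviceID is assumed as the counterpart of the GetDeviceID accessor read in AclModelOptions below, and the factory name is illustrative.

#include <cstdint>
#include <memory>
#include "include/api/context.h"

// Sketch only: attaches a single Ascend310DeviceInfo entry to a fresh Context.
std::shared_ptr<mindspore::Context> MakeAscend310Context(uint32_t device_id) {
  auto context = std::make_shared<mindspore::Context>();
  auto ascend310_info = std::make_shared<mindspore::Ascend310DeviceInfo>();
  ascend310_info->SetDeviceID(device_id);  // assumed setter matching GetDeviceID()
  // Device-specific options are now attached through MutableDeviceInfo().
  context->MutableDeviceInfo().emplace_back(ascend310_info);
  return context;
}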

@ -27,10 +27,19 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
if (context == nullptr) {
return;
}
insert_op_cfg_path_ = ModelContext::GetInsertOpConfigPath(context);
input_format_ = ModelContext::GetInputFormat(context);
input_shape_map_ = ModelContext::GetInputShapeMap(context);
auto out_type = ModelContext::GetOutputType(context);
auto &device_infos = context->MutableDeviceInfo();
if (device_infos.size() != 1) {
return;
}
auto ascend310_info = device_infos[0]->Cast<Ascend310DeviceInfo>();
if (ascend310_info == nullptr) {
return;
}
insert_op_cfg_path_ = ascend310_info->GetInsertOpConfigPath();
input_format_ = ascend310_info->GetInputFormat();
input_shape_map_ = ascend310_info->GetInputShapeMap();
auto out_type = ascend310_info->GetOutputType();
auto iter = kSupportedDtypeOptionMap.find(out_type);
if (out_type == DataType::kTypeUnknown) {
// do nothing
@ -39,10 +48,12 @@ AclModelOptions::AclModelOptions(const std::shared_ptr<Context> &context) {
} else {
output_type_ = iter->second;
}
dynamic_batch_size_ = ModelContext::GetDynamicBatchSize(context);
precision_mode_ = ModelContext::GetPrecisionMode(context);
op_select_impl_mode_ = ModelContext::GetOpSelectImplMode(context);
fusion_switch_cfg_path_ = ModelContext::GetFusionSwitchConfigPath(context);
dynamic_batch_size_ = ascend310_info->GetDynamicBatchSize();
precision_mode_ = ascend310_info->GetPrecisionMode();
op_select_impl_mode_ = ascend310_info->GetOpSelectImplMode();
fusion_switch_cfg_path_ = ascend310_info->GetFusionSwitchConfigPath();
device_id_ = ascend310_info->GetDeviceID();
dump_cfg_path_ = ascend310_info->GetDumpConfigPath();
}
void AclModelOptions::RenameInput(const std::vector<std::string> &input_names) {

Some files were not shown because too many files have changed in this diff.
