/** * Copyright 2019 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef PREDICT_INCLUDE_SESSION_H_ #define PREDICT_INCLUDE_SESSION_H_ #include #include #include #include #include #include "include/context.h" #include "include/tensor.h" #define MSPREDICT_API __attribute__((visibility("default"))) namespace mindspore { namespace predict { using NODE_ID = std::string; ///\brief Graph defined by MindSpore predict. /// ///\note /// The caller does not need to care about detailed implementation of this class, so just list the class name here. class Graph; ///\brief GraphExecution defined by MindSpore predict. /// ///\note /// The caller does not need to care about detailed implementation of this class, so just list the class name here. class GraphExecution; ///\brief MindSpore predict session. /// /// This class represents session of MindSpore predict. /// ///\note /// The caller needs to allocate and free memory of inputs and outputs. /// New Session is not suggested, please use CreateSession function to create new session class. class MSPREDICT_API Session { public: ///\brief Constructor of MindSpore predict session. /// ///\param[in] ctx The context of the session. /// ///\return Instance of MindSpore predict session. explicit Session(const Context &ctx); ///\brief Destructor of MindSpore predict session. ~Session(); ///\brief Init the session. /// ///\param[in] ctx The context of the session. ///\param[in] size The size of the session. ///\param[in] graphBuf The buffer of the graph, used for build session. /// ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR. int Init(const char *graphBuf, size_t size); ///\brief Get the input of session. /// ///\return Input node's input tensors if found, empty vector otherwise. /// ///\note /// The caller needs to allocate and free memory of inputs. std::vector GetInput(); ///\brief Run the session. /// ///\param[in] inputs The input of the session. /// ///\return Return RET_OK if run success, otherwhise return RET_ERROR. ///\note /// Currently input tensors' data format only support FORMAT_NCHW. /// Currently input tensors' data type only support FLOAT. int Run(const std::vector &inputs); ///\brief Get the output of session. /// ///\param[in] nodeName Given output node name. /// ///\return Output node's output tensors if found, empty vector otherwise. /// ///\note /// The caller needs to free memory of outputs. std::vector GetOutput(const std::string &nodeName); ///\brief Get the all output of session. /// ///\return Every output node's output tensors. /// ///\note /// The caller needs to free memory of outputs. std::map> GetAllOutput(); protected: ///\brief Init the executor. /// ///\return Return RET_OK if the initialization is success, otherwhise return RET_ERROR. int InitExecutor(); const Context &_ctx; Graph *_graph = nullptr; GraphExecution *_executor = nullptr; bool reinitExecutor = true; }; ///\brief MindSpore predict neural network session create function /// /// This function used to create MindSpore predict neural network session, which will be used to run the neural network. /// ///\param[in] sessionName The name of the session. ///\param[in] graphBuf The buffer of the graph, used for build session. ///\param[in] size The size of the session. ///\param[in] ctx The context of the session. /// ///\return Instance of MindSpore predict session. /// ///\note /// The caller needs to allocate and free memory of graph buffer. std::shared_ptr MSPREDICT_API CreateSession(const char *graphBuf, size_t size, const Context &ctx); } // namespace predict } // namespace mindspore #endif // PREDICT_INCLUDE_SESSION_H_