|
|
|
@ -27,6 +27,7 @@
|
|
|
|
|
#include "ir/dtype.h"
|
|
|
|
|
#include "base/base.h"
|
|
|
|
|
#include "ir/primitive.h"
|
|
|
|
|
#include "ir/kernel_info_dev.h"
|
|
|
|
|
#include "runtime/device/device_address.h"
|
|
|
|
|
#include "backend/kernel_compiler/kernel.h"
|
|
|
|
|
#include "backend/kernel_compiler/kernel_build_info.h"
|
|
|
|
@ -109,7 +110,7 @@ class AnfRuntimeAlgorithm {
|
|
|
|
|
// get output format from prev node,input_index is the input index of current node related to prev node
|
|
|
|
|
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
|
|
|
|
|
// get reshape_type of from the output of input node.
|
|
|
|
|
static std::vector<kernel::Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
|
|
|
|
|
static std::vector<Axis> GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx);
|
|
|
|
|
// get output shapes inferred by ME from input nodes.
|
|
|
|
|
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
// get input shapes inferred by ME from input nodes.
|
|
|
|
@ -119,9 +120,9 @@ class AnfRuntimeAlgorithm {
|
|
|
|
|
// get input shapes which will built and run in device
|
|
|
|
|
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
|
|
|
|
|
// Get Input Padding Axis
|
|
|
|
|
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
static std::vector<Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
// Get Output Padding Axis
|
|
|
|
|
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
static std::vector<Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
// get output data type inferred by ME of anf node
|
|
|
|
|
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
|
|
|
|
|
// get output original data type from prev node,input_index is the input index of current node related to prev node
|
|
|
|
|