|
|
|
@ -31,6 +31,7 @@ namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
|
#define MUL_ADD_MATCH_PATH_LEN 2
|
|
|
|
|
#define ADD_OP_BIAS_INDEX 1
|
|
|
|
|
#define MUL_OP_INPUT_INDEX 0
|
|
|
|
|
#define MUL_OP_BIAS_INDEX 1
|
|
|
|
|
#define MUL_OP_INPUT_NUM 2
|
|
|
|
|
#define ADD_OP_INPUT_NUM 2
|
|
|
|
@ -60,6 +61,23 @@ STATUS MulAddFusionPass::DefinePattern() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool ScaleInputShapeValid(const std::vector<int> &input_shape, const std::vector<int> &scale_shape,
|
|
|
|
|
const std::vector<int> &offset_shape) {
|
|
|
|
|
if (input_shape.size() < scale_shape.size() || scale_shape.size() == 0) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
size_t rank_diff = input_shape.size() - scale_shape.size();
|
|
|
|
|
for (size_t i = 0; i < scale_shape.size(); ++i) {
|
|
|
|
|
if (input_shape[i + rank_diff] != scale_shape[i]) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (scale_shape != offset_shape) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternName,
|
|
|
|
|
std::unordered_map<std::string, std::shared_ptr<Path>> &matchedPath) {
|
|
|
|
|
MS_ASSERT(graph != nullptr);
|
|
|
|
@ -79,7 +97,7 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
|
|
|
|
|
MS_ASSERT(graph->allTensors.size() > mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
|
|
|
|
|
const auto &mulNodeBiasTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_BIAS_INDEX));
|
|
|
|
|
MS_ASSERT(mulNodeBiasTensor != nullptr);
|
|
|
|
|
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode || mulNodeBiasTensor->dims.size() == 4) {
|
|
|
|
|
if (mulNodeBiasTensor->refCount != schema::NodeType::NodeType_ValueNode) {
|
|
|
|
|
// dont fusion, return
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
@ -96,7 +114,11 @@ STATUS MulAddFusionPass::DoFusion(MetaGraphT *graph, const std::string &patternN
|
|
|
|
|
// dont fusion, return
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// scale requires scale shape tail sub of input shape, scale shape same as bias shape
|
|
|
|
|
const auto &mulNodeInputTensor = graph->allTensors.at(mulNodeInputIndex.at(MUL_OP_INPUT_INDEX));
|
|
|
|
|
if (!ScaleInputShapeValid(mulNodeInputTensor->dims, mulNodeBiasTensor->dims, addNodeBiasTensor->dims)) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
// convert mul and add to scale
|
|
|
|
|
auto status = AddNewScaleNode(graph, mulNode, addNode.get(), addNodeInputIndex.at(ADD_OP_BIAS_INDEX));
|
|
|
|
|
if (RET_OK != status) {
|
|
|
|
|