commit
7a2bb89b30
@ -0,0 +1,64 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*conv_activation_fusion.h
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/gllo/fusion/conv_activation_fusion.h"
|
||||||
|
#include <memory>
|
||||||
|
#include "schema/inner/model_generated.h"
|
||||||
|
#include "src/ir/primitive_t_value.h"
|
||||||
|
#include "mindspore/ccsrc/utils/utils.h"
|
||||||
|
#include "src/gllo/common/utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
const BaseRef ConvActivationFusion::DefinePattern() const {
|
||||||
|
VarPtr X = std::make_shared<Var>();
|
||||||
|
// conv2d inputs may be 2 or 3 inputs,match move to process
|
||||||
|
auto prim = new schema::PrimitiveT();
|
||||||
|
prim->value.type = primitive_type;
|
||||||
|
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);
|
||||||
|
|
||||||
|
return VectorRef({prim_value, X});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr ConvActivationFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_LOG(DEBUG) << "conv activation pass process:" << schema::EnumNamesPrimitiveType()[primitive_type];
|
||||||
|
CheckIfFuncGraphIsNull(func_graph);
|
||||||
|
|
||||||
|
CheckIfAnfNodeIsNull(node);
|
||||||
|
auto act_node = node->cast<CNodePtr>();
|
||||||
|
CheckIfCNodeIsNull(act_node);
|
||||||
|
CheckInputSize(act_node, 2);
|
||||||
|
|
||||||
|
auto act_primitive = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(act_node->input(0));
|
||||||
|
if (act_primitive->GetPrimitiveT()->value.AsActivation()->type != activation_type) {
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
AnfNodePtr pre_node = act_node->input(1);
|
||||||
|
CheckIfAnfNodeIsNull(pre_node);
|
||||||
|
if (pre_node != nullptr && pre_node->isa<CNode>()) {
|
||||||
|
auto conv_node = pre_node->cast<CNodePtr>();
|
||||||
|
auto node_type = GetCNodeType(conv_node);
|
||||||
|
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||||
|
auto primitiveT_value = GetValueNode<std::shared_ptr<lite::PrimitiveTValue>>(conv_node->input(0));
|
||||||
|
primitiveT_value->GetPrimitiveT()->value.AsConv2D()->activationType = activation_type;
|
||||||
|
return pre_node;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*conv_activation_fusion.h
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
|
||||||
|
|
||||||
|
#include "src/gllo/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class ConvActivationFusion : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit ConvActivationFusion(bool multigraph = true,
|
||||||
|
schema::PrimitiveType primitive = schema::PrimitiveType_LeakyReLU,
|
||||||
|
schema::ActivationType activation = schema::ActivationType_LEAKY_RELU) : primitive_type(
|
||||||
|
primitive), activation_type(activation), PatternProcessPass("conv_activation_fusion", multigraph) {}
|
||||||
|
~ConvActivationFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
schema::PrimitiveType primitive_type;
|
||||||
|
schema::ActivationType activation_type;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_ACTIVATION_FUSION_H_
|
@ -0,0 +1,126 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*conv_activation_fusion.h
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "src/gllo/fusion/conv_scale_fusion.h"
|
||||||
|
#include <memory>
|
||||||
|
#include "schema/inner/model_generated.h"
|
||||||
|
#include "src/ir/primitive_t_value.h"
|
||||||
|
#include "src/param_value_lite.h"
|
||||||
|
#include "mindspore/ccsrc/utils/utils.h"
|
||||||
|
#include "src/gllo/common/utils.h"
|
||||||
|
#include "include/errorcode.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
const BaseRef ConvScaleFusion::DefinePattern() const {
|
||||||
|
VarPtr X = std::make_shared<Var>();
|
||||||
|
// conv2d inputs may be 2 or 3 inputs,match move to process
|
||||||
|
auto prim = new schema::PrimitiveT();
|
||||||
|
prim->value.type = schema::PrimitiveType_Scale;
|
||||||
|
auto prim_value = std::make_shared<lite::PrimitiveTValue>(prim);
|
||||||
|
|
||||||
|
return VectorRef({prim_value, X});
|
||||||
|
}
|
||||||
|
|
||||||
|
const AnfNodePtr ConvScaleFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
|
||||||
|
const EquivPtr &) const {
|
||||||
|
MS_LOG(DEBUG) << "conv activation pass process";
|
||||||
|
CheckIfFuncGraphIsNull(func_graph);
|
||||||
|
|
||||||
|
CheckIfAnfNodeIsNull(node);
|
||||||
|
auto scale_node = node->cast<CNodePtr>();
|
||||||
|
CheckIfCNodeIsNull(scale_node);
|
||||||
|
CheckInputSize(scale_node, 2);
|
||||||
|
|
||||||
|
AnfNodePtr pre_node = scale_node->input(1);
|
||||||
|
CheckIfAnfNodeIsNull(pre_node);
|
||||||
|
if (pre_node != nullptr && pre_node->isa<CNode>()) {
|
||||||
|
auto conv_node = pre_node->cast<CNodePtr>();
|
||||||
|
auto node_type = GetCNodeType(conv_node);
|
||||||
|
if (node_type == schema::PrimitiveType_Conv2D || node_type == schema::PrimitiveType_DepthwiseConv2D) {
|
||||||
|
return DoFusion(conv_node, scale_node);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return node;
|
||||||
|
}
|
||||||
|
const AnfNodePtr ConvScaleFusion::DoFusion(const CNodePtr &conv_node, const CNodePtr &scale_node) const {
|
||||||
|
if (scale_node->inputs().size() == 3) {
|
||||||
|
GetTransParam(scale_node->input(2), nullptr);
|
||||||
|
} else if (scale_node->inputs().size() == 4) {
|
||||||
|
// todo add bias fusion zhengjun10
|
||||||
|
GetTransParam(scale_node->input(2), scale_node->input(3));
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfNodePtr conv_weight_node;
|
||||||
|
if (conv_node->inputs().size() == 3) {
|
||||||
|
conv_weight_node = conv_node->input(2);
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "scale inputs size is error:" << scale_node->DebugString();
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
auto conv_weight_param = conv_weight_node->cast<ParameterPtr>()->default_param();
|
||||||
|
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(conv_weight_param);
|
||||||
|
auto old_conv_weight = reinterpret_cast<const float *>(weight_value->tensor_addr());
|
||||||
|
|
||||||
|
auto new_conv_weight = new(std::nothrow) float[weight_value->tensor_shape_size()];
|
||||||
|
CalNewWeightTensor(old_conv_weight, new_conv_weight, weight_value->tensor_shape_size());
|
||||||
|
weight_value->set_tensor_addr(new_conv_weight);
|
||||||
|
return conv_node;
|
||||||
|
}
|
||||||
|
|
||||||
|
const lite::STATUS ConvScaleFusion::GetTransParam(const AnfNodePtr &scale_weight_node,
|
||||||
|
const AnfNodePtr &scale_bias_node) const {
|
||||||
|
if (!scale_weight_node->isa<Parameter>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "scale weight node not paramter node";
|
||||||
|
}
|
||||||
|
if (scale_bias_node != nullptr && !scale_bias_node->isa<Parameter>()) {
|
||||||
|
MS_LOG(EXCEPTION) << "scale bias node not paramter node";
|
||||||
|
}
|
||||||
|
auto scale_weight_param = scale_weight_node->cast<ParameterPtr>()->default_param();
|
||||||
|
auto weight_value = std::dynamic_pointer_cast<ParamValueLite>(scale_weight_param);
|
||||||
|
auto weight_data = reinterpret_cast<const float *>(weight_value->tensor_addr());
|
||||||
|
|
||||||
|
if (0 != memcpy_s(trans_scale, kernel_nums * sizeof(float), weight_data, kernel_nums * sizeof(float))) {
|
||||||
|
MS_LOG(ERROR) << "memcpy_s transScale failed";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
|
||||||
|
const lite::STATUS ConvScaleFusion::CalNewWeightTensor(const float *oldWeightTensor, float *newWeightTensor,
|
||||||
|
const size_t tensor_shape_size) const {
|
||||||
|
MS_ASSERT(oldWeightTensor != nullptr);
|
||||||
|
if (0 != memset_s(newWeightTensor, tensor_shape_size * sizeof(float), 0, tensor_shape_size * sizeof(float))) {
|
||||||
|
MS_LOG(ERROR) << "memset newWeightData failed";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
if (kernel_nums == 0) {
|
||||||
|
MS_LOG(ERROR) << "kernel nums is 0";
|
||||||
|
return lite::RET_ERROR;
|
||||||
|
}
|
||||||
|
auto kernel_size = tensor_shape_size / kernel_nums;
|
||||||
|
for (size_t i = 0; i < kernel_nums; i++) {
|
||||||
|
for (size_t j = 0; j < kernel_size; j++) {
|
||||||
|
newWeightTensor[i * kernel_size + j] = oldWeightTensor[i * kernel_size + j] * trans_scale[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return lite::RET_OK;
|
||||||
|
}
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
@ -0,0 +1,40 @@
|
|||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*conv_activation_fusion.h
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||||
|
#define MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||||
|
|
||||||
|
#include "src/gllo/common/optimizer.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace opt {
|
||||||
|
class ConvScaleFusion : public PatternProcessPass {
|
||||||
|
public:
|
||||||
|
explicit ConvScaleFusion(bool multigraph = true) : PatternProcessPass("conv_scale_fusion", multigraph) {}
|
||||||
|
~ConvScaleFusion() override = default;
|
||||||
|
const BaseRef DefinePattern() const override;
|
||||||
|
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
|
||||||
|
const AnfNodePtr DoFusion(const CNodePtr &, const CNodePtr &) const;
|
||||||
|
const lite::STATUS GetTransParam(const AnfNodePtr &, const AnfNodePtr &) const;
|
||||||
|
const lite::STATUS CalNewWeightTensor(const float *, float *, const size_t) const;
|
||||||
|
private:
|
||||||
|
float *trans_scale = nullptr;
|
||||||
|
int kernel_nums = 0;
|
||||||
|
};
|
||||||
|
} // namespace opt
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_LITE_SRC_PASS_FUSION_CONV_SCALE_FUSION_H_
|
||||||
|
|
Loading…
Reference in new issue