!13996 [lite]gelu fusion and fp16 slice
From: @xu_anyue Reviewed-by: @hangangqiang, @HilbertDavid Signed-off-by: @hangangqiang (pull/13996/MERGE)
commit
27f7bfef50
@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "src/runtime/kernel/arm/fp16/slice_fp16.h"
|
||||
#include "src/kernel_registry.h"
|
||||
#include "nnacl/base/slice_base.h"
|
||||
#include "nnacl/fp16/cast_fp16.h"
|
||||
|
||||
using mindspore::lite::KernelRegistrar;
|
||||
using mindspore::lite::RET_ERROR;
|
||||
using mindspore::lite::RET_OK;
|
||||
using mindspore::schema::PrimitiveType_SliceFusion;
|
||||
|
||||
namespace mindspore::kernel {
|
||||
int SliceFp16Launch(void *cdata, int task_id) {
|
||||
if (cdata == nullptr) {
|
||||
MS_LOG(ERROR) << "Input cdata is nullptr!";
|
||||
return RET_ERROR;
|
||||
}
|
||||
auto kernel = reinterpret_cast<SliceFp16CPUKernel *>(cdata);
|
||||
return kernel->SliceFp16ParallelRun(task_id);
|
||||
}
|
||||
|
||||
// Releases the fp16 conversion buffer allocated in Init(), if one was made.
SliceFp16CPUKernel::~SliceFp16CPUKernel() {
  if (input_data_ == nullptr) {
    return;
  }
  context_->allocator->Free(input_data_);
  input_data_ = nullptr;
}
|
||||
|
||||
int SliceFp16CPUKernel::Init() {
|
||||
auto input_tensor = in_tensors_.at(0);
|
||||
if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
|
||||
input_data_ =
|
||||
reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
|
||||
Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
|
||||
}
|
||||
return SliceCPUKernel::Init();
|
||||
}
|
||||
|
||||
int SliceFp16CPUKernel::SliceFp16ParallelRun(int thread_id) {
|
||||
void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
|
||||
DoSlice(input_data, out_tensors_.at(0)->data_c(), param_, thread_id, lite::DataTypeSize(kNumberTypeFloat16));
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int SliceFp16CPUKernel::Run() {
|
||||
void *input_data = input_data_ == nullptr ? in_tensors_.at(0)->data_c() : input_data_;
|
||||
if (param_->size_[1] < op_parameter_->thread_num_) {
|
||||
DoSliceNoParallel(input_data, out_tensors_.at(0)->data_c(), param_, lite::DataTypeSize(kNumberTypeFloat16));
|
||||
return RET_OK;
|
||||
}
|
||||
auto ret = ParallelLaunch(this->context_->thread_pool_, SliceFp16Launch, this, op_parameter_->thread_num_);
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "fp16 slice launch fail!ret: " << ret;
|
||||
return RET_ERROR;
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
// Register the fp16 Slice kernel for the SliceFusion primitive on CPU.
REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_SliceFusion, LiteKernelCreator<SliceFp16CPUKernel>)
|
||||
} // namespace mindspore::kernel
|
@ -0,0 +1,41 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
|
||||
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
|
||||
|
||||
#include <vector>
|
||||
#include "src/lite_kernel.h"
|
||||
#include "src/runtime/kernel/arm/base/slice_base.h"
|
||||
|
||||
namespace mindspore::kernel {
|
||||
// Fp16 Slice kernel. Reuses the generic SliceCPUKernel; when the input is a
// const fp32 tensor, Init() converts it once into an fp16 buffer owned by
// this kernel.
class SliceFp16CPUKernel : public SliceCPUKernel {
 public:
  SliceFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                     const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
      : SliceCPUKernel(parameter, inputs, outputs, ctx) {}
  // Frees the fp16 conversion buffer allocated in Init(), if any.
  ~SliceFp16CPUKernel() override;

  int Init() override;
  int Run() override;
  // Slices one shard of the input; invoked per task from the thread pool.
  int SliceFp16ParallelRun(int thread_id);

 private:
  // Fp16 copy of a const fp32 input; nullptr when the input is already fp16.
  float16_t *input_data_ = nullptr;
};
|
||||
} // namespace mindspore::kernel
|
||||
|
||||
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_SLICE_FP16_H_
|
@ -0,0 +1,85 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/gelu_fusion.h"
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "ops/fusion/activation.h"
|
||||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Builds the replacement Activation(GELU) cnode for a matched gelu pattern.
// The new node takes the pattern's input tensor, inherits the matched root
// node's scope name (with a "_gelu" suffix) and abstract.
// Returns nullptr on any allocation/construction failure.
CNodePtr GeLUFusion::CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                    const EquivPtr &equiv) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  auto gelu_prim = std::make_shared<ops::Activation>();
  if (gelu_prim == nullptr) {
    MS_LOG(ERROR) << "new Activation primitive failed.";
    return nullptr;
  }
  gelu_prim->set_activation_type(mindspore::GELU);
  auto input_node = utils::cast<AnfNodePtr>((*equiv)[input_]);
  MS_ASSERT(input_node != nullptr);
  auto gelu_cnode = func_graph->NewCNode(gelu_prim, {input_node});
  // Fix: the original dereferenced NewCNode's result and node->abstract()
  // unconditionally, crashing when either is nullptr.
  if (gelu_cnode == nullptr) {
    MS_LOG(ERROR) << "new gelu cnode failed.";
    return nullptr;
  }
  gelu_cnode->set_fullname_with_scope(node->fullname_with_scope() + "_gelu");
  if (node->abstract() != nullptr) {
    gelu_cnode->set_abstract(node->abstract()->Clone());
  }
  return gelu_cnode;
}
|
||||
|
||||
const float GeLUFusion::GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
MS_ASSERT(input != nullptr);
|
||||
float value = -1;
|
||||
auto node = utils::cast<AnfNodePtr>((*equiv)[input]);
|
||||
if (node == nullptr || !utils::isa<ParameterPtr>(node)) {
|
||||
return value;
|
||||
}
|
||||
auto parameter_node = node->cast<ParameterPtr>();
|
||||
if (!parameter_node->has_default() || parameter_node->default_param() == nullptr) {
|
||||
return value;
|
||||
}
|
||||
auto param_value_lite = parameter_node->default_param()->cast<ParamValueLitePtr>();
|
||||
if (param_value_lite == nullptr) {
|
||||
return value;
|
||||
}
|
||||
if (param_value_lite->tensor_type() != kNumberTypeFloat32 && param_value_lite->tensor_type() != kNumberTypeFloat) {
|
||||
return value;
|
||||
}
|
||||
if (param_value_lite->tensor_size() != sizeof(float)) {
|
||||
return value;
|
||||
}
|
||||
return *static_cast<float *>(param_value_lite->tensor_addr());
|
||||
}
|
||||
|
||||
// Pass entry point: validates the matched subgraph and, when it really is a
// gelu decomposition, returns the fused Activation(GELU) node to splice in.
// Returns nullptr to leave the graph unchanged.
const AnfNodePtr GeLUFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
                                     const EquivPtr &equiv) const {
  MS_ASSERT(func_graph != nullptr);
  MS_ASSERT(node != nullptr);
  MS_ASSERT(equiv != nullptr);
  MS_LOG(DEBUG) << "gelu_fusion pass";
  // Only cnodes can be replaced, and the matched constants must agree with
  // the gelu formula before fusing.
  if (!utils::isa<CNodePtr>(node) || !CheckPattern(equiv)) {
    return nullptr;
  }
  auto gelu_cnode = CreateGeLUNode(func_graph, node, equiv);
  if (gelu_cnode == nullptr) {
    MS_LOG(DEBUG) << "new gelu node failed.";
  }
  return gelu_cnode;
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,48 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "backend/optimizer/common/optimizer.h"
|
||||
#include "utils/utils.h"
|
||||
#include "tools/optimizer/common/gllo_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Base pass that rewrites a decomposed gelu expression into a single
// Activation(GELU) node. Subclasses define the framework-specific pattern
// and validate the matched scalar constants via CheckPattern().
class GeLUFusion : public PatternProcessPass {
 public:
  explicit GeLUFusion(const std::string &name = "gelu_fusion", bool multigraph = true)
      : PatternProcessPass(name, multigraph), input_(std::make_shared<Var>()) {}

  ~GeLUFusion() override = default;
  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;

 protected:
  // True when the constants bound by the pattern match the gelu formula.
  virtual bool CheckPattern(const EquivPtr &equiv) const = 0;
  // Reads a scalar float from a matched parameter; returns -1 when the node
  // is not a float32 scalar parameter.
  const float GetParameterValue(const EquivPtr &equiv, const VarPtr &input) const;
  // Pattern variable bound to the gelu input tensor x.
  VarPtr input_ = nullptr;

 private:
  CNodePtr CreateGeLUNode(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &equiv) const;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_GELU_FUSION_H_
|
@ -0,0 +1,55 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/onnx_gelu_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr float DIFF_THRESHOLD = 0.0001;
|
||||
constexpr float DIV_Y = 1.41421;
|
||||
constexpr float ADD_Y = 1.0;
|
||||
constexpr float MUL1_y = 0.5;
|
||||
} // namespace
|
||||
|
||||
// gelu(x) = 1/2 * x * [1 + erf(x / sqrt(2))]
|
||||
const BaseRef OnnxGeLUFusion::DefinePattern() const {
|
||||
VectorRef div_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimDivFusion>), input_, div_y_});
|
||||
VectorRef erf_ref({std::make_shared<CondVar>(IsSpecifiedNode<&kPrimErf>), div_ref});
|
||||
VectorRef add_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), erf_ref, add_y_});
|
||||
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul1_y_});
|
||||
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_ref, add_ref});
|
||||
return mul2_ref;
|
||||
}
|
||||
|
||||
bool OnnxGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
float div_y = GetParameterValue(equiv, div_y_);
|
||||
if (div_y < 0 || fabs(div_y - DIV_Y) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float add_y = GetParameterValue(equiv, add_y_);
|
||||
if (add_y < 0 || fabs(add_y - ADD_Y) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float mul1_y = GetParameterValue(equiv, mul1_y_);
|
||||
if (mul1_y < 0 || fabs(mul1_y - MUL1_y) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,49 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/optimizer/fusion/gelu_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Matches the onnx erf-based gelu decomposition:
//   gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
class OnnxGeLUFusion : public GeLUFusion {
 public:
  explicit OnnxGeLUFusion(const std::string &name = "onnx_gelu_fusion", bool multigraph = true)
      : GeLUFusion(name, multigraph) {
    div_y_ = std::make_shared<Var>();
    add_y_ = std::make_shared<Var>();
    mul1_y_ = std::make_shared<Var>();
  }
  ~OnnxGeLUFusion() override = default;

 private:
  bool CheckPattern(const EquivPtr &equiv) const override;
  const BaseRef DefinePattern() const override;

 private:
  // Pattern variables for the constant operands: sqrt(2), 1.0 and 0.5.
  VarPtr div_y_ = nullptr;
  VarPtr add_y_ = nullptr;
  VarPtr mul1_y_ = nullptr;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_ONNX_GELU_FUSION_H_
|
@ -0,0 +1,88 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tools/optimizer/fusion/tf_gelu_fusion.h"
|
||||
#include "ops/op_utils.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace {
|
||||
constexpr float DIFF_THRESHOLD = 0.0001;
|
||||
constexpr float POW_Y = 3;
|
||||
constexpr float MUL1_Y = 0.044715;
|
||||
constexpr float MUL2_X = 0.79788;
|
||||
constexpr float ADD2_X = 1.0;
|
||||
constexpr float MUL3_X = 0.5;
|
||||
bool CheckTanh(const EquivPtr &equiv, const VarPtr &input) {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
MS_ASSERT(input != nullptr);
|
||||
auto anf_node = utils::cast<AnfNodePtr>((*equiv)[input]);
|
||||
MS_ASSERT(anf_node != nullptr);
|
||||
AnfNodePtr value_node = anf_node;
|
||||
if (anf_node->isa<CNode>()) {
|
||||
value_node = anf_node->cast<CNodePtr>()->input(0);
|
||||
}
|
||||
auto act_prim = GetValueNode<PrimitivePtr>(value_node);
|
||||
if (act_prim == nullptr) {
|
||||
return false;
|
||||
}
|
||||
return act_prim->GetAttr(ops::kActivationType) != nullptr &&
|
||||
GetValue<int64_t>(act_prim->GetAttr(ops::kActivationType)) == mindspore::TANH;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// gelu(x) = 1/2 * x * [1 + tanh(0.79788 * (x + 0.044715 * x ^ 3))]
|
||||
const BaseRef TfGeLUFusion::DefinePattern() const {
|
||||
VectorRef pow_ref({power_, input_, power_y_});
|
||||
VectorRef mul1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul1_x_, pow_ref});
|
||||
VectorRef add1_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), input_, mul1_ref});
|
||||
VectorRef mul2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul2_x_, add1_ref});
|
||||
VectorRef tanh_ref({tanh_, mul2_ref});
|
||||
VectorRef add2_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimAddFusion>), add2_x_, tanh_ref});
|
||||
VectorRef mul3_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), mul3_x_, add2_ref});
|
||||
VectorRef mul4_ref({std::make_shared<CondVar>(IsSpecifiedNode<&prim::kPrimMulFusion>), input_, mul3_ref});
|
||||
return mul4_ref;
|
||||
}
|
||||
|
||||
bool TfGeLUFusion::CheckPattern(const EquivPtr &equiv) const {
|
||||
MS_ASSERT(equiv != nullptr);
|
||||
if (!CheckTanh(equiv, tanh_)) {
|
||||
return false;
|
||||
}
|
||||
float pow_y = GetParameterValue(equiv, power_y_);
|
||||
if (pow_y < 0 || fabs(pow_y - POW_Y) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float mul1_y = GetParameterValue(equiv, mul1_x_);
|
||||
if (mul1_y < 0 || fabs(mul1_y - MUL1_Y) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float mul2_x = GetParameterValue(equiv, mul2_x_);
|
||||
if (mul2_x < 0 || fabs(mul2_x - MUL2_X) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float add2_x = GetParameterValue(equiv, add2_x_);
|
||||
if (add2_x < 0 || fabs(add2_x - ADD2_X) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
float mul3_x = GetParameterValue(equiv, mul3_x_);
|
||||
if (mul3_x < 0 || fabs(mul3_x - MUL3_X) > DIFF_THRESHOLD) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
@ -0,0 +1,57 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
|
||||
#define MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "tools/optimizer/fusion/gelu_fusion.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace opt {
|
||||
// Matches the tensorflow tanh-approximated gelu decomposition:
//   gelu(x) = 0.5 * x * (1 + tanh(0.79788 * (x + 0.044715 * x^3)))
class TfGeLUFusion : public GeLUFusion {
 public:
  explicit TfGeLUFusion(const std::string &name = "tf_gelu_fusion", bool multigraph = true)
      : GeLUFusion(name, multigraph) {
    power_ = std::make_shared<Var>();
    power_y_ = std::make_shared<Var>();
    mul1_x_ = std::make_shared<Var>();
    mul2_x_ = std::make_shared<Var>();
    tanh_ = std::make_shared<Var>();
    add2_x_ = std::make_shared<Var>();
    mul3_x_ = std::make_shared<Var>();
  }
  ~TfGeLUFusion() override = default;

 private:
  bool CheckPattern(const EquivPtr &equiv) const override;
  const BaseRef DefinePattern() const override;

 private:
  // Pattern variables: the pow primitive/exponent, the tanh node, and the
  // scalar constants 0.044715, 0.79788, 1.0 and 0.5.
  VarPtr power_ = nullptr;
  VarPtr power_y_ = nullptr;
  VarPtr mul1_x_ = nullptr;
  VarPtr mul2_x_ = nullptr;
  VarPtr tanh_ = nullptr;
  VarPtr add2_x_ = nullptr;
  VarPtr mul3_x_ = nullptr;
};
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_LITE_TOOLS_OPTIMIZER_FUSION_TF_GELU_FUSION_H_
|
Loading…
Reference in new issue