cherry-pick from feature/anakin-engine: deal the changing shape when using anakin #16189
parent
c79f06d3d8
commit
a25331bc26
@ -0,0 +1,85 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "paddle/fluid/framework/ir/anakin_fillconstant_elementwisemul_fuse.h"
|
||||
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
|
||||
#define GET_NODES \
|
||||
GET_IR_NODE(fill_constant); \
|
||||
GET_IR_NODE(fill_constant_out); \
|
||||
GET_IR_NODE(elementwise_mul); \
|
||||
GET_IR_NODE(elementwise_mul_out);
|
||||
|
||||
std::unique_ptr<ir::Graph> AnakinFillconstantElementwisemulFuse::ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const {
|
||||
const std::string pattern_name = "anakin_fillconstant_elementwisemul_fuse";
|
||||
FusePassBase::Init(pattern_name, graph.get());
|
||||
|
||||
GraphPatternDetector gpd;
|
||||
auto* x = gpd.mutable_pattern()
|
||||
->NewNode("x")
|
||||
->assert_is_op_input("elementwise_mul", "X")
|
||||
->AsInput();
|
||||
|
||||
patterns::AnakinFillConstantElementWiseMulFuse pattern(gpd.mutable_pattern(),
|
||||
pattern_name);
|
||||
pattern(x);
|
||||
|
||||
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
|
||||
Graph* g) {
|
||||
GET_NODES;
|
||||
|
||||
PADDLE_ENFORCE(subgraph.count(x));
|
||||
auto* elementwise_in = subgraph.at(x);
|
||||
float constant_value =
|
||||
boost::get<float>(fill_constant->Op()->GetAttr("value"));
|
||||
|
||||
framework::OpDesc new_op_desc;
|
||||
new_op_desc.SetType("scale");
|
||||
new_op_desc.SetInput("X", {elementwise_in->Name()});
|
||||
new_op_desc.SetAttr("scale", constant_value);
|
||||
new_op_desc.SetAttr("bias", static_cast<float>(0.0));
|
||||
new_op_desc.SetAttr("bias_after_scale", true);
|
||||
new_op_desc.SetOutput("Out", {elementwise_mul_out->Name()});
|
||||
new_op_desc.Flush();
|
||||
|
||||
// Create a new node for the fused op.
|
||||
auto* scale_op = graph->CreateOpNode(&new_op_desc);
|
||||
|
||||
IR_NODE_LINK_TO(elementwise_in, scale_op); // Input
|
||||
IR_NODE_LINK_TO(scale_op, elementwise_mul_out); // Output
|
||||
|
||||
// Delete the unneeded nodes.
|
||||
GraphSafeRemoveNodes(graph.get(),
|
||||
{fill_constant, fill_constant_out, elementwise_mul});
|
||||
};
|
||||
|
||||
gpd(graph.get(), handler);
|
||||
return graph;
|
||||
}
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_PASS(anakin_fillconstant_elementwisemul_fuse,
|
||||
paddle::framework::ir::AnakinFillconstantElementwisemulFuse);
|
@ -0,0 +1,35 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
|
||||
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace framework {
|
||||
namespace ir {
|
||||
|
||||
class AnakinFillconstantElementwisemulFuse : public FusePassBase {
|
||||
public:
|
||||
virtual ~AnakinFillconstantElementwisemulFuse() {}
|
||||
|
||||
protected:
|
||||
std::unique_ptr<ir::Graph> ApplyImpl(
|
||||
std::unique_ptr<ir::Graph> graph) const override;
|
||||
};
|
||||
|
||||
} // namespace ir
|
||||
} // namespace framework
|
||||
} // namespace paddle
|
@ -0,0 +1,56 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/fluid/inference/anakin/convert/scale.h"
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
|
||||
using anakin::graph::GraphGlobalMem;
|
||||
using anakin::AK_FLOAT;
|
||||
using anakin::saber::NV;
|
||||
using anakin::saber::Shape;
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
void ScaleOpConverter::operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) {
|
||||
framework::OpDesc op_desc(op, nullptr);
|
||||
PADDLE_ENFORCE_EQ(op_desc.Input("X").size(), 1);
|
||||
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
|
||||
|
||||
auto op_name = op_desc.Type() + ":" + op_desc.Output("Out").front();
|
||||
|
||||
auto input_name = op_desc.Input("X").front();
|
||||
auto output_name = op_desc.Output("Out").front();
|
||||
float scale = boost::get<float>(op_desc.GetAttr("scale"));
|
||||
float bias = boost::get<float>(op_desc.GetAttr("bias"));
|
||||
float bias_after_scale =
|
||||
boost::get<bool>(op_desc.GetAttr("bias_after_scale"));
|
||||
PADDLE_ENFORCE(bias_after_scale,
|
||||
"The anakin scale layer only support bias after scale now.");
|
||||
|
||||
engine_->AddOp(op_name, "Power", {input_name}, {output_name});
|
||||
engine_->AddOpAttr(op_name, "shift", bias);
|
||||
engine_->AddOpAttr(op_name, "scale", scale);
|
||||
engine_->AddOpAttr(op_name, "power", static_cast<float>(1.0));
|
||||
}
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
||||
|
||||
REGISTER_ANAKIN_OP_CONVERTER(scale, ScaleOpConverter);
|
@ -0,0 +1,37 @@
|
||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace inference {
|
||||
namespace anakin {
|
||||
|
||||
class ScaleOpConverter : public AnakinOpConverter {
|
||||
public:
|
||||
ScaleOpConverter() = default;
|
||||
|
||||
virtual void operator()(const framework::proto::OpDesc &op,
|
||||
const framework::Scope &scope,
|
||||
bool test_mode) override;
|
||||
virtual ~ScaleOpConverter() {}
|
||||
};
|
||||
|
||||
} // namespace anakin
|
||||
} // namespace inference
|
||||
} // namespace paddle
|
Loading…
Reference in new issue