commit
6ce25c99a0
@@ -0,0 +1,146 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
// implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void CPUQuantizeSquashPass::FindNodesToKeep(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector gpd;
  patterns::DequantAny deq_any_pattern{gpd.mutable_pattern(), "dequant_any"};
  deq_any_pattern();

  int found_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, deq_any_pattern);

    if (nodes_keep_counter->find(dequant_out) == nodes_keep_counter->end())
      (*nodes_keep_counter)[dequant_out] = 1;
    else
      (*nodes_keep_counter)[dequant_out] += 1;

    found_count++;
  };
  gpd(graph, handler);
  AddStatis(found_count);
}

void CPUQuantizeSquashPass::Squash(
    Graph* graph,
    std::unordered_map<const Node*, int>* nodes_keep_counter) const {
  GraphPatternDetector gpd;
  patterns::DequantQuantAny squash_pattern{gpd.mutable_pattern(), "squash"};
  squash_pattern();

  int found_squash_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "squash dequantize-quantize ops pair";

    GET_IR_NODE_FROM_SUBGRAPH(dequant_in, dequant_in, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_op, dequant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(dequant_out, dequant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, squash_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, squash_pattern);

    auto* next_op_desc = next_op->Op();
    float dequant_scale = boost::get<float>(dequant_op->Op()->GetAttr("Scale"));
    float quant_scale = boost::get<float>(quant_op->Op()->GetAttr("Scale"));
    PADDLE_ENFORCE(nodes_keep_counter->find(dequant_out) !=
                   nodes_keep_counter->end());

    // check if dequantize op should be kept or removed, decrease the counter
    bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;

    if (dequant_scale == quant_scale) {
      // squash dequantize-quantize to nothing
      auto quant_out_var_name = quant_out->Name();
      auto next_op_inputs = next_op_desc->InputNames();
      for (const auto& name : next_op_inputs) {
        auto var_name = next_op_desc->Input(name)[0];
        if (var_name.compare(quant_out_var_name) == 0) {
          next_op_desc->SetInput(
              name, std::vector<std::string>({dequant_in->Name()}));
          break;
        }
      }

      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op, quant_out});
      else
        GraphSafeRemoveNodes(graph,
                             {dequant_op, quant_op, dequant_out, quant_out});

      IR_NODE_LINK_TO(dequant_in, next_op);

      found_squash_count++;
    } else {
      // squash dequantize-quantize to requantize op
      OpDesc desc;
      desc.SetType("requantize");
      desc.SetInput("Input", std::vector<std::string>({dequant_in->Name()}));
      desc.SetOutput("Output", std::vector<std::string>({quant_out->Name()}));
      desc.SetAttr("Scale_in", dequant_scale);
      desc.SetAttr("Scale_out", quant_scale);

      auto requant_op = g->CreateOpNode(&desc);

      if (keep_dequant)
        GraphSafeRemoveNodes(graph, {quant_op});
      else
        GraphSafeRemoveNodes(graph, {dequant_op, quant_op, dequant_out});

      IR_NODE_LINK_TO(dequant_in, requant_op);
      IR_NODE_LINK_TO(requant_op, quant_out);

      found_squash_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_squash_count);
  PrettyLogDetail("---    squashed %d dequantize-quantize pairs",
                  found_squash_count);
}

std::unique_ptr<ir::Graph> CPUQuantizeSquashPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init("cpu_quantize_squash_pass", graph.get());

  std::unordered_map<const Node*, int> nodes_keep_counter;
  FindNodesToKeep(graph.get(), &nodes_keep_counter);
  Squash(graph.get(), &nodes_keep_counter);

  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(cpu_quantize_squash_pass,
              paddle::framework::ir::CPUQuantizeSquashPass);
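Why the squash is sound: with the oneDNN-style convention that quantize multiplies by its Scale and dequantize divides by it (an assumption for this sketch, not stated in the diff), a dequantize with scale S1 followed by a quantize with scale S2 maps an int8 value v to v * S2 / S1, which is exactly one requantize op with Scale_in = S1 and Scale_out = S2, and the identity when S1 == S2. A minimal, self-contained sketch of that arithmetic (rounding and int8 saturation of real kernels are omitted):

#include <cassert>
#include <cstdint>

// Hypothetical standalone models of the three ops; scale convention assumed.
inline float dequantize(int8_t v, float scale) {
  return static_cast<float>(v) / scale;
}
inline int8_t quantize(float v, float scale) {
  return static_cast<int8_t>(v * scale);
}
inline int8_t requantize(int8_t v, float scale_in, float scale_out) {
  return static_cast<int8_t>(static_cast<float>(v) * scale_out / scale_in);
}

int main() {
  const int8_t v = 42;
  // unequal scales: a dequantize->quantize pair equals a single requantize
  assert(quantize(dequantize(v, 2.0f), 4.0f) == requantize(v, 2.0f, 4.0f));
  // equal scales: the pair is an identity, so both ops can be removed
  assert(quantize(dequantize(v, 2.0f), 2.0f) == v);
  return 0;
}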
@@ -0,0 +1,58 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Squash the dequantize->quantize pair pattern into a requantize op
 */
class CPUQuantizeSquashPass : public FusePassBase {
 public:
  virtual ~CPUQuantizeSquashPass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;

  /*
   * For each dequantize op's output, find the number of operators it is an
   * input to
   */
  void FindNodesToKeep(
      Graph* graph,
      std::unordered_map<const Node*, int>* nodes_keep_counter) const;

  /*
   * Squash dequantize-quantize op pairs into a requantize op, or into nothing
   */
  void Squash(Graph* graph,
              std::unordered_map<const Node*, int>* nodes_keep_counter) const;

  const std::string name_scope_{"squash"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,179 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/cpu_quantize_squash_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
namespace ir {

void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn,
           float scale = 0) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("use_mkldnn", use_mkldnn);
  op->SetAttr("name", name);
  if (type == "conv2d") {
    op->SetInput("Input", {inputs[0]});
    if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
    if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
    op->SetOutput("Output", {outputs[0]});
  } else if (type == "quantize") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale", scale);
  } else if (type == "dequantize") {
    op->SetInput("Input", {inputs[0]});
    op->SetOutput("Output", {outputs[0]});
    op->SetAttr("Scale", scale);
  }
}

// (a,w1,b1)->Conv1->d
// d->Dequant->e
// e->Quant->f
// (f,w2,b2)->Conv2->i
ProgramDesc BuildProgramDesc(bool use_mkldnn, float scale1, float scale2) {
  ProgramDesc prog;
  for (auto& v : std::initializer_list<std::string>(
           {"a", "w1", "b1", "d", "e", "f", "w2", "b2", "i"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    if (v.find("w") == 0 || v.find("b") == 0) {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "conv2d", "Conv1", {"a", "w1", "b1"}, {"d"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"d"}, {"e"}, use_mkldnn, scale1);
  SetOp(&prog, "quantize", "Quant", {"e"}, {"f"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"f", "w2", "b2"}, {"i"}, use_mkldnn);
  return prog;
}

static const std::initializer_list<std::string> variable_names{
    "a", "b", "c", "d", "e", "f", "g", "h"};

// a->Conv1->b
// b->Dequant->c
//
// c->Quant1->d and d->Conv2->e
//
// c->Conv3->f
//
// c->Quant2->g and g->Conv4->h
ProgramDesc BuildProgramDesc2(bool use_mkldnn, float scale1, float scale2,
                              float scale3) {
  ProgramDesc prog;
  for (auto& v : variable_names) {
    prog.MutableBlock(0)->Var(v);
  }

  SetOp(&prog, "conv2d", "Conv1", {"a"}, {"b"}, use_mkldnn);
  SetOp(&prog, "dequantize", "Dequant", {"b"}, {"c"}, use_mkldnn, scale1);

  SetOp(&prog, "quantize", "Quant1", {"c"}, {"d"}, use_mkldnn, scale2);
  SetOp(&prog, "conv2d", "Conv2", {"d"}, {"e"}, use_mkldnn);

  SetOp(&prog, "conv2d", "Conv3", {"c"}, {"f"}, use_mkldnn);

  SetOp(&prog, "quantize", "Quant2", {"c"}, {"g"}, use_mkldnn, scale3);
  SetOp(&prog, "conv2d", "Conv4", {"g"}, {"h"}, use_mkldnn);

  return prog;
}

void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
                      const char* var_name) {
  auto x = scope->Var(var_name);
  auto tensor = x->GetMutable<LoDTensor>();
  tensor->mutable_data(place, proto::VarType::FP32,
                       ::paddle::memory::Allocator::kDefault, 1);
}

void MainTest(const ProgramDesc& prog, int removed_nodes_num) {
  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  // Init scope, as it is used in the pass
  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  exe.CreateVariables(prog, 0, true, &scope);

  for (auto& v : variable_names) {
    InitTensorHolder(&scope, place, v.c_str());
  }

  graph->Set(kParamScopeAttr, new framework::Scope*(&scope));

  auto pass = PassRegistry::Instance().Get("cpu_quantize_squash_pass");

  int original_nodes_num = graph->Nodes().size();

  graph = pass->Apply(std::move(graph));

  int current_nodes_num = graph->Nodes().size();

  EXPECT_EQ(original_nodes_num - removed_nodes_num, current_nodes_num);
}

TEST(CpuQuantizeSquashPass, equal_scales) {
  auto scale = 1.2345f;
  auto use_mkldnn = true;
  // Remove 4 nodes: Dequant, Quant, e, f
  auto remove_nodes = 4;
  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);

  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc(use_mkldnn, scale, scale), remove_nodes);
}

TEST(CpuQuantizeSquashPass, unequal_scales) {
  auto scale1 = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Dequant, Quant, e
  // Insert 1 node: requantize
  auto remove_nodes = 2;
  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);

  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc(use_mkldnn, scale1, scale2), remove_nodes);
}

TEST(CpuQuantizeSquashPass, branch_to_equal_unequal_and_fp32) {
  // Delete both quantize ops,
  // bypass dequantize in both branches,
  // insert requantize on one branch
  auto scale = 1.2345f;
  auto scale2 = 21.0f;
  auto use_mkldnn = true;
  // Remove 3 nodes: Quant1, Quant2, d
  // Insert 1 node: requantize
  auto remove_nodes = 2;
  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);

  use_mkldnn = !use_mkldnn;
  MainTest(BuildProgramDesc2(use_mkldnn, scale, scale, scale2), remove_nodes);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(cpu_quantize_squash_pass);
File diff suppressed because it is too large
@@ -0,0 +1,42 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/hostdevice.h"

#include "math.h"  // NOLINT

namespace paddle {
namespace operators {

inline HOSTDEVICE platform::float16 real_exp(platform::float16 x) {
  return static_cast<platform::float16>(::expf(static_cast<float>(x)));
}

inline HOSTDEVICE float real_exp(float x) { return ::expf(x); }

inline HOSTDEVICE double real_exp(double x) { return ::exp(x); }

inline HOSTDEVICE platform::float16 real_log(platform::float16 x) {
  return static_cast<platform::float16>(::logf(static_cast<float>(x)));
}

inline HOSTDEVICE float real_log(float x) { return ::logf(x); }

inline HOSTDEVICE double real_log(double x) { return ::log(x); }

}  // namespace operators
}  // namespace paddle
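These overloads exist so templated operator kernels can call real_exp / real_log generically and get the right precision chosen at compile time, with float16 routed through the float math routines. A small, self-contained sketch of the dispatch pattern (host-only, float16 overloads omitted; softplus is just an illustrative caller, not part of this diff):

#include <cmath>
#include <cstdio>

// Mirrors the overload set above for plain host code.
inline float real_exp(float x) { return ::expf(x); }
inline double real_exp(double x) { return ::exp(x); }
inline float real_log(float x) { return ::logf(x); }
inline double real_log(double x) { return ::log(x); }

// A generic kernel body: overload resolution picks the precision.
template <typename T>
T softplus(T x) {
  return real_log(T(1) + real_exp(x));  // log(1 + exp(x))
}

int main() {
  std::printf("%f %f\n", static_cast<double>(softplus(1.0f)), softplus(1.0));
  return 0;
}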
@@ -0,0 +1,82 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_test import OpTest
import unittest
import numpy as np
import six


class CrossEntropy2OpTestBase(OpTest):
    def initParameters(self):
        return [32, 64], 'float32', -100

    def calc_output(self, logits, label, ignore_index):
        ret = np.zeros(shape=label.shape, dtype=logits.dtype)
        match_x = np.zeros(shape=label.shape, dtype=logits.dtype)
        for idx in six.moves.range(label.shape[0]):
            if label[idx] == ignore_index:
                continue
            match_x[idx] = logits[idx][label[idx]]
            ret[idx] = -np.log(match_x[idx])
        return ret, match_x

    def setUp(self):
        self.shape, self.dtype, self.ignore_index = self.initParameters()
        self.op_type = 'cross_entropy2'
        feature_size = int(self.shape[-1])
        batch_size = int(np.prod(self.shape) / feature_size)
        logits = (np.random.random(size=self.shape) + 1).astype(self.dtype)
        label = np.random.random_integers(
            low=0, high=feature_size - 1,
            size=self.shape[0:-1] + [1]).astype('int64')
        outputs, match_x = self.calc_output(
            np.reshape(logits, [batch_size, feature_size]),
            np.reshape(label, [batch_size, 1]), self.ignore_index)
        self.inputs = {'X': logits, 'Label': label}
        self.outputs = {
            'Y': np.reshape(outputs, label.shape),
            'MatchX': np.reshape(match_x, label.shape),
            'XShape': np.zeros(
                shape=logits.shape, dtype=logits.dtype)
        }
        self.attrs = {'ignore_index': self.ignore_index}

    def test_check_output(self):
        self.check_output(no_check_set=['XShape'])

    def test_check_grad(self):
        self.check_grad(
            inputs_to_check=['X'],
            output_names=['Y'],
            no_grad_set=['XShape', 'MatchX', 'Label'])


class CrossEntropy2OpTest2(CrossEntropy2OpTestBase):
    def initParameters(self):
        return [32, 64], 'float64', 3


class CrossEntropy2OpTest3(CrossEntropy2OpTestBase):
    def initParameters(self):
        return [4, 8, 16, 32], 'float32', -100


class CrossEntropy2OpTest4(CrossEntropy2OpTestBase):
    def initParameters(self):
        return [4, 8, 16, 32], 'float32', 3


if __name__ == '__main__':
    unittest.main()
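For reference, the expected outputs that calc_output builds can be written as follows, with $x$ the reshaped logits and $\ell_i$ the label of row $i$:

\[
\mathrm{MatchX}_i = x_{i,\ell_i}, \qquad
Y_i =
\begin{cases}
-\log x_{i,\ell_i}, & \ell_i \neq \text{ignore\_index} \\
0, & \text{otherwise}
\end{cases}
\]

Rows whose label equals ignore_index are left at zero in both outputs, matching the continue branch above.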