[slim] Add quantization strategy and distillation strategy. (#16408)
* Add fsp operator: 1. Add unit test. 2. Add Python API. 3. Add layer test.
* Add quantization strategy: 1. Add API. 2. Add unit test.
* Add distillation strategy.
* Add unit test config file for quantization.
* Fix copyright. test=develop
* Fix setup.py.
* Fix document of layers.py. test=develop
* Fix unit test in Python 3. test=develop
* Fix documents. test=develop
* Refine fsp op by batched GEMM; remove unused import. test=develop
* Fix test_dist_se_resnext: 1. disable distillation test. 2. reset framework.py. test=develop
* Enable unit test of distillation after fixing Block._clone_variable. test=develop
* Fix CDN issue. test=develop
parent de3b70a101
commit e9bec9369b
@ -0,0 +1,128 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fsp_op.h"

namespace paddle {
namespace operators {

class FSPOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of FSPOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of FSPOp should not be null.");

    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");

    PADDLE_ENFORCE(
        x_dims.size() == 4,
        "The Input(X) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        y_dims.size() == 4,
        "The Input(Y) must have shape [batch_size, channel, height, width].");
    PADDLE_ENFORCE(
        (x_dims[2] == y_dims[2]) && (x_dims[3] == y_dims[3]),
        "The Input(X) and Input(Y) should have the same height and width.");

    ctx->SetOutputDim("Out", {x_dims[0], x_dims[1], y_dims[1]});
    ctx->ShareLoD("X", "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    framework::LibraryType library_{framework::LibraryType::kPlain};
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
                                   ctx.device_context(), layout_, library_);
  }
};

class FSPOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "(Tensor) The input of FSP op with shape [batch_size, x_channel, "
             "height, width].");
    AddInput("Y",
             "(Tensor) The input of FSP op with shape "
             "[batch_size, y_channel, height, width]. "
             "The y_channel can be different from the x_channel of Input(X), "
             "while the other dimensions must be the same as Input(X)'s.");
    AddOutput(
        "Out",
        "(Tensor) The output of FSP op with shape "
        "[batch_size, x_channel, y_channel]. The x_channel is the channel "
        "of Input(X) and the y_channel is the channel of Input(Y).");
    AddComment(R"DOC(
This op is used to calculate the flow of solution procedure (FSP) matrix of two feature maps.
Given feature map x with shape [x_channel, h, w] and feature map y with shape
[y_channel, h, w], we can get the fsp matrix of x and y in two steps:

step 1: reshape x into a matrix with shape [x_channel, h * w] and reshape and
        transpose y into a matrix with shape [h * w, y_channel]
step 2: multiply x and y to get the fsp matrix with shape [x_channel, y_channel]

The output is a batch of fsp matrices.
    )DOC");
  }
};

class FSPOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
    auto x_grad_name = framework::GradVarName("X");
    auto y_grad_name = framework::GradVarName("Y");
    if (ctx->HasOutput(x_grad_name)) {
      ctx->SetOutputDim(x_grad_name, x_dims);
    }
    if (ctx->HasOutput(y_grad_name)) {
      ctx->SetOutputDim(y_grad_name, y_dims);
    }
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
        ctx.device_context());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(fsp, ops::FSPOp, ops::FSPOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(fsp_grad, ops::FSPOpGrad);
REGISTER_OP_CPU_KERNEL(
    fsp, ops::FSPOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    fsp_grad, ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::FSPGradOpKernel<paddle::platform::CPUDeviceContext, double>);
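The op comment above describes the FSP computation in two steps. As a quick illustration, the same computation for a single sample can be written in NumPy as below; this is an illustrative sketch (the function name and arrays are made up here), not the Paddle kernel itself:

import numpy as np

def fsp_matrix_reference(x, y):
    # x: [x_channel, h, w], y: [y_channel, h, w] -> FSP matrix [x_channel, y_channel]
    x_channel, h, w = x.shape
    y_channel = y.shape[0]
    x_mat = x.reshape(x_channel, h * w)        # step 1: flatten the spatial dims of x
    y_mat = y.reshape(y_channel, h * w).T      # step 1: flatten and transpose y
    return x_mat.dot(y_mat) / (h * w)          # step 2: multiply, averaged over h * w

x = np.random.rand(16, 8, 8).astype('float32')
y = np.random.rand(32, 8, 8).astype('float32')
print(fsp_matrix_reference(x, y).shape)  # (16, 32)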
@ -0,0 +1,24 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fsp_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(fsp, ops::FSPOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPOpKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(fsp_grad,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, float>,
                        ops::FSPGradOpKernel<plat::CUDADeviceContext, double>);
@ -0,0 +1,136 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class FSPOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Input<Tensor>("Y");
    auto* output = context.Output<Tensor>("Out");
    output->mutable_data<T>(context.GetPlace());
    auto x_dims = x->dims();
    auto y_dims = y->dims();

    auto batch_size = x_dims[0];
    auto x_channel = x_dims[1];
    auto y_channel = y_dims[1];
    auto height = x_dims[2];
    auto width = x_dims[3];

    auto blas = math::GetBlas<DeviceContext, T>(context);

    math::MatDescriptor x_mat_desc;
    x_mat_desc.height_ = x_channel;
    x_mat_desc.width_ = height * width;
    x_mat_desc.batch_size_ = batch_size;
    x_mat_desc.stride_ = x_channel * height * width;

    math::MatDescriptor y_mat_desc;
    y_mat_desc.height_ = height * width;
    y_mat_desc.width_ = y_channel;
    y_mat_desc.batch_size_ = batch_size;
    y_mat_desc.stride_ = y_channel * height * width;
    y_mat_desc.trans_ = true;

    blas.MatMul(*x, x_mat_desc, *y, y_mat_desc,
                static_cast<T>(1.0 / (height * width)), output,
                static_cast<T>(0.0));
  }
};

template <typename DeviceContext, typename T>
class FSPGradOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* d_y = context.Output<Tensor>(framework::GradVarName("Y"));
    if (d_x == nullptr && d_y == nullptr) {
      return;
    }
    auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
    auto d_out_dims = d_out->dims();
    auto batch_size = d_out_dims[0];
    auto x_channel = d_out_dims[1];
    auto y_channel = d_out_dims[2];
    int64_t h = 0;
    int64_t w = 0;

    auto blas = math::GetBlas<DeviceContext, T>(context);
    math::SetConstant<DeviceContext, T> set_zero;
    if (d_x != nullptr) {
      d_x->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_x,
               static_cast<T>(0));
      auto* y = context.Input<Tensor>("Y");
      auto y_dims = y->dims();
      h = y_dims[2];
      w = y_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = x_channel;
      d_out_mat_desc.width_ = y_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;

      math::MatDescriptor y_mat_desc;
      y_mat_desc.height_ = y_channel;
      y_mat_desc.width_ = h * w;
      y_mat_desc.batch_size_ = batch_size;
      y_mat_desc.stride_ = y_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *y, y_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_x, static_cast<T>(0.0));
    }

    if (d_y != nullptr) {
      d_y->mutable_data<T>(context.GetPlace());
      set_zero(context.template device_context<DeviceContext>(), d_y,
               static_cast<T>(0));
      auto* x = context.Input<Tensor>("X");
      auto x_dims = x->dims();
      h = x_dims[2];
      w = x_dims[3];

      math::MatDescriptor d_out_mat_desc;
      d_out_mat_desc.height_ = y_channel;
      d_out_mat_desc.width_ = x_channel;
      d_out_mat_desc.batch_size_ = batch_size;
      d_out_mat_desc.stride_ = x_channel * y_channel;
      d_out_mat_desc.trans_ = true;

      math::MatDescriptor x_mat_desc;
      x_mat_desc.height_ = x_channel;
      x_mat_desc.width_ = h * w;
      x_mat_desc.batch_size_ = batch_size;
      x_mat_desc.stride_ = x_channel * h * w;

      blas.MatMul(*d_out, d_out_mat_desc, *x, x_mat_desc,
                  static_cast<T>(1.0 / (h * w)), d_y, static_cast<T>(0.0));
    }
  }
};

}  // namespace operators
}  // namespace paddle
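The forward kernel computes, per sample, Out = X2d * Y2d^T / (h * w) with one batched GEMM, and the backward kernel mirrors it with one batched GEMM for each of d(X) and d(Y). As a sanity reference, the per-sample gradients reduce to the NumPy expressions below; this is a sketch of the math under the shapes enforced by InferShape, not Paddle code:

import numpy as np

def fsp_grad_reference(x, y, d_out):
    # x: [x_channel, h, w], y: [y_channel, h, w], d_out: [x_channel, y_channel]
    x_channel, h, w = x.shape
    y_channel = y.shape[0]
    x2d = x.reshape(x_channel, h * w)
    y2d = y.reshape(y_channel, h * w)
    d_x = d_out.dot(y2d) / (h * w)     # matches the first MatMul in FSPGradOpKernel
    d_y = d_out.T.dot(x2d) / (h * w)   # matches the second MatMul (d_out transposed)
    return d_x.reshape(x.shape), d_y.reshape(y.shape)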
@ -0,0 +1,94 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..core.strategy import Strategy
from ....framework import Program, program_guard
from .... import Executor
import logging

__all__ = ['DistillationStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class DistillationStrategy(Strategy):
    def __init__(self, distillers=None, start_epoch=0, end_epoch=0):
        """
        Args:
            distillers(list): A list of distillers used to combine the student graph and
                              the teacher graph by adding distillation losses.
            start_epoch(int): The epoch in which to merge the student graph and the teacher
                              graph for distillation training. default: 0
            end_epoch(int): The epoch in which to finish distillation training. default: 0
        """
        super(DistillationStrategy, self).__init__(start_epoch, end_epoch)
        self.distillers = distillers

    def on_compression_begin(self, context):
        # load from checkpoint
        if context.epoch_id > 0:
            if context.epoch_id > self.start_epoch and context.epoch_id < self.end_epoch:
                _logger.info('Restore DistillationStrategy')
                self._create_distillation_graph(context)
                _logger.info('Restore DistillationStrategy finish.')

    def on_epoch_begin(self, context):
        if self.start_epoch == context.epoch_id:
            _logger.info('DistillationStrategy::on_epoch_begin.')
            self._create_distillation_graph(context)
            _logger.info('DistillationStrategy set optimize_graph.')

    def _create_distillation_graph(self, context):
        """
        step 1: Merge the student graph and the teacher graph into the distillation graph.
        step 2: Add losses into the distillation graph by distillers.
        step 3: Append backward ops and optimize ops into the distillation graph for training.
        """
        # step 1
        teacher = context.teacher_graphs[0]
        for var in teacher.program.list_vars():
            var.stop_gradient = True
        graph = context.train_graph.clone()
        graph.merge(teacher)
        graph.out_nodes['student_loss'] = graph.out_nodes['loss']

        # step 2
        for distiller in self.distillers:
            graph = distiller.distiller_loss(graph)

        # step 3
        startup_program = Program()
        with program_guard(graph.program, startup_program):
            context.distiller_optimizer._name = 'distillation_optimizer'
            context.distiller_optimizer.minimize(
                graph.var(graph.out_nodes['loss'])._var)
        exe = Executor(context.place)
        exe.run(startup_program, scope=context.scope)

        # back up the graph for fine-tuning after distillation
        context.put('distillation_backup_optimize_graph',
                    context.optimize_graph)
        context.optimize_graph = graph

    def on_epoch_end(self, context):
        if context.epoch_id == (self.end_epoch - 1):
            _logger.info('DistillationStrategy::on_epoch_end.')
            # restore optimize_graph for fine-tuning or other strategies in the next stage.
            context.optimize_graph = context.get(
                'distillation_backup_optimize_graph')
            _logger.info(
                'DistillationStrategy restored context.optimize_graph.')
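For orientation, this strategy is normally driven by a YAML config (see the distillation compress.yaml later in this diff), but it can also be assembled programmatically from the distillers defined in the next file. The sketch below mirrors the values used in that YAML; the import path is an assumption, since the package layout is not part of this excerpt:

# Hypothetical programmatic equivalent of the shipped YAML config (import path assumed).
from paddle.fluid.contrib.slim.distillation import (FSPDistiller, L2Distiller,
                                                    DistillationStrategy)

fsp = FSPDistiller(
    student_pairs=[['student_conv2_1_dw.tmp_0', 'student_conv1.tmp_0']],
    teacher_pairs=[['teacher_conv2_1_dw.tmp_0', 'teacher_conv1.tmp_0']],
    distillation_loss_weight=1)
l2 = L2Distiller(
    student_feature_map='student.tmp_2',
    teacher_feature_map='teacher.tmp_2',
    distillation_loss_weight=1)

# Merge the graphs at epoch 0 and restore the backed-up optimize graph after epoch 1.
strategy = DistillationStrategy(distillers=[fsp, l2], start_epoch=0, end_epoch=1)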
@ -0,0 +1,188 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .... import layers
from .... import optimizer
from .... import Executor
from .... import Program
from .... import program_guard
from .... import regularizer

__all__ = ['FSPDistiller', 'L2Distiller']


class L2Distiller(object):
    """
    Combine two layers, one from the student net and one from the teacher net, by an
    l2-loss, and add the loss into the total loss used for distillation training.
    """

    def __init__(self,
                 student_feature_map,
                 teacher_feature_map,
                 distillation_loss_weight=1):
        """
        Args:
            student_feature_map(str): The name of the feature map from the student network.
            teacher_feature_map(str): The name of the feature map from the teacher network.
                                      Its shape should be the same as that of the student feature map.
            distillation_loss_weight(float): The weight of the l2-loss.
        """
        self.student_feature_map = student_feature_map
        self.teacher_feature_map = teacher_feature_map
        self.distillation_loss_weight = distillation_loss_weight

    def distiller_loss(self, graph):
        """
        Modify the graph in place to add the l2-loss.
        Args:
            graph(GraphWrapper): The graph to be modified.
        Returns:
            GraphWrapper: The modified graph.
        """
        distiller_pass = L2DistillerPass(self.student_feature_map,
                                         self.teacher_feature_map,
                                         self.distillation_loss_weight)
        dis_graph = distiller_pass.apply(graph)
        return dis_graph


class L2DistillerPass(object):
    """
    The pass used to add the l2-loss.
    """

    def __init__(self,
                 student_feature_map,
                 teacher_feature_map,
                 distillation_loss_weight=1):
        """
        Args:
            student_feature_map(str): The name of the feature map from the student network.
            teacher_feature_map(str): The name of the feature map from the teacher network.
                                      Its shape should be the same as that of the student feature map.
            distillation_loss_weight(float): The weight of the l2-loss.
        """
        self.student_feature_map = student_feature_map
        self.teacher_feature_map = teacher_feature_map
        self.distillation_loss_weight = distillation_loss_weight

    def apply(self, graph):
        ret_graph = graph
        with program_guard(ret_graph.program):

            student_feature_map = ret_graph.var(self.student_feature_map)._var
            teacher_feature_map = ret_graph.var(self.teacher_feature_map)._var
            l2loss = layers.reduce_mean(
                layers.square(student_feature_map - teacher_feature_map))

            distillation_loss = l2loss * self.distillation_loss_weight
            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
            loss = distillation_loss + student_loss

            ret_graph.out_nodes[
                'l2loss_' + self.student_feature_map + "_" +
                self.teacher_feature_map] = distillation_loss.name
            ret_graph.out_nodes['loss'] = loss.name
        return ret_graph


class FSPDistiller(object):
    """
    Combine pairs of layers from the student net and the teacher net by an fsp-loss.
    """

    def __init__(self, student_pairs, teacher_pairs,
                 distillation_loss_weight=1):
        """
        Args:
            student_pairs(list<tuple>): Each tuple, with two variable names, in student_pairs
                                        indicates a section of the student network. The variables
                                        in a tuple should have the same feature map size.
            teacher_pairs(list<tuple>): Each tuple, with two variable names, in teacher_pairs
                                        indicates a section of the teacher network. The variables
                                        in a tuple should have the same feature map size. The
                                        variable named teacher_pairs[i][j] should have the same
                                        channel number as the variable named student_pairs[i][j].
            distillation_loss_weight(float): The weight of the fsp-loss. default: 1.
        """
        self.student_pairs = student_pairs
        self.teacher_pairs = teacher_pairs
        self.distillation_loss_weight = distillation_loss_weight

    def distiller_loss(self, graph):
        """
        Modify the graph in place to add the fsp-loss.
        Args:
            graph(GraphWrapper): The graph to be modified.
        Returns:
            GraphWrapper: The modified graph.
        """
        distiller_pass = FSPDistillerPass(self.student_pairs,
                                          self.teacher_pairs,
                                          self.distillation_loss_weight)
        dis_graph = distiller_pass.apply(graph)
        return dis_graph


class FSPDistillerPass(object):
    '''
    The pass used to add the fsp-loss by combining layers from the student net and
    the teacher net.
    '''

    def __init__(self, s_pairs, t_pairs, distillation_loss_weight=1):
        """
        Args:
            s_pairs(list<tuple>): Each tuple, with two variable names, in s_pairs indicates
                                  a section of the student network. The variables in a tuple
                                  should have the same feature map size.
            t_pairs(list<tuple>): Each tuple, with two variable names, in t_pairs indicates
                                  a section of the teacher network. The variables in a tuple
                                  should have the same feature map size. The variable named
                                  t_pairs[i][j] should have the same channel number as the
                                  variable named s_pairs[i][j].
            distillation_loss_weight(float): The weight of the fsp-loss. default: 1.
        """
        self.s_pairs = s_pairs
        self.t_pairs = t_pairs
        self.distillation_loss_weight = distillation_loss_weight

    def apply(self, graph):
        ret_graph = graph
        with program_guard(ret_graph.program):
            losses = []
            for s_pair, t_pair in zip(self.s_pairs, self.t_pairs):
                s_pair_start = ret_graph.var(s_pair[0])._var
                s_pair_end = ret_graph.var(s_pair[1])._var
                s_fsp_matrix = self._fsp_matrix(s_pair_start, s_pair_end)
                t_pair_start = ret_graph.var(t_pair[0])._var
                t_pair_end = ret_graph.var(t_pair[1])._var
                t_fsp_matrix = self._fsp_matrix(t_pair_start, t_pair_end)
                l2_loss = layers.reduce_mean(
                    layers.square(s_fsp_matrix - t_fsp_matrix))
                losses.append(l2_loss)
            distillation_loss = layers.sum(
                losses) * self.distillation_loss_weight
            student_loss = ret_graph.var(ret_graph.out_nodes['loss'])._var
            loss = distillation_loss + student_loss

            ret_graph.out_nodes[
                'fsp_distillation_loss'] = distillation_loss.name
            ret_graph.out_nodes['loss'] = loss.name
        return ret_graph

    def _fsp_matrix(self, fea_map_0, fea_map_1):
        return layers.fsp_matrix(fea_map_0, fea_map_1)
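Applied one after another, the two passes above grow the graph's 'loss' node into the student loss plus the weighted l2 and fsp terms. The small NumPy sketch below restates that combination on plain arrays (an illustration of the resulting objective only; it does not use the GraphWrapper API):

import numpy as np

def combined_distillation_loss(student_loss, s_map, t_map,
                               s_fsp_list, t_fsp_list,
                               l2_weight=1.0, fsp_weight=1.0):
    # student_loss: scalar; s_map/t_map: the feature maps given to L2Distiller;
    # s_fsp_list/t_fsp_list: FSP matrices computed from the student/teacher pairs.
    l2_term = l2_weight * np.mean((s_map - t_map) ** 2)
    fsp_term = fsp_weight * sum(
        np.mean((s - t) ** 2) for s, t in zip(s_fsp_list, t_fsp_list))
    return student_loss + l2_term + fsp_term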
@ -0,0 +1,209 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import sys
import numpy as np
from .... import Executor
from .... import io
from .... import core
from ....compiler import CompiledProgram
from ....compiler import BuildStrategy
from ....framework import IrGraph
from ..core.strategy import Strategy
from .quantization_pass import *

__all__ = ['QuantizationStrategy']

logging.basicConfig(format='%(asctime)s-%(levelname)s: %(message)s')
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)


class QuantizationStrategy(Strategy):
    """
    The strategy for quantization.
    """

    def __init__(self,
                 start_epoch=0,
                 end_epoch=0,
                 float_model_save_path=None,
                 mobile_model_save_path=None,
                 int8_model_save_path=None,
                 activation_bits=8,
                 weight_bits=8,
                 activation_quantize_type='abs_max',
                 save_in_nodes=None,
                 save_out_nodes=None):
        """
        Args:
            start_epoch(int): The 'on_epoch_begin' function will be called in start_epoch. default: 0
            end_epoch(int): The 'on_epoch_end' function will be called in end_epoch. default: 0
            float_model_save_path(str): The path to save the model with float weights.
                                        None means the float model is not saved. default: None.
            mobile_model_save_path(str): The path to save the model for paddle-mobile execution.
                                         None means the mobile model is not saved. default: None.
            int8_model_save_path(str): The path to save the model with int8_t weights.
                                       None means the int8 model is not saved. default: None.
            activation_bits(int): quantization bit number for activations. default: 8.
            weight_bits(int): quantization bit number for weights. The bias is not quantized.
                              default: 8.
            activation_quantize_type(str): quantization type for activations;
                                           'abs_max', 'range_abs_max' and 'moving_average_abs_max'
                                           are supported. In 'abs_max' mode, the quantization scale
                                           is calculated dynamically at each step in both the
                                           training and testing periods. In 'range_abs_max' mode,
                                           a static quantization scale is calculated during
                                           training and used in inference.
            save_in_nodes(list<str>): A list of variable names used to prune the graph
                                      for saving the inference model.
            save_out_nodes(list<str>): A list of variable names used to prune the graph
                                       for saving the inference model.
        """
        super(QuantizationStrategy, self).__init__(start_epoch, end_epoch)
        self.start_epoch = start_epoch
        self.end_epoch = end_epoch
        self.float_model_save_path = float_model_save_path
        self.mobile_model_save_path = mobile_model_save_path
        self.int8_model_save_path = int8_model_save_path
        self.activation_bits = activation_bits
        self.weight_bits = weight_bits
        self.activation_quantize_type = activation_quantize_type
        self.save_out_nodes = save_out_nodes
        self.save_in_nodes = save_in_nodes

    def on_epoch_begin(self, context):
        """
        Insert fake_quantize_op and fake_dequantize_op before training and testing.
        """
        super(QuantizationStrategy, self).on_compression_begin(context)
        if self.start_epoch == context.epoch_id:
            _logger.info('QuantizationStrategy::on_epoch_begin')
            train_ir_graph = IrGraph(
                core.Graph(context.optimize_graph.program.desc), for_test=False)
            test_ir_graph = IrGraph(
                core.Graph(context.eval_graph.program.desc), for_test=True)
            transform_pass = QuantizationTransformPass(
                scope=context.scope,
                place=context.place,
                weight_bits=self.weight_bits,
                activation_bits=self.activation_bits,
                activation_quantize_type=self.activation_quantize_type)
            transform_pass.apply(train_ir_graph)
            transform_pass.apply(test_ir_graph)

            build_strategy = BuildStrategy()
            build_strategy.enable_inplace = False
            build_strategy.memory_optimize = False
            # for quantization training
            context.optimize_graph.compiled_graph = CompiledProgram(
                train_ir_graph.graph).with_data_parallel(
                    loss_name=context.optimize_graph.out_nodes['loss'],
                    build_strategy=build_strategy)
            # for evaluation; a program compiled from an IR graph must be compiled with data parallel.
            context.eval_graph.compiled_graph = CompiledProgram(
                test_ir_graph.graph).with_data_parallel(
                    build_strategy=build_strategy)
            # for saving the inference model after training
            context.put('quantization_test_ir_graph_backup', test_ir_graph)
            _logger.info('Finish QuantizationStrategy::on_epoch_begin')

    def on_epoch_end(self, context):
        """
        Freeze the graph and save the inference model.
        """
        super(QuantizationStrategy, self).on_compression_end(context)

        if context.epoch_id == self.end_epoch:
            _logger.info('QuantizationStrategy::on_epoch_end')
            test_ir_graph = context.get('quantization_test_ir_graph_backup')
            # freeze the graph after training
            freeze_pass = QuantizationFreezePass(
                scope=context.scope,
                place=context.place,
                weight_bits=self.weight_bits,
                activation_bits=self.activation_bits)
            freeze_pass.apply(test_ir_graph)

            # for other strategies
            context.eval_graph.program = test_ir_graph.to_program()

            if self.save_out_nodes is None:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in context.eval_graph.out_nodes.values()
                ]
            else:
                out_vars = [
                    context.eval_graph.var(var_name)._var
                    for var_name in self.save_out_nodes
                ]

            if self.save_in_nodes is None:
                in_vars = list(context.eval_graph.out_nodes.values())
            else:
                in_vars = self.save_in_nodes

            # save float model
            if self.float_model_save_path:
                executor = Executor(context.place)
                io.save_inference_model(
                    self.float_model_save_path,
                    in_vars,
                    out_vars,
                    executor,
                    main_program=test_ir_graph.to_program(),
                    model_filename='model',
                    params_filename='weights',
                    export_for_deployment=True)

            # save int8 model
            if self.int8_model_save_path:
                convert_int8_pass = ConvertToInt8Pass(
                    scope=context.scope, place=context.place)
                convert_int8_pass.apply(test_ir_graph)

                executor = Executor(context.place)
                io.save_inference_model(
                    self.int8_model_save_path,
                    in_vars,
                    out_vars,
                    executor,
                    main_program=test_ir_graph.to_program(),
                    model_filename='model',
                    params_filename='weights',
                    export_for_deployment=True)

            # save mobile model
            if self.mobile_model_save_path:
                if not self.int8_model_save_path:
                    # convert the weights to int8_t type
                    convert_int8_pass = ConvertToInt8Pass(
                        scope=context.scope, place=context.place)
                    convert_int8_pass.apply(test_ir_graph)
                # make some changes on the graph for mobile inference
                mobile_pass = TransformForMobilePass()
                mobile_pass.apply(test_ir_graph)
                executor = Executor(context.place)
                io.save_inference_model(
                    self.mobile_model_save_path,
                    in_vars,
                    out_vars,
                    executor,
                    main_program=test_ir_graph.to_program(),
                    model_filename='model',
                    params_filename='weights',
                    export_for_deployment=True)
            _logger.info('Finish QuantizationStrategy::on_epoch_end')
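For intuition about the 'abs_max' activation mode referenced in the docstring: a fake-quantize op conceptually quantizes a tensor to signed integer levels with a per-tensor scale and immediately dequantizes it, so training still runs in float. A minimal NumPy sketch of that idea (a simplification for illustration, not the QuantizationTransformPass implementation):

import numpy as np

def fake_quantize_abs_max(x, bits=8):
    # Quantize to integer levels and immediately dequantize ("fake" quantization),
    # so downstream ops keep seeing float values during training.
    qmax = (1 << (bits - 1)) - 1              # 127 for 8 bits
    scale = max(np.max(np.abs(x)), 1e-8)      # dynamic per-tensor scale ('abs_max')
    q = np.round(x / scale * qmax)            # integer levels in [-qmax, qmax]
    return q * scale / qmax                   # float approximation seen downstream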
@ -0,0 +1,46 @@
#start_epoch(int): The epoch in which to merge the student graph and the teacher graph for
#                  distillation training. default: 0
#
#end_epoch(int): The epoch in which to finish distillation training. default: 0
#
#student_feature_map(str): The name of the feature map from the student network.
#
#teacher_feature_map(str): The name of the feature map from the teacher network.
#                          Its shape should be the same as that of the student feature map.
#
#student_pairs(list<tuple>): Each tuple, with two variable names, in student_pairs indicates
#                            a section of the student network. The variables in a tuple should
#                            have the same feature map size.
#
#teacher_pairs(list<tuple>): Each tuple, with two variable names, in teacher_pairs indicates
#                            a section of the teacher network. The variables in a tuple should
#                            have the same feature map size. The variable named teacher_pairs[i][j]
#                            should have the same channel number as the variable named
#                            student_pairs[i][j].
#
#distillation_loss_weight(float): The weight of the loss.
version: 1.0
distillers:
    fsp_distiller:
        class: 'FSPDistiller'
#        teacher_pairs: [['teacher_depthwise_conv2d_1.tmp_0', 'teacher_conv2d_3.tmp_0']]
#        student_pairs: [['student_depthwise_conv2d_1.tmp_0', 'student_conv2d_3.tmp_0']]
        teacher_pairs: [['teacher_conv2_1_dw.tmp_0', 'teacher_conv1.tmp_0']]
        student_pairs: [['student_conv2_1_dw.tmp_0', 'student_conv1.tmp_0']]
        distillation_loss_weight: 1
    l2_distiller:
        class: 'L2Distiller'
        teacher_feature_map: 'teacher.tmp_2'
        student_feature_map: 'student.tmp_2'
        distillation_loss_weight: 1
strategies:
    distillation_strategy:
        class: 'DistillationStrategy'
        distillers: ['fsp_distiller', 'l2_distiller']
        start_epoch: 0
        end_epoch: 1
compressor:
    epoch: 1
    checkpoint_path: './distillation_checkpoints/'
    strategies:
        - distillation_strategy
@ -0,0 +1,48 @@
#start_epoch(int): The epoch to insert quantization operators. default: 0
#
#end_epoch(int): The epoch to save the inference model. default: 0
#
#float_model_save_path(str): The path to save the model with float weights.
#                            None means the float model is not saved. default: None.
#
#mobile_model_save_path(str): The path to save the model for paddle-mobile execution.
#                             None means the mobile model is not saved. default: None.
#
#int8_model_save_path(str): The path to save the model with int8_t weights.
#                           None means the int8 model is not saved. default: None.
#
#activation_bits(int): quantization bit number for activations. default: 8.
#
#weight_bits(int): quantization bit number for weights. The bias is not quantized.
#                  default: 8.
#
#activation_quantize_type(str): quantization type for activations;
#                               'abs_max', 'range_abs_max' and 'moving_average_abs_max' are supported.
#                               In 'abs_max' mode, the quantization scale is calculated
#                               dynamically at each step in both training and testing. In
#                               'range_abs_max' mode, a static quantization scale is calculated
#                               during training and used in inference.
#
#save_in_nodes(list<str>): A list of variable names used to prune the graph
#                          for saving the inference model.
#
#save_out_nodes(list<str>): A list of variable names used to prune the graph
#                           for saving the inference model.
version: 1.0
strategies:
    quantization_strategy:
        class: 'QuantizationStrategy'
        start_epoch: 0
        end_epoch: 0
        float_model_save_path: './output/float'
        weight_bits: 8
        activation_bits: 8
        weight_quantize_type: 'abs_max'
        activation_quantize_type: 'abs_max'
        save_in_nodes: ['image']
        save_out_nodes: ['quan.tmp_2']
compressor:
    epoch: 1
    checkpoint_path: './checkpoints_quan/'
    strategies:
        - quantization_strategy
@ -0,0 +1,94 @@
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
#     http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper


class TestDistillationStrategy(unittest.TestCase):
    """
    Test API of distillation strategy.
    """

    def test_compression(self):
        if not fluid.core.is_compiled_with_cuda():
            return
        class_dim = 10
        image_shape = [1, 28, 28]
        image = fluid.layers.data(
            name='image', shape=image_shape, dtype='float32')
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = MobileNet(name="student").net(input=image, class_dim=class_dim)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)
        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        val_feed_list = [('img', image.name), ('label', label.name)]
        val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
                                                        acc_top5.name)]

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]

        # define teacher program
        teacher_program = fluid.Program()
        startup_program = fluid.Program()
        with fluid.program_guard(teacher_program, startup_program):
            img = teacher_program.global_block()._clone_variable(
                image, force_persistable=False)
            predict = MobileNet(name="teacher").net(input=img,
                                                    class_dim=class_dim)

        exe.run(startup_program)

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            fluid.default_main_program(),
            train_reader=train_reader,
            train_feed_list=train_feed_list,
            train_fetch_list=train_fetch_list,
            eval_program=val_program,
            eval_reader=val_reader,
            eval_feed_list=val_feed_list,
            eval_fetch_list=val_fetch_list,
            teacher_programs=[teacher_program.clone(for_test=True)],
            train_optimizer=optimizer,
            distiller_optimizer=optimizer)
        com_pass.config('./distillation/compress.yaml')
        eval_graph = com_pass.run()


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,82 @@
# copyright (c) 2019 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
#     http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.

import paddle
import unittest
import paddle.fluid as fluid
from mobilenet import MobileNet
from paddle.fluid.contrib.slim.core import Compressor
from paddle.fluid.contrib.slim.graph import GraphWrapper


class TestQuantizationStrategy(unittest.TestCase):
    """
    Test API of quantization strategy.
    """

    def test_compression(self):
        if not fluid.core.is_compiled_with_cuda():
            return
        class_dim = 10
        image_shape = [1, 28, 28]
        image = fluid.layers.data(
            name='image', shape=image_shape, dtype='float32')
        image.stop_gradient = False
        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
        out = MobileNet(name='quan').net(input=image, class_dim=class_dim)
        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
        val_program = fluid.default_main_program().clone(for_test=False)

        cost = fluid.layers.cross_entropy(input=out, label=label)
        avg_cost = fluid.layers.mean(x=cost)

        optimizer = fluid.optimizer.Momentum(
            momentum=0.9,
            learning_rate=0.01,
            regularization=fluid.regularizer.L2Decay(4e-5))

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        val_reader = paddle.batch(paddle.dataset.mnist.test(), batch_size=128)

        val_feed_list = [('img', image.name), ('label', label.name)]
        val_fetch_list = [('acc_top1', acc_top1.name), ('acc_top5',
                                                        acc_top5.name)]

        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=128)
        train_feed_list = [('img', image.name), ('label', label.name)]
        train_fetch_list = [('loss', avg_cost.name)]

        com_pass = Compressor(
            place,
            fluid.global_scope(),
            fluid.default_main_program(),
            train_reader=train_reader,
            train_feed_list=train_feed_list,
            train_fetch_list=train_fetch_list,
            eval_program=val_program,
            eval_reader=val_reader,
            eval_feed_list=val_feed_list,
            eval_fetch_list=val_fetch_list,
            train_optimizer=optimizer)
        com_pass.config('./quantization/compress.yaml')
        eval_graph = com_pass.run()


if __name__ == '__main__':
    unittest.main()
@ -0,0 +1,60 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


def fsp_matrix(a, b):
    batch = a.shape[0]
    a_channel = a.shape[1]
    b_channel = b.shape[1]
    h = a.shape[2]
    w = a.shape[3]
    a_t = a.transpose([0, 2, 3, 1])
    a_t = a_t.reshape([batch, h * w, a_channel])
    b_t = b.transpose([0, 2, 3, 1]).reshape([batch, h * w, b_channel])
    a_r = a_t.repeat(
        b_channel, axis=1).reshape(
            [batch, h * w, b_channel, a_channel]).transpose([0, 1, 3, 2])
    b_r = b_t.repeat(
        a_channel, axis=1).reshape([batch, h * w, a_channel, b_channel])
    return np.mean(a_r * b_r, axis=1)


class TestFSPOp(OpTest):
    def setUp(self):
        self.op_type = "fsp"
        self.initTestCase()

        feature_map_0 = np.random.uniform(0, 10, self.a_shape).astype('float32')
        feature_map_1 = np.random.uniform(0, 10, self.b_shape).astype('float32')

        self.inputs = {'X': feature_map_0, 'Y': feature_map_1}
        self.outputs = {'Out': fsp_matrix(feature_map_0, feature_map_1)}

    def initTestCase(self):
        self.a_shape = (2, 16, 32, 31)
        self.b_shape = (2, 28, 32, 31)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)


if __name__ == '__main__':
    unittest.main()