[slim] Refine framework of slim and add filter pruning strategy (#16226)
* First pr of paddle slim. 1. Add framework of paddle slim 2. Add filter pruning strategy test=develop * Rename unitest to tests. test=develop * Add prettytable into requirements. test=develop * Change in_nodes and out_nodes to odered dict. test=develop * Remove distillation. test=develop * Fix API.spec test=develop * Fix unitest. test=develop * Fix unitest in windows. test=develop * Fix unitest in windows. test=develop * Fix unitest. test=develop * Hide some functions. test=develop * Fix python import in python3.5 test=develop * Fix compress pass. test=develop * Fix unitest of test_dist_ctr. test=develop * Enhence flops. * use os.path.join * Fix pickle for python3 Fix log and comments. test=develop * 1. Remove feed_reader in compress pass 2. Fix cache reader 3. Rename CompressPass to Compressor 4. Add comments for distiller optimizer 5. Remove unused pruner currently 6. Add some comments. 7. Change API.spec test=develop * Fix pruning in python3. test=develop * Fix unitest in python3. test=develop * Fix format in python3. test=developrevert-16303-checkpoint
parent
18779b5b8f
commit
2e5831f0dc
@ -1,129 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ....core import CPUPlace
|
||||
from ..graph import get_executor
|
||||
|
||||
__all__ = ['Context', 'CompressPass']
|
||||
|
||||
|
||||
class Context(object):
|
||||
"""
|
||||
The context in the process of compression.
|
||||
Args:
|
||||
exe: The executor used to execute graph.
|
||||
graph: The graph to be compressed.
|
||||
scope: The scope used to execute graph.
|
||||
program_exe: The program_exe is used to execute the program
|
||||
created for modifying the variables in scope.
|
||||
"""
|
||||
|
||||
def __init__(self, exe, graph, scope, program_exe=None):
|
||||
# The total number of epoches to be trained.
|
||||
self.epoch = 0
|
||||
# Current epoch
|
||||
self.epoch_id = 0
|
||||
# Current batch
|
||||
self.batch_id = 0
|
||||
self.exe = exe
|
||||
self.graph = graph
|
||||
self.scope = scope
|
||||
self.program_exe = program_exe
|
||||
|
||||
|
||||
class CompressPass(object):
|
||||
"""
|
||||
The pass used to compress model.
|
||||
Args:
|
||||
place: The device used in compression.
|
||||
data_reader: The data_reader used to run graph.
|
||||
data_feeder: The data_feeder used to run graph.
|
||||
scope: The scope used to run graph.
|
||||
metrics: The metrics for evaluating model.
|
||||
epoch: The total epoches of trainning in compression.
|
||||
program_exe: The program_exe is used to execute the program
|
||||
created for modifying the variables in scope.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
place=None,
|
||||
data_reader=None,
|
||||
data_feeder=None,
|
||||
scope=None,
|
||||
metrics=None,
|
||||
epoch=None,
|
||||
program_exe=None):
|
||||
self.strategies = []
|
||||
self.place = CPUPlace() if place is None else place
|
||||
self.data_reader = data_reader
|
||||
self.data_feeder = data_feeder
|
||||
self.scope = scope
|
||||
self.metrics = metrics
|
||||
self.epoch = epoch
|
||||
self.program_exe = program_exe
|
||||
|
||||
def add_strategy(self, strategy):
|
||||
"""
|
||||
Add a strategy to current compress pass.
|
||||
Args:
|
||||
strategy: The strategy to be added into current compress pass.
|
||||
"""
|
||||
self.strategies.append(strategy)
|
||||
self.epoch = max(strategy.end_epoch, self.epoch)
|
||||
|
||||
def apply(self, graph):
|
||||
"""
|
||||
Compress a model.
|
||||
Args:
|
||||
graph: The target graph to be compressed.
|
||||
"""
|
||||
self.executor = get_executor(graph, self.place)
|
||||
context = Context(
|
||||
self.executor, graph, self.scope, program_exe=self.program_exe)
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_compress_begin(context)
|
||||
|
||||
for epoch in range(self.epoch):
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_epoch_begin(context)
|
||||
|
||||
for data in self.data_reader():
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_batch_begin(context)
|
||||
fetches = None
|
||||
if self.metrics:
|
||||
fetches = self.metrics.values()
|
||||
feed = None
|
||||
if self.data_feeder:
|
||||
feed = self.data_feeder.feed(data)
|
||||
results = self.executor.run(graph,
|
||||
fetches=fetches,
|
||||
scope=self.scope,
|
||||
feed=feed)
|
||||
if results:
|
||||
print("results: {}".format(
|
||||
zip(self.metrics.keys(), results)))
|
||||
for strategy in self.strategies:
|
||||
strategy.on_batch_end(context)
|
||||
context.batch_id += 1
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_epoch_end(context)
|
||||
context.epoch_id += 1
|
||||
|
||||
for strategy in self.strategies:
|
||||
strategy.on_compress_end(context)
|
File diff suppressed because it is too large
Load Diff
@ -1,39 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .compress_pass import CompressPass
|
||||
from .config import ConfigFactory
|
||||
|
||||
__all__ = ['build_compressor']
|
||||
|
||||
|
||||
def build_compressor(place=None,
|
||||
data_reader=None,
|
||||
data_feeder=None,
|
||||
scope=None,
|
||||
metrics=None,
|
||||
epoch=None,
|
||||
config=None):
|
||||
if config is not None:
|
||||
factory = ConfigFactory(config)
|
||||
comp_pass = factory.get_compress_pass()
|
||||
else:
|
||||
comp_pass = CompressPass()
|
||||
comp_pass.place = place
|
||||
comp_pass.data_reader = data_reader
|
||||
comp_pass.data_feeder = data_feeder
|
||||
comp_pass.scope = scope
|
||||
comp_pass.metrics = metrics
|
||||
comp_pass.epoch = epoch
|
||||
return comp_pass
|
@ -1,28 +0,0 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.3
|
||||
'conv1_2.w': 0.4
|
||||
'*': 0.9
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
||||
strategies:
|
||||
strategy_1:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
end_epoch: 10
|
||||
delta_rate: 0.20
|
||||
acc_loss_threshold: 0.2
|
||||
sensitivities:
|
||||
'conv1_1.w': 0.4
|
||||
|
||||
compress_pass:
|
||||
class: 'CompressPass'
|
||||
epoch: 100
|
||||
strategies:
|
||||
- strategy_1
|
@ -1,69 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import paddle.fluid as fluid
|
||||
import paddle
|
||||
import os
|
||||
import sys
|
||||
from paddle.fluid.contrib.slim import CompressPass
|
||||
from paddle.fluid.contrib.slim import build_compressor
|
||||
from paddle.fluid.contrib.slim import ImitationGraph
|
||||
|
||||
|
||||
class LinearModel(object):
|
||||
def __init__(slef):
|
||||
pass
|
||||
|
||||
def train(self):
|
||||
train_program = fluid.Program()
|
||||
startup_program = fluid.Program()
|
||||
startup_program.random_seed = 10
|
||||
with fluid.program_guard(train_program, startup_program):
|
||||
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
|
||||
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
|
||||
predict = fluid.layers.fc(input=x, size=1, act=None)
|
||||
cost = fluid.layers.square_error_cost(input=predict, label=y)
|
||||
avg_cost = fluid.layers.mean(cost)
|
||||
eval_program = train_program.clone()
|
||||
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
|
||||
sgd_optimizer.minimize(avg_cost)
|
||||
|
||||
train_reader = paddle.batch(
|
||||
paddle.dataset.uci_housing.train(), batch_size=1)
|
||||
eval_reader = paddle.batch(
|
||||
paddle.dataset.uci_housing.test(), batch_size=1)
|
||||
place = fluid.CPUPlace()
|
||||
train_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
|
||||
eval_feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
|
||||
exe = fluid.Executor(place)
|
||||
exe.run(startup_program)
|
||||
train_metrics = {"loss": avg_cost.name}
|
||||
eval_metrics = {"loss": avg_cost.name}
|
||||
|
||||
graph = ImitationGraph(train_program)
|
||||
config = './config.yaml'
|
||||
comp_pass = build_compressor(
|
||||
place,
|
||||
data_reader=train_reader,
|
||||
data_feeder=train_feeder,
|
||||
scope=fluid.global_scope(),
|
||||
metrics=train_metrics,
|
||||
epoch=1,
|
||||
config=config)
|
||||
comp_pass.apply(graph)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = LinearModel()
|
||||
model.train()
|
@ -1,49 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import subprocess
|
||||
from ....framework import Program
|
||||
from ....framework import Block
|
||||
from .... import core
|
||||
|
||||
__all__ = ['Graph', 'ImitationGraph', 'IRGraph']
|
||||
|
||||
|
||||
class Graph(object):
|
||||
"""
|
||||
Base class for all graph.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def all_parameters(self):
|
||||
"""
|
||||
Return all the parameters in current graph.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ImitationGraph(Graph):
|
||||
def __init__(self, program=None):
|
||||
super(ImitationGraph, self).__init__()
|
||||
self.program = Program() if program is None else program
|
||||
|
||||
def all_parameters(self):
|
||||
return self.program.global_block().all_parameters()
|
||||
|
||||
|
||||
class IRGraph(Graph):
|
||||
pass
|
@ -1,42 +0,0 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
__all__ = ['GraphPass', 'PruneParameterPass']
|
||||
|
||||
|
||||
class GraphPass(object):
|
||||
"""
|
||||
Base class for all graph pass.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def apply(self, graph):
|
||||
pass
|
||||
|
||||
|
||||
class PruneParameterPass(GraphPass):
|
||||
"""
|
||||
Generate a graph for pruning parameters from target graph.
|
||||
"""
|
||||
|
||||
def __init__(self, pruned_params, thresholds):
|
||||
super(PruneParameterPass, self).__init__()
|
||||
self.pruned_params = pruned_params
|
||||
self.thresholds = thresholds
|
||||
self.default_threshold = thresholds['*']
|
||||
|
||||
def apply(self, graph):
|
||||
pass
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,29 +0,0 @@
|
||||
version: 1.0
|
||||
include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"]
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.3
|
||||
'conv1_2.w': 0.4
|
||||
'*': 0.9
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
||||
strategies:
|
||||
strategy_1:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_2'
|
||||
start_epoch: 0
|
||||
end_epoch: 10
|
||||
delta_rate: 0.20
|
||||
acc_loss_threshold: 0.2
|
||||
sensitivities:
|
||||
'conv1_1.w': 0.4
|
||||
|
||||
compress_pass:
|
||||
class: 'CompressPass'
|
||||
epoch: 100
|
||||
strategies:
|
||||
- strategy_1
|
@ -0,0 +1,34 @@
|
||||
#start_epoch: The 'on_epoch_begin' function will be called in start_epoch. default: 0.
|
||||
#end_epoch: The 'on_epoch_end' function will be called in end_epoch. default: 10.
|
||||
#delta_rate: The delta used to generate ratios when calculating sensitivities.
|
||||
#target_ratio: The flops ratio to be pruned from current model.
|
||||
#metric_name: The metric used to evaluate the model.
|
||||
#pruned_params: The pattern str to match the parameter names to be pruned.
|
||||
#sensitivities_file: The sensitivities file.
|
||||
#num_steps: The number of pruning steps.
|
||||
#eval_rate: The rate of sampled data used to calculate sensitivities.
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
strategies:
|
||||
sensitive_pruning_strategy:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 0
|
||||
delta_rate: 0.1
|
||||
target_ratio: 0.3
|
||||
num_steps: 1
|
||||
eval_rate: 0.5
|
||||
pruned_params: '.*_sep_weights'
|
||||
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
|
||||
metric_name: 'acc_top1'
|
||||
compressor:
|
||||
epoch: 120
|
||||
checkpoint_path: './checkpoints/'
|
||||
strategies:
|
||||
- sensitive_pruning_strategy
|
@ -1,12 +0,0 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_2:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.5
|
||||
'conv1_2.w': 0.2
|
||||
'*': 0.7
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
@ -1,12 +0,0 @@
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_3:
|
||||
class: 'RatioPruner'
|
||||
ratios:
|
||||
'conv1_1.w': 0.5
|
||||
'conv1_2.w': 0.2
|
||||
'*': 0.7
|
||||
group_dims:
|
||||
'*': [1, 2, 3]
|
||||
criterions:
|
||||
'*': 'l1-norm'
|
@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
@ -0,0 +1,34 @@
|
||||
#start_epoch: The 'on_epoch_begin' function will be called in start_epoch. default: 0.
|
||||
#end_epoch: The 'on_epoch_end' function will be called in end_epoch. default: 10.
|
||||
#delta_rate: The delta used to generate ratios when calculating sensitivities.
|
||||
#target_ratio: The flops ratio to be pruned from current model.
|
||||
#metric_name: The metric used to evaluate the model.
|
||||
#pruned_params: The pattern str to match the parameter names to be pruned.
|
||||
#sensitivities_file: The sensitivities file.
|
||||
#num_steps: The number of pruning steps.
|
||||
#eval_rate: The rate of sampled data used to calculate sensitivities.
|
||||
version: 1.0
|
||||
pruners:
|
||||
pruner_1:
|
||||
class: 'StructurePruner'
|
||||
pruning_axis:
|
||||
'*': 0
|
||||
criterions:
|
||||
'*': 'l1_norm'
|
||||
strategies:
|
||||
sensitive_pruning_strategy:
|
||||
class: 'SensitivePruneStrategy'
|
||||
pruner: 'pruner_1'
|
||||
start_epoch: 1
|
||||
delta_rate: 0.2
|
||||
target_ratio: 0.08
|
||||
num_steps: 1
|
||||
eval_rate: 0.5
|
||||
pruned_params: 'conv6_sep_weights'
|
||||
sensitivities_file: 'mobilenet_acc_top1_sensitive.data'
|
||||
metric_name: 'acc_top1'
|
||||
compressor:
|
||||
epoch: 2
|
||||
checkpoint_path: './checkpoints/'
|
||||
strategies:
|
||||
- sensitive_pruning_strategy
|
@ -0,0 +1,210 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import paddle.fluid as fluid
|
||||
from paddle.fluid.initializer import MSRA
|
||||
from paddle.fluid.param_attr import ParamAttr
|
||||
|
||||
__all__ = ['MobileNet']
|
||||
|
||||
train_parameters = {
|
||||
"input_size": [3, 224, 224],
|
||||
"input_mean": [0.485, 0.456, 0.406],
|
||||
"input_std": [0.229, 0.224, 0.225],
|
||||
"learning_strategy": {
|
||||
"name": "piecewise_decay",
|
||||
"batch_size": 256,
|
||||
"epochs": [30, 60, 90],
|
||||
"steps": [0.1, 0.01, 0.001, 0.0001]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class MobileNet():
|
||||
def __init__(self):
|
||||
self.params = train_parameters
|
||||
|
||||
def net(self, input, class_dim=1000, scale=1.0):
|
||||
# conv1: 112x112
|
||||
input = self.conv_bn_layer(
|
||||
input,
|
||||
filter_size=3,
|
||||
channels=3,
|
||||
num_filters=int(32 * scale),
|
||||
stride=2,
|
||||
padding=1,
|
||||
name="conv1")
|
||||
|
||||
# 56x56
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=32,
|
||||
num_filters2=64,
|
||||
num_groups=32,
|
||||
stride=1,
|
||||
scale=scale,
|
||||
name="conv2_1")
|
||||
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=64,
|
||||
num_filters2=128,
|
||||
num_groups=64,
|
||||
stride=2,
|
||||
scale=scale,
|
||||
name="conv2_2")
|
||||
|
||||
# 28x28
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=128,
|
||||
num_filters2=128,
|
||||
num_groups=128,
|
||||
stride=1,
|
||||
scale=scale,
|
||||
name="conv3_1")
|
||||
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=128,
|
||||
num_filters2=256,
|
||||
num_groups=128,
|
||||
stride=2,
|
||||
scale=scale,
|
||||
name="conv3_2")
|
||||
|
||||
# 14x14
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=256,
|
||||
num_filters2=256,
|
||||
num_groups=256,
|
||||
stride=1,
|
||||
scale=scale,
|
||||
name="conv4_1")
|
||||
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=256,
|
||||
num_filters2=512,
|
||||
num_groups=256,
|
||||
stride=2,
|
||||
scale=scale,
|
||||
name="conv4_2")
|
||||
|
||||
# 14x14
|
||||
for i in range(5):
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=512,
|
||||
num_filters2=512,
|
||||
num_groups=512,
|
||||
stride=1,
|
||||
scale=scale,
|
||||
name="conv5" + "_" + str(i + 1))
|
||||
# 7x7
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=512,
|
||||
num_filters2=1024,
|
||||
num_groups=512,
|
||||
stride=2,
|
||||
scale=scale,
|
||||
name="conv5_6")
|
||||
|
||||
input = self.depthwise_separable(
|
||||
input,
|
||||
num_filters1=1024,
|
||||
num_filters2=1024,
|
||||
num_groups=1024,
|
||||
stride=1,
|
||||
scale=scale,
|
||||
name="conv6")
|
||||
|
||||
input = fluid.layers.pool2d(
|
||||
input=input,
|
||||
pool_size=0,
|
||||
pool_stride=1,
|
||||
pool_type='avg',
|
||||
global_pooling=True)
|
||||
|
||||
output = fluid.layers.fc(input=input,
|
||||
size=class_dim,
|
||||
act='softmax',
|
||||
param_attr=ParamAttr(
|
||||
initializer=MSRA(), name="fc7_weights"),
|
||||
bias_attr=ParamAttr(name="fc7_offset"))
|
||||
return output
|
||||
|
||||
def conv_bn_layer(self,
|
||||
input,
|
||||
filter_size,
|
||||
num_filters,
|
||||
stride,
|
||||
padding,
|
||||
channels=None,
|
||||
num_groups=1,
|
||||
act='relu',
|
||||
use_cudnn=True,
|
||||
name=None):
|
||||
conv = fluid.layers.conv2d(
|
||||
input=input,
|
||||
num_filters=num_filters,
|
||||
filter_size=filter_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
groups=num_groups,
|
||||
act=None,
|
||||
use_cudnn=use_cudnn,
|
||||
param_attr=ParamAttr(
|
||||
initializer=MSRA(), name=name + "_weights"),
|
||||
bias_attr=False)
|
||||
bn_name = name + "_bn"
|
||||
return fluid.layers.batch_norm(
|
||||
input=conv,
|
||||
act=act,
|
||||
param_attr=ParamAttr(name=bn_name + "_scale"),
|
||||
bias_attr=ParamAttr(name=bn_name + "_offset"),
|
||||
moving_mean_name=bn_name + '_mean',
|
||||
moving_variance_name=bn_name + '_variance')
|
||||
|
||||
def depthwise_separable(self,
|
||||
input,
|
||||
num_filters1,
|
||||
num_filters2,
|
||||
num_groups,
|
||||
stride,
|
||||
scale,
|
||||
name=None):
|
||||
depthwise_conv = self.conv_bn_layer(
|
||||
input=input,
|
||||
filter_size=3,
|
||||
num_filters=int(num_filters1 * scale),
|
||||
stride=stride,
|
||||
padding=1,
|
||||
num_groups=int(num_groups * scale),
|
||||
use_cudnn=False,
|
||||
name=name + "_dw")
|
||||
|
||||
pointwise_conv = self.conv_bn_layer(
|
||||
input=depthwise_conv,
|
||||
filter_size=1,
|
||||
num_filters=int(num_filters2 * scale),
|
||||
stride=1,
|
||||
padding=0,
|
||||
name=name + "_sep")
|
||||
return pointwise_conv
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue