|
|
@ -2305,9 +2305,10 @@ class MaxLayer(LayerBase):
|
|
|
|
active_type='linear',
|
|
|
|
active_type='linear',
|
|
|
|
device=None,
|
|
|
|
device=None,
|
|
|
|
bias=False,
|
|
|
|
bias=False,
|
|
|
|
output_max_index=None):
|
|
|
|
output_max_index=None,
|
|
|
|
|
|
|
|
**xargs):
|
|
|
|
super(MaxLayer, self).__init__(
|
|
|
|
super(MaxLayer, self).__init__(
|
|
|
|
name, 'max', 0, inputs=inputs, device=device)
|
|
|
|
name, 'max', 0, inputs=inputs, device=device, **xargs)
|
|
|
|
config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input')
|
|
|
|
config_assert(len(self.inputs) == 1, 'MaxLayer must have 1 input')
|
|
|
|
self.config.trans_type = trans_type
|
|
|
|
self.config.trans_type = trans_type
|
|
|
|
self.config.active_type = active_type
|
|
|
|
self.config.active_type = active_type
|
|
|
@ -2609,14 +2610,16 @@ class AverageLayer(LayerBase):
|
|
|
|
trans_type='non-seq',
|
|
|
|
trans_type='non-seq',
|
|
|
|
active_type='linear',
|
|
|
|
active_type='linear',
|
|
|
|
device=None,
|
|
|
|
device=None,
|
|
|
|
bias=False):
|
|
|
|
bias=False,
|
|
|
|
|
|
|
|
**xargs):
|
|
|
|
super(AverageLayer, self).__init__(
|
|
|
|
super(AverageLayer, self).__init__(
|
|
|
|
name,
|
|
|
|
name,
|
|
|
|
'average',
|
|
|
|
'average',
|
|
|
|
0,
|
|
|
|
0,
|
|
|
|
inputs=inputs,
|
|
|
|
inputs=inputs,
|
|
|
|
device=device,
|
|
|
|
device=device,
|
|
|
|
active_type=active_type)
|
|
|
|
active_type=active_type,
|
|
|
|
|
|
|
|
**xargs)
|
|
|
|
self.config.average_strategy = average_strategy
|
|
|
|
self.config.average_strategy = average_strategy
|
|
|
|
self.config.trans_type = trans_type
|
|
|
|
self.config.trans_type = trans_type
|
|
|
|
config_assert(len(inputs) == 1, 'AverageLayer must have 1 input')
|
|
|
|
config_assert(len(inputs) == 1, 'AverageLayer must have 1 input')
|
|
|
@ -3490,7 +3493,7 @@ def parse_config(config_file, config_arg_str):
|
|
|
|
def parse_config_and_serialize(config_file, config_arg_str):
|
|
|
|
def parse_config_and_serialize(config_file, config_arg_str):
|
|
|
|
try:
|
|
|
|
try:
|
|
|
|
config = parse_config(config_file, config_arg_str)
|
|
|
|
config = parse_config(config_file, config_arg_str)
|
|
|
|
#logger.info(config)
|
|
|
|
# logger.info(config)
|
|
|
|
return config.SerializeToString()
|
|
|
|
return config.SerializeToString()
|
|
|
|
except:
|
|
|
|
except:
|
|
|
|
traceback.print_exc()
|
|
|
|
traceback.print_exc()
|
|
|
|