|
|
|
@ -265,13 +265,20 @@ class ImperativeQuantAware(object):
|
|
|
|
|
if hasattr(layer, "skip_quant") and layer.skip_quant == True:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
scopes = name.split('.')
|
|
|
|
|
target = scopes[-1]
|
|
|
|
|
last_idx = 0
|
|
|
|
|
idx = 0
|
|
|
|
|
obj = model
|
|
|
|
|
parent = model
|
|
|
|
|
for i in range(len(scopes) - 1):
|
|
|
|
|
obj = getattr(parent, scopes[i])
|
|
|
|
|
parent = obj
|
|
|
|
|
|
|
|
|
|
while idx < len(name):
|
|
|
|
|
if (name[idx] == '.'):
|
|
|
|
|
if hasattr(parent, name[last_idx:idx]):
|
|
|
|
|
obj = getattr(obj, name[last_idx:idx])
|
|
|
|
|
parent = obj
|
|
|
|
|
last_idx = idx + 1
|
|
|
|
|
idx += 1
|
|
|
|
|
target = name[last_idx:idx]
|
|
|
|
|
|
|
|
|
|
quant_layer = self._get_quantized_counterpart(layer)
|
|
|
|
|
setattr(quant_layer, "layer_name", layer.full_name())
|
|
|
|
|
setattr(obj, target, quant_layer)
|
|
|
|
|