|
|
@ -536,7 +536,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
|
|
|
|
STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveTValue> primitiveT_value,
|
|
|
|
bool depthwise) {
|
|
|
|
bool perchanel, bool depthwise) {
|
|
|
|
// const vector<int> dims = filter->dims;
|
|
|
|
// const vector<int> dims = filter->dims;
|
|
|
|
// perlayer
|
|
|
|
// perlayer
|
|
|
|
if (!weight->isa<Parameter>()) {
|
|
|
|
if (!weight->isa<Parameter>()) {
|
|
|
@ -544,9 +544,17 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr weight, std::shared_ptr<P
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
|
|
|
|
auto parameter = std::dynamic_pointer_cast<Parameter>(weight);
|
|
|
|
|
|
|
|
if (parameter == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not cast to Parameter";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
|
|
|
|
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
|
|
|
|
|
|
|
|
if (paramValue == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << weight->fullname_with_scope() << " can not get value";
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
|
|
|
|
auto status = QuantFilter(paramValue, primitiveT_value, QuantType_PostTraining, quant_max, quant_min, bit_num,
|
|
|
|
per_channel_, depthwise);
|
|
|
|
perchanel, depthwise);
|
|
|
|
if (status != RET_OK) {
|
|
|
|
if (status != RET_OK) {
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed: " << status;
|
|
|
|
return status;
|
|
|
|
return status;
|
|
|
@ -690,11 +698,29 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|
|
|
auto op_name = cnode->fullname_with_scope();
|
|
|
|
auto op_name = cnode->fullname_with_scope();
|
|
|
|
auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
|
|
|
|
auto op_type = primitiveT_value->GetPrimitiveT()->value.type;
|
|
|
|
MS_LOG(INFO) << "OpName: " << op_name;
|
|
|
|
MS_LOG(INFO) << "OpName: " << op_name;
|
|
|
|
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) {
|
|
|
|
if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D &&
|
|
|
|
|
|
|
|
op_type != PrimitiveType_FullConnection) {
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
auto input_node = cnode->input(i);
|
|
|
|
if (!input_node->isa<mindspore::CNode>()) {
|
|
|
|
if (!input_node->isa<mindspore::CNode>()) {
|
|
|
|
MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode";
|
|
|
|
MS_LOG(DEBUG) << "node: " << cnode_name << " input " << i << " not a cnode";
|
|
|
|
|
|
|
|
// get dtype
|
|
|
|
|
|
|
|
auto abstractBase = input_node->abstract();
|
|
|
|
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << input_node->fullname_with_scope();
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << input_node->fullname_with_scope();
|
|
|
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
|
|
|
if (abstractTensor->element()->GetTypeTrack()->type_id() == kNumberTypeFloat32) {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "this parameter do quant";
|
|
|
|
|
|
|
|
DoWeightQuant(input_node, primitiveT_value, false, false);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << "this parameter no need to do quant";
|
|
|
|
|
|
|
|
}
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
|
|
|
|
auto input_cnode = std::dynamic_pointer_cast<mindspore::CNode>(input_node);
|
|
|
@ -704,8 +730,15 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|
|
|
<< " PrimitiveTValue is null";
|
|
|
|
<< " PrimitiveTValue is null";
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
|
|
|
|
if (!input_cnode_primitiveT_value->GetOutputQuantParams().empty()) {
|
|
|
|
primitiveT_value->AddInputQuantParam(quant_param);
|
|
|
|
for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) {
|
|
|
|
|
|
|
|
primitiveT_value->AddInputQuantParam(quant_param);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
// do input quant
|
|
|
|
|
|
|
|
double scale = input_scale[cnode];
|
|
|
|
|
|
|
|
int32_t zp = input_zero_point[cnode];
|
|
|
|
|
|
|
|
DoQuantInput(scale, zp, &input_min_max[cnode], primitiveT_value);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
@ -715,8 +748,12 @@ STATUS PostTrainingQuantizer::QuantNode() {
|
|
|
|
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
|
|
|
|
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitiveT_value);
|
|
|
|
// do weight quant
|
|
|
|
// do weight quant
|
|
|
|
auto weight = cnode->input(2);
|
|
|
|
auto weight = cnode->input(2);
|
|
|
|
bool depthwise = op_type == PrimitiveType_DeDepthwiseConv2D;
|
|
|
|
bool depthwise = op_type == PrimitiveType_DepthwiseConv2D;
|
|
|
|
DoWeightQuant(weight, primitiveT_value, depthwise);
|
|
|
|
bool perchannel = per_channel_;
|
|
|
|
|
|
|
|
if (op_type == PrimitiveType_FullConnection) {
|
|
|
|
|
|
|
|
perchannel = false;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
DoWeightQuant(weight, primitiveT_value, perchannel, depthwise);
|
|
|
|
// do bias quant
|
|
|
|
// do bias quant
|
|
|
|
if (cnode->inputs().size() == 4) {
|
|
|
|
if (cnode->inputs().size() == 4) {
|
|
|
|
auto bias = cnode->input(3);
|
|
|
|
auto bias = cnode->input(3);
|
|
|
|