|
|
|
@ -49,13 +49,13 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto inputNode = cnode->input(2);
|
|
|
|
|
if (!inputNode->isa<Parameter>()) {
|
|
|
|
|
auto input_node = cnode->input(2);
|
|
|
|
|
if (!input_node->isa<Parameter>()) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto paramNode = inputNode->cast<ParameterPtr>();
|
|
|
|
|
if (!paramNode->has_default()) {
|
|
|
|
|
auto param_node = input_node->cast<ParameterPtr>();
|
|
|
|
|
if (!param_node->has_default()) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -65,14 +65,26 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
|
|
|
|
|
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
|
|
|
|
bool depthwise = op_type == schema::PrimitiveType_DepthwiseConv2D ? true : false;
|
|
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
|
ParamValueLitePtr param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
|
|
|
|
auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0,
|
|
|
|
|
bitNum, true, depthwise);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
// set dtype
|
|
|
|
|
param_value->set_tensor_type(kNumberTypeUInt8);
|
|
|
|
|
auto abstractBase = param_node->abstract();
|
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
|
|
|
|
|
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -86,14 +98,14 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ParamValueLitePtr param_value = nullptr;
|
|
|
|
|
ParameterPtr param_node = nullptr;
|
|
|
|
|
for (size_t i = 1; i < node->size(); i++) {
|
|
|
|
|
auto inputNode = node->input(i);
|
|
|
|
|
if (inputNode->isa<Parameter>() == true) {
|
|
|
|
|
auto paramNode = inputNode->cast<ParameterPtr>();
|
|
|
|
|
if ((paramNode != nullptr) && (paramNode->has_default() == true)) {
|
|
|
|
|
param_value = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
|
|
|
|
|
param_node = inputNode->cast<ParameterPtr>();
|
|
|
|
|
if ((param_node != nullptr) && (param_node->has_default() == true)) {
|
|
|
|
|
param_value = std::static_pointer_cast<ParamValueLite>(param_node->default_param());
|
|
|
|
|
if ((param_value == nullptr) || (param_value->tensor_size() == 0)
|
|
|
|
|
|| (param_value->tensor_shape().size() != 4)
|
|
|
|
|
|| (param_value->tensor_addr() == nullptr)
|
|
|
|
|
|| (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) {
|
|
|
|
|
param_value = nullptr;
|
|
|
|
@ -115,12 +127,26 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<schema::QuantParamT> quant_params;
|
|
|
|
|
primitive_c->AddInputQuantParam(quant_params);
|
|
|
|
|
auto status = QuantFilter<uint8_t>(param_value, primitive_c, QuantType_WeightQuant, 255, 0, bitNum, true, false);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
param_value->set_tensor_type(kNumberTypeUInt8);
|
|
|
|
|
// set dtype
|
|
|
|
|
auto abstractBase = param_node->abstract();
|
|
|
|
|
if (abstractBase == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter is nullptr, " << param_node->name();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
if (!utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
|
|
|
|
MS_LOG(ERROR) << "Abstract of parameter should be anstract tensor, " << param_node->name();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
abstractTensor->element()->set_type(TypeIdToType(kNumberTypeUInt8));
|
|
|
|
|
primitive_c->SetQuantType(schema::QuantType_WeightQuant);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|