diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc
index 369c23be60..f3b6bfc289 100644
--- a/mindspore/lite/src/ops/deconv2d.cc
+++ b/mindspore/lite/src/ops/deconv2d.cc
@@ -17,6 +17,13 @@
 #include "src/ops/deconv2d.h"
 #include <memory>
 #include <string>
+#include "include/errorcode.h"
+#include "src/common/log_adapter.h"
+#ifdef PRIMITIVE_WRITEABLE
+#include <float.h>
+
+#include "tools/converter/quantizer/quantize_util.h"
+#endif
 
 namespace mindspore {
 namespace lite {
@@ -58,6 +65,121 @@ void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()-
 void DeConv2D::SetActivationType(int activation_type) {
   this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
 }
+template <typename T>
+void ConvertConvWeight(const ParameterPtr &param_node) {
+  MS_ASSERT(param_node != nullptr);
+  auto param = param_node->default_param();
+  auto weight = std::dynamic_pointer_cast<ParamValueLite>(param);
+  MS_ASSERT(weight != nullptr);
+
+  std::unique_ptr<T[]> buf(new (std::nothrow) T[weight->tensor_shape_size()]);
+  if (buf == nullptr) {
+    MS_LOG(ERROR) << "new buf failed";
+    return;
+  }
+
+  size_t filter_k = weight->tensor_shape()[0];
+  size_t filter_c = weight->tensor_shape()[1];
+  size_t filter_h = weight->tensor_shape()[2];
+  size_t filter_w = weight->tensor_shape()[3];
+  T *p1Buff = nullptr;
+  T *p2Buff = nullptr;
+  for (size_t k = 0; k < filter_k; ++k) {
+    for (size_t c = 0; c < filter_c; ++c) {
+      for (size_t h = 0; h < filter_h; ++h) {
+        for (size_t w = 0; w < filter_w; ++w) {
+          p1Buff = reinterpret_cast<T *>(weight->tensor_addr()) +
+                   ((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w));
+          p2Buff =
+            buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w));
+          *p2Buff = *p1Buff;
+        }
+      }
+    }
+  }
+
+  auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(),
+                        weight->tensor_shape_size() * sizeof(T));
+  if (ret != EOK) {
+    MS_LOG(ERROR) << "memcpy_s failed: " << ret;
+    return;
+  }
+
+  auto abstract_base = param_node->abstract();
+  MS_ASSERT(abstract_base != nullptr);
+  if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
+    auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
+    utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[0] = filter_c;
+    utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k;
+    utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h;
+    utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w;
+  }
+  return;
+}
+
+void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
+                                         const std::vector<AnfNodePtr> &inputs) {
+  auto attr = std::make_unique<schema::DeDepthwiseConv2DT>();
+  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
+  if (format == "NCHW") {
+    attr->format = schema::Format::Format_NCHW;
+  } else if (format == "NHWC") {
+    attr->format = schema::Format::Format_NHWC;
+  } else {
+    attr->format = schema::Format::Format_NUM_OF_FORMAT;
+  }
+  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
+  attr->padUp = pad_list[0];
+  attr->padDown = pad_list[1];
+  attr->padLeft = pad_list[2];
+  attr->padRight = pad_list[3];
+
+  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
+  attr->dilateH = dilation[0];
+  attr->dilateW = dilation[1];
+
+  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
+  attr->kernelH = kernel_size[0];
+  attr->kernelW = kernel_size[1];
+
+  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
+  attr->strideH = stride[0];
+  attr->strideW = stride[1];
+
+  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
+  if (pad_mode == "valid") {
+    attr->padMode = schema::PadMode_VALID;
+  } else if (pad_mode == "same") {
+    attr->padMode = schema::PadMode_SAME_UPPER;
+  } else {
+    attr->padMode = schema::PadMode_NOTSET;
+  }
+
+  if (prim.GetAttr("activation_name") != nullptr) {
+    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
+    attr->activationType = kActivationTypeMap[activate_name];
+  } else {
+    attr->activationType = schema::ActivationType_NO_ACTIVATION;
+  }
+
+  int channel_multiplier = 1;
+  if (prim.GetAttr("channel_multiplier") != nullptr) {
+    channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
+  }
+  attr->channelMultiplier = channel_multiplier;
+
+  MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
+  auto input_node = inputs[kAnfPopulaterOne];
+  MS_ASSERT(input_node != nullptr);
+  if (input_node->isa<Parameter>()) {
+    auto param_node = input_node->cast<ParameterPtr>();
+    ConvertConvWeight<float>(param_node);
+  }
+
+  primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
+  primitive->value.value = attr.release();
+}
+
 void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
   auto attr = std::make_unique<schema::DeConv2DT>();
   attr->group = group;
@@ -125,6 +247,8 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
   int group = GetValue<int>(prim.GetAttr("group"));
   if (group == 1) {
     PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
+  } else if (group > 1) {
+    PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
   }
 
   if (GetQuantType() == schema::QuantType_AwareTraining) {
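Reviewer note (example only, not part of the patch): `ConvertConvWeight<T>` above permutes the weight tensor from KCHW to CKHW in place, which is what lets a grouped DeConv2D be re-tagged as DeDepthwiseConv2D. Below is a minimal standalone sketch of the same index arithmetic, using a plain `std::vector` in place of `ParamValueLite`; the helper name `TransposeKCHWToCKHW` and the driver in `main` are invented for illustration.

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Standalone sketch of the axis swap done by ConvertConvWeight: element
// (k, c, h, w) of a KCHW buffer moves to (c, k, h, w) of a CKHW buffer.
template <typename T>
std::vector<T> TransposeKCHWToCKHW(const std::vector<T> &src, std::size_t K, std::size_t C, std::size_t H,
                                   std::size_t W) {
  std::vector<T> dst(src.size());
  for (std::size_t k = 0; k < K; ++k) {
    for (std::size_t c = 0; c < C; ++c) {
      for (std::size_t h = 0; h < H; ++h) {
        for (std::size_t w = 0; w < W; ++w) {
          // The two flat indices below are the same formulas the patch
          // computes into p1Buff (source) and p2Buff (destination).
          dst[(c * K * H * W) + (k * H * W) + (h * W) + w] = src[(k * C * H * W) + (c * H * W) + (h * W) + w];
        }
      }
    }
  }
  return dst;
}

int main() {
  // K=2, C=3, H=W=1: a KCHW layout groups values by k first.
  std::vector<int> kchw = {0, 1, 2, 3, 4, 5};  // (k,c) = (0,0)(0,1)(0,2)(1,0)(1,1)(1,2)
  for (int v : TransposeKCHWToCKHW(kchw, 2, 3, 1, 1)) std::cout << v << ' ';  // prints: 0 3 1 4 2 5
  std::cout << '\n';
  return 0;
}
```

The patch performs the same permutation through a scratch buffer plus `memcpy_s`, then rewrites the parameter's abstract shape from (K, C, H, W) to (C, K, H, W) to match.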
GetValue(prim.GetAttr("pad_mode")); + if (pad_mode == "valid") { + attr->padMode = schema::PadMode_VALID; + } else if (pad_mode == "same") { + attr->padMode = schema::PadMode_SAME_UPPER; + } else { + attr->padMode = schema::PadMode_NOTSET; + } + + if (prim.GetAttr("activation_name") != nullptr) { + std::string activate_name = GetValue(prim.GetAttr("activation_name")); + attr->activationType = kActivationTypeMap[activate_name]; + } else { + attr->activationType = schema::ActivationType_NO_ACTIVATION; + } + + int channel_mutiplier = 1; + if (prim.GetAttr("channel_mutiplier") != nullptr) { + channel_mutiplier = GetValue(prim.GetAttr("channel_multiplier")); + } + attr->channelMultiplier = channel_mutiplier; + + MS_ASSERT(inputs.size() == kAnfPopulaterTwo); + auto input_node = inputs[kAnfPopulaterOne]; + MS_ASSERT(input_node != nullptr); + if (input_node->isa()) { + auto param_node = input_node->cast(); + ConvertConvWeight(param_node); + } + + primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D; + primitive->value.value = attr.release(); +} + void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) { auto attr = std::make_unique(); attr->group = group; @@ -125,6 +247,8 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector &i int group = GetValue(prim.GetAttr("group")); if (group == 1) { PopulaterDeConv2DSingleGroup(prim, this->primitive_, group); + } else if (group > 1) { + PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs); } if (GetQuantType() == schema::QuantType_AwareTraining) { diff --git a/mindspore/lite/src/ops/deconv2d.h b/mindspore/lite/src/ops/deconv2d.h index e11d346e16..77c7a3ced9 100644 --- a/mindspore/lite/src/ops/deconv2d.h +++ b/mindspore/lite/src/ops/deconv2d.h @@ -48,6 +48,8 @@ class DeConv2D : public PrimitiveC { void SetHasBias(bool has_bias); void SetActivationType(int activation_type); void PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group); + void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group, + const std::vector &inputs); int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else diff --git a/mindspore/lite/src/ops/dedepthwise_conv2d.cc b/mindspore/lite/src/ops/dedepthwise_conv2d.cc index c92666aae6..634ff21ba8 100644 --- a/mindspore/lite/src/ops/dedepthwise_conv2d.cc +++ b/mindspore/lite/src/ops/dedepthwise_conv2d.cc @@ -153,7 +153,7 @@ int DeDepthwiseConv2D::InferShape(std::vector inputs_, std::vect out_shape.at(1) = output_h; out_shape.at(2) = output_w; if (GetChannelMultiplier() * input_channel != weight->shape()[0]) { - MS_LOG(ERROR) << "Conv depthwise only support group equals output channel."; + MS_LOG(ERROR) << "Conv dedepthwise only support group equals output channel."; return RET_ERROR; } out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel diff --git a/mindspore/lite/src/ops/maximum.cc b/mindspore/lite/src/ops/maximum.cc index 39223ee4b5..55224dc589 100644 --- a/mindspore/lite/src/ops/maximum.cc +++ b/mindspore/lite/src/ops/maximum.cc @@ -14,11 +14,45 @@ * limitations under the License. 
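Reviewer note (example only, not part of the patch): `Maximum::UnPackAttr` follows the allocate-or-validate pattern these ops share: create `primitive_` if it does not exist yet, reject a pre-existing primitive of the wrong type, then attach a default-constructed attribute table. A toy sketch of just that control flow; `ToyPrimitive`, `ToyValue`, and `MaximumAttr` are invented stand-ins, not the real schema types.

```cpp
#include <iostream>
#include <new>

enum class ToyType { kNone, kMaximum };

struct ToyValue {
  ToyType type = ToyType::kNone;
  void *value = nullptr;  // stands in for the schema::PrimitiveT union payload
};
struct ToyPrimitive {
  ToyValue value;
};
struct MaximumAttr {};  // stands in for schema::MaximumT

// Returns 0 on success and 1 on error, mirroring RET_OK / RET_ERROR.
int UnPackMaximum(ToyPrimitive *&prim) {
  if (prim == nullptr) {
    prim = new (std::nothrow) ToyPrimitive;
    if (prim == nullptr) return 1;         // allocation failed
    prim->value.type = ToyType::kMaximum;  // freshly created: claim the type
  }
  if (prim->value.type != ToyType::kMaximum) return 1;  // pre-existing, wrong type
  if (prim->value.value == nullptr) {
    auto *attr = new (std::nothrow) MaximumAttr;
    if (attr == nullptr) return 1;
    prim->value.value = attr;  // attach default-constructed attributes
  }
  return 0;
}

int main() {
  ToyPrimitive *prim = nullptr;
  std::cout << UnPackMaximum(prim) << '\n';  // 0: created and populated
  delete static_cast<MaximumAttr *>(prim->value.value);
  delete prim;
  return 0;
}
```

Maximum carries essentially no attributes of its own, so its `UnPackAttr` only has to establish the type tag and an empty `MaximumT` payload.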
*/ +#include "include/errorcode.h" #include "src/ops/maximum.h" +#include "src/common/log_adapter.h" +#ifdef PRIMITIVE_WRITEABLE +#include + +#include "tools/converter/quantizer/quantize_util.h" +#endif namespace mindspore { namespace lite { #ifdef PRIMITIVE_WRITEABLE +int Maximum::UnPackAttr(const Primitive &prim, const std::vector &inputs) { + if (this->primitive_ == nullptr) { + this->primitive_ = new (std::nothrow) schema::PrimitiveT; + if (this->primitive_ == nullptr) { + MS_LOG(ERROR) << "new primitiveT failed"; + return RET_ERROR; + } + this->primitive_->value.type = schema::PrimitiveType_Maximum; + } + if (this->primitive_->value.type != schema::PrimitiveType_Maximum) { + MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type; + return RET_ERROR; + } + if (this->primitive_->value.value == nullptr) { + auto attr = new (std::nothrow) schema::MaximumT(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new primitiveT value failed"; + return RET_ERROR; + } + this->primitive_->value.value = attr; + if (this->primitive_->value.value == nullptr) { + MS_LOG(ERROR) << "primitive value is nullptr"; + return RET_ERROR; + } + } + return RET_OK; +} #else int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { MS_ASSERT(nullptr != primitive); diff --git a/mindspore/lite/src/ops/maximum.h b/mindspore/lite/src/ops/maximum.h index 6704e8fe61..e8c0ae6eec 100644 --- a/mindspore/lite/src/ops/maximum.h +++ b/mindspore/lite/src/ops/maximum.h @@ -22,6 +22,7 @@ #include #include "src/ops/arithmetic.h" +#include "src/ops/primitive_c.h" namespace mindspore { namespace lite { @@ -31,6 +32,7 @@ class Maximum : public Arithmetic { MS_DECLARE_PARENT(Arithmetic, Arithmetic); Maximum() = default; explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} + int UnPackAttr(const Primitive &prim, const std::vector &inputs) override; #else Maximum() = default; diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 135409e2b8..0ce0cccaae 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -423,6 +423,8 @@ std::shared_ptr PrimitiveC::Create(const Primitive &prim, const std: return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Cast") { return NewPrimitiveC(prim, inputs, quantType); + } else if (op_type == "Maximum") { + return NewPrimitiveC(prim, inputs, quantType); } else if (op_type == "Split") { return NewPrimitiveC(prim, inputs, quantType); diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc index 9867565691..8083585a3c 100644 --- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc +++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc @@ -18,23 +18,19 @@ #include "tools/optimizer/common/gllo_utils.h" using mindspore::lite::converter::FmkType_CAFFE; -using mindspore::lite::converter::FmkType_TFLITE; -using mindspore::lite::converter::FmkType_ONNX; using mindspore::lite::converter::FmkType_MS; -using mindspore::schema::QuantType_WeightQuant; -using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::lite::converter::FmkType_ONNX; +using mindspore::lite::converter::FmkType_TFLITE; using mindspore::schema::QuantType_AwareTraining; using mindspore::schema::QuantType_PostTraining; +using mindspore::schema::QuantType_QUANT_NONE; +using mindspore::schema::QuantType_WeightQuant; namespace 
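Reviewer note (example only, not part of the patch): `PrimitiveC::Create` grows by one `else if` per op, so every new op such as `Maximum` has to touch this chain. A name-keyed registry is a common alternative shape for this kind of dispatch; the sketch below is hypothetical (simplified `PrimitiveC`, `NewPrimitiveC`, and table), not a change this patch makes.

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical sketch: replacing Create()'s else-if chain with a lookup table.
// PrimitiveC and NewPrimitiveC are simplified stand-ins for the real classes.
struct PrimitiveC {
  virtual ~PrimitiveC() = default;
};
struct Maximum : PrimitiveC {};
struct Split : PrimitiveC {};

template <typename T>
std::shared_ptr<PrimitiveC> NewPrimitiveC() {
  return std::make_shared<T>();
}

std::shared_ptr<PrimitiveC> Create(const std::string &op_type) {
  using FactoryFn = std::function<std::shared_ptr<PrimitiveC>()>;
  static const std::unordered_map<std::string, FactoryFn> kRegistry = {
    {"Maximum", NewPrimitiveC<Maximum>},
    {"Split", NewPrimitiveC<Split>},
  };
  auto it = kRegistry.find(op_type);
  return it == kRegistry.end() ? nullptr : it->second();
}

int main() {
  std::cout << (Create("Maximum") != nullptr) << '\n';  // 1
  std::cout << (Create("Unknown") != nullptr) << '\n';  // 0
  return 0;
}
```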
diff --git a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
index 9867565691..8083585a3c 100644
--- a/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
+++ b/mindspore/lite/tools/optimizer/graph/weight_format_hardcode_pass.cc
@@ -18,23 +18,19 @@
 #include "tools/optimizer/common/gllo_utils.h"
 
 using mindspore::lite::converter::FmkType_CAFFE;
-using mindspore::lite::converter::FmkType_TFLITE;
-using mindspore::lite::converter::FmkType_ONNX;
 using mindspore::lite::converter::FmkType_MS;
-using mindspore::schema::QuantType_WeightQuant;
-using mindspore::schema::QuantType_QUANT_NONE;
+using mindspore::lite::converter::FmkType_ONNX;
+using mindspore::lite::converter::FmkType_TFLITE;
 using mindspore::schema::QuantType_AwareTraining;
 using mindspore::schema::QuantType_PostTraining;
+using mindspore::schema::QuantType_QUANT_NONE;
+using mindspore::schema::QuantType_WeightQuant;
 namespace mindspore::opt {
 namespace {
 constexpr size_t kConvWeightIndex = 2;
 }  // namespace
-void WeightFormatHardCodePass::SetQuantType(QuantType type) {
-  this->quant_type = type;
-}
-void WeightFormatHardCodePass::SetFmkType(FmkType type) {
-  this->fmk_type = type;
-}
+void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
+void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
 lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
                                                      const ParamValueLitePtr &param_value) const {
   MS_ASSERT(conv_cnode != nullptr);
@@ -42,11 +38,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
   switch (quant_type) {
     case schema::QuantType_PostTraining:
     case QuantType_WeightQuant:
-    case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
+    case QuantType_QUANT_NONE:
+      param_value->set_format(schema::Format::Format_KCHW);
       break;
     default: {
-      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
-                    << conv_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
+                    << ", node: " << conv_node->fullname_with_scope();
       return lite::RET_ERROR;
     }
   }
@@ -68,12 +65,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
       } else if (op_type == schema::PrimitiveType_DeConv2D) {
         param_value->set_format(schema::Format::Format_KCHW);
       } else {
-        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
-                      << conv_node->fullname_with_scope();
+        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
+                      << ", node: " << conv_node->fullname_with_scope();
         return lite::RET_ERROR;
       }
-    }
-      break;
+    } break;
     case QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
@@ -81,19 +77,18 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
       // sum up from current code and model weight
      // conv (K x C/group x kH x kW) group = 1
       // depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
       // deconv (C x K/group x kH x kW) group = 1
       // dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
-      if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D
-          || op_type == schema::PrimitiveType_DeConv2D) {
+      if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
+          op_type == schema::PrimitiveType_DeConv2D) {
         param_value->set_format(schema::Format::Format_KCHW);
       } else {
-        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
-                      << conv_node->fullname_with_scope();
+        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
+                      << ", node: " << conv_node->fullname_with_scope();
         return lite::RET_ERROR;
       }
-    }
-      break;
+    } break;
     default: {
-      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
-                    << conv_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
+                    << ", node: " << conv_node->fullname_with_scope();
       return lite::RET_ERROR;
     }
   }
@@ -114,8 +109,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
       } else {
         param_value->set_format(schema::Format::Format_KCHW);
       }
-    }
-      break;
+    } break;
     case QuantType_PostTraining:
     case QuantType_WeightQuant:
     case QuantType_QUANT_NONE: {
@@ -124,18 +118,19 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
         param_value->set_format(schema::Format::Format_KCHW);
       } else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
         param_value->set_format(schema::Format::Format_CKHW);
+      } else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
+        param_value->set_format(schema::Format::Format_CKHW);
       } else if (op_type == schema::PrimitiveType_DeConv2D) {
         param_value->set_format(schema::Format::Format_KCHW);
       } else {
-        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
-                      << conv_node->fullname_with_scope();
+        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
+                      << ", node: " << conv_node->fullname_with_scope();
         return lite::RET_ERROR;
       }
-    }
-      break;
+    } break;
     default: {
-      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
-                    << conv_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
+                    << ", node: " << conv_node->fullname_with_scope();
       return lite::RET_ERROR;
     }
   }
@@ -159,15 +154,14 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod
       } else if (op_type == schema::PrimitiveType_DeConv2D) {
         param_value->set_format(schema::Format::Format_CHWK);
       } else {
-        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
-                      << conv_node->fullname_with_scope();
+        MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
+                      << ", node: " << conv_node->fullname_with_scope();
         return lite::RET_ERROR;
       }
-    }
-      break;
+    } break;
     default: {
-      MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
-                    << conv_node->fullname_with_scope();
+      MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
+                    << ", node: " << conv_node->fullname_with_scope();
       return lite::RET_ERROR;
     }
   }
@@ -183,8 +177,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
     }
     auto conv_cnode = node->cast<CNodePtr>();
     auto type = opt::GetCNodeType(node);
-    if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
-        && type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
+    if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
+        type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
       continue;
     }
     MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
@@ -197,15 +191,20 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
     }
     lite::STATUS status;
     switch (fmk_type) {
-      case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value);
+      case FmkType_CAFFE:
+        status = HardCodeCAFFE(node, param_value);
         break;
-      case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value);
+      case FmkType_TFLITE:
+        status = HardCodeTFLITE(node, param_value);
         break;
-      case FmkType_ONNX:status = HardCodeONNX(node, param_value);
+      case FmkType_ONNX:
+        status = HardCodeONNX(node, param_value);
         break;
-      case FmkType_MS:status = HardCodeMS(node, param_value);
+      case FmkType_MS:
+        status = HardCodeMS(node, param_value);
         break;
-      default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
+      default:
+        MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
         return false;
     }
     if (status != lite::RET_OK) {
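Reviewer note (example only, not part of the patch): taken together, the MS-framework branches above hard-code the following weight layouts for the non-aware-training quant types: Conv2D and DeConv2D stay KCHW (out-channel-major), while DepthwiseConv2D and the newly handled DeDepthwiseConv2D get CKHW (in-channel-major, matching the (C, multiplier, H, W) shapes noted in the ONNX comment block). The same decision condensed into one function, with local stand-in enums for the schema types:

```cpp
#include <iostream>

// Local stand-ins for schema::PrimitiveType and schema::Format.
enum class OpType { kConv2D, kDepthwiseConv2D, kDeConv2D, kDeDepthwiseConv2D };
enum class Format { kKCHW, kCKHW };

// Weight layout HardCodeMS picks for QUANT_NONE / PostTraining / WeightQuant.
Format MsWeightFormat(OpType op) {
  switch (op) {
    case OpType::kConv2D:
    case OpType::kDeConv2D:
      return Format::kKCHW;  // out-channel-major filter layout
    case OpType::kDepthwiseConv2D:
    case OpType::kDeDepthwiseConv2D:  // the case this patch adds
      return Format::kCKHW;  // in-channel-major: (C, multiplier, H, W)
  }
  return Format::kKCHW;  // unreachable with the enum above
}

int main() {
  std::cout << (MsWeightFormat(OpType::kDeDepthwiseConv2D) == Format::kCKHW) << '\n';  // prints 1
  return 0;
}
```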