dedepthwise_conv2d adapter

pull/6899/head
guohongzilong 4 years ago
parent a95ed7e8e0
commit a0e15592e1

@ -17,6 +17,13 @@
#include "src/ops/deconv2d.h"
#include <memory>
#include <string>
#include "include/errorcode.h"
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#endif
namespace mindspore {
namespace lite {
@ -58,6 +65,121 @@ void DeConv2D::SetHasBias(bool has_bias) { this->primitive_->value.AsDeConv2D()-
void DeConv2D::SetActivationType(int activation_type) {
this->primitive_->value.AsDeConv2D()->activationType = (schema::ActivationType)activation_type;
}
template <typename T>
void ConvertConvWeight(const ParameterPtr &param_node) {
MS_ASSERT(param_node != nullptr);
auto param = param_node->default_param();
auto weight = std::dynamic_pointer_cast<ParamValueLite>(param);
MS_ASSERT(weight != nullptr);
std::unique_ptr<T> buf(new (std::nothrow) T[weight->tensor_shape_size()]);
if (buf == nullptr) {
MS_LOG(ERROR) << "new buf failed";
return;
}
size_t filter_k = weight->tensor_shape()[0];
size_t filter_c = weight->tensor_shape()[1];
size_t filter_h = weight->tensor_shape()[2];
size_t filter_w = weight->tensor_shape()[3];
T *p1Buff = nullptr;
T *p2Buff = nullptr;
for (size_t k = 0; k < filter_k; ++k) {
for (size_t c = 0; c < filter_c; ++c) {
for (size_t h = 0; h < filter_h; ++h) {
for (size_t w = 0; w < filter_w; ++w) {
p1Buff = reinterpret_cast<float *>(weight->tensor_addr()) +
((k * filter_c * filter_h * filter_w) + (c * filter_h * filter_w) + (h * filter_w) + (w));
p2Buff =
buf.get() + ((c * filter_k * filter_h * filter_w) + (k * filter_h * filter_w) + (h * filter_w) + (w));
*p2Buff = *p1Buff;
}
}
}
}
auto ret = ::memcpy_s(weight->tensor_addr(), weight->tensor_shape_size() * sizeof(T), buf.get(),
weight->tensor_shape_size() * sizeof(T));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed: " << ret;
return;
}
auto abstract_base = param_node->abstract();
MS_ASSERT(abstract_base != nullptr);
if (utils::isa<abstract::AbstractTensorPtr>(abstract_base)) {
auto abstract_tensor = utils::cast<abstract::AbstractTensorPtr>(abstract_base);
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[0] = filter_c;
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[1] = filter_k;
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[2] = filter_h;
utils::cast<abstract::ShapePtr>(abstract_tensor->BuildShape())->shape()[3] = filter_w;
}
return;
}
// Populates a DeDepthwiseConv2D attribute on `primitive` from the anf
// Primitive `prim`, used when a deconvolution has group > 1 (treated as
// de-depthwise). Also transposes the weight parameter (inputs[1]) from
// KCHW to CKHW via ConvertConvWeight<float>.
//
// Fix vs. original: the channel-multiplier presence check used the
// misspelled attr key "channel_mutiplier" while the read used
// "channel_multiplier", so the two could disagree. Both now use
// "channel_multiplier" (the key the value was actually read from).
void DeConv2D::PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
                                         const std::vector<AnfNodePtr> &inputs) {
  auto attr = std::make_unique<schema::DeDepthwiseConv2DT>();
  // Data layout: anything other than NCHW/NHWC is recorded as the sentinel.
  auto format = GetValue<std::string>(prim.GetAttr("data_format"));
  if (format == "NCHW") {
    attr->format = schema::Format::Format_NCHW;
  } else if (format == "NHWC") {
    attr->format = schema::Format::Format_NHWC;
  } else {
    attr->format = schema::Format::Format_NUM_OF_FORMAT;
  }
  // pad_list order: up, down, left, right.
  auto pad_list = GetValue<std::vector<int>>(prim.GetAttr("pad_list"));
  attr->padUp = pad_list[0];
  attr->padDown = pad_list[1];
  attr->padLeft = pad_list[2];
  attr->padRight = pad_list[3];
  auto dilation = GetValue<std::vector<int>>(prim.GetAttr("dilation"));
  attr->dilateH = dilation[0];
  attr->dilateW = dilation[1];
  auto kernel_size = GetValue<std::vector<int>>(prim.GetAttr("kernel_size"));
  attr->kernelH = kernel_size[0];
  attr->kernelW = kernel_size[1];
  auto stride = GetValue<std::vector<int>>(prim.GetAttr("stride"));
  attr->strideH = stride[0];
  attr->strideW = stride[1];
  auto pad_mode = GetValue<std::string>(prim.GetAttr("pad_mode"));
  if (pad_mode == "valid") {
    attr->padMode = schema::PadMode_VALID;
  } else if (pad_mode == "same") {
    attr->padMode = schema::PadMode_SAME_UPPER;
  } else {
    attr->padMode = schema::PadMode_NOTSET;
  }
  // Optional activation fused into the op; defaults to none.
  if (prim.GetAttr("activation_name") != nullptr) {
    std::string activate_name = GetValue<std::string>(prim.GetAttr("activation_name"));
    attr->activationType = kActivationTypeMap[activate_name];
  } else {
    attr->activationType = schema::ActivationType_NO_ACTIVATION;
  }
  // BUG FIX: check and read the same attr key.
  int channel_multiplier = 1;
  if (prim.GetAttr("channel_multiplier") != nullptr) {
    channel_multiplier = GetValue<int>(prim.GetAttr("channel_multiplier"));
  }
  attr->channelMultiplier = channel_multiplier;
  // inputs[kAnfPopulaterOne] is the weight; transpose it when it is a Parameter.
  MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
  auto input_node = inputs[kAnfPopulaterOne];
  MS_ASSERT(input_node != nullptr);
  if (input_node->isa<Parameter>()) {
    auto param_node = input_node->cast<ParameterPtr>();
    ConvertConvWeight<float>(param_node);
  }
  primitive->value.type = schema::PrimitiveType_DeDepthwiseConv2D;
  primitive->value.value = attr.release();
}
void DeConv2D::PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group) {
auto attr = std::make_unique<schema::DeConv2DT>();
attr->group = group;
@ -125,6 +247,8 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &i
int group = GetValue<int>(prim.GetAttr("group"));
if (group == 1) {
PopulaterDeConv2DSingleGroup(prim, this->primitive_, group);
} else if (group > 1) {
PopulaterConv2DMultiGroup(prim, this->primitive_, group, inputs);
}
if (GetQuantType() == schema::QuantType_AwareTraining) {

@ -48,6 +48,8 @@ class DeConv2D : public PrimitiveC {
void SetHasBias(bool has_bias);
void SetActivationType(int activation_type);
void PopulaterDeConv2DSingleGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group);
void PopulaterConv2DMultiGroup(const Primitive &prim, schema::PrimitiveT *primitive, const int &group,
const std::vector<AnfNodePtr> &inputs);
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else

@ -153,7 +153,7 @@ int DeDepthwiseConv2D::InferShape(std::vector<lite::Tensor *> inputs_, std::vect
out_shape.at(1) = output_h;
out_shape.at(2) = output_w;
if (GetChannelMultiplier() * input_channel != weight->shape()[0]) {
MS_LOG(ERROR) << "Conv depthwise only support group equals output channel.";
MS_LOG(ERROR) << "Conv dedepthwise only support group equals output channel.";
return RET_ERROR;
}
out_shape.at(3) = weight->shape()[0] * weight->shape()[3]; // in_channel * out_channel

@ -14,11 +14,45 @@
* limitations under the License.
*/
#include "include/errorcode.h"
#include "src/ops/maximum.h"
#include "src/common/log_adapter.h"
#ifdef PRIMITIVE_WRITEABLE
#include <float.h>
#include "tools/converter/quantizer/quantize_util.h"
#endif
namespace mindspore {
namespace lite {
#ifdef PRIMITIVE_WRITEABLE
// Lazily creates the underlying schema::PrimitiveT (tagged Maximum) and its
// MaximumT value if either is missing. Returns RET_OK on success, RET_ERROR
// on allocation failure or when an existing primitive has a different type.
//
// Fix vs. original: the post-assignment null re-check of value.value was
// dead code — `attr` had already been verified non-null — and is removed.
int Maximum::UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) {
  if (this->primitive_ == nullptr) {
    this->primitive_ = new (std::nothrow) schema::PrimitiveT;
    if (this->primitive_ == nullptr) {
      MS_LOG(ERROR) << "new primitiveT failed";
      return RET_ERROR;
    }
    this->primitive_->value.type = schema::PrimitiveType_Maximum;
  }
  if (this->primitive_->value.type != schema::PrimitiveType_Maximum) {
    MS_LOG(ERROR) << "Primitive type is error :" << this->primitive_->value.type;
    return RET_ERROR;
  }
  if (this->primitive_->value.value == nullptr) {
    auto attr = new (std::nothrow) schema::MaximumT();
    if (attr == nullptr) {
      MS_LOG(ERROR) << "new primitiveT value failed";
      return RET_ERROR;
    }
    this->primitive_->value.value = attr;
  }
  return RET_OK;
}
#else
int Maximum::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) {
MS_ASSERT(nullptr != primitive);

@ -22,6 +22,7 @@
#include <cmath>
#include "src/ops/arithmetic.h"
#include "src/ops/primitive_c.h"
namespace mindspore {
namespace lite {
@ -31,6 +32,7 @@ class Maximum : public Arithmetic {
MS_DECLARE_PARENT(Arithmetic, Arithmetic);
Maximum() = default;
explicit Maximum(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override;
#else
Maximum() = default;

@ -423,6 +423,8 @@ std::shared_ptr<PrimitiveC> PrimitiveC::Create(const Primitive &prim, const std:
return NewPrimitiveC<StridedSlice>(prim, inputs, quantType);
} else if (op_type == "Cast") {
return NewPrimitiveC<Cast>(prim, inputs, quantType);
} else if (op_type == "Maximum") {
return NewPrimitiveC<Maximum>(prim, inputs, quantType);
} else if (op_type == "Split") {
return NewPrimitiveC<Split>(prim, inputs, quantType);

@ -18,23 +18,19 @@
#include "tools/optimizer/common/gllo_utils.h"
using mindspore::lite::converter::FmkType_CAFFE;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_MS;
using mindspore::schema::QuantType_WeightQuant;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::lite::converter::FmkType_ONNX;
using mindspore::lite::converter::FmkType_TFLITE;
using mindspore::schema::QuantType_AwareTraining;
using mindspore::schema::QuantType_PostTraining;
using mindspore::schema::QuantType_QUANT_NONE;
using mindspore::schema::QuantType_WeightQuant;
namespace mindspore::opt {
namespace {
constexpr size_t kConvWeightIndex = 2;
} // namespace
void WeightFormatHardCodePass::SetQuantType(QuantType type) {
this->quant_type = type;
}
void WeightFormatHardCodePass::SetFmkType(FmkType type) {
this->fmk_type = type;
}
void WeightFormatHardCodePass::SetQuantType(QuantType type) { this->quant_type = type; }
void WeightFormatHardCodePass::SetFmkType(FmkType type) { this->fmk_type = type; }
lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node,
const ParamValueLitePtr &param_value) const {
MS_ASSERT(conv_cnode != nullptr);
@ -42,11 +38,12 @@ lite::STATUS WeightFormatHardCodePass::HardCodeCAFFE(const AnfNodePtr &conv_node
switch (quant_type) {
case schema::QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE:param_value->set_format(schema::Format::Format_KCHW);
case QuantType_QUANT_NONE:
param_value->set_format(schema::Format::Format_KCHW);
break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
@ -68,12 +65,11 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
} break;
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
@ -81,19 +77,18 @@ lite::STATUS WeightFormatHardCodePass::HardCodeONNX(const AnfNodePtr &conv_node,
// depth (K x C/group x kH x kW) group = channelOut ==> (K, multiplier, H, W)
// deconv (C x K/group x kH x kW) group = 1
// dedepth (C x K/group x kH x kW) group = channelIn ==> (C, multiplier, H, W)
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D
|| op_type == schema::PrimitiveType_DeConv2D) {
if (op_type == schema::PrimitiveType_Conv2D || op_type == schema::PrimitiveType_DepthwiseConv2D ||
op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
@ -114,8 +109,7 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
} else {
param_value->set_format(schema::Format::Format_KCHW);
}
}
break;
} break;
case QuantType_PostTraining:
case QuantType_WeightQuant:
case QuantType_QUANT_NONE: {
@ -124,18 +118,19 @@ lite::STATUS WeightFormatHardCodePass::HardCodeMS(const AnfNodePtr &conv_node,
param_value->set_format(schema::Format::Format_KCHW);
} else if (op_type == schema::PrimitiveType_DepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CKHW);
} else if (op_type == schema::PrimitiveType_DeDepthwiseConv2D) {
param_value->set_format(schema::Format::Format_CKHW);
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_KCHW);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
} break;
default: {
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported quantType: " << EnumNameQuantType(quant_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
@ -159,15 +154,14 @@ lite::STATUS WeightFormatHardCodePass::HardCodeTFLITE(const AnfNodePtr &conv_nod
} else if (op_type == schema::PrimitiveType_DeConv2D) {
param_value->set_format(schema::Format::Format_CHWK);
} else {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
break;
} break;
default: {
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type) << ", node: "
<< conv_node->fullname_with_scope();
MS_LOG(ERROR) << "Unsupported opType: " << EnumNamePrimitiveType(op_type)
<< ", node: " << conv_node->fullname_with_scope();
return lite::RET_ERROR;
}
}
@ -183,8 +177,8 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
}
auto conv_cnode = node->cast<CNodePtr>();
auto type = opt::GetCNodeType(node);
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D
&& type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
if (type != schema::PrimitiveType_Conv2D && type != schema::PrimitiveType_DepthwiseConv2D &&
type != schema::PrimitiveType_DeConv2D && type != schema::PrimitiveType_DeDepthwiseConv2D) {
continue;
}
MS_ASSERT(conv_cnode->inputs().size() > kConvWeightIndex);
@ -197,15 +191,20 @@ bool WeightFormatHardCodePass::Run(const FuncGraphPtr &graph) {
}
lite::STATUS status;
switch (fmk_type) {
case FmkType_CAFFE:status = HardCodeCAFFE(node, param_value);
case FmkType_CAFFE:
status = HardCodeCAFFE(node, param_value);
break;
case FmkType_TFLITE:status = HardCodeTFLITE(node, param_value);
case FmkType_TFLITE:
status = HardCodeTFLITE(node, param_value);
break;
case FmkType_ONNX:status = HardCodeONNX(node, param_value);
case FmkType_ONNX:
status = HardCodeONNX(node, param_value);
break;
case FmkType_MS:status = HardCodeMS(node, param_value);
case FmkType_MS:
status = HardCodeMS(node, param_value);
break;
default:MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
default:
MS_LOG(ERROR) << "Unsupported fmkType: " << fmk_type << ", node: " << node->fullname_with_scope();
return false;
}
if (status != lite::RET_OK) {

Loading…
Cancel
Save