|
|
|
@ -29,7 +29,7 @@
|
|
|
|
|
namespace mindspore::lite {
|
|
|
|
|
void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
|
|
|
|
|
const std::unique_ptr<schema::PrimitiveT> &primitive,
|
|
|
|
|
const int &group) {
|
|
|
|
|
const int &group, const std::vector<AnfNodePtr> &inputs) {
|
|
|
|
|
auto attr = std::make_unique<schema::DepthwiseConv2DT>();
|
|
|
|
|
auto format = GetValue<std::string>(prim->GetAttr("data_format"));
|
|
|
|
|
if (format == "NCHW") {
|
|
|
|
@ -66,6 +66,28 @@ void AnfConvPopulater::PopulaterConv2DMultiGroup(const PrimitivePtr &prim,
|
|
|
|
|
attr->padMode = schema::PadMode_NOTSET;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int channel_mutiplier = 1;
|
|
|
|
|
if (prim->GetAttr("channel_mutiplier") != nullptr) {
|
|
|
|
|
channel_mutiplier = GetValue<int>(prim->GetAttr("channel_multiplier"));
|
|
|
|
|
}
|
|
|
|
|
attr->channelMultiplier = channel_mutiplier;
|
|
|
|
|
|
|
|
|
|
MS_ASSERT(inputs.size() == kAnfPopulaterTwo);
|
|
|
|
|
auto inputNode = inputs[kAnfPopulaterOne];
|
|
|
|
|
MS_ASSERT(inputNode != nullptr);
|
|
|
|
|
if (inputNode->isa<Parameter>()) {
|
|
|
|
|
auto paramNode = inputNode->cast<ParameterPtr>();
|
|
|
|
|
auto abstractBase = paramNode->abstract();
|
|
|
|
|
MS_ASSERT(abstractBase != nullptr);
|
|
|
|
|
if (utils::isa<abstract::AbstractTensorPtr>(abstractBase)) {
|
|
|
|
|
auto abstractTensor = utils::cast<abstract::AbstractTensorPtr>(abstractBase);
|
|
|
|
|
MS_ASSERT(abstractTensor != nullptr);
|
|
|
|
|
if (abstractTensor->format() == schema::Format_NCHW) {
|
|
|
|
|
abstractTensor->set_format(schema::Format_KCHW);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
primitive->value.type = schema::PrimitiveType_DepthwiseConv2D;
|
|
|
|
|
primitive->value.value = attr.release();
|
|
|
|
|
}
|
|
|
|
@ -214,7 +236,7 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
|
|
|
|
|
|
|
|
|
|
int group = GetValue<int>(prim->GetAttr("group"));
|
|
|
|
|
if (group > 1) {
|
|
|
|
|
PopulaterConv2DMultiGroup(prim, primitive, group);
|
|
|
|
|
PopulaterConv2DMultiGroup(prim, primitive, group, inputs);
|
|
|
|
|
} else {
|
|
|
|
|
PopulaterConv2DSingleGroup(prim, primitive, group);
|
|
|
|
|
}
|
|
|
|
|