!11521 fix packedOp implements && gather infershape

From: @xutianchun
Reviewed-by: @hangangqiang,@HilbertDavid
Signed-off-by: @hangangqiang
pull/11521/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 794ab0dfcf

@ -81,5 +81,13 @@ std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t
}
return post_node_idxes;
}
bool IsPackedOp(schema::PrimitiveType op_type) {
static std::vector<schema::PrimitiveType> packed_ops = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm};
return IsContain(packed_ops, op_type);
}
} // namespace lite
} // namespace mindspore

@ -36,6 +36,8 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model);
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model);
std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, size_t tensor_idx);
bool IsPackedOp(schema::PrimitiveType op_type);
} // namespace lite
} // namespace mindspore

@ -49,7 +49,7 @@ static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
auto node = model->all_nodes_[post_node_idx];
MS_ASSERT(node != nullptr);
return IsContain(packed_op, static_cast<schema::PrimitiveType>(node->primitive_->Type()));
return IsPackedOp(static_cast<schema::PrimitiveType>(node->primitive_->Type()));
});
}

@ -112,6 +112,9 @@ int Gather::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
auto output = outputs_.front();
MS_ASSERT(input != nullptr);
output->set_data_type(input->data_type());
if (this->quant_type() == schema::QuantType_WeightQuant) {
output->set_data_type(kNumberTypeFloat32);
}
output->set_format(input->format());
if (!infer_flag()) {
return RET_INFER_INVALID;

@ -188,7 +188,7 @@ kernel::LiteKernel *Scheduler::FindBackendKernel(const std::vector<Tensor *> &in
if (primitive->quant_type() == schema::QuantType_WeightQuant) {
data_type = kNumberTypeFloat32;
}
if (!IsContain(packed_op, (schema::PrimitiveType)primitive->Type())) {
if (!IsPackedOp((schema::PrimitiveType)primitive->Type())) {
need_restore = false;
}
kernel::KernelKey desc{kCPU, data_type, static_cast<schema::PrimitiveType>(primitive->Type())};

@ -26,12 +26,6 @@
#include "src/ops/primitive_c.h"
namespace mindspore::lite {
static std::vector<schema::PrimitiveType> packed_op = {
schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D,
schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D,
schema::PrimitiveType_MatMul, schema::PrimitiveType_Lstm};
class Scheduler {
public:
Scheduler(const InnerContext *ctx, Model *src_model, std::vector<Tensor *> *src_tensors)

Loading…
Cancel
Save