!10919 debug tts encoder

From: @wangzhe128
Reviewed-by: 
Signed-off-by:
pull/10919/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 0e047cbbeb

@ -22,3 +22,10 @@ int Fill(float *output, int size, float data) {
}
return NNACL_OK;
}
int FillInt32(int *output, int size, int data) {
for (int i = 0; i < size; ++i) {
output[i] = data;
}
return NNACL_OK;
}

@ -35,6 +35,8 @@ typedef struct FillParameter {
extern "C" {
#endif
int Fill(float *output, int size, float data);
int FillInt32(int *output, int size, int data);
#ifdef __cplusplus
}
#endif

@ -56,26 +56,6 @@ PrimitiveC *FillCreator(const schema::Primitive *primitive) { return PrimitiveC:
Registry FillRegistry(schema::PrimitiveType_Fill, FillCreator);
#endif
template <typename T>
void CalShape(const T *data, const std::vector<Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
int input_count = inputs[0]->ElementsNum();
int index = 0;
int size = 1;
for (int i = 0; i < shape_size; i++) {
if (static_cast<int>(data[i]) == -1) {
index = i;
} else if (static_cast<int>(data[i]) == 0) {
size *= inputs[0]->shape().at(i);
} else {
size *= data[i];
}
out_shape->push_back(data[i]);
}
if (static_cast<int>(data[index]) == -1) {
(*out_shape).at(index) = input_count / size;
}
}
int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
MS_ASSERT(this->primitive_ != nullptr);
auto input = inputs_.front();
@ -94,54 +74,23 @@ int Fill::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> output
return RET_INFER_INVALID;
}
std::vector<int> out_shape;
std::vector<int> output_shape;
auto param_dims = GetDims();
for (size_t i = 0; i < param_dims.size(); i++) {
output_shape.push_back(param_dims.at(i));
}
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
if (shape_tensor->IsConst()) {
if (shape_tensor->data_c() == nullptr || (shape_tensor->shape().size() == 1 && shape_tensor->shape()[0] == 0)) {
MS_LOG(DEBUG) << "reshape to a scalar.";
output->set_shape(out_shape);
return RET_OK;
}
}
if (shape_tensor->data_c() == nullptr) {
MS_LOG(INFO) << "Do infer shape in runtime.";
auto input_dims = inputs_.at(1);
MS_ASSERT(input_dims != nullptr);
if (input_dims->data_c() == nullptr) {
return RET_INFER_INVALID;
}
size_t shape_size = shape_tensor->ElementsNum();
switch (shape_tensor->data_type()) {
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->MutableData());
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->MutableData());
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt64: {
auto data = reinterpret_cast<int64_t *>(shape_tensor->MutableData());
CalShape<int64_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->MutableData());
CalShape<float>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->MutableData());
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
} break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();
return RET_INFER_ERR;
}
}
} else {
for (size_t i = 0; i < GetDims().size(); i++) {
out_shape.push_back(GetDims().at(i));
}
int *dims_data = reinterpret_cast<int *>(input_dims->data_c());
output_shape = std::vector<int>{dims_data, dims_data + input_dims->ElementsNum()};
}
output->set_shape(out_shape);
output->set_shape(output_shape);
return RET_OK;
}
} // namespace lite

@ -116,6 +116,15 @@ int Transpose::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> o
MS_ASSERT(output != nullptr);
std::vector<int> perm = GetPerm();
if (inputs_.size() == kDoubleNum) {
auto input_perm = inputs_.at(1);
MS_ASSERT(input_perm != nullptr);
if (input_perm->data_c() == nullptr) {
return RET_INFER_INVALID;
}
int *perm_data = reinterpret_cast<int *>(input_perm->data_c());
perm = std::vector<int>{perm_data, perm_data + input_perm->ElementsNum()};
}
std::vector<int> nchw2nhwc_perm = {0, 2, 3, 1};
std::vector<int> nhwc2nchw_perm = {0, 3, 1, 2};
std::vector<int> in_shape = input->shape();

@ -48,7 +48,15 @@ int FillCPUKernel::DoFill(int task_id) {
return RET_OK;
}
int offset = task_id * thread_sz_stride_;
int ret = Fill(out_ptr_ + offset, size, src_data_);
auto input_tensor = in_tensors_.at(0);
int ret = RET_OK;
if (input_tensor->data_type() == kNumberTypeFloat32 || input_tensor->data_type() == kNumberTypeFloat) {
ret = Fill(out_ptr_ + offset, size, src_data_);
} else if (input_tensor->data_type() == kNumberTypeInt32 || input_tensor->data_type() == kNumberTypeInt) {
ret = FillInt32(int32_out_ptr_ + offset, size, int32_src_data_);
} else {
return RET_ERROR;
}
if (ret != RET_OK) {
MS_LOG(ERROR) << "FillRun error task_id[" << task_id << "] error_code[" << ret << "]";
return ret;
@ -67,11 +75,20 @@ int FillRun(void *cdata, int task_id) {
}
int FillCPUKernel::Run() {
auto fillData = in_tensors_.at(in_tensors_.size() - 1);
auto fill_input = in_tensors_.front();
auto output = out_tensors_.front();
auto fill_data = reinterpret_cast<float *>(fillData->MutableData());
src_data_ = fill_data[0];
out_ptr_ = reinterpret_cast<float *>(output->MutableData());
if (fill_input->data_type() == kNumberTypeFloat32 || fill_input->data_type() == kNumberTypeFloat) {
auto fill_data = reinterpret_cast<float *>(fill_input->MutableData());
src_data_ = fill_data[0];
out_ptr_ = reinterpret_cast<float *>(output->MutableData());
} else if (fill_input->data_type() == kNumberTypeInt32 || fill_input->data_type() == kNumberTypeInt) {
auto fill_data = reinterpret_cast<int *>(fill_input->MutableData());
int32_src_data_ = fill_data[0];
int32_out_ptr_ = reinterpret_cast<int *>(output->MutableData());
} else {
MS_LOG(ERROR) << "unsupported fill data type " << fill_input->data_type();
return RET_ERROR;
}
auto ret = ParallelLaunch(this->context_->thread_pool_, FillRun, this, thread_sz_count_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "FillRun error error_code[" << ret << "]";
@ -80,5 +97,6 @@ int FillCPUKernel::Run() {
return RET_OK;
}
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Fill, LiteKernelCreator<FillCPUKernel>)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Fill, LiteKernelCreator<FillCPUKernel>)
} // namespace mindspore::kernel

@ -44,6 +44,8 @@ class FillCPUKernel : public LiteKernel {
int data_size_;
float src_data_;
float *out_ptr_;
int int32_src_data_;
int *int32_out_ptr_;
int thread_count_;
};
} // namespace mindspore::kernel

@ -138,10 +138,10 @@ int GruCPUKernel::Run() {
MS_ASSERT(output != nullptr);
auto input_ptr = reinterpret_cast<float *>(input->data_c());
MS_ASSERT(input_ptr);
auto output_ptr = reinterpret_cast<float *>(output->MutableData());
auto output_ptr = reinterpret_cast<float *>(output->data_c());
MS_ASSERT(output_ptr);
auto output_hidden_state = out_tensors_[1];
memcpy(output_hidden_state->MutableData(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
memcpy(output_hidden_state->data_c(), hidden_state->data_c(), hidden_state->ElementsNum() * sizeof(float));
int check_seq_len = gru_parm_->seq_len_;
if (in_tensors_.size() == 6) {
auto seq_len = reinterpret_cast<int *>(in_tensors_.at(5)->data_c());
@ -152,12 +152,12 @@ int GruCPUKernel::Run() {
check_seq_len = MSMIN(check_seq_len, MSMAX(0, seq_len[0]));
}
MS_ASSERT(weight_g_ptr_);
MS_ASSERT(weight_r_ptr_);
MS_ASSERT(bias_ptr_);
MS_ASSERT(gate_buffer_);
MS_ASSERT(weight_g_ptr_ != nullptr);
MS_ASSERT(weight_r_ptr_ != nullptr);
MS_ASSERT(bias_ptr_ != nullptr);
MS_ASSERT(gate_buffer_ != nullptr);
Gru(output_ptr, input_ptr, weight_g_ptr_, weight_r_ptr_, bias_ptr_,
reinterpret_cast<float *>(output_hidden_state->MutableData()), gate_buffer_, check_seq_len, gru_parm_);
reinterpret_cast<float *>(output_hidden_state->data_c()), gate_buffer_, check_seq_len, gru_parm_);
return RET_OK;
}

@ -39,7 +39,7 @@ int TransposeCPUKernel::Init() {
int TransposeCPUKernel::ReSize() {
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(op_parameter_);
if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_)) {
if (in_tensors_.at(kInputIndex)->shape().size() != static_cast<size_t>(param->num_axes_) && in_tensors_.size() != 2) {
return RET_OK;
}
auto &inTensor = in_tensors_.front();
@ -89,6 +89,20 @@ int TransposeCPUKernel::Run() {
MS_ASSERT(out_data_);
TransposeParameter *param = reinterpret_cast<TransposeParameter *>(this->op_parameter_);
if (in_tensors_.size() == 2) {
auto input_perm = in_tensors_.at(1);
MS_ASSERT(input_perm != nullptr);
MS_ASSERT(input_perm->data_c() != nullptr);
int *perm_data = reinterpret_cast<int *>(input_perm->data_c());
auto perm = std::vector<int>{perm_data, perm_data + input_perm->ElementsNum()};
for (int i = 0; i < input_perm->ElementsNum(); ++i) {
param->perm_[i] = perm[i];
}
for (int i = input_perm->ElementsNum(); i <= 8; ++i) {
param->perm_[i] = 0;
}
param->num_axes_ = input_perm->ElementsNum();
}
if (in_tensor->shape().size() != static_cast<size_t>(param->num_axes_)) {
memcpy(out_data_, in_data_, in_tensor->ElementsNum() * sizeof(float));
return RET_OK;

@ -162,7 +162,9 @@ FuncGraphPtr AnfTransform::TransformSingleFuncGraph(const FuncGraphPtr &old_grap
inne_context_ptr->Init();
const_fold_pm->AddPass(std::make_shared<opt::ConstFoldPass>(inne_context_ptr));
}
const_fold_pm->AddPass(std::make_shared<opt::UpdateConv2DParamPass>());
auto update_conv2d_param_pass = std::make_shared<opt::UpdateConv2DParamPass>();
update_conv2d_param_pass->SetFmkType(config->fmk);
const_fold_pm->AddPass(update_conv2d_param_pass);
fusion_pm->AddPass(std::make_shared<opt::ConvConvFusion>());
convert_pm->AddPass(std::make_shared<opt::ClipConvertActivationPass>());
if (config->fmk == lite::converter::FmkType_TFLITE) {

@ -280,6 +280,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() {
second_partial_node_->outputIndex.push_back(graph_->allTensors.size() - 1);
}
auto origin_switch_outputs = switch_node_->outputIndex;
switch_node_->outputIndex.clear();
for (size_t i = 3; i < switch_node_->inputIndex.size(); i++) {
auto &switch_in_tensor = graph_->allTensors.at(i);
@ -338,7 +339,7 @@ STATUS SingleSwitchPass::InsertPartialAndMergeAfterSwitch() {
merge_node->inputIndex.insert(merge_node->inputIndex.end(), second_partial_node_->outputIndex.begin(),
second_partial_node_->outputIndex.end());
}
merge_node->outputIndex = origin_switch_output_tensor_indices_;
merge_node->outputIndex = origin_switch_outputs;
graph_->nodes.push_back(std::move(merge_node));
return RET_OK;
}

@ -67,19 +67,23 @@ STATUS TFConvParser::Parse(const tensorflow::NodeDef &tf_op,
attr->strideW = strides[1];
auto weight_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (weight_node == nullptr) {
MS_LOG(ERROR) << "Find Conv2D input weights failed";
return RET_ERROR;
}
std::vector<int64_t> kernels(4);
status = ParseKernels(*weight_node, attr->format, &kernels);
if (status != RET_OK) {
return status;
if (weight_node != nullptr) {
std::vector<int64_t> kernels(4);
status = ParseKernels(*weight_node, attr->format, &kernels);
if (status != RET_OK) {
return status;
}
attr->kernelH = kernels[0];
attr->kernelW = kernels[1];
attr->channelIn = kernels[2];
attr->channelOut = kernels[3];
} else {
attr->kernelH = -1;
attr->kernelW = -1;
attr->channelIn = -1;
attr->channelOut = -1;
MS_LOG(WARNING) << "parsing of kernelH/W channelIn/Out is delayed";
}
attr->kernelH = kernels[0];
attr->kernelW = kernels[1];
attr->channelIn = kernels[2];
attr->channelOut = kernels[3];
status = ParsePadMode(tf_op, &attr->padMode);
if (status != RET_OK) {

@ -42,20 +42,15 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op,
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Fill;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
inputs->emplace_back(tf_op.input(1));
// parse dims
tensorflow::AttrValue attr_value;
auto dims_node = GetConstInputNode(tf_node_map, tf_op.input(0));
MS_ASSERT(dims_node != nullptr);
if (dims_node != nullptr && TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) {
if (dims_node != nullptr) {
if (!TensorFlowUtils::FindAttrValue(*dims_node, "value", &attr_value)) {
MS_LOG(ERROR) << "fill dims input not have value attr";
return RET_ERROR;
}
if (attr_value.value_case() != tensorflow::AttrValue::kTensor) {
MS_LOG(ERROR) << "The attrValue of value should have tensor type, actual: " << attr_value.value_case()
<< ", node: " << tf_op.name().c_str();
@ -66,32 +61,44 @@ STATUS TFFillParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "The dimsTensor dataType should be DT_INT32, actual : " << dims_tensor.dtype();
return RET_ERROR;
}
const tensorflow::TensorShapeProto &dimsTensorShape = dims_tensor.tensor_shape();
size_t shapeSize = 1;
for (int i = 0; i < dimsTensorShape.dim_size(); i++) {
shapeSize *= dimsTensorShape.dim(i).size();
const tensorflow::TensorShapeProto &dims_tensor_shape = dims_tensor.tensor_shape();
size_t shape_size = 1;
for (int i = 0; i < dims_tensor_shape.dim_size(); i++) {
shape_size *= dims_tensor_shape.dim(i).size();
}
size_t size = dims_tensor.int_val().size();
if (size > 0) {
for (size_t i = 0; i < shapeSize; i++) {
attr->dims.emplace_back(dims_tensor.int_val().Get(0));
for (size_t i = 0; i < shape_size; i++) {
attr->dims.emplace_back(dims_tensor.int_val().Get(i));
}
} else {
size = dims_tensor.tensor_content().length();
if (size == shapeSize * sizeof(int32_t)) {
attr->dims.resize(shapeSize);
if (size > 0) {
if (size != shape_size * sizeof(int32_t)) {
MS_LOG(ERROR) << "tensor size mismatch";
return RET_ERROR;
}
attr->dims.resize(shape_size);
if (EOK != ::memcpy_s(attr->dims.data(), size, dims_tensor.tensor_content().data(), size)) {
MS_LOG(ERROR) << "Memcpy_s from dimsTensor to attr failed";
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Can not find weight data, node: " << dims_node->name().c_str();
return RET_ERROR;
MS_LOG(DEBUG) << "empty dims";
}
}
} else {
inputs->emplace_back(tf_op.input(0));
}
primitive->value.type = schema::PrimitiveType_Fill;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
return RET_OK;
}
TFNodeRegistrar g_tfFillParser("Fill", new TFFillParser());

@ -15,6 +15,7 @@
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_FILL_PARSER_H_
#include <string>
#include <memory>
#include <map>

@ -46,7 +46,7 @@ const NodeDef *TFNodeParser::GetConstInputNode(const std::map<string, const tens
node = tf_node_map.at(flatten_input_name);
}
if (node->op() != "Const") {
MS_LOG(ERROR) << "Attr node is not Const";
MS_LOG(DEBUG) << "Attr node is not Const";
return nullptr;
}
return node;

@ -54,7 +54,7 @@ STATUS TFPoolParser::Parse(const tensorflow::NodeDef &tf_op,
if (attr_value.s() == "VALID") {
attr->padMode = schema::PadMode_VALID;
} else if (attr_value.s() == "SAME") {
attr->padMode = schema::PadMode_VALID;
attr->padMode = schema::PadMode_SAME_UPPER;
}
}

@ -0,0 +1,61 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/parser/tf/tf_rsqrt_parser.h"
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser_registry.h"
namespace mindspore {
namespace lite {
STATUS TFRsqrtParser::Parse(const tensorflow::NodeDef &tf_op,
const std::map<string, const tensorflow::NodeDef *> &tf_node_map, PrimitiveC **primitiveC,
std::vector<std::string> *inputs, int *output_size) {
MS_LOG(INFO) << "TF RsqrtParser";
if (primitiveC == nullptr || output_size == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_NULL_PTR;
}
auto primitive = std::make_unique<schema::PrimitiveT>();
if (primitive == nullptr) {
MS_LOG(ERROR) << "primitive is nullptr";
return RET_NULL_PTR;
}
auto attr = std::make_unique<schema::RsqrtT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
primitive->value.type = schema::PrimitiveType_Rsqrt;
primitive->value.value = attr.release();
*primitiveC = PrimitiveC::Create(primitive.release());
if (*primitiveC == nullptr) {
MS_LOG(ERROR) << "primitiveC is nullptr";
return RET_ERROR;
}
*output_size = 1;
for (int i = 0; i < tf_op.input_size(); i++) {
inputs->emplace_back(tf_op.input(i));
}
return RET_OK;
}
TFNodeRegistrar g_tfRsqrtParser("Rsqrt", new TFRsqrtParser());
} // namespace lite
} // namespace mindspore

@ -0,0 +1,38 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_
#define MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_
#include <string>
#include <memory>
#include <map>
#include <vector>
#include "tools/converter/parser/tf/tf_node_parser.h"
namespace mindspore {
namespace lite {
class TFRsqrtParser : public TFNodeParser {
public:
TFRsqrtParser() = default;
~TFRsqrtParser() override = default;
STATUS Parse(const tensorflow::NodeDef &tf_op, const std::map<string, const tensorflow::NodeDef *> &tf_node_map,
PrimitiveC **primitiveC, std::vector<std::string> *inputs, int *output_size) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_TF_TF_RSQRT_PARSER_H_

@ -41,28 +41,36 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op,
MS_LOG(ERROR) << "new attr failed";
return RET_NULL_PTR;
}
attr->conjugate = false;
auto status = AddOpInput(tf_op, 0, inputs);
if (status != RET_OK) {
return status;
}
auto perm_node = GetConstInputNode(tf_node_map, tf_op.input(1));
if (perm_node == nullptr) {
MS_LOG(ERROR) << "Find Transpose input perm failed";
return RET_ERROR;
}
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->perm.push_back(tensor_proto.int_val(i));
status = AddOpInput(tf_op, 1, inputs);
if (status != RET_OK) {
return status;
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->perm.push_back(data[i]);
tensorflow::AttrValue attr_value;
if (!TensorFlowUtils::FindAttrValue(*perm_node, "value", &attr_value)) {
MS_LOG(ERROR) << "The value attr should be specified";
return RET_ERROR;
}
auto tensor_proto = attr_value.tensor();
if (tensor_proto.int_val_size() > 0) {
for (int i = 0; i < tensor_proto.int_val_size(); ++i) {
attr->perm.push_back(tensor_proto.int_val(i));
}
} else {
auto data_num = tensor_proto.tensor_content().size() / sizeof(int32_t);
auto data = reinterpret_cast<const int32_t *>(tensor_proto.tensor_content().data());
for (size_t i = 0; i < data_num; ++i) {
attr->perm.push_back(data[i]);
}
}
}
@ -75,7 +83,6 @@ STATUS TFTransposeParser::Parse(const tensorflow::NodeDef &tf_op,
}
*output_size = 1;
auto status = AddOpInput(tf_op, 0, inputs);
return status;
}
TFNodeRegistrar g_tfTransposeParser("Transpose", new TFTransposeParser());

@ -693,7 +693,7 @@ STATUS SetFilterDim(const ParamValueLitePtr &tensor, kTransFilterType type, int3
tensor->set_tensor_shape({filterC, filterK, filterH, filterW});
} else if (type == kKHWC2CHWK) {
tensor->set_tensor_shape({filterC, filterH, filterW, filterK});
} else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC) {
} else if (type == kKCHW2KHWC || type == kCKHW2KHWC || type == kCHWK2KHWC || type == kHWCK2KHWC) {
tensor->set_tensor_shape({filterK, filterH, filterW, filterC});
} else {
MS_LOG(ERROR) << "Unsupported transFilterType: " << type;
@ -812,7 +812,8 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
}
} break;
case kHWCK2KCHW:
case kHWCK2CKHW: {
case kHWCK2CKHW:
case kHWCK2KHWC: {
for (int h = 0; h < filterH; ++h) {
for (int w = 0; w < filterW; ++w) {
for (int c = 0; c < filterC; ++c) {
@ -821,9 +822,12 @@ static STATUS TransFilterData(const ParamValueLitePtr &tensor, kTransFilterType
if (type == kHWCK2KCHW) {
p2Buff =
buf.get() + ((k * filterC * filterH * filterW) + (c * filterH * filterW) + (h * filterW) + (w));
} else {
} else if (type == kHWCK2CKHW) {
p2Buff =
buf.get() + ((c * filterK * filterH * filterW) + (k * filterH * filterW) + (h * filterW) + (w));
} else {
p2Buff =
buf.get() + ((k * filterH * filterW * filterC) + (h * filterW * filterC) + (w * filterC) + (c));
}
*p2Buff = *p1Buff;
}

@ -25,7 +25,6 @@
namespace mindspore {
namespace opt {
namespace {
constexpr size_t kWhileCommonInputsLength = 2;
constexpr size_t kWhileUniqInputsLength = 6;
constexpr size_t kCondNodesNum = 12;
constexpr size_t kCondCNodesNum = 4;
@ -47,16 +46,11 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
: PatternProcessPass(name, multigraph) {
/*
* vars for while input
* common:
* 0:const0 1:init_state
* fw_while_inputs:
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
* bw_while_inputs:
* 0:cond 1:body 2:kernel_gate 3:bias_gate 4:cand_kernel 5:cand_bias
*/
for (size_t i = 0; i < kWhileCommonInputsLength; ++i) {
common_vars_.emplace_back(std::make_shared<Var>());
}
for (size_t i = 0; i < kWhileUniqInputsLength; ++i) {
fw_vars_.emplace_back(std::make_shared<Var>());
bw_vars_.emplace_back(std::make_shared<Var>());
@ -64,17 +58,16 @@ BiDirectionTfGruCellFusion::BiDirectionTfGruCellFusion(const std::string &name,
input_ = std::make_shared<Var>();
input_length_ = std::make_shared<Var>();
transpose_input_ = std::make_shared<Var>();
fw_init_state_ = std::make_shared<Var>();
bw_init_state_ = std::make_shared<Var>();
}
const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
auto const1 = std::make_shared<CondVar>(IsParameterNode);
auto ele_shape = std::make_shared<CondVar>(IsParameterNode);
// forward
auto fw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
auto fw_max2 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, fw_max1});
auto fw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)),
std::make_shared<CondVar>(IsParameterNode), fw_max1});
auto fw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), transpose_input_});
@ -84,32 +77,33 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), fw_stride, fw_max2});
auto fw_reserve =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
fw_stride});
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)),
std::make_shared<CondVar>(IsParameterNode), fw_stride});
auto fw_from_tensor =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)),
transpose_input_, ele_shape});
transpose_input_, std::make_shared<CondVar>(IsParameterNode)});
auto is_fw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], common_vars_[0], fw_stride, common_vars_[0],
fw_reserve, common_vars_[1], fw_min, fw_from_tensor, input_length_});
auto fw_while = VectorRef({is_fw_while, fw_vars_[0], fw_vars_[1], std::make_shared<CondVar>(IsParameterNode),
fw_stride, std::make_shared<CondVar>(IsParameterNode), fw_reserve, fw_init_state_, fw_min,
fw_from_tensor, input_length_});
fw_while.insert(fw_while.end(), fw_vars_.begin() + 2, fw_vars_.end());
fw_while.emplace_back(common_vars_[1]);
fw_while.emplace_back(std::make_shared<Var>());
auto fw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
fw_while, std::make_shared<Var>()});
auto fw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
fw_get_item, ele_shape});
auto fw_out_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), fw_stack});
fw_get_item, std::make_shared<CondVar>(IsParameterNode)});
auto fw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
fw_stack, std::make_shared<Var>()});
// backward
auto bw_reverse_seq = VectorRef(
{std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), input_, input_length_});
auto bw_max1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Reduce)), input_length_});
auto bw_max2 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)), const1, bw_max1});
auto bw_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_reverse_seq});
auto bw_max2 = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Maximum)),
std::make_shared<CondVar>(IsParameterNode), bw_max1});
auto bw_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
bw_reverse_seq, std::make_shared<Var>()});
auto bw_shape =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Shape)), bw_trans});
auto bw_stride =
@ -117,22 +111,23 @@ const BaseRef BiDirectionTfGruCellFusion::DefinePattern() const {
auto bw_min =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Minimum)), bw_stride, bw_max2});
auto bw_reserve =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)), ele_shape,
bw_stride});
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListReserve)),
std::make_shared<CondVar>(IsParameterNode), bw_stride});
auto bw_from_tensor =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListFromTensor)), bw_trans,
ele_shape});
std::make_shared<CondVar>(IsParameterNode)});
auto is_bw_while = std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_While));
auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], common_vars_[0], bw_stride, common_vars_[0],
bw_reserve, common_vars_[1], bw_min, bw_from_tensor, input_length_});
auto bw_while = VectorRef({is_bw_while, bw_vars_[0], bw_vars_[1], std::make_shared<CondVar>(IsParameterNode),
bw_stride, std::make_shared<CondVar>(IsParameterNode), bw_reserve, bw_init_state_, bw_min,
bw_from_tensor, input_length_});
bw_while.insert(bw_while.end(), bw_vars_.begin() + 2, bw_vars_.end());
bw_while.emplace_back(common_vars_[1]);
bw_while.emplace_back(std::make_shared<Var>());
auto bw_get_item = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TupleGetItem)),
bw_while, std::make_shared<Var>()});
auto bw_stack = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_TensorListStack)),
bw_get_item, ele_shape});
auto bw_out_trans =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)), bw_stack});
bw_get_item, std::make_shared<CondVar>(IsParameterNode)});
auto bw_out_trans = VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_Transpose)),
bw_stack, std::make_shared<Var>()});
auto bw_reverse1 =
VectorRef({std::make_shared<CondVar>(std::bind(IsOpType, p1, schema::PrimitiveType_ReverseSequence)), bw_out_trans,
input_length_});
@ -416,10 +411,12 @@ STATUS BiDirectionTfGruCellFusion::ConvertBiasData(const AnfNodePtr &gate_bias,
}
CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &func_graph,
const AnfNodePtr &hidden_state,
const AnfNodePtr &fw_init_state,
const AnfNodePtr &bw_init_state,
const std::string base_name) const {
MS_ASSERT(func_graph);
MS_ASSERT(hidden_state);
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(fw_init_state != nullptr);
MS_ASSERT(bw_init_state != nullptr);
auto stack_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::StackT> attr = std::make_unique<schema::StackT>();
attr->axis = 0;
@ -427,9 +424,9 @@ CNodePtr BiDirectionTfGruCellFusion::GetStackedHiddenState(const FuncGraphPtr &f
stack_primitive->value.value = attr.release();
auto stack_cvalue = lite::PrimitiveC::Create(stack_primitive.release());
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(stack_cvalue));
std::vector<AnfNodePtr> new_node_inputs = {value_node, hidden_state, hidden_state};
std::vector<AnfNodePtr> new_node_inputs = {value_node, fw_init_state, bw_init_state};
auto new_node = func_graph->NewCNode(new_node_inputs);
new_node->set_abstract(hidden_state->abstract()->Clone());
new_node->set_abstract(fw_init_state->abstract()->Clone());
new_node->set_fullname_with_scope("stack_hidden_" + base_name);
return new_node;
}
@ -452,31 +449,33 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
auto value_node = NewValueNode(std::shared_ptr<lite::PrimitiveC>(gru_cvalue));
auto fw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[2]]);
MS_ASSERT(fw_gate_kernel);
MS_ASSERT(fw_gate_kernel != nullptr);
auto fw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[3]]);
MS_ASSERT(fw_gate_bias);
MS_ASSERT(fw_gate_bias != nullptr);
auto fw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[4]]);
MS_ASSERT(fw_cand_kernel);
MS_ASSERT(fw_cand_kernel != nullptr);
auto fw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[fw_vars_[5]]);
MS_ASSERT(fw_cand_bias);
MS_ASSERT(fw_cand_bias != nullptr);
auto bw_gate_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[2]]);
MS_ASSERT(bw_gate_kernel);
MS_ASSERT(bw_gate_kernel != nullptr);
auto bw_gate_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[3]]);
MS_ASSERT(bw_gate_bias);
MS_ASSERT(bw_gate_bias != nullptr);
auto bw_cand_kernel = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[4]]);
MS_ASSERT(bw_cand_kernel);
MS_ASSERT(bw_cand_kernel != nullptr);
auto bw_cand_bias = utils::cast<AnfNodePtr>((*equiv)[bw_vars_[5]]);
MS_ASSERT(bw_cand_bias);
MS_ASSERT(bw_cand_bias != nullptr);
auto hidden = utils::cast<AnfNodePtr>((*equiv)[common_vars_[1]]);
MS_ASSERT(hidden);
auto stacked_hidden = GetStackedHiddenState(func_graph, hidden, base_name);
auto fw_init_state = utils::cast<AnfNodePtr>((*equiv)[fw_init_state_]);
MS_ASSERT(fw_init_state != nullptr);
auto bw_init_state = utils::cast<AnfNodePtr>((*equiv)[bw_init_state_]);
MS_ASSERT(bw_init_state != nullptr);
auto stacked_hidden = GetStackedHiddenState(func_graph, fw_init_state, bw_init_state, base_name);
if (stacked_hidden == nullptr) {
return nullptr;
}
auto input_length = utils::cast<AnfNodePtr>((*equiv)[input_length_]);
MS_ASSERT(hidden);
MS_ASSERT(hidden != nullptr);
int input_size = 0;
int hidden_size = 0;
@ -536,8 +535,8 @@ CNodePtr BiDirectionTfGruCellFusion::CreateBiDirectionGruNode(const FuncGraphPtr
CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const {
MS_ASSERT(func_graph);
MS_ASSERT(gru_output);
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(gru_output != nullptr);
auto split_primitive = std::make_unique<schema::PrimitiveT>();
std::unique_ptr<schema::SplitT> split_attr = std::make_unique<schema::SplitT>();
split_attr->numberSplit = 2;
@ -603,8 +602,8 @@ CNodePtr BiDirectionTfGruCellFusion::GetPostProcessNode(const FuncGraphPtr &func
const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &concat_node,
const EquivPtr &equiv) const {
MS_ASSERT(func_graph);
MS_ASSERT(concat_node);
MS_ASSERT(func_graph != nullptr);
MS_ASSERT(concat_node != nullptr);
MS_LOG(DEBUG) << "bidirection tf gru fusion pass";
if (CheckIfFuncGraphIsNull(func_graph) != lite::RET_OK || CheckIfAnfNodeIsNull(concat_node) != lite::RET_OK) {
lite::ReturnCode::GetSingleReturnCode()->UpdateReturnCode(lite::RET_NULL_PTR);
@ -612,7 +611,7 @@ const AnfNodePtr BiDirectionTfGruCellFusion::Process(const FuncGraphPtr &func_gr
}
auto transpose_input = utils::cast<AnfNodePtr>((*equiv)[transpose_input_]);
MS_ASSERT(transpose_input);
MS_ASSERT(transpose_input != nullptr);
if (!utils::isa<CNodePtr>(transpose_input) || GetCNodeType(transpose_input) != schema::PrimitiveType_Transpose) {
return nullptr;
}

@ -54,18 +54,19 @@ class BiDirectionTfGruCellFusion : public PatternProcessPass {
float *tensor_data) const;
void CopyFlattenMatData(const float *mat, const int R, const int C, const int r0, const int r1, const int c0,
const int c1, float *data, bool t = false) const;
CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &hidden_state,
const std::string base_name) const;
CNodePtr GetStackedHiddenState(const FuncGraphPtr &func_graph, const AnfNodePtr &fw_init_state,
const AnfNodePtr &bw_init_state, const std::string base_name) const;
CNodePtr GetPostProcessNode(const FuncGraphPtr &func_graph, const CNodePtr &gru_output,
const std::string base_name) const;
private:
std::vector<VarPtr> common_vars_;
std::vector<VarPtr> fw_vars_;
std::vector<VarPtr> bw_vars_;
VarPtr input_;
VarPtr input_length_;
VarPtr transpose_input_;
VarPtr fw_init_state_;
VarPtr bw_init_state_;
};
} // namespace opt
} // namespace mindspore

@ -53,7 +53,44 @@ bool UpdateConv2DParamPass::Run(const FuncGraphPtr &func_graph) {
primT->value.AsDepthwiseConv2D()->channelIn = weight->tensor_shape().at(0);
}
}
} else if (type == schema::PrimitiveType_Conv2D) {
auto conv2d_cnode = node->cast<CNodePtr>();
auto primitive_c = GetValueNode<std::shared_ptr<lite::PrimitiveC>>(conv2d_cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << "Conv2D node has no primitiveC.";
continue;
}
auto primT = primitive_c->primitiveT();
if (primT == nullptr) {
MS_LOG(ERROR) << "Conv2D node has no primitiveT.";
continue;
}
auto conv2d_primt = primT->value.AsConv2D();
auto weight_node = conv2d_cnode->input(lite::kAnfPopulaterInputNumTwo);
if (weight_node == nullptr) {
MS_LOG(ERROR) << "Conv2D weight node is nullptr.";
continue;
}
if (!weight_node->isa<Parameter>()) {
MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
continue;
}
auto weight_param = weight_node->cast<ParameterPtr>();
if (!weight_param->has_default()) {
MS_LOG(ERROR) << "Conv2D weight node is not parameter.";
continue;
}
auto default_param = weight_param->default_param();
auto weight_tensor = std::dynamic_pointer_cast<ParamValueLite>(default_param);
auto weight_shape = weight_tensor->tensor_shape();
if (fmk_type == lite::converter::FmkType_TF && conv2d_primt->format == schema::Format_NHWC) {
conv2d_primt->kernelH = weight_shape[0];
conv2d_primt->kernelW = weight_shape[1];
conv2d_primt->channelIn = weight_shape[2];
conv2d_primt->channelOut = weight_shape[3];
}
}
if (status != lite::RET_OK && status != lite::RET_NO_CHANGE) {
MS_LOG(ERROR) << "remove identity pass is failed.";
return false;

@ -19,13 +19,19 @@
#include "schema/inner/model_generated.h"
#include "backend/optimizer/common/pass.h"
#include "tools/optimizer/common/gllo_utils.h"
#include "tools/converter/converter_flags.h"
using mindspore::lite::converter::FmkType;
namespace mindspore::opt {
class UpdateConv2DParamPass : public Pass {
public:
UpdateConv2DParamPass() : Pass("update_conv2d_param_pass") {}
~UpdateConv2DParamPass() override = default;
bool Run(const FuncGraphPtr &graph) override;
void SetFmkType(FmkType fmk_type) { this->fmk_type = fmk_type; }
private:
FmkType fmk_type = lite::converter::FmkType_ONNX;
};
} // namespace mindspore::opt
#endif // MINDSPORE_LITE_SRC_PASS_UPDATE_CONV2D_PARAM_PASS_H_

Loading…
Cancel
Save