run transformer decoder successfully

pull/5864/head
cjh9368 5 years ago
parent e0c7ad784b
commit 9d8fd9252d

@@ -91,7 +91,7 @@ int Gather::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor
   std::vector<int> out_shape{in_shape};
   out_shape.erase(out_shape.begin() + axis);
   for (int i = 0; i < indices_rank; i++) {
-    out_shape.insert(out_shape.begin() + axis, indices_shape[i]);
+    out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);
   }
   output->set_shape(out_shape);
   return RET_OK;
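
Note: the old code inserted every indices dimension at the same offset `axis`, which reverses their order in the output shape; inserting at `axis + i` keeps them in order. A minimal standalone sketch of the intended shape rule (function name is illustrative, not part of the patch):

  // Gather output shape = in_shape with the `axis` dimension replaced by indices_shape.
  // Example: in_shape = {4, 5, 6}, axis = 1, indices_shape = {2, 3}  ->  {4, 2, 3, 6}.
  #include <vector>
  std::vector<int> GatherOutShape(const std::vector<int> &in_shape, const std::vector<int> &indices_shape, int axis) {
    std::vector<int> out_shape(in_shape);
    out_shape.erase(out_shape.begin() + axis);
    for (int i = 0; i < static_cast<int>(indices_shape.size()); ++i) {
      out_shape.insert(out_shape.begin() + axis + i, indices_shape[i]);  // preserves indices order
    }
    return out_shape;
  }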

@@ -55,7 +55,7 @@ kernel::LiteKernel *CpuFullConnectionInt8KernelCreator(const std::vector<lite::t
 }
 int RestoreFullconnectWeight(lite::tensor::Tensor *input_tensor) {
   MS_ASSERT(input_tensor != nullptr);
-  if (input_tensor->data_type() != kNumberTypeUInt8) {
+  if (input_tensor->data_type() != kNumberTypeInt8) {
     MS_LOG(ERROR) << "full connect input type error" << input_tensor->data_type();
     return RET_ERROR;
   }
@@ -63,7 +63,7 @@ int RestoreFullconnectWeight(lite::tensor::Tensor *input_tensor) {
     MS_LOG(ERROR) << "no quant param";
     return RET_ERROR;
   }
-  const auto* quant_data = static_cast<const uint8_t*>(input_tensor->Data());
+  const auto* quant_data = static_cast<const int8_t*>(input_tensor->Data());
   auto* dequant_data = static_cast<float *>(malloc(input_tensor->DataSize() * sizeof(float)));
   if (dequant_data == nullptr) {
     MS_LOG(ERROR) << "malloc faile";
@@ -108,7 +108,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::t
   MS_ASSERT(desc.type == schema::PrimitiveType_Concat);
   auto *weight_tensor = inputs.at(kWeightIndex);
   auto *restore_data = weight_tensor->Data();
-  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (!weight_tensor->GetQuantParams().empty()) {
     RestoreFullconnectWeight(inputs.at(kWeightIndex));
   }
   auto kernel = new (std::nothrow) FullconnectionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
@@ -123,7 +123,7 @@ kernel::LiteKernel *CpuFullConnectionFp32KernelCreator(const std::vector<lite::t
                   << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(opParameter->type_));
     return nullptr;
   }
-  if (primitive->GetQuantType() == schema::QuantType_WeightQuant) {
+  if (!weight_tensor->GetQuantParams().empty()) {
     weight_tensor->FreeData();
     weight_tensor->SetData(restore_data);
   }

@@ -116,11 +116,6 @@ int RestoreMulWeight(lite::tensor::Tensor *input_tensor) {
   return RET_OK;
 }
 int ArithmeticSelfCPUKernel::Run() {
-  void *restore_data = nullptr;
-  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
-    restore_data = in_tensors_[1]->Data();
-    RestoreMulWeight(in_tensors_[1]);
-  }
   auto ret = Prepare();
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Prepare fail!ret: " << ret;
@@ -135,10 +130,6 @@ int ArithmeticSelfCPUKernel::Run() {
     MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
     return ret;
   }
-  if (primitive_->GetQuantType() == schema::QuantType_WeightQuant) {
-    in_tensors_[1]->FreeData();
-    in_tensors_[1]->SetData(restore_data);
-  }
   return RET_OK;
 }

@@ -35,29 +35,11 @@ int GatherInt8CPUKernel::Init() {
   axis_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->axis_;
   batchDims_ = (reinterpret_cast<GatherParameter *>(op_parameter_))->batchDims_;
   auto in_quant_args = in_tensors_.at(0)->GetQuantParams();
-  auto ind_quant_args = in_tensors_.at(1)->GetQuantParams();
   auto out_quant_args = out_tensors_.at(0)->GetQuantParams();
   param_.alpha_ = in_quant_args.front().scale / out_quant_args.front().scale;
   param_.zp_in_ = in_quant_args.front().zeroPoint;
   param_.zp_out_ = out_quant_args.front().zeroPoint;
-  auto indices_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->Data());
-  if (indices_ != nullptr) {
-    free(indices_);
-    indices_ = nullptr;
-  }
-  int count = in_tensors_.at(1)->ElementsNum();
-  indices_ = reinterpret_cast<int *>(malloc(count * sizeof(int)));
-  if (indices_ == nullptr) {
-    MS_LOG(ERROR) << "Gather Malloc indices_ error!";
-    return RET_ERROR;
-  }
-  (void)memset(indices_, 0, count * sizeof(int));
-  for (int i = 0; i < count; ++i) {
-    indices_[i] =
-      static_cast<int>(round((indices_ptr[i] - ind_quant_args.front().zeroPoint) * ind_quant_args.front().scale));
-  }
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -73,6 +55,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
   auto input_ptr = reinterpret_cast<int8_t *>(input_tensor->Data());
   auto output_ptr = reinterpret_cast<int8_t *>(out_tensor->Data());
+  auto indices_ptr = reinterpret_cast<int32_t *>(in_tensors_.at(1)->Data());
   auto in_shape = input_tensor->shape();
   int in_rank = in_shape.size();
@@ -80,8 +63,8 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
   const int limit = in_shape[axis_];
   for (int i = 0; i < indices_element_size; ++i) {
-    if (indices_[i] >= limit) {
-      MS_LOG(ERROR) << " indice data: " << indices_[i] << " is not in [ 0, " << limit - 1 << " ]";
+    if (indices_ptr[i] >= limit) {
+      MS_LOG(ERROR) << " indice data: " << indices_ptr[i] << " is not in [ 0, " << limit - 1 << " ]";
       return RET_ERROR;
     }
   }
@@ -103,7 +86,7 @@ int GatherInt8CPUKernel::DoGather(int task_id) {
   int error_code;
   input_ptr += thread_stride * limit;
   output_ptr += thread_stride * indices_element_size;
-  error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_, indices_element_size, param_);
+  error_code = GatherInt8(input_ptr, output_ptr, count, inner_size, limit, indices_ptr, indices_element_size, param_);
   if (error_code != RET_OK) {
     return RET_ERROR;
@@ -127,6 +110,7 @@ int GatherInt8CPUKernel::Run() {
     MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
     return prepare_ret;
   }
   int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, GatherInt8Run, this, thread_count_);
   if (error_code != RET_OK) {
     MS_LOG(ERROR) << "Gather function error error_code[" << error_code << "]";

@@ -30,8 +30,6 @@ class GatherInt8CPUKernel : public LiteKernel {
                       const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {}
   ~GatherInt8CPUKernel() {
-    free(indices_);
-    indices_ = nullptr;
   }
   int Init() override;
@@ -40,7 +38,6 @@ class GatherInt8CPUKernel : public LiteKernel {
   int DoGather(int task_id);

  private:
-  int *indices_ = nullptr;
   int thread_count_;
   int batchDims_;
   int axis_;
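
Note: across these GatherInt8 hunks the kernel stops owning a converted `indices_` buffer: Init() no longer dequantizes the indices tensor, and DoGather() reads the indices as plain int32 values, on the assumption that the indices tensor is not quantized. A minimal standalone sketch of the range check that remains per call (names are illustrative):

  // Reject any index outside [0, limit - 1], where limit is the size of the gathered axis.
  #include <cstdint>
  bool GatherIndicesValid(const int32_t *indices, int count, int limit) {
    for (int i = 0; i < count; ++i) {
      if (indices[i] < 0 || indices[i] >= limit) {
        return false;  // caller logs the offending value and returns RET_ERROR
      }
    }
    return true;
  }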

@@ -129,7 +129,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr<schema::MetaGraphT> &
   for (auto node : graph_input_nodes_) {
     for (auto input : node->inputIndex) {
       auto tensor = meta_graphT->allTensors[input].get();
-      if (tensor->data.empty()) {
+      if (tensor->nodeType != schema::NodeType_CNode && tensor->data.empty()) {
         tensor->nodeType = schema::NodeType_ValueNode;
         tensor->format = schema::Format_NHWC;
         if (!IsContain(meta_graphT->inputIndex, input)) {
@@ -261,7 +261,6 @@ int AnfExporter::ConvertInputParameter(const std::shared_ptr<AnfNode> input_anod
     return RET_OK;
   }
   auto paramTensor = std::make_unique<schema::TensorT>();
-  paramTensor->nodeType = schema::NodeType_ValueNode;
   paramTensor->format = schema::Format_NHWC;
   auto abstractBase = paramNode->abstract();
   if (abstractBase == nullptr) {
@@ -341,11 +340,10 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
   if (cnode->inputs().size() <= 1) {
     return RET_OK;
   }
-  bool is_graph_input = true;
+  bool is_graph_input = false;
   for (size_t i = 1; i < cnode->inputs().size(); i++) {
     auto input_node = cnode->input(i);
     if (input_node->isa<CNode>()) {
-      is_graph_input = false;
       auto ret = ConvertInputCNode(input_node, fb_node);
       if (ret != RET_OK) {
         MS_LOG(ERROR) << "ConvertInputCNode failed";
@@ -357,6 +355,9 @@ int AnfExporter::SetOpInputNode(const CNodePtr &cnode, const std::unique_ptr<sch
         MS_LOG(ERROR) << "ConvertInputParameter failed";
         return RET_ERROR;
       }
+      if (!input_node->cast<ParameterPtr>()->has_default()) {
+        is_graph_input = true;
+      }
     } else if (input_node->isa<ValueNode>()) {
       auto ret = ConvertInputValueNode(input_node, meta_graphT, fb_node);
       if (ret != RET_OK) {
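
Note: these two hunks change how graph inputs are detected: instead of treating a node as a graph-input consumer unless it has a CNode input, the node is marked only when one of its Parameter inputs has no default value (imported weights are Parameters that do carry defaults). A toy, self-contained sketch of that predicate (types are illustrative, not the MindSpore API):

  #include <optional>
  #include <vector>
  struct ToyParameter {
    std::optional<std::vector<float>> default_value;  // weights carry a default, graph inputs do not
    bool has_default() const { return default_value.has_value(); }
  };
  // A Parameter counts as a graph input only when it has no default value.
  bool IsGraphInput(const ToyParameter &p) { return !p.has_default(); }
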
@@ -382,7 +383,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
     auto tuple = std::reinterpret_pointer_cast<abstract::AbstractTuple>(cnode->abstract());
     for (size_t i = 0; i < tuple->size(); i++) {
       auto msTensor = new schema::TensorT();
-      msTensor->nodeType = schema::NodeType_Parameter;
+      msTensor->nodeType = schema::NodeType_CNode;
       fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
       if (tuple->size() == 1) {
         node_id_map_[cnode_name] = meta_graphT->allTensors.size();
@@ -399,7 +400,7 @@ void AnfExporter::SetOpOutputNode(const CNodePtr &cnode, const std::unique_ptr<s
     }
   } else {
     auto ms_tensor = new schema::TensorT();
-    ms_tensor->nodeType = schema::NodeType_Parameter;
+    ms_tensor->nodeType = schema::NodeType_CNode;
     fb_node->outputIndex.emplace_back(meta_graphT->allTensors.size());
     node_id_map_[cnode_name] = meta_graphT->allTensors.size();
     meta_graphT->allTensors.emplace_back(ms_tensor);

@@ -59,8 +59,8 @@ int AnfImporterFromMetaGraphT::ConverterConstTensor() {
       std::memcpy(tensor_data, tensor->data.data(), size);
       param_value->set_tensor_addr(tensor_data);
       param_value->set_tensor_size(size);
-      parameter->set_default_param(param_value);
     }
+    parameter->set_default_param(param_value);
     AddNode(i, parameter);
   }
   return RET_OK;
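
Note: this hunk attaches the ParamValue to the parameter unconditionally, so parameters built from tensors without data still get their metadata, while the raw buffer is copied only when the flatbuffer tensor actually carries data. A toy sketch of that split (types are illustrative, not the MindSpore classes):

  #include <cstdint>
  #include <cstdlib>
  #include <cstring>
  #include <vector>
  struct ToyParamValue {
    void *addr = nullptr;
    size_t size = 0;
  };
  struct ToyParameter {
    ToyParamValue *default_param = nullptr;
    void set_default_param(ToyParamValue *v) { default_param = v; }
  };
  void ImportTensor(ToyParameter *parameter, const std::vector<uint8_t> &data) {
    auto *param_value = new ToyParamValue();
    if (!data.empty()) {                        // copy the buffer only when the tensor has data
      param_value->size = data.size();
      param_value->addr = std::malloc(param_value->size);
      std::memcpy(param_value->addr, data.data(), param_value->size);
    }
    parameter->set_default_param(param_value);  // attached in every case after this change
  }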
