fix some problem of warpctc network in pynative mode

pull/9082/head
lvchangquan 4 years ago
parent f4c126ddeb
commit 5afd5d6934

@ -194,6 +194,7 @@ void RunOpAscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_g
data_layout_pm->AddPass(std::make_shared<GetitemTuple>()); data_layout_pm->AddPass(std::make_shared<GetitemTuple>());
data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>()); data_layout_pm->AddPass(std::make_shared<CommonSubexpressionElimination>());
data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>()); data_layout_pm->AddPass(std::make_shared<EliminateRedundantOp>());
data_layout_pm->AddPass(std::make_shared<InsertTransposeForDynamicGRUV2>());
data_layout_pm->AddPass(std::make_shared<OptimizeDependence>()); data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
data_layout_pm->AddPass(std::make_shared<TransDataSplit>()); data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>()); data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
@ -330,6 +331,9 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
ir_fusion_pm->AddPass(std::make_shared<AddnFission>()); ir_fusion_pm->AddPass(std::make_shared<AddnFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>()); ir_fusion_pm->AddPass(std::make_shared<InsertPadForNMSWithMask>());
ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>()); ir_fusion_pm->AddPass(std::make_shared<TensorScatterUpdateFission>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicRNN>());
ir_fusion_pm->AddPass(std::make_shared<InsertPlaceholderForDynamicGRUV2>());
ir_fusion_pm->AddPass(std::make_shared<EraseVisitAttr>());
optimizer->AddPassManager(ir_fusion_pm); optimizer->AddPassManager(ir_fusion_pm);
(void)optimizer->Optimize(kernel_graph); (void)optimizer->Optimize(kernel_graph);

@ -57,7 +57,6 @@ const AnfNodePtr InsertPlaceholderForDynamicRNN::Process(const FuncGraphPtr &fun
auto value_node = NewValueNode(value); auto value_node = NewValueNode(value);
value_node->set_abstract(std::make_shared<abstract::AbstractNone>()); value_node->set_abstract(std::make_shared<abstract::AbstractNone>());
auto new_node = kernel_graph->NewValueNode(value_node); auto new_node = kernel_graph->NewValueNode(value_node);
kernel_graph->AddValueNodeToGraph(new_node);
new_inputs.push_back(new_node); new_inputs.push_back(new_node);
} }
new_inputs.push_back(input_node); new_inputs.push_back(input_node);

@ -363,9 +363,6 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) { if (real_input_node->isa<Parameter>() && !AnfAlgo::IsParameterWeight(real_input_node->cast<ParameterPtr>())) {
continue; continue;
} }
if (selected_kernel_info->GetInputFormat(input_index) == kOpFormat_FRACTAL_ZN_LSTM) {
continue;
}
// we set special device info of a input tensor. // we set special device info of a input tensor.
bool is_ref = false; bool is_ref = false;
auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node); auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
@ -376,9 +373,12 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
continue; continue;
} }
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
std::vector<std::string> output_format = {AnfAlgo::GetOutputFormat(real_input_node, 0)};
if (IsValueNode<tensor::Tensor>(input_kernel_node) && if (IsValueNode<tensor::Tensor>(input_kernel_node) &&
AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) { AnfAlgo::GetOutputDeviceDataType(input_kernel_node, 0) == kTypeUnknown) {
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)}; if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
output_format = {selected_kernel_info->GetInputFormat(input_index)};
}
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);
@ -386,7 +386,9 @@ void SetTensorDeviceInfo(const CNodePtr &kernel_node) {
continue; continue;
} }
if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) { if (AnfAlgo::GetOutputDeviceDataType(real_input_node, 0) == kTypeUnknown || is_ref) {
std::vector<std::string> output_format = {selected_kernel_info->GetInputFormat(input_index)}; if (selected_kernel_info->GetInputFormat(input_index) != kOpFormat_FRACTAL_ZN_LSTM) {
output_format = {selected_kernel_info->GetInputFormat(input_index)};
}
builder->SetOutputsFormat(output_format); builder->SetOutputsFormat(output_format);
std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)}; std::vector<TypeId> output_type = {selected_kernel_info->GetInputDeviceType(input_index)};
builder->SetOutputsDeviceType(output_type); builder->SetOutputsDeviceType(output_type);

@ -751,6 +751,18 @@ void KernelRuntime::GenLaunchArgs(const mindspore::kernel::KernelMod &kernel_mod
return GenAddrCleanLaunchArgs(cnode, kernel_inputs); return GenAddrCleanLaunchArgs(cnode, kernel_inputs);
} }
for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) { for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
auto op_name = AnfAlgo::GetCNodeName(cnode);
constexpr auto none_placeholder_index = 3;
if (op_name == kDynamicRNNOpName && i == none_placeholder_index) {
continue;
}
if (op_name == kDynamicGRUV2OpName) {
auto none_index = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(cnode, "placeholder_index");
auto item = std::find(none_index.begin(), none_index.end(), i);
if (item != none_index.end()) {
continue;
}
}
auto real_input = AnfAlgo::GetRealInputIndex(kernel, i); auto real_input = AnfAlgo::GetRealInputIndex(kernel, i);
auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input); auto device_address = AnfAlgo::GetPrevNodeOutputAddr(kernel, real_input);
MS_EXCEPTION_IF_NULL(device_address); MS_EXCEPTION_IF_NULL(device_address);

Loading…
Cancel
Save