diff --git a/include/tvm/runtime/registry.h b/include/tvm/runtime/registry.h
index d668984..3676a61 100644
--- a/include/tvm/runtime/registry.h
+++ b/include/tvm/runtime/registry.h
@@ -319,6 +319,34 @@ class Registry {
 #define TVM_REGISTER_EXT_TYPE(T)                                 \
   TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) =            \
       ::tvm::runtime::ExtTypeVTable::Register_<T>()
+
+/*!
+ * \brief Map a TVM runtime API symbol name to a custom runtime symbol name.
+ *
+ *  If a global function "codegen.GetTransRTFunc" is registered, it is called
+ *  with the original name and its string result is used; otherwise the
+ *  original name is returned unchanged.
+ *
+ * \param orig_name The original TVM runtime API symbol name.
+ * \return The (possibly translated) symbol name.
+ */
+inline std::string TransRuntimeFuncName(const std::string& orig_name) {
+  const PackedFunc* trans_func = Registry::Get("codegen.GetTransRTFunc");
+  if (trans_func == nullptr) {
+    return orig_name;
+  }
+  // Copy the result out of the TVMRetValue temporary into an owned string;
+  // a raw pointer into the temporary would dangle once it is destroyed.
+  std::string dst_name = (*trans_func)(orig_name);
+  return dst_name;
+}
+
+/*!
+ * \brief Translate a TVM runtime API name via TransRuntimeFuncName.
+ *  Expands to a std::string expression (no GNU statement-expression needed).
+ */
+#define TVM_RT_FUNC_TRANS(OrigFuncStr) \
+  ::tvm::runtime::TransRuntimeFuncName(OrigFuncStr)
 
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/codegen/llvm/codegen_cpu.cc b/src/codegen/llvm/codegen_cpu.cc
index 0ba0c58..2850ad4 100644
--- a/src/codegen/llvm/codegen_cpu.cc
+++ b/src/codegen/llvm/codegen_cpu.cc
@@ -99,26 +99,31 @@ void CodeGenCPU::Init(const std::string& module_name,
     // We will need this in environment for backward registration.
     f_tvm_register_system_symbol_ = llvm::Function::Create(
         llvm::FunctionType::get(t_int_, {t_char_->getPointerTo(), t_void_p_}, false),
-        llvm::Function::ExternalLinkage, "TVMBackendRegisterSystemLibSymbol", module_.get());
+        llvm::Function::ExternalLinkage,
+        TVM_RT_FUNC_TRANS("TVMBackendRegisterSystemLibSymbol"), module_.get());
   } else {
     f_tvm_register_system_symbol_ = nullptr;
   }
   if (dynamic_lookup || system_lib) {
     f_tvm_func_call_ = llvm::Function::Create(
         ftype_tvm_func_call_,
-        llvm::Function::ExternalLinkage, "TVMFuncCall", module_.get());
+        llvm::Function::ExternalLinkage,
+        TVM_RT_FUNC_TRANS("TVMFuncCall"), module_.get());
     f_tvm_get_func_from_env_ = llvm::Function::Create(
         ftype_tvm_get_func_from_env_,
         llvm::Function::ExternalLinkage, "TVMBackendGetFuncFromEnv", module_.get());
     f_tvm_api_set_last_error_ = llvm::Function::Create(
         ftype_tvm_api_set_last_error_,
-        llvm::Function::ExternalLinkage, "TVMAPISetLastError", module_.get());
+        llvm::Function::ExternalLinkage,
+        TVM_RT_FUNC_TRANS("TVMAPISetLastError"), module_.get());
     f_tvm_parallel_launch_ = llvm::Function::Create(
         ftype_tvm_parallel_launch_,
-        llvm::Function::ExternalLinkage, "TVMBackendParallelLaunch", module_.get());
+        llvm::Function::ExternalLinkage,
+        TVM_RT_FUNC_TRANS("TVMBackendParallelLaunch"), module_.get());
     f_tvm_parallel_barrier_ = llvm::Function::Create(
         ftype_tvm_parallel_barrier_,
-        llvm::Function::ExternalLinkage, "TVMBackendParallelBarrier", module_.get());
+        llvm::Function::ExternalLinkage,
+        TVM_RT_FUNC_TRANS("TVMBackendParallelBarrier"), module_.get());
   }
   this->InitGlobalContext(dynamic_lookup);
 }
@@ -794,7 +799,22 @@ llvm::Value* CodeGenCPU::CreateIntrinsic(const Call* op) {
   } else if (op->is_intrinsic(intrinsic::tvm_static_handle)) {
     return CreateStaticHandle();
   } else if (op->is_intrinsic(intrinsic::tvm_throw_last_error)) {
-    builder_->CreateRet(ConstInt32(-1));
+    // Return the error code (default -1) from the current function.
+    CHECK_LE(op->args.size(), 1U)
+        << "tvm_throw_last_error takes at most one argument (the error code)";
+    llvm::Value* ret_code = ConstInt32(-1);
+    if (op->args.size() == 1U) {
+      CHECK(op->args[0].type() == Int(32))
+          << "tvm_throw_last_error expects an int32 error code";
+      ret_code = MakeValue(op->args[0]);
+    }
+    builder_->CreateRet(ret_code);
+    // The ret terminates the current basic block. Code emitted after this
+    // point (e.g. the br that closes an enclosing IfThenElse or loop body)
+    // still needs a valid insertion point, so continue in a fresh block; it
+    // has no predecessors, so it is unreachable and never executed.
+    builder_->SetInsertPoint(
+        BasicBlock::Create(*ctx_, "throw_cont", function_));
     return ConstInt32(-1);
   } else if (op->is_intrinsic(intrinsic::tvm_struct_get)) {
     CHECK_EQ(op->args.size(), 3U);
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index e73956c..3a7b46c 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -104,7 +104,10 @@ class BuiltinLower : public IRMutator {
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
     Stmt throw_last_error = Evaluate::make(Call::make(Int(32),
-                                           intrinsic::tvm_throw_last_error, {},
+                                           intrinsic::tvm_throw_last_error,
+                                           // NOTE(review): 1001 is assumed to be the custom
+                                           // runtime's workspace-allocation error code.
+                                           {make_const(Int(32), 1001)},
                                            Call::Intrinsic));
 
     Stmt body = Block::make(
@@ -117,7 +120,7 @@ class BuiltinLower : public IRMutator {
     Stmt alloca = LetStmt::make(
         op->buffer_var,
         Call::make(op->buffer_var.type(),
-                   "TVMBackendAllocWorkspace",
+                   TVM_RT_FUNC_TRANS("TVMBackendAllocWorkspace"),
                    {cast(Int(32), device_type_),
                     cast(Int(32), device_id_),
                     cast(UInt(64), total_bytes),
@@ -127,7 +130,7 @@ class BuiltinLower : public IRMutator {
         body);
 
     Expr free_op = Call::make(Int(32),
-                              "TVMBackendFreeWorkspace",
+                              TVM_RT_FUNC_TRANS("TVMBackendFreeWorkspace"),
                               {cast(Int(32), device_type_),
                                cast(Int(32), device_id_),
                                op->buffer_var},