You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
121 lines
4.2 KiB
121 lines
4.2 KiB
diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc
|
|
--- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800
|
|
+++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800
|
|
@@ -20,6 +20,11 @@
|
|
/*!
|
|
* \file make_api.cc Build API function.
|
|
*/
|
|
+
|
|
+/*
|
|
+ * 2019.12.30 - Define a new function to push buffer nodes from api_args to args_real.
|
|
+ */
|
|
+
|
|
#include <tvm/ir_pass.h>
|
|
#include <tvm/ir.h>
|
|
#include <tvm/ir_visitor.h>
|
|
@@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr
|
|
return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
|
|
}
|
|
|
|
+Array<Var> Param ( Array<NodeRef> api_args,Array<Var> args_real) {
|
|
+ int num_args = static_cast<int>(api_args.size());
|
|
+ for (int i = 0; i < num_args; i++) {
|
|
+ const BufferNode *v = api_args[i].as<BufferNode>();
|
|
+ if(v) {
|
|
+ args_real.push_back(v->data);
|
|
+ }
|
|
+ }
|
|
+ return args_real;
|
|
+}
|
|
+
|
|
LoweredFunc MakeAPI(Stmt body,
|
|
std::string name,
|
|
Array<NodeRef> api_args,
|
|
@@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body,
|
|
bool is_restricted) {
|
|
const Stmt nop = Evaluate::make(0);
|
|
int num_args = static_cast<int>(api_args.size());
|
|
+ Array<Var> args_real;
|
|
+ args_real = Param (api_args, args_real);
|
|
CHECK_LE(num_unpacked_args, num_args);
|
|
int num_packed_args = num_args - num_unpacked_args;
|
|
// Data field definitions
|
|
@@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
|
|
NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
|
|
n->name = name;
|
|
n->args = args;
|
|
+ n->args_real = args_real;
|
|
n->handle_data_type = binder.def_handle_dtype();
|
|
n->is_packed_func = num_unpacked_args == 0;
|
|
n->is_restricted = is_restricted;
|
|
diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc
|
|
--- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800
|
|
+++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800
|
|
@@ -21,6 +21,11 @@
|
|
* \file split_host_device.cc
|
|
* \brief Split device function from host.
|
|
*/
|
|
+
|
|
+/*
|
|
+ * 2019.12.30 - Add new implementation for host device splitter.
|
|
+ */
|
|
+
|
|
#include <tvm/ir.h>
|
|
#include <tvm/lowered_func.h>
|
|
#include <tvm/channel.h>
|
|
@@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato
|
|
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
|
|
if (op->attr_key == attr::thread_extent) {
|
|
IterVar iv = Downcast<IterVar>(op->node);
|
|
+ iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag);
|
|
CHECK_NE(iv->thread_tag.length(), 0U);
|
|
// thread_extent can appear multiple times
|
|
// use the first appearance as def.
|
|
@@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta
|
|
name_ = f->name;
|
|
NodePtr<LoweredFuncNode> n =
|
|
make_node<LoweredFuncNode>(*f.operator->());
|
|
+ args_real = n->args_real;
|
|
n->body = this->Mutate(f->body);
|
|
n->func_type = kHostFunc;
|
|
Array<LoweredFunc> ret{LoweredFunc(n)};
|
|
@@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta
|
|
}
|
|
|
|
private:
|
|
+ Array<Var> args_real;
|
|
Stmt SplitDeviceFunc(Stmt body) {
|
|
std::ostringstream os;
|
|
os << name_ << "_kernel" << device_funcs_.size();
|
|
@@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta
|
|
n->args.push_back(v);
|
|
}
|
|
}
|
|
+std::shared_ptr<LoweredFuncNode> na = std::make_shared<LoweredFuncNode>();
|
|
+ for (unsigned i = 0; i < (unsigned)args_real.size(); i++) {
|
|
+ bool match = false;
|
|
+ for (unsigned j = 0; j < (unsigned)n->args.size(); j++) {
|
|
+ if (strcmp(args_real[i].get()->name_hint.c_str(), n->args[j].get()->name_hint.c_str()) == 0) {
|
|
+ na->args.push_back(n->args[j]);
|
|
+ match = true;
|
|
+ break;
|
|
+ } else {
|
|
+ continue;
|
|
+ }
|
|
+ }
|
|
+
|
|
+ if (!match) {
|
|
+ na->args.push_back(args_real[i]);
|
|
+ // mark handle data type.
|
|
+ for (auto kv : handle_data_type_) {
|
|
+ if (strcmp(args_real[i].get()->name_hint.c_str(), kv.first->name_hint.c_str()) == 0) {
|
|
+ n->handle_data_type.Set(args_real[i], kv.second);
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+ }
|
|
+ n->args = na->args;
|
|
LoweredFunc f_device(n);
|
|
Array<Expr> call_args;
|
|
call_args.push_back(StringImm::make(f_device->name));
|