You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
mindspore/third_party/patch/incubator-tvm/src_pass.patch

121 lines
4.2 KiB

diff -Npur tvm/src/pass/make_api.cc tvm_new/src/pass/make_api.cc
--- tvm/src/pass/make_api.cc 2019-12-14 15:11:37.626419432 +0800
+++ tvm_new/src/pass/make_api.cc 2019-12-14 14:58:46.562493287 +0800
@@ -20,6 +20,11 @@
/*!
* \file make_api.cc Build API function.
*/
+
+/*
+ * 2019.12.30 - Define new function to push the data vars of buffer nodes from api_args to args_real.
+ */
+
#include <tvm/ir_pass.h>
#include <tvm/ir.h>
#include <tvm/ir_visitor.h>
@@ -40,6 +45,17 @@ inline Stmt MakeAssertEQ(Expr lhs, Expr
return AssertStmt::make(lhs == rhs, msg, Evaluate::make(0));
}
+// Appends the data Var of every Buffer found in api_args to args_real,
+// keeping argument order; non-buffer entries are skipped. Returns the result.
+Array<Var> Param(Array<NodeRef> api_args, Array<Var> args_real) {
+  for (const NodeRef& arg : api_args) {
+    if (const BufferNode* buffer = arg.as<BufferNode>()) {
+      args_real.push_back(buffer->data);
+    }
+  }
+  return args_real;
+}
+
LoweredFunc MakeAPI(Stmt body,
std::string name,
Array<NodeRef> api_args,
@@ -47,6 +63,8 @@ LoweredFunc MakeAPI(Stmt body,
bool is_restricted) {
const Stmt nop = Evaluate::make(0);
int num_args = static_cast<int>(api_args.size());
+ Array<Var> args_real;
+ args_real = Param (api_args, args_real);
CHECK_LE(num_unpacked_args, num_args);
int num_packed_args = num_args - num_unpacked_args;
// Data field definitions
@@ -170,6 +188,7 @@ LoweredFunc MakeAPI(Stmt body,
NodePtr<LoweredFuncNode> n = make_node<LoweredFuncNode>();
n->name = name;
n->args = args;
+ n->args_real = args_real;
n->handle_data_type = binder.def_handle_dtype();
n->is_packed_func = num_unpacked_args == 0;
n->is_restricted = is_restricted;
diff -Npur tvm/src/pass/split_host_device.cc tvm_new/src/pass/split_host_device.cc
--- tvm/src/pass/split_host_device.cc 2019-12-14 15:11:37.626419432 +0800
+++ tvm_new/src/pass/split_host_device.cc 2019-12-14 11:28:49.293979656 +0800
@@ -21,6 +21,11 @@
* \file split_host_device.cc
* \brief Split device function from host.
*/
+
+/*
+ * 2019.12.30 - Add new implementation for the host device splitter.
+ */
+
#include <tvm/ir.h>
#include <tvm/lowered_func.h>
#include <tvm/channel.h>
@@ -38,6 +43,7 @@ class IRUseDefAnalysis : public IRMutato
Stmt Mutate_(const AttrStmt *op, const Stmt& s) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
+ iv = IterVarNode::make(Range(0, op->value), iv->var, iv->iter_type, iv->thread_tag);
CHECK_NE(iv->thread_tag.length(), 0U);
// thread_extent can appear multiple times
// use the first appearance as def.
@@ -186,6 +192,7 @@ class HostDeviceSplitter : public IRMuta
name_ = f->name;
NodePtr<LoweredFuncNode> n =
make_node<LoweredFuncNode>(*f.operator->());
+ args_real = n->args_real;
n->body = this->Mutate(f->body);
n->func_type = kHostFunc;
Array<LoweredFunc> ret{LoweredFunc(n)};
@@ -196,6 +203,7 @@ class HostDeviceSplitter : public IRMuta
}
private:
+ Array<Var> args_real;
Stmt SplitDeviceFunc(Stmt body) {
std::ostringstream os;
os << name_ << "_kernel" << device_funcs_.size();
@@ -223,6 +231,30 @@ class HostDeviceSplitter : public IRMuta
n->args.push_back(v);
}
}
+  // Reorder the device function's arguments to follow args_real; a plain
+  // Array is enough as scratch storage, so no throwaway node is allocated.
+  Array<Var> ordered_args;
+  for (size_t i = 0; i < args_real.size(); ++i) {
+    bool match = false;
+    for (size_t j = 0; j < n->args.size(); ++j) {
+      if (n->args[j]->name_hint == args_real[i]->name_hint) {
+        ordered_args.push_back(n->args[j]);
+        match = true;
+        break;
+      }
+    }
+    if (!match) {
+      // No argument of that name yet: pass the buffer variable itself.
+      ordered_args.push_back(args_real[i]);
+      // mark handle data type.
+      for (const auto& kv : handle_data_type_) {
+        if (kv.first->name_hint == args_real[i]->name_hint) {
+          n->handle_data_type.Set(args_real[i], kv.second);
+        }
+      }
+    }
+  }
+  n->args = ordered_args;
LoweredFunc f_device(n);
Array<Expr> call_args;
call_args.push_back(StringImm::make(f_device->name));