|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#include "src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
|
|
|
|
|
#include <set>
|
|
|
|
|
#include "src/runtime/opencl/opencl_executor.h"
|
|
|
|
|
#include "src/runtime/opencl/opencl_runtime.h"
|
|
|
|
|
#include "src/runtime/kernel/opencl/utils.h"
|
|
|
|
@ -181,11 +182,31 @@ int SubGraphOpenCLKernel::Init() {
|
|
|
|
|
}
|
|
|
|
|
nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end());
|
|
|
|
|
|
|
|
|
|
UpdateTensorDataType();
|
|
|
|
|
|
|
|
|
|
MallocTensorWithReuse();
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int SubGraphOpenCLKernel::UpdateTensorDataType() {
|
|
|
|
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
|
|
|
|
bool is_fp16 = ocl_runtime->GetFp16Enable();
|
|
|
|
|
if (is_fp16 && (in_tensors_[0]->data_type() == kNumberTypeFloat32)) {
|
|
|
|
|
std::set<lite::tensor::Tensor *> out_set;
|
|
|
|
|
out_set.insert(in_tensors_.begin(), in_tensors_.end());
|
|
|
|
|
out_set.insert(out_tensors_.begin(), out_tensors_.end());
|
|
|
|
|
for (auto iv : nodes_) {
|
|
|
|
|
auto cur_outs = iv->out_tensors();
|
|
|
|
|
for (auto jv : cur_outs) {
|
|
|
|
|
if (out_set.count(jv) == 0) {
|
|
|
|
|
jv->set_data_type(kNumberTypeFloat16);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
int SubGraphOpenCLKernel::MallocTensorWithReuse() {
|
|
|
|
|
kernel::LiteKernelUtil::InitTensorRefCount(nodes_);
|
|
|
|
|
for (auto *kernel : nodes_) {
|
|
|
|
|