|
|
@ -38,12 +38,8 @@ int ConvolutionOpenCLKernel::Init() {
|
|
|
|
MS_LOG(ERROR) << "ConvolutionOpenCLKernel only support Batch=1!";
|
|
|
|
MS_LOG(ERROR) << "ConvolutionOpenCLKernel only support Batch=1!";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
auto io_NHWC = inputs_[0]->GetFormat() == schema::Format_NHWC && outputs_[0]->GetFormat() == schema::Format_NHWC;
|
|
|
|
outputs_[0]->SetFormat(schema::Format_NHWC4);
|
|
|
|
auto io_NHWC4 = inputs_[0]->GetFormat() == schema::Format_NHWC4 && outputs_[0]->GetFormat() == schema::Format_NHWC4;
|
|
|
|
io_dataformat_ = outputs_[0]->GetFormat();
|
|
|
|
if (!io_NHWC && !io_NHWC4) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "input and output data_format is invalid!";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
io_dataformat_ = inputs_[0]->GetFormat();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (inputs_[1]->GetFormat() != schema::Format_KHWC) {
|
|
|
|
if (inputs_[1]->GetFormat() != schema::Format_KHWC) {
|
|
|
|
MS_LOG(ERROR) << "weight data_format is invalid!";
|
|
|
|
MS_LOG(ERROR) << "weight data_format is invalid!";
|
|
|
@ -62,6 +58,7 @@ int ConvolutionOpenCLKernel::Init() {
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|
|
this->InitBuffer();
|
|
|
|
this->InitBuffer();
|
|
|
|
|
|
|
|
MS_LOG(DEBUG) << kernel_name << " Init Done!";
|
|
|
|
return 0;
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
int ConvolutionOpenCLKernel::InitBuffer() {
|
|
|
|
int ConvolutionOpenCLKernel::InitBuffer() {
|
|
|
@ -123,7 +120,7 @@ int ConvolutionOpenCLKernel::InitBuffer() {
|
|
|
|
int ConvolutionOpenCLKernel::ReSize() { return 0; }
|
|
|
|
int ConvolutionOpenCLKernel::ReSize() { return 0; }
|
|
|
|
|
|
|
|
|
|
|
|
int ConvolutionOpenCLKernel::Run() {
|
|
|
|
int ConvolutionOpenCLKernel::Run() {
|
|
|
|
MS_LOG(INFO) << "ConvolutionOpenCLKernel::Run()";
|
|
|
|
MS_LOG(DEBUG) << this->Name() << " Running!";
|
|
|
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
|
|
|
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
|
|
|
|
|
|
|
|
|
|
|
|
auto param = reinterpret_cast<ConvParameter *>(opParameter);
|
|
|
|
auto param = reinterpret_cast<ConvParameter *>(opParameter);
|
|
|
|