Add mkldnn nearest_interp and bilinear_interp op (#30016)
* Add mkldnn nearest_interp and bilinear_interp op
* do not run mkldnn interpolate by default
* add interpolate_mkldnn_pass
parent 65d4ff753b
commit c3c064a8fc
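Because the commit does not run the MKL-DNN interpolate kernels by default, the new interpolate_mkldnn_pass has to be added to the inference pass list explicitly. Below is a minimal, hedged sketch of how that could look from Python; it assumes the paddle.fluid.core.AnalysisConfig binding of this era exposes pass_builder() with an append_pass method and that create_paddle_predictor is available, so verify against your Paddle build before relying on it.

# Hedged sketch (not part of this commit): opting in to the new pass.
# Assumes AnalysisConfig exposes pass_builder()/append_pass in Python.
from paddle.fluid.core import AnalysisConfig, create_paddle_predictor

config = AnalysisConfig("path/to/inference_model")  # hypothetical model dir
config.enable_mkldnn()  # puts use_mkldnn on the graph, which the pass checks
config.pass_builder().append_pass("interpolate_mkldnn_pass")
predictor = create_paddle_predictor(config)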
@@ -0,0 +1,67 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/interpolate_mkldnn_pass.h"
#include <string>
#include <vector>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
class OpDesc;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {

class Graph;

void InterpolateMKLDNNPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "Pointer to graph argument should not be NULL."));
  if (!(graph->Has("use_mkldnn") && graph->Get<bool>("use_mkldnn"))) {
    VLOG(3) << "Do not handle interpolate_mkldnn_pass";
    return;
  }
  VLOG(4) << "Handle interpolate_mkldnn_pass";

  Init("interpolate_mkldnn_pass", graph);

  int found_count = 0;
  const std::vector<std::string> interpolate_op_types = {
      "bilinear_interp", "nearest_interp", "trilinear_interp", "bicubic_interp",
      "linear_interp"};

  for (const Node* node : graph->Nodes()) {
    if (node->IsOp() &&
        std::find(interpolate_op_types.begin(), interpolate_op_types.end(),
                  node->Name()) != interpolate_op_types.end()) {
      auto* op_desc = node->Op();
      op_desc->SetAttr("use_mkldnn", true);
      ++found_count;
    }
  }

  AddStatis(found_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(interpolate_mkldnn_pass,
              paddle::framework::ir::InterpolateMKLDNNPass);
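For orientation, the pass above is just an attribute rewrite: it walks the graph and flips use_mkldnn on every interpolate op it finds. An equivalent, purely illustrative Python sketch over a fluid Program would look like the following; Operator._set_attr is an assumption about fluid internals, not API shipped by this commit.

# Illustrative mirror of InterpolateMKLDNNPass; not code from this commit.
INTERPOLATE_OP_TYPES = {"bilinear_interp", "nearest_interp", "trilinear_interp",
                        "bicubic_interp", "linear_interp"}

def mark_interpolate_ops_mkldnn(program):
    """Set use_mkldnn=True on each interpolate op, like the C++ pass."""
    found_count = 0
    for block in program.blocks:
        for op in block.ops:
            if op.type in INTERPOLATE_OP_TYPES:
                op._set_attr("use_mkldnn", True)
                found_count += 1
    return found_count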
@@ -0,0 +1,41 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * Change the interpolate op to run MKLDNN.
 */
class Graph;

class InterpolateMKLDNNPass : public FusePassBase {
 public:
  virtual ~InterpolateMKLDNNPass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
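Although InterpolateMKLDNNPass derives from FusePassBase, it fuses nothing; the base class appears to be used only for the Init bookkeeping and the AddStatis counter seen in ApplyImpl above.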
@@ -0,0 +1,174 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/operators/interpolate_op.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using dnnl::memory;
using dnnl::primitive;
using dnnl::reorder;
using dnnl::stream;
using dnnl::resampling_forward;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;

template <typename T = float>
class InterpolateMKLDNNHandler
    : public platform::MKLDNNHandlerT<T, dnnl::resampling_forward> {
 public:
  InterpolateMKLDNNHandler(const dnnl::algorithm algo,
                           const paddle::platform::MKLDNNDeviceContext& dev_ctx,
                           const dnnl::engine engine, platform::Place cpu_place,
                           const Tensor* x, Tensor* z,
                           const std::string& uniq_name)
      : platform::MKLDNNHandlerT<T, dnnl::resampling_forward>(
            dev_ctx, engine, cpu_place,
            platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
                                uniq_name)) {
    if (!this->isCached()) {
      const auto src_x_tz = framework::vectorize(x->dims());
      const auto dst_tz = framework::vectorize(z->dims());
      const auto src_md = dnnl::memory::desc(
          src_x_tz, platform::MKLDNNGetDataType<T>(), x->format());
      const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType<T>(),
                                       MKLDNNMemoryFormat::any);
      this->AcquireForwardPrimitiveDescriptor(
          dnnl::prop_kind::forward_inference, algo, src_md, dst_md);
    }
  }
};

template <typename T = float>
class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
  std::vector<int> ComputeOutputShape(
      const framework::ExecutionContext& ctx) const {
    const auto* x = ctx.Input<Tensor>("X");
    auto in_dims = x->dims();
    const bool is_channel_last = false;  // In mkldnn kernel, always use NCHW

    framework::DDim in_dhw_dims;
    if (is_channel_last) {  // NDHWC, NHWC, NWC
      in_dhw_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
    } else {  // NCDHW, NCHW, NCW
      in_dhw_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
    }

    std::vector<int> out_dims;
    if (in_dhw_dims.size() == 1) {
      out_dims.push_back(ctx.Attr<int>("out_w"));
    } else if (in_dhw_dims.size() == 2) {
      out_dims.push_back(ctx.Attr<int>("out_h"));
      out_dims.push_back(ctx.Attr<int>("out_w"));
    } else if (in_dhw_dims.size() == 3) {
      out_dims.push_back(ctx.Attr<int>("out_d"));
      out_dims.push_back(ctx.Attr<int>("out_h"));
      out_dims.push_back(ctx.Attr<int>("out_w"));
    }

    auto list_new_size_tensor = ctx.MultiInput<framework::Tensor>("SizeTensor");
    auto out_size = ctx.Input<Tensor>("OutSize");
    if (list_new_size_tensor.size() > 0) {
      auto new_size = get_new_shape(list_new_size_tensor);
      if (new_size.size() == out_dims.size()) {
        out_dims = new_size;
      }
    } else if (out_size != nullptr) {
      auto out_size_data = get_new_data_from_tensor<int>(out_size);
      if (out_size_data.size() == out_dims.size()) {
        out_dims = out_size_data;
      }
    } else {
      float scale;
      auto scale_tensor = ctx.Input<Tensor>("Scale");
      if (scale_tensor != nullptr) {
        auto scale_data = get_new_data_from_tensor<float>(scale_tensor);
        scale = scale_data[0];
      } else {
        scale = ctx.Attr<float>("scale");
      }
      if (scale > 0) {
        std::vector<int64_t> in_dhw_vec = framework::vectorize(in_dhw_dims);
        std::transform(
            in_dhw_vec.begin(), in_dhw_vec.end(), out_dims.begin(),
            [&](int64_t i) -> int { return static_cast<int>(i * scale); });
      }
    }

    PADDLE_ENFORCE_GT(std::all_of(out_dims.begin(), out_dims.end(),
                                  [](int i) { return i > 0; }),
                      0, platform::errors::InvalidArgument(
                             "out_d, out_h, out_w of Op(interpolate) "
                             "should be greater than 0."));

    out_dims.insert(out_dims.begin(), in_dims[0]);
    if (is_channel_last) {
      out_dims.push_back(in_dims[in_dims.size() - 1]);
    } else {
      out_dims.insert(out_dims.begin() + 1, in_dims[1]);
    }
    return out_dims;
  }

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& dev_ctx =
        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
    const auto& mkldnn_engine = dev_ctx.GetEngine();

    const auto* x = ctx.Input<Tensor>("X");
    std::vector<float> scale_prior;
    auto* z = ctx.Output<Tensor>("Out");

    auto interp_method = ctx.Attr<std::string>("interp_method");
    dnnl::algorithm algo = (interp_method == "nearest")
                               ? dnnl::algorithm::resampling_nearest
                               : dnnl::algorithm::resampling_linear;

    auto out_dims_vec = ComputeOutputShape(ctx);
    framework::DDim dim_out = framework::make_ddim(out_dims_vec);
    z->mutable_data<T>(dim_out, ctx.GetPlace());

    InterpolateMKLDNNHandler<T> handler(algo, dev_ctx, mkldnn_engine,
                                        ctx.GetPlace(), x, z,
                                        ctx.OutputName("Out"));

    auto src_memory_p = handler.AcquireSrcMemory(x);
    auto dst_memory_p = handler.AcquireDstMemory(z);

    auto resampling_prim = handler.AcquireForwardPrimitive();
    const std::unordered_map<int, dnnl::memory> args = {
        {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
    mkldnn::stream astream(mkldnn_engine);
    resampling_prim->execute(astream, args);
    astream.wait();

    z->set_layout(DataLayout::kMKLDNN);
    z->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(nearest_interp, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>);
REGISTER_OP_KERNEL(bilinear_interp, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::InterpolateMKLDNNKernel<float>);
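The registrations above attach the kernel to the existing nearest_interp and bilinear_interp ops, so any program that creates those ops with use_mkldnn=True hits this code path. A hedged end-to-end sketch follows, using the fluid 1.x-style API of this era (fluid.data and fluid.layers.resize_nearest, which builds a nearest_interp op); layer names may differ in later versions.

import numpy as np
import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.data(name="x", shape=[1, 3, 4, 4], dtype="float32")
    # resize_nearest creates a nearest_interp op; once use_mkldnn=True is set
    # on it (e.g. by interpolate_mkldnn_pass), the kernel above is selected.
    y = fluid.layers.resize_nearest(x, out_shape=[8, 8])

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup)
out, = exe.run(main,
               feed={"x": np.random.rand(1, 3, 4, 4).astype("float32")},
               fetch_list=[y])
print(out.shape)  # (1, 3, 8, 8)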
@@ -0,0 +1,201 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import math
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci


def bilinear_interp_mkldnn_np(input,
                              out_h,
                              out_w,
                              out_size=None,
                              actual_shape=None,
                              data_layout='NCHW'):
    """bilinear interpolation implemented in shape [N, C, H, W]"""
    if data_layout == "NHWC":
        input = np.transpose(input, (0, 3, 1, 2))  # NHWC => NCHW
    if out_size is not None:
        out_h = out_size[0]
        out_w = out_size[1]
    if actual_shape is not None:
        out_h = actual_shape[0]
        out_w = actual_shape[1]
    batch_size, channel, in_h, in_w = input.shape

    out = np.zeros((batch_size, channel, out_h, out_w))

    for oh in range(out_h):
        h0 = int(math.floor((oh + 0.5) * in_h / out_h - 0.5))
        h1 = int(math.ceil((oh + 0.5) * in_h / out_h - 0.5))
        h0 = max(h0, 0)
        h1 = min(h1, in_h - 1)
        Wh = (oh + 0.5) * in_h / out_h - 0.5 - h0
        for ow in range(out_w):
            w0 = int(math.floor((ow + 0.5) * in_w / out_w - 0.5))
            w1 = int(math.ceil((ow + 0.5) * in_w / out_w - 0.5))
            w0 = max(w0, 0)
            w1 = min(w1, in_w - 1)
            Ww = (ow + 0.5) * in_w / out_w - 0.5 - w0
            input_h0_w0 = input[:, :, h0, w0]
            input_h1_w0 = input[:, :, h1, w0]
            input_h0_w1 = input[:, :, h0, w1]
            input_h1_w1 = input[:, :, h1, w1]
            out[:, :, oh, ow] = input_h0_w0 * (1 - Wh) * (
                1 - Ww) + input_h1_w0 * Wh * (1 - Ww) + input_h0_w1 * (
                    1 - Wh) * Ww + input_h1_w1 * Wh * Ww

    if data_layout == "NHWC":
        out = np.transpose(out, (0, 2, 3, 1))  # NCHW => NHWC

    return out.astype(input.dtype)


@skip_check_grad_ci(reason="Have not implemented interpolate grad kernel.")
class TestBilinearInterpMKLDNNOp(OpTest):
    def init_test_case(self):
        pass

    def setUp(self):
        self.op_type = "bilinear_interp"
        self.interp_method = 'bilinear'
        self._cpu_only = True
        self.use_mkldnn = True
        self.input_shape = [1, 1, 2, 2]
        self.data_layout = 'NCHW'
        # priority: actual_shape > out_size > scale > out_h & out_w
        self.out_h = 1
        self.out_w = 1
        self.scale = 2.0
        self.out_size = None
        self.actual_shape = None

        self.init_test_case()

        input_np = np.random.random(self.input_shape).astype("float32")
        if self.data_layout == "NCHW":
            in_h = self.input_shape[2]
            in_w = self.input_shape[3]
        else:
            in_h = self.input_shape[1]
            in_w = self.input_shape[2]

        if self.scale > 0:
            out_h = int(in_h * self.scale)
            out_w = int(in_w * self.scale)
        else:
            out_h = self.out_h
            out_w = self.out_w

        output_np = bilinear_interp_mkldnn_np(input_np, out_h, out_w,
                                              self.out_size, self.actual_shape,
                                              self.data_layout)

        self.inputs = {'X': input_np}
        if self.out_size is not None:
            self.inputs['OutSize'] = self.out_size
        if self.actual_shape is not None:
            self.inputs['OutSize'] = self.actual_shape
        self.attrs = {
            'interp_method': self.interp_method,
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'data_layout': self.data_layout,
            'use_mkldnn': self.use_mkldnn
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestBilinearInterpOpMKLDNNNHWC(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 27
        self.out_w = 49
        self.scale = 2.0
        self.data_layout = 'NHWC'


class TestBilinearNeighborInterpMKLDNNCase2(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [3, 3, 9, 6]
        self.out_h = 12
        self.out_w = 12
        self.scale = 1.


class TestBilinearNeighborInterpDataLayout(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [2, 4, 4, 5]
        self.out_h = 6
        self.out_w = 7
        self.scale = 0.
        self.data_layout = "NHWC"


class TestBilinearNeighborInterpCase3(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [1, 1, 32, 64]
        self.out_h = 64
        self.out_w = 128
        self.scale = 0.


class TestBilinearNeighborInterpCase4(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [4, 1, 7, 8]
        self.out_h = 1
        self.out_w = 1
        self.scale = 0.
        self.out_size = np.array([2, 2]).astype("int32")


class TestBilinearNeighborInterpCase5(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [1, 1, 9, 6]
        self.out_h = 12
        self.out_w = 12
        self.scale = 0.
        self.out_size = np.array([13, 13]).astype("int32")


class TestBilinearNeighborInterpCase6(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [1, 1, 32, 64]
        self.out_h = 64
        self.out_w = 32
        self.scale = 0.
        self.out_size = np.array([65, 129]).astype("int32")


class TestBilinearNeighborInterpSame(TestBilinearInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [2, 3, 32, 64]
        self.out_h = 32
        self.out_w = 64
        self.scale = 0.


if __name__ == "__main__":
    unittest.main()
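As a quick, illustrative sanity check of the reference function above (not part of the test file): when upsampling 2x2 to 4x4, the corner outputs reduce to the corner inputs, because at the corners h0 == h1 and w0 == w1 and the four bilinear weights sum to 1.

x = np.arange(4, dtype="float32").reshape(1, 1, 2, 2)  # [[0, 1], [2, 3]]
y = bilinear_interp_mkldnn_np(x, out_h=4, out_w=4)
assert y.shape == (1, 1, 4, 4)
assert y[0, 0, 0, 0] == x[0, 0, 0, 0] and y[0, 0, -1, -1] == x[0, 0, -1, -1]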
@@ -0,0 +1,166 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci


def nearest_neighbor_interp_mkldnn_np(X,
                                      out_h,
                                      out_w,
                                      out_size=None,
                                      actual_shape=None,
                                      data_layout='NCHW'):
    """nearest neighbor interpolation implemented in shape [N, C, H, W]"""
    if data_layout == "NHWC":
        X = np.transpose(X, (0, 3, 1, 2))  # NHWC => NCHW
    if out_size is not None:
        out_h = out_size[0]
        out_w = out_size[1]
    if actual_shape is not None:
        out_h = actual_shape[0]
        out_w = actual_shape[1]

    n, c, in_h, in_w = X.shape

    fh = fw = 0.0
    if (out_h > 1):
        fh = out_h * 1.0 / in_h
    if (out_w > 1):
        fw = out_w * 1.0 / in_w

    out = np.zeros((n, c, out_h, out_w))

    for oh in range(out_h):
        ih = int(round((oh + 0.5) / fh - 0.5))
        for ow in range(out_w):
            iw = int(round((ow + 0.5) / fw - 0.5))
            out[:, :, oh, ow] = X[:, :, ih, iw]

    if data_layout == "NHWC":
        out = np.transpose(out, (0, 2, 3, 1))  # NCHW => NHWC

    return out.astype(X.dtype)


@skip_check_grad_ci(reason="Have not implemented interpolate grad kernel.")
class TestNearestInterpMKLDNNOp(OpTest):
    def init_test_case(self):
        pass

    def setUp(self):
        self.op_type = "nearest_interp"
        self.interp_method = 'nearest'
        self._cpu_only = True
        self.use_mkldnn = True
        self.input_shape = [1, 1, 2, 2]
        self.data_layout = 'NCHW'
        # priority: actual_shape > out_size > scale > out_h & out_w
        self.out_h = 1
        self.out_w = 1
        self.scale = 2.0
        self.out_size = None
        self.actual_shape = None

        self.init_test_case()

        input_np = np.random.random(self.input_shape).astype("float32")
        if self.data_layout == "NCHW":
            in_h = self.input_shape[2]
            in_w = self.input_shape[3]
        else:
            in_h = self.input_shape[1]
            in_w = self.input_shape[2]

        if self.scale > 0:
            out_h = int(in_h * self.scale)
            out_w = int(in_w * self.scale)
        else:
            out_h = self.out_h
            out_w = self.out_w

        output_np = nearest_neighbor_interp_mkldnn_np(
            input_np, out_h, out_w, self.out_size, self.actual_shape,
            self.data_layout)

        self.inputs = {'X': input_np}
        if self.out_size is not None:
            self.inputs['OutSize'] = self.out_size
        if self.actual_shape is not None:
            self.inputs['OutSize'] = self.actual_shape
        self.attrs = {
            'interp_method': self.interp_method,
            'out_h': self.out_h,
            'out_w': self.out_w,
            'scale': self.scale,
            'data_layout': self.data_layout,
            'use_mkldnn': self.use_mkldnn
        }
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output(check_dygraph=False)


class TestNearestInterpOpMKLDNNNHWC(TestNearestInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [3, 2, 32, 16]
        self.out_h = 27
        self.out_w = 49
        self.scale = 2.0
        self.data_layout = 'NHWC'


class TestNearestNeighborInterpMKLDNNCase2(TestNearestInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [3, 3, 9, 6]
        self.out_h = 12
        self.out_w = 12
        self.scale = 1.


class TestNearestNeighborInterpCase3(TestNearestInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [1, 1, 32, 64]
        self.out_h = 64
        self.out_w = 128
        self.scale = 0.


class TestNearestNeighborInterpCase4(TestNearestInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [1, 1, 32, 64]
        self.out_h = 64
        self.out_w = 32
        self.scale = 0.
        self.out_size = np.array([65, 129]).astype("int32")


class TestNearestNeighborInterpSame(TestNearestInterpMKLDNNOp):
    def init_test_case(self):
        self.input_shape = [2, 3, 32, 64]
        self.out_h = 32
        self.out_w = 64
        self.scale = 0.


if __name__ == "__main__":
    unittest.main()
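And an equally illustrative check of the nearest-neighbour reference above (not shipped with the test): with a scale of 2, ih = round((oh + 0.5) / fh - 0.5) maps output rows [0, 1, 2, 3] to input rows [0, 0, 1, 1], so each source pixel is replicated into a 2x2 block.

x = np.arange(4, dtype="float32").reshape(1, 1, 2, 2)  # [[0, 1], [2, 3]]
y = nearest_neighbor_interp_mkldnn_np(x, out_h=4, out_w=4)
assert y.shape == (1, 1, 4, 4)
assert (y[0, 0, :2, :2] == 0).all() and (y[0, 0, 2:, 2:] == 3).all()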