[CustomOp] Split test and add inference test (#31078)
* split test & add inference test
* add timeout config
* change to setup install
* change to jit compile
* add verbose for test
* fix load setup name repeat
* polish details
* resolve conflict
* fix code format error
parent d3f09ad702
commit e60fd1f6a8
@@ -0,0 +1,76 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

template <typename data_t>
void assign_cpu_kernel(const data_t* x_data,
                       data_t* out_data,
                       int64_t x_numel) {
  for (int i = 0; i < x_numel; ++i) {
    out_data[i] = x_data[i];
  }
}

template <typename data_t>
void fill_constant_cpu_kernel(data_t* out_data, int64_t x_numel, data_t value) {
  for (int i = 0; i < x_numel; ++i) {
    out_data[i] = value;
  }
}

std::vector<paddle::Tensor> MultiOutCPU(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());
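
  // PD_DISPATCH_FLOATING_TYPES switches on x.type() and instantiates the
  // lambda below with data_t bound to float or double accordingly.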
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
      }));

  // fake multi output: Fake_float64 with float64 dtype
  auto fake_float64 = paddle::Tensor(paddle::PlaceType::kCPU);
  fake_float64.reshape(x.shape());

  fill_constant_cpu_kernel<double>(
      fake_float64.mutable_data<double>(x.place()), x.size(), 0.);

  // fake multi output: ZFake_int32 with int32 dtype
  auto zfake_int32 = paddle::Tensor(paddle::PlaceType::kCPU);
  zfake_int32.reshape(x.shape());

  fill_constant_cpu_kernel<int32_t>(
      zfake_int32.mutable_data<int32_t>(x.place()), x.size(), 1);

  return {out, fake_float64, zfake_int32};
}

std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
  return {x_shape, x_shape, x_shape};
}

std::vector<paddle::DataType> InferDtype(paddle::DataType x_dtype) {
  return {x_dtype, paddle::DataType::FLOAT64, paddle::DataType::INT32};
}
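
// InferShape and InferDtype run when the network is constructed (static
// graph), telling the framework the shapes and dtypes of the three outputs
// before the kernel executes.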
PD_BUILD_OP("multi_out")
    .Inputs({"X"})
    .Outputs({"Out", "Fake_float64", "ZFake_int32"})
    .SetKernelFn(PD_KERNEL(MultiOutCPU))
    .SetInferShapeFn(PD_INFER_SHAPE(InferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(InferDtype));
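
For context, a minimal usage sketch (not part of this commit): assuming the C++ file above is saved as multi_out_op.cc, it can be JIT-compiled and exercised from Python roughly as below. The module name, source path, and checks are illustrative, not the committed test.

import numpy as np
import paddle
from paddle.utils.cpp_extension import load

# JIT-compile the op defined above; module name and source path are
# hypothetical.
multi_out_module = load(
    name='multi_out_jit', sources=['multi_out_op.cc'], verbose=True)

x = paddle.to_tensor(np.random.uniform(-1, 1, [4, 8]).astype('float32'))
out, fake_float64, zfake_int32 = multi_out_module.multi_out(x)

# Out mirrors X; the fake outputs are filled with 0. (float64) and 1 (int32).
assert np.array_equal(out.numpy(), x.numpy())
assert np.all(fake_float64.numpy() == 0.)
assert np.all(zfake_int32.numpy() == 1)
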
@@ -0,0 +1,86 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import subprocess
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_compile_args
from test_custom_relu_op_setup import custom_relu_dynamic, custom_relu_static

# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be compiled again unless the shared lib is removed.
if os.name == 'nt':
    cmd = 'del {}\\custom_relu_module_jit.pyd'.format(get_build_directory())
    run_cmd(cmd, True)

# Compile and load the custom op just-in-time.
# custom_relu_op_dup.cc is only used for the multi-op test; it does not
# define a new op. If you want to test only one op, remove this source file.
custom_module = load(
    name='custom_relu_module_jit',
    sources=[
        'custom_relu_op.cc', 'custom_relu_op.cu', 'custom_relu_op_dup.cc'
    ],
    extra_include_paths=paddle_includes,  # add for Coverage CI
    extra_cflags=extra_compile_args,  # add for Coverage CI
    verbose=True)
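# Note: load() compiles the sources into a shared library under
# get_build_directory() and returns a module-like object that exposes one
# Python API per op registered in those sources (here custom_relu and
# custom_relu_dup).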


class TestJITLoad(unittest.TestCase):
    def setUp(self):
        self.custom_ops = [
            custom_module.custom_relu, custom_module.custom_relu_dup
        ]
        self.dtypes = ['float32', 'float64']
        self.devices = ['cpu', 'gpu']

    def test_static(self):
        for device in self.devices:
            for dtype in self.dtypes:
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out = custom_relu_static(custom_op, device, dtype, x)
                    pd_out = custom_relu_static(custom_op, device, dtype, x,
                                                False)
                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))

    def test_dynamic(self):
        for device in self.devices:
            for dtype in self.dtypes:
                x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
                for custom_op in self.custom_ops:
                    out, x_grad = custom_relu_dynamic(custom_op, device, dtype,
                                                      x)
                    pd_out, pd_x_grad = custom_relu_dynamic(custom_op, device,
                                                            dtype, x, False)
                    self.assertTrue(
                        np.array_equal(out, pd_out),
                        "custom op out: {},\n paddle api out: {}".format(
                            out, pd_out))
                    self.assertTrue(
                        np.array_equal(x_grad, pd_x_grad),
                        "custom op x grad: {},\n paddle api x grad: {}".format(
                            x_grad, pd_x_grad))


if __name__ == '__main__':
    unittest.main()
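
For reference, the imported helpers live in test_custom_relu_op_setup.py; a rough, illustrative approximation of custom_relu_dynamic (not the committed code) looks like this:

def custom_relu_dynamic(func, device, dtype, np_x, use_func=True):
    # Run the custom op (or paddle's built-in relu as the reference) in
    # dygraph mode and return the forward output and the input gradient.
    paddle.set_device(device)

    t = paddle.to_tensor(np_x)
    t.stop_gradient = False

    out = func(t) if use_func else paddle.nn.functional.relu(t)
    out.backward()

    # t.grad holds d(out)/d(t) after backward().
    return out.numpy(), t.grad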