Add unittests and OP version registry for tensorrt_subgraph_pass (#27544)

* add unittests and op version registry for tensorrt_subgraph_pass

* rename to test_trt_subgraph_pass.py

* fix softmax converter diff when padding dim=1
Author: Pei Yang, committed via GitHub
Commit: ae6e40a7fd (parent 6822307745)

--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
@@ -358,3 +359,31 @@ REGISTER_PASS(tensorrt_subgraph_pass,
     .RequirePassAttr("max_batch_size")
     .RequirePassAttr("workspace_size")
     .RequirePassAttr("min_subgraph_size");
+REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("conv2d", 0)
+            .EQ("pool2d", 0)
+            .EQ("relu", 0)
+            .EQ("softmax", 0)
+            .EQ("sigmoid", 0)
+            .EQ("hard_swish", 0)
+            .EQ("depthwise_conv2d", 0)
+            .EQ("batch_norm", 0)
+            .EQ("concat", 0)
+            .EQ("tanh", 0)
+            .EQ("pad", 0)
+            .EQ("elementwise_add", 0)
+            .EQ("elementwise_mul", 0)
+            .EQ("prelu", 0)
+            .LE("conv2d_transpose", 1)
+            .LE("leaky_relu", 1)
+            .EQ("fc", 0)
+            .EQ("shuffle_channel", 0)
+            .EQ("swish", 0)
+            .EQ("split", 0)
+            .EQ("instance_norm", 0)
+            .EQ("gelu", 0)
+            .EQ("layer_norm", 0)
+            .EQ("scale", 0));

--- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 namespace paddle {
@@ -39,9 +40,41 @@ class SoftMaxOpConverter : public OpConverter {
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
     auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
+    nvinfer1::Dims input_shape = input1->getDimensions();
+    int input_dims = input_shape.nbDims;
+    int axis = op_desc.HasAttr("axis")
+                   ? BOOST_GET_CONST(int, op_desc.GetAttr("axis"))
+                   : -1;
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax,
                                        *const_cast<nvinfer1::ITensor*>(input1));
+    uint32_t axes = std::max(0, input_dims - 3);
+    // TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers
+    // support Nd.
+    int padded_dims = 0;
+    int explicit_batch = 0;
+    if (engine_->with_dynamic_shape()) explicit_batch = 1;
+    for (int i = input_dims - 1; i > explicit_batch; i--) {
+      if (input_shape.d[i] == 1) {
+        padded_dims += 1;
+      } else {
+        break;
+      }
+    }
+    if (!engine_->with_dynamic_shape()) {
+      if (axis == -1) {
+        axes = input_dims - 1 - padded_dims;
+      } else {
+        axes = axis;
+      }
+    } else {
+      if (axis == -1) {
+        axes = input_dims - 1 - padded_dims;
+      } else {
+        axes = axis + 1;
+      }
+    }
+    layer->setAxes(1 << axes);
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
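
The axis arithmetic is easiest to see on concrete shapes. A standalone sketch mirroring the converter's logic (plain Python, same control flow as the diff): when no explicit axis is given, trailing size-1 dims are treated as padding left behind by other converters; in dynamic-shape mode the batch dim is part of the TRT shape, so a user-specified axis shifts by one.

```python
# Plain-Python mirror of the converter's axis selection above.
def softmax_axes(shape, axis=-1, dynamic_shape=False):
    """shape: TRT dims (batch dim included only in dynamic-shape mode)."""
    input_dims = len(shape)
    explicit_batch = 1 if dynamic_shape else 0
    # Count trailing size-1 dims, which are padding added by other converters.
    padded_dims = 0
    for i in range(input_dims - 1, explicit_batch, -1):
        if shape[i] == 1:
            padded_dims += 1
        else:
            break
    if axis == -1:
        return input_dims - 1 - padded_dims  # last non-padded dim
    return axis + 1 if dynamic_shape else axis

# e.g. an fc output padded to [1000, 1, 1] (static shape, no batch dim):
print(softmax_axes([1000, 1, 1]))                          # 0 -> the 1000 dim
print(softmax_axes([8, 1000, 1, 1], dynamic_shape=True))   # 1 -> the 1000 dim
```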

--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -113,7 +113,14 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
         op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
       std::vector<int> paddings =
           BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
-      if (paddings.size() > 2) return false;
+      std::string padding_algorithm = "EXPLICIT";
+      if (desc.HasAttr("padding_algorithm"))
+        padding_algorithm =
+            BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
+      if (paddings.size() > 2 ||
+          (padding_algorithm == "SAME" && op_type != "pool2d"))
+        return false;
     }
     if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
   }
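
In other words, conv-like ops now fall back to native GPU execution not only when `paddings` has four entries (asymmetric padding) but also when `padding_algorithm` is `SAME`; only pool2d is exempt from the SAME check. A standalone sketch of the predicate:

```python
# Standalone sketch of the stricter teller predicate above.
def conv_like_supported(op_type, paddings, padding_algorithm="EXPLICIT"):
    if len(paddings) > 2:  # 4-element paddings mean asymmetric padding
        return False
    if padding_algorithm == "SAME" and op_type != "pool2d":
        return False
    return True

print(conv_like_supported("conv2d", [1, 1]))          # True
print(conv_like_supported("conv2d", [1, 1], "SAME"))  # False: runs natively
print(conv_like_supported("pool2d", [0, 0], "SAME"))  # True: pool2d exempt
```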

--- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -50,10 +50,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
   float *output = reinterpret_cast<float **>(outputs)[0];
   int begin_norm_axis = begin_norm_axis_;
   float eps = eps_;
-  int c = input_dims.d[begin_norm_axis - 1];
-  scale_t.Resize(framework::make_ddim({c}));
-  bias_t.Resize(framework::make_ddim({c}));
+  std::vector<int> input_shape;
+  input_shape.push_back(batch_size);
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    input_shape.push_back(input_dims.d[i]);
+  }
+  const auto input_ddim = framework::make_ddim(input_shape);
+  auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1);
+  int feature_size = static_cast<int>(matrix_dim[1]);
+  scale_t.Resize(framework::make_ddim({feature_size}));
+  bias_t.Resize(framework::make_ddim({feature_size}));
   mean_t.Resize(framework::make_ddim(mean_shape_));
   variance_t.Resize(framework::make_ddim(variance_shape_));
   int device_id;
@@ -63,15 +71,11 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
   float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
   float *variance_d =
       variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c,
+  cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size,
                   cudaMemcpyHostToDevice, stream);
-  cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c,
+  cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
                   cudaMemcpyHostToDevice, stream);
-  std::vector<int> input_shape;
-  input_shape.push_back(batch_size);
-  for (int i = 0; i < input_dims.nbDims; i++) {
-    input_shape.push_back(input_dims.d[i]);
-  }
   paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
   layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
              variance_d, begin_norm_axis, eps);
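
The fix replaces the single-dimension channel count `c` with the full normalized width: `framework::flatten_to_2d(ddim, k)` folds a shape into `[prod(shape[:k]), prod(shape[k:])]`, so `feature_size` covers every element that each mean/variance pair normalizes over, not just one dim. A small numpy sketch of that semantics (the shape and split index are made-up examples):

```python
import numpy as np

# Sketch of flatten_to_2d semantics: split a shape at index k into a
# 2-D matrix; the second extent is the feature_size used above.
def flatten_to_2d(shape, k):
    return [int(np.prod(shape[:k])), int(np.prod(shape[k:]))]

# Full input shape [batch, seq, hidden] = [4, 128, 768], split at k=2:
print(flatten_to_2d([4, 128, 768], 2))  # [512, 768] -> feature_size = 768
# The old code sized scale/bias from a single dim only, which is too
# small whenever the normalized region spans more than one dim.
```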

--- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py
@@ -133,7 +133,7 @@ class InferencePassTest(unittest.TestCase):
         for place_ in use_gpu:
             self.check_output_with_option(place_, atol)
-    def check_output_with_option(self, use_gpu, atol=1e-5):
+    def check_output_with_option(self, use_gpu, atol=1e-5, flatten=False):
         '''
         Check whether calculating on CPU and GPU, enable TensorRT
         or disable TensorRT, enable MKLDNN or disable MKLDNN
@@ -155,9 +155,13 @@ class InferencePassTest(unittest.TestCase):
                 format(device))
         for out, analysis_output in zip(outs, analysis_outputs):
+            out = np.array(out)
+            if flatten:
+                out = out.flatten()
+                analysis_output = analysis_output.flatten()
             self.assertTrue(
                 np.allclose(
-                    np.array(out), analysis_output, atol=atol),
+                    out, analysis_output, atol=atol),
                 "Output has diff between inference and training forward at {} ".
                 format(device))
@@ -172,9 +176,13 @@ class InferencePassTest(unittest.TestCase):
                 "The number of outputs is different between GPU and TensorRT. ")
             for out, tensorrt_output in zip(outs, tensorrt_outputs):
+                out = np.array(out)
+                if flatten:
+                    out = out.flatten()
+                    tensorrt_output = tensorrt_output.flatten()
                 self.assertTrue(
                     np.allclose(
-                        np.array(out), tensorrt_output, atol=atol),
+                        out, tensorrt_output, atol=atol),
                     "Output has diff between GPU and TensorRT. ")
         # Check whether the mkldnn results and the CPU results are the same.
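
With the `flatten` switch in place, a TRT subgraph test can compare outputs element-wise even when TensorRT reports a padded or flattened shape. A minimal sketch of such a test (class and layer choices are illustrative, assuming the `InferencePassTest` harness and its nested `TensorRTParam` helper shown above; the actual cases added by this commit live in test_trt_subgraph_pass.py):

```python
import numpy as np
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle.fluid.core import AnalysisConfig

class TensorRTSubgraphPassSoftmaxTest(InferencePassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(name="data", shape=[-1, 1000, 1, 1], dtype="float32")
            out = fluid.layers.softmax(data)
        self.feeds = {"data": np.random.random([1, 1000, 1, 1]).astype("float32")}
        self.enable_trt = True
        self.trt_parameters = InferencePassTest.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
        self.fetch_list = [out]

    def test_check_output(self):
        if fluid.core.is_compiled_with_cuda():
            # flatten=True tolerates shape padding introduced by TRT converters
            self.check_output_with_option(use_gpu=True, flatten=True)
```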
