add test_mode in trt/activation_op

wangkuiyi-patch-1
Luo Tao 7 years ago
parent c73977af03
commit f6fb51a164

@ -23,7 +23,7 @@ class ReluOpConverter : public OpConverter {
public:
ReluOpConverter() {}
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override {
const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr);
@ -34,7 +34,12 @@ class ReluOpConverter : public OpConverter {
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU);
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
}
};

Loading…
Cancel
Save