add test_mode in trt/activation_op

wangkuiyi-patch-1
Luo Tao 7 years ago
parent c73977af03
commit f6fb51a164

@ -23,7 +23,7 @@ class ReluOpConverter : public OpConverter {
public: public:
ReluOpConverter() {} ReluOpConverter() {}
void operator()(const framework::proto::OpDesc& op, void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope) override { const framework::Scope& scope, bool test_mode) override {
// Here the two nullptr looks strange, that's because the // Here the two nullptr looks strange, that's because the
// framework::OpDesc's constructor is strange. // framework::OpDesc's constructor is strange.
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
@ -34,7 +34,12 @@ class ReluOpConverter : public OpConverter {
nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor), engine_, Activation, *const_cast<nvinfer1::ITensor*>(input_tensor),
nvinfer1::ActivationType::kRELU); nvinfer1::ActivationType::kRELU);
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]); auto output_name = op_desc.Output("Out")[0];
engine_->SetITensor(output_name, layer->getOutput(0));
if (test_mode) { // the test framework can not determine which is the
// output, so place the declaration inside.
engine_->DeclareOutput(output_name);
}
} }
}; };

Loading…
Cancel
Save