Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into rewrite_allocation
	
		
	
				
					
				
			test=developpanyx0718-patch-1
						commit
						98bbfc17be
					
				@ -0,0 +1,219 @@
 | 
				
			||||
# Accumulator for .part.cu kernel files discovered by op_library(); each call
# prepends to it via PARENT_SCOPE so the top-level list sees every file.
set(PART_CUDA_KERNEL_FILES)

# op_library(TARGET [SRCS src...] [DEPS dep...])
#
# Creates an operator library. The interface is the same as cc_library, but it
# handles splitting GPU/CPU/HIP/MKLDNN sources and links the common operator
# dependencies (operator, op_registry, math_function).
#
# Side effects in the caller's scope:
#   - appends TARGET to the cached OP_LIBRARY list,
#   - appends to PART_CUDA_KERNEL_FILES and DEPS_OPS (PARENT_SCOPE),
#   - appends USE_OP / USE_CPU_ONLY_OP / USE_NO_KERNEL_OP /
#     USE_OP_DEVICE_KERNEL registration lines to ${pybind_file}
#     (a variable expected to be set by the including scope).
function(op_library TARGET)
    set(cc_srcs)
    set(cu_srcs)
    set(hip_cu_srcs)
    set(miopen_hip_cc_srcs)
    set(cu_cc_srcs)
    set(cudnn_cu_cc_srcs)
    set(CUDNN_FILE)
    set(mkldnn_cc_srcs)
    set(MKLDNN_FILE)
    set(op_common_deps operator op_registry math_function)
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs SRCS DEPS)
    # 0 until a pybind registration line has been emitted for this op.
    set(pybind_flag 0)
    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    list(LENGTH op_library_SRCS op_library_SRCS_len)
    if (${op_library_SRCS_len} EQUAL 0)
        # No explicit SRCS: infer sources from conventional file names
        # derived from the target name.
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cc)
            list(APPEND cc_srcs ${TARGET}.cc)
        endif()
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu.cc)
            list(APPEND cu_cc_srcs ${TARGET}.cu.cc)
        endif()
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
            list(APPEND cu_srcs ${TARGET}.cu)
        endif()
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
            # .part.cu kernels are both compiled into this op and collected
            # globally for separate whole-program handling.
            set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
                    ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)
            list(APPEND cu_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
        endif()

        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
            list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
        endif()
        string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
        if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
            list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
        endif()
        if(WITH_AMD_GPU)
            string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
                list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
            endif()
        endif()
        if(WITH_MKLDNN)
            string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
            if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc)
                list(APPEND mkldnn_cc_srcs ${MKLDNN_FILE}.cc)
            endif()
        endif()
    else()
        # Explicit SRCS: classify each file by suffix. Quoting "${src}"
        # prevents if() from re-dereferencing a value that happens to name
        # another variable (CMP0054 footgun).
        foreach(src ${op_library_SRCS})
            if ("${src}" MATCHES ".*\\.hip.cu$")
                list(APPEND hip_cu_srcs ${src})
            elseif ("${src}" MATCHES ".*\\.cu$")
                list(APPEND cu_srcs ${src})
            elseif("${src}" MATCHES ".*_cudnn_op.cu.cc$")
                list(APPEND cudnn_cu_cc_srcs ${src})
            elseif(WITH_AMD_GPU AND "${src}" MATCHES ".*_miopen_op.hip.cc$")
                list(APPEND miopen_hip_cc_srcs ${src})
            elseif(WITH_MKLDNN AND "${src}" MATCHES ".*_mkldnn_op.cc$")
                list(APPEND mkldnn_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cu.cc$")
                list(APPEND cu_cc_srcs ${src})
            elseif("${src}" MATCHES ".*\\.cc$")
                list(APPEND cc_srcs ${src})
            else()
                message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
            endif()
        endforeach()
    endif()

    list(LENGTH cc_srcs cc_srcs_len)
    if (${cc_srcs_len} EQUAL 0)
        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
    endif()
    if (WIN32)
    # Remove windows unsupported ops, because windows has no nccl, no warpctc
    # and similar dependencies.
    foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
     "crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
      "fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
        if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
          return()
        endif()
    endforeach()
    endif()
    # Record this op in the global (cached) list of op libraries.
    set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} CACHE INTERNAL "op libs")

    list(LENGTH op_library_DEPS op_library_DEPS_len)
    if (${op_library_DEPS_len} GREATER 0)
        set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
    endif()
    if (WITH_GPU)
        nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    elseif (WITH_AMD_GPU)
        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
                ${op_common_deps})
    else()
        cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
            ${op_common_deps})
    endif()

    # Define operators that don't need pybind here (they are registered
    # manually elsewhere).
    foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op")
        if ("${TARGET}" STREQUAL "${manual_pybind_op}")
            set(pybind_flag 1)
        endif()
    endforeach()

    # The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
    # Note that it's enough to just adding one operator to pybind in a *_op.cc file.
    # And for detail pybind information, please see generated paddle/pybind/pybind.h.
    file(READ ${TARGET}.cc TARGET_CONTENT)
    # If the .cc registers more than one operator, extract the first
    # registered name and use it as the pybind target name.
    string(REGEX MATCH "REGISTER_OPERATOR\\(.*REGISTER_OPERATOR\\(" multi_register "${TARGET_CONTENT}")
    string(REGEX MATCH "REGISTER_OPERATOR\\([a-z0-9_]*," one_register "${multi_register}")
    if ("${one_register}" STREQUAL "")
        string(REPLACE "_op" "" TARGET "${TARGET}")
    else ()
        string(REPLACE "REGISTER_OPERATOR(" "" TARGET "${one_register}")
        string(REPLACE "," "" TARGET "${TARGET}")
    endif()

    # pybind USE_NO_KERNEL_OP
    # HACK: if REGISTER_OP_CPU_KERNEL presents the operator must have kernel
    string(REGEX MATCH "REGISTER_OP_CPU_KERNEL" regex_result "${TARGET_CONTENT}")
    string(REPLACE "_op" "" TARGET "${TARGET}")
    if (${pybind_flag} EQUAL 0 AND "${regex_result}" STREQUAL "")
        file(APPEND ${pybind_file} "USE_NO_KERNEL_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_CPU_ONLY_OP: no device sources of any kind were found.
    list(LENGTH cu_srcs cu_srcs_len)
    list(LENGTH cu_cc_srcs cu_cc_srcs_len)
    list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
    list(LENGTH hip_cu_srcs hip_cu_srcs_len)
    list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
    if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
        ${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
        set(pybind_flag 1)
    endif()

    # pybind USE_OP_DEVICE_KERNEL for CUDNN
    list(LENGTH cudnn_cu_cc_srcs cudnn_cu_cc_srcs_len)
    if (WITH_GPU AND ${cudnn_cu_cc_srcs_len} GREATER 0)
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MIOPEN
    if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
    endif()

    # pybind USE_OP_DEVICE_KERNEL for MKLDNN
    if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
      # Append first implemented MKLDNN activation operator. MKLDNN_FILE is
      # quoted because it is empty when SRCS were passed explicitly; the
      # unquoted form made if() fail to parse in that case.
      if ("${MKLDNN_FILE}" STREQUAL "activation_mkldnn_op")
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(relu, MKLDNN);\n")
      else()
        file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
      endif()
    endif()

    # pybind USE_OP
    if (${pybind_flag} EQUAL 0)
      # NOTE(*): activation use macro to regist the kernels, set use_op manually.
      if ("${TARGET}" STREQUAL "activation")
        file(APPEND ${pybind_file} "USE_OP(relu);\n")
      elseif("${TARGET}" STREQUAL "fake_dequantize")
        file(APPEND ${pybind_file} "USE_OP(fake_dequantize_max_abs);\n")
      elseif("${TARGET}" STREQUAL "fake_quantize")
        file(APPEND ${pybind_file} "USE_OP(fake_quantize_abs_max);\n")
      elseif("${TARGET}" STREQUAL "tensorrt_engine_op")
          message(STATUS "Pybind skips [tensorrt_engine_op], for this OP is only used in inference")
      elseif("${TARGET}" STREQUAL "fc")
        # HACK: fc only have mkldnn and cpu, which would mismatch the cpu only condition
        file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
      else()
        file(APPEND ${pybind_file} "USE_OP(${TARGET});\n")
      endif()
    endif()
endfunction()
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
# register_operators([EXCLUDES op_name...] [DEPS dep...])
#
# Discovers every *_op.cc in the current source directory and registers each
# one through op_library(). Names listed under EXCLUDES are skipped; the DEPS
# list, when non-empty, is forwarded to every op_library() call.
function(register_operators)
    set(options "")
    set(oneValueArgs "")
    set(multiValueArgs EXCLUDES DEPS)
    cmake_parse_arguments(register_operators "${options}" "${oneValueArgs}"
            "${multiValueArgs}" ${ARGN})

    # Build the candidate op-name list: glob the op sources, fold the
    # "_mkldnn" variants onto their base name, strip the ".cc" suffix, and
    # drop duplicates introduced by that folding.
    file(GLOB OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "*_op.cc")
    string(REPLACE "_mkldnn" "" OPS "${OPS}")
    string(REPLACE ".cc" "" OPS "${OPS}")
    list(REMOVE_DUPLICATES OPS)
    list(LENGTH register_operators_DEPS register_operators_DEPS_len)

    foreach(op_name ${OPS})
        # Only register ops that are not excluded.
        list(FIND register_operators_EXCLUDES ${op_name} excluded_at)
        if (${excluded_at} EQUAL -1)
            if (${register_operators_DEPS_len} GREATER 0)
                op_library(${op_name} DEPS ${register_operators_DEPS})
            else()
                op_library(${op_name})
            endif()
        endif()
    endforeach()
endfunction()
 | 
				
			||||
@ -0,0 +1,57 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/ir/is_test_pass.h"

#include <algorithm>  // std::find (was pulled in transitively before)
#include <string>
#include <utility>

namespace paddle {
namespace framework {
namespace ir {

// Forces is_test=true on every op that already declares the attribute, and
// inserts is_test=true on a whitelist of activation/pooling ops that lack it.
// Returns the same graph, modified in place.
std::unique_ptr<ir::Graph> IsTestPass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  VLOG(3) << "Sets is_test attribute to true and if it is missing, inserts it "
             "for activations and pooling.";
  // Op types that accept an optional is_test attribute without declaring it.
  auto op_list = {"pool2d",      "sigmoid",      "logsigmoid",
                  "softshrink",  "exp",          "brelu",
                  "pow",         "leaky_relu",   "stanh",
                  "relu",        "tanh",         "tanh_shrink",
                  "sqrt",        "abs",          "ceil",
                  "elu",         "floor",        "cos",
                  "sin",         "round",        "reciprocal",
                  "hard_shrink", "hard_sigmoid", "relu6",
                  "soft_relu",   "swish",        "thresholded_relu",
                  "log",         "square",       "softplus",
                  "softsign"};
  for (const Node* n : graph->Nodes()) {
    if (n->IsOp()) {
      auto* op = n->Op();
      if (op->HasAttr("is_test")) {
        // The op already declares is_test: just flip it on.
        op->SetAttr("is_test", true);
      } else if (std::find(begin(op_list), end(op_list), op->Type()) !=
                 end(op_list)) {
        // Whitelisted op without the attribute: insert is_test=true.
        op->MutableAttrMap()->insert(
            std::pair<std::string, Attribute>("is_test", true));
      }
    }
  }
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(is_test_pass, paddle::framework::ir::IsTestPass);
 | 
				
			||||
@ -0,0 +1,31 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/ir/pass.h"

namespace paddle {
namespace framework {
namespace ir {

// Graph pass that sets the "is_test" attribute to true on every operator
// that already has it, and inserts is_test=true on a fixed whitelist of
// activation/pooling operator types that are missing it (the whitelist
// lives in the .cc implementation).
class IsTestPass : public Pass {
 protected:
  // Mutates the graph in place and returns the same graph object.
  std::unique_ptr<ir::Graph> ApplyImpl(
      std::unique_ptr<ir::Graph> graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,117 @@
 | 
				
			||||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/is_test_pass.h"

#include <gtest/gtest.h>

namespace paddle {
namespace framework {
namespace ir {

// Tri-state for the is_test attribute of a generated op: explicitly false,
// explicitly true, or absent from the attribute map.
enum class ISTEST_STATE { FALSE, TRUE, UNSET };

// Appends one op to block 0 of `prog` with the given type, a "name" attr,
// X inputs, Out outputs, a use_mkldnn attr, and the requested is_test state
// (UNSET erases the attribute so the op carries none).
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
           const std::vector<std::string>& inputs,
           const std::vector<std::string>& outputs, bool use_mkldnn = false,
           ISTEST_STATE is_test = ISTEST_STATE::UNSET) {
  auto* op = prog->MutableBlock(0)->AppendOp();
  op->SetType(type);
  op->SetAttr("name", name);
  op->SetInput("X", inputs);
  op->SetOutput("Out", outputs);
  op->SetAttr("use_mkldnn", use_mkldnn);
  if (is_test == ISTEST_STATE::UNSET)
    // Remove any default is_test so the pass sees a truly missing attribute.
    op->MutableAttrMap()->erase("is_test");
  else if (is_test == ISTEST_STATE::FALSE)
    op->SetAttr("is_test", false);
  else
    op->SetAttr("is_test", true);
}

// a->pool2d->b
// b->relu->c
// c,weights1)->conv2d->d
//
// d->pool2d->e
// e->hard_sigmoid->f
// (f,weights2)->conv2d->g
//
// g->pool2d->h
// h->tanh->i
// (i,weights3)->conv2d->j
//
// Builds three pool->activation->conv chains: the first with is_test=TRUE,
// the second with is_test=FALSE, the third with is_test UNSET. Only conv3
// (a conv2d, not on the pass's whitelist, with the attr unset) should end
// up without an is_test attribute after the pass runs.
ProgramDesc BuildProgramDesc() {
  ProgramDesc prog;
  for (auto& v :
       std::vector<std::string>({"a", "b", "c", "d", "e", "f", "g", "h", "i",
                                 "j", "weights1", "weights2", "weights3"})) {
    auto* var = prog.MutableBlock(0)->Var(v);
    var->SetType(proto::VarType::SELECTED_ROWS);
    if (v == "weights1" || v == "weights2" || v == "weights3") {
      var->SetPersistable(true);
    }
  }

  SetOp(&prog, "pool2d", "pooling1", std::vector<std::string>({"a"}),
        std::vector<std::string>({"b"}), true, ISTEST_STATE::TRUE);
  SetOp(&prog, "relu", "activation1", std::vector<std::string>({"b"}),
        std::vector<std::string>({"c"}), true, ISTEST_STATE::TRUE);
  SetOp(&prog, "conv2d", "conv1", std::vector<std::string>({"c", "weights1"}),
        std::vector<std::string>({"d"}), true, ISTEST_STATE::TRUE);

  SetOp(&prog, "pool2d", "pooling2", std::vector<std::string>({"d"}),
        std::vector<std::string>({"e"}), false, ISTEST_STATE::FALSE);
  SetOp(&prog, "hard_sigmoid", "activation2", std::vector<std::string>({"e"}),
        std::vector<std::string>({"f"}), false, ISTEST_STATE::FALSE);
  SetOp(&prog, "conv2d", "conv2", std::vector<std::string>({"f", "weights2"}),
        std::vector<std::string>({"g"}), false, ISTEST_STATE::FALSE);

  SetOp(&prog, "pool2d", "pooling3", std::vector<std::string>({"g"}),
        std::vector<std::string>({"h"}), false, ISTEST_STATE::UNSET);
  SetOp(&prog, "tanh", "activation3", std::vector<std::string>({"h"}),
        std::vector<std::string>({"i"}), true, ISTEST_STATE::UNSET);
  SetOp(&prog, "conv2d", "conv3", std::vector<std::string>({"i", "weights3"}),
        std::vector<std::string>({"j"}), false, ISTEST_STATE::UNSET);

  return prog;
}

// Runs is_test_pass over the program above and checks that every op except
// conv3 carries is_test=true afterwards.
TEST(IsTestPass, basic) {
  auto prog = BuildProgramDesc();

  std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));

  auto pass = PassRegistry::Instance().Get("is_test_pass");

  graph = pass->Apply(std::move(graph));

  for (auto* node : graph->Nodes()) {
    if (node->IsOp()) {
      auto* op = node->Op();
      auto op_name = boost::get<std::string>(op->GetAttr("name"));
      if (op_name == "conv3") {
        // conv2d is not whitelisted and had no is_test attr: must stay unset.
        ASSERT_FALSE(op->HasAttr("is_test"));
      } else {
        ASSERT_TRUE(op->HasAttr("is_test"));
        EXPECT_TRUE(boost::get<bool>(op->GetAttr("is_test")));
      }
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(is_test_pass);
 | 
				
			||||
@ -0,0 +1,80 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/prelu_op_plugin.h"

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * PRelu converter from fluid to tensorRT.
 */
class PReluOpConverter : public OpConverter {
 public:
  // Converts one fluid prelu op into a TensorRT plugin layer.
  // `op` is the fluid op description, `scope` holds the Alpha weight tensor,
  // and `test_mode` additionally declares the output as a network output.
  void operator()(const framework::proto::OpDesc& op,
                  const framework::Scope& scope, bool test_mode) override {
    VLOG(4) << "convert fluid prelu op to tensorrt prelu layer";

    framework::OpDesc op_desc(op, nullptr);
    // Declare inputs
    int input_num = op_desc.Input("X").size();
    PADDLE_ENFORCE(input_num == 1);
    auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
    // Get output
    size_t output_num = op_desc.Output("Out").size();
    PADDLE_ENFORCE(output_num == 1);
    // Get attrs
    std::string mode = boost::get<std::string>(op_desc.GetAttr("mode"));
    // Fetch the Alpha weight tensor from the scope.
    auto* alpha_var = scope.FindVar(op_desc.Input("Alpha")[0]);
    PADDLE_ENFORCE_NOT_NULL(alpha_var);
    auto* alpha_tensor = alpha_var->GetMutable<framework::LoDTensor>();

    // Copy alpha into a freshly allocated tensor on CUDAPlace so the plugin
    // owns a device-resident copy independent of the scope's tensor.
    platform::CUDAPlace place;
    std::unique_ptr<framework::LoDTensor> alpha_tensor_device(
        new framework::LoDTensor());
    alpha_tensor_device->Resize(alpha_tensor->dims());
    TensorCopySync(*alpha_tensor, place, alpha_tensor_device.get());
    float* alpha_data = alpha_tensor_device->mutable_data<float>(place);

    // Transform alpha to TensorRTEngine::Weight
    // NOTE(review): the Weight wraps a raw pointer into the device copy;
    // lifetime is kept alive via engine_->weight_map below.
    TensorRTEngine::Weight alpha_rt(nvinfer1::DataType::kFLOAT,
                                    static_cast<void*>(alpha_data),
                                    alpha_tensor_device->numel());
    PReluPlugin* plugin = new PReluPlugin(alpha_rt, mode);
    nvinfer1::IPluginLayer* layer =
        engine_->AddPlugin(&input, input_num, plugin);
    // keep alpha tensor to avoid release it's memory
    engine_->weight_map[op_desc.Input("Alpha")[0]] =
        std::move(alpha_tensor_device);

    std::string layer_name = "prelu (Output: ";
    auto output_name = op_desc.Output("Out")[0];
    layer->getOutput(0)->setName(output_name.c_str());
    engine_->SetITensor(output_name, layer->getOutput(0));
    layer_name += output_name;
    if (test_mode) {
      // In unit tests the layer output must also be a declared network output.
      engine_->DeclareOutput(output_name);
    }
    layer->setName((layer_name + ")").c_str());
  }
};

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

REGISTER_TRT_OP_CONVERTER(prelu, PReluOpConverter);
 | 
				
			||||
@ -0,0 +1,94 @@
 | 
				
			||||
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"

namespace paddle {
namespace framework {
namespace tensorrt {

// mode="channel": one alpha per channel, alpha shape (3, 1, 1) against a
// (3, 2, 2) input. Validates the TRT conversion against the fluid op.
TEST(prelu_op, test_channel_wise) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(3, 1, 1));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("channel"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

// mode="element": one alpha per input element; alpha carries a leading batch
// dimension (10, 3, 2, 2) — presumably matching the validator's max batch of
// 10 (TODO confirm against TRTConvertValidation).
TEST(prelu_op, test_element_wise) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims4(10, 3, 2, 2));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("element"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

// mode="all": a single scalar alpha, shape (1, 1, 1), shared by every element.
TEST(prelu_op, test_scalar) {
  std::unordered_set<std::string> parameters({"prelu_alpha"});
  framework::Scope scope;
  TRTConvertValidation validator(10, parameters, scope, 1000);
  validator.DeclInputVar("prelu_input", nvinfer1::DimsCHW(3, 2, 2));
  validator.DeclParamVar("prelu_alpha", nvinfer1::Dims3(1, 1, 1));
  validator.DeclOutputVar("prelu_out", nvinfer1::DimsCHW(3, 2, 2));

  // Prepare Op description
  framework::OpDesc desc;
  desc.SetType("prelu");
  desc.SetInput("X", {"prelu_input"});
  desc.SetInput("Alpha", {"prelu_alpha"});
  desc.SetOutput("Out", {"prelu_out"});

  desc.SetAttr("mode", std::string("all"));

  validator.SetOp(*desc.Proto());

  validator.Execute(1);
}

}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle

// USE_OP(prelu);
USE_CPU_ONLY_OP(prelu);
 | 
				
			||||
Some files were not shown because too many files have changed in this diff Show More
					Loading…
					
					
				
		Reference in new issue