Layer normalization fuse pass. (#30721)
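This pass matches the nine-op subgraph that builds layer normalization by hand
(reduce_mean, elementwise_sub, elementwise_pow, reduce_mean, elementwise_add,
sqrt, elementwise_div, elementwise_mul, elementwise_add) and replaces it with a
single layer_norm op normalizing over the last input dimension. A minimal NumPy
sketch of the equivalence the pass relies on (illustrative only, not part of
the patch; shapes taken from the C++ tester below):

import numpy as np

x = np.random.rand(3, 32, 48).astype(np.float32)   # input
gamma = np.random.rand(48).astype(np.float32)      # scale, channelwise
beta = np.random.rand(48).astype(np.float32)       # shift, channelwise
eps = 1e-5

# The decomposed subgraph, op by op:
u = x.mean(axis=-1, keepdims=True)                 # reduce_mean
var = ((x - u) ** 2).mean(axis=-1, keepdims=True)  # elementwise_pow + reduce_mean
lnorm = (x - u) / np.sqrt(var + eps)               # elementwise_add, sqrt, elementwise_div
y = gamma * lnorm + beta                           # elementwise_mul + elementwise_add

# y is exactly layer normalization over the last axis, i.e. what the fused
# layer_norm op computes with begin_norm_axis = x.ndim - 1.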

paddle/fluid/framework/ir/layer_norm_fuse_pass.cc
@@ -0,0 +1,231 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

// cpplint incorrectly complains about a missing <string> header for the
// line below.
using string::PrettyLogDetail;  // NOLINT

namespace {

void validateReduceOpAttrs(const Node* node, const std::string& name) {
  const auto* op = node->Op();
  if (op->HasAttr("dim")) {
    auto dims = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
    PADDLE_ENFORCE_EQ(dims.size(), 1,
                      platform::errors::PreconditionNotMet(
                          "The LayerNorm fusion %s reduction must happen "
                          "only over a single dimension.",
                          name));
    PADDLE_ENFORCE_EQ(dims.front(), -1,
                      platform::errors::PreconditionNotMet(
                          "The LayerNorm fusion %s reduction must happen "
                          "over the last dimension.",
                          name));
  }
  if (op->HasAttr("reduce_all")) {
    PADDLE_ENFORCE(!BOOST_GET_CONST(bool, op->GetAttr("reduce_all")),
                   platform::errors::PreconditionNotMet(
                       "The LayerNorm fusion %s reduction must have the "
                       "\'reduce_all\' attribute set to false.",
                       name));
  }
  if (op->HasAttr("keep_dim")) {
    PADDLE_ENFORCE(BOOST_GET_CONST(bool, op->GetAttr("keep_dim")),
                   platform::errors::PreconditionNotMet(
                       "The LayerNorm fusion %s reduction must have the "
                       "\'keep_dim\' attribute set to true.",
                       name));
  }
}

void setIntermediateOut(OpDesc* desc, const std::string& out_name,
                        const std::string& scope_name) {
  std::string new_name = scope_name + "/at." + out_name + ".new";
  desc->SetOutput(out_name, {new_name});
}

void addIntermediateOut(Node* op_node, const std::string& out_name,
                        const std::string& scope_name, Graph* graph) {
  std::string new_name = scope_name + "/at." + out_name + ".new";
  VarDesc out_var(new_name);
  out_var.SetPersistable(false);
  auto* node_var = graph->CreateVarNode(&out_var);
  IR_NODE_LINK_TO(op_node, node_var);
}

}  // namespace

void LayerNormFusePass::ApplyImpl(Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(graph,
                          platform::errors::InvalidArgument(
                              "The input graph of "
                              "LayerNormFusePass should not be nullptr."));
  FusePassBase::Init(scope_name_, graph);

  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));

  GraphPatternDetector gpd;
  patterns::LayerNorm layer_norm_pattern(gpd.mutable_pattern(), scope_name_);
  layer_norm_pattern();

  int found_layer_norm_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "Fuse LayerNorm from subgraph.";
    GET_IR_NODE_FROM_SUBGRAPH(x, x, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_mean, x_mean, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_mean_out, x_mean_out, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean, x_sub_mean, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_out, x_sub_mean_out,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(sqr_pow, sqr_pow, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr, x_sub_mean_sqr,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_sub_mean_sqr_out, x_sub_mean_sqr_out,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev, std_dev, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev_out, std_dev_out, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eps, eps, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps, std_dev_eps, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_out, std_dev_eps_out,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt, std_dev_eps_sqrt,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(std_dev_eps_sqrt_out, std_dev_eps_sqrt_out,
                              layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(division, division, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(division_out, division_out, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(gamma, gamma, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale, scale, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(beta, beta, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(shift, shift, layer_norm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(shift_out, shift_out, layer_norm_pattern);

    auto* eps_tensor = scope->FindVar(eps->Name())->GetMutable<LoDTensor>();

    // ------------------ subgraph nodes' validation ------------------------
    PADDLE_ENFORCE_EQ(
        eps_tensor->numel(), 1,
        platform::errors::InvalidArgument(
            "The LayerNorm divisor epsilon value must be a one-element "
            "tensor, but has %s elements.",
            eps_tensor->numel()));
    PADDLE_ENFORCE_EQ(eps_tensor->type(), proto::VarType::FP32,
                      platform::errors::InvalidArgument(
                          "The LayerNorm divisor epsilon value must be of "
                          "FP32 data type, but is %s.",
                          eps_tensor->type()));

    const auto& gamma_shape = gamma->Var()->GetShape();
    const auto& beta_shape = beta->Var()->GetShape();
    const auto& x_shape = x->Var()->GetShape();
    int64_t x_last_dim = x_shape.back();

    PADDLE_ENFORCE_EQ(gamma_shape.size(), 1,
                      platform::errors::InvalidArgument(
                          "The LayerNorm gamma (scale) tensor shape must be "
                          "one-dimensional, but is %s.",
                          gamma_shape.size()));
    PADDLE_ENFORCE_EQ(beta_shape.size(), 1,
                      platform::errors::InvalidArgument(
                          "The LayerNorm beta (shift) tensor shape must be "
                          "one-dimensional, but is %s.",
                          beta_shape.size()));
    PADDLE_ENFORCE_EQ(beta_shape, gamma_shape,
                      platform::errors::InvalidArgument(
                          "The LayerNorm beta and gamma tensors' shapes "
                          "must be equal."));
    PADDLE_ENFORCE_EQ(gamma_shape.front(), x_last_dim,
                      platform::errors::InvalidArgument(
                          "The LayerNorm beta and gamma tensors' shapes "
                          "must be equal to the size of the input's last "
                          "dimension."));

    validateReduceOpAttrs(x_mean, "input mean");
    validateReduceOpAttrs(std_dev, "std_dev mean");

    // ------------------ op creation and placement -------------------------

    OpDesc ln_op_desc;
    ln_op_desc.SetType("layer_norm");
    ln_op_desc.SetInput("X", {x->Name()});
    ln_op_desc.SetInput("Scale", {gamma->Name()});
    ln_op_desc.SetInput("Bias", {beta->Name()});
    ln_op_desc.SetOutput("Y", {shift_out->Name()});
    setIntermediateOut(&ln_op_desc, "Mean", scope_name_);
    setIntermediateOut(&ln_op_desc, "Variance", scope_name_);
    ln_op_desc.SetAttr("begin_norm_axis", static_cast<int>(x_shape.size() - 1));
    ln_op_desc.SetAttr("epsilon", *(eps_tensor->data<float>()));
    ln_op_desc.SetAttr("is_test", true);
    Node* ln_op = g->CreateOpNode(&ln_op_desc);

    addIntermediateOut(ln_op, "Mean", scope_name_, g);
    addIntermediateOut(ln_op, "Variance", scope_name_, g);

    IR_NODE_LINK_TO(x, ln_op);
    IR_NODE_LINK_TO(gamma, ln_op);
    IR_NODE_LINK_TO(beta, ln_op);
    IR_OP_VAR_LINK(ln_op, shift_out);
    GraphSafeRemoveNodes(
        g,
        {x_mean, x_mean_out, x_sub_mean, x_sub_mean_out, sqr_pow,
         x_sub_mean_sqr, x_sub_mean_sqr_out, std_dev, std_dev_out, eps,
         std_dev_eps, std_dev_eps_out, std_dev_eps_sqrt, std_dev_eps_sqrt_out,
         division, division_out, scale, scale_out, shift});
    found_layer_norm_count++;
  };

  gpd(graph, handler);
  AddStatis(found_layer_norm_count);
  PrettyLogDetail("---    Fused %d subgraphs into layer_norm op.",
                  found_layer_norm_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(layer_norm_fuse_pass, paddle::framework::ir::LayerNormFusePass);
REGISTER_PASS_CAPABILITY(layer_norm_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .GE("elementwise_add", 0)
            .LE("elementwise_add", 1)
            .GE("elementwise_div", 0)
            .LE("elementwise_div", 1)
            .GE("elementwise_mul", 0)
            .LE("elementwise_mul", 1)
            .GE("elementwise_pow", 0)
            .LE("elementwise_pow", 1)
            .GE("elementwise_sub", 0)
            .LE("elementwise_sub", 1)
            .EQ("reduce_mean", 0)
            .EQ("sqrt", 0));
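Why validateReduceOpAttrs() insists on dim == {-1}, keep_dim == true and
reduce_all == false: the elementwise ops that follow each reduction rely on
broadcasting the (..., 1)-shaped statistics back against the input, and the
rewrite is only a layer normalization when the reduction runs over the last
dimension alone. A small NumPy illustration (not part of the patch):

import numpy as np

x = np.random.rand(3, 32, 48)

# keep_dim=True keeps the reduced axis: shape (3, 32, 1) broadcasts with x.
u = x.mean(axis=-1, keepdims=True)
print((x - u).shape)  # (3, 32, 48) -- matches elementwise_sub in the pattern

# keep_dim=False yields shape (3, 32), which no longer broadcasts against
# (3, 32, 48); the equivalent elementwise_sub would be ill-formed:
# np.subtract(x, x.mean(axis=-1))  # raises a shape mismatch error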

paddle/fluid/framework/ir/layer_norm_fuse_pass.h
@@ -0,0 +1,84 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

/*
 * \brief   Fuse the subgraph representing layer normalization into
 *          layer_norm op.
 *
 * \note    The following graph represents this equation:
 *
 *                       x - u(x)
 *          y(c) * -------------------  + b(c)
 *                 sqrt(sigma^2 + eps)
 *
 *          x        - input data
 *          u(x)     - mean
 *          sigma^2  - variance
 *          eps      - epsilon
 *          y(c)     - gamma (scale), channelwise
 *          b(c)     - beta (shift), channelwise
 *
 *
 *            X
 *           / \
 *          /   reduce_mean "u(x)"
 *          \   /
 *      elementwise_sub     "x - u(x)"
 *      /           \    2
 *      |            \  /
 *      |      elementwise_pow  "(x - u(x))^2"
 *      |             |
 *      |       reduce_mean     "sigma^2 = 1/C*Sum{(x - u(x))^2}"
 *      |             |     eps
 *      |             |     /
 *      |       elementwise_add "sigma^2 + epsilon"
 *      \             |
 *       \           sqrt       "sqrt(sigma^2 + epsilon)"
 *        \          /
 *         \        /
 *       elementwise_div        "lnorm = {x-u(x)}/{sqrt(sigma^2 + epsilon)}"
 *              |
 *       gamma  |
 *          \   |
 *       elementwise_mul        "scale: gamma(C) * lnorm"
 *              |
 *        beta  |
 *          \   |
 *       elementwise_add        "shift: gamma(C) * lnorm + beta(C)"
 */
class LayerNormFusePass : public FusePassBase {
 public:
  virtual ~LayerNormFusePass() {}

 protected:
  void ApplyImpl(ir::Graph *graph) const override;

 private:
  const std::string scope_name_{"layer_norm_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
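For reference, the whole diagram above collapses into a single framework op.
A hedged sketch in the fluid-era Python API (assuming the layer_norm layer and
its begin_norm_axis argument; the input name is made up):

import paddle.fluid as fluid

data = fluid.data(name="data", shape=[3, 64, 120], dtype="float32")
# One op instead of the nine-op subgraph; normalizing over the last axis
# corresponds to begin_norm_axis = rank - 1, matching what ApplyImpl() sets.
out = fluid.layers.layer_norm(data, begin_norm_axis=2, epsilon=1e-5)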

paddle/fluid/framework/ir/layer_norm_fuse_pass_tester.cc
@@ -0,0 +1,199 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/layer_norm_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_test_util.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {
namespace ir {

namespace {

ProgramDesc BuildGraphProgram() {
  auto prog = test::BuildProgramDesc(
      {"x", "x_mean_out", "x_sub_mean_out", "x_sub_mean_sqr_out", "std_dev_out",
       "std_dev_eps_out", "std_dev_eps_sqrt_out", "division_out", "scale_out",
       "shift_out"},
      {"sqr_pow", "eps", "gamma", "beta"});

  const auto& block_desc = prog.Block(0);
  auto* x_var_desc = block_desc.FindVar("x");
  x_var_desc->SetDataType(proto::VarType::FP32);
  x_var_desc->SetShape({3, 32, 48});

  auto* eps_var_desc = block_desc.FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP32);
  eps_var_desc->SetShape({1});

  auto* gamma_var_desc = block_desc.FindVar("gamma");
  gamma_var_desc->SetDataType(proto::VarType::FP32);
  gamma_var_desc->SetShape({48});

  auto* beta_var_desc = block_desc.FindVar("beta");
  beta_var_desc->SetDataType(proto::VarType::FP32);
  beta_var_desc->SetShape({48});

  auto* x_mean = test::CreateOp(&prog, "reduce_mean", {{"X", "x"}},
                                {{"Out", "x_mean_out"}}, false);
  x_mean->SetAttr("dim", std::vector<int>{-1});
  x_mean->SetAttr("keep_dim", true);
  x_mean->SetAttr("reduce_all", false);

  test::CreateOp(&prog, "elementwise_sub", {{"X", "x"}, {"Y", "x_mean_out"}},
                 {{"Out", "x_sub_mean_out"}}, false);
  test::CreateOp(&prog, "elementwise_pow",
                 {{"X", "x_sub_mean_out"}, {"Y", "sqr_pow"}},
                 {{"Out", "x_sub_mean_sqr_out"}}, false);
  auto* std_dev =
      test::CreateOp(&prog, "reduce_mean", {{"X", "x_sub_mean_sqr_out"}},
                     {{"Out", "std_dev_out"}}, false);
  std_dev->SetAttr("dim", std::vector<int>{-1});
  std_dev->SetAttr("keep_dim", true);
  std_dev->SetAttr("reduce_all", false);

  test::CreateOp(&prog, "elementwise_add", {{"X", "std_dev_out"}, {"Y", "eps"}},
                 {{"Out", "std_dev_eps_out"}}, false);
  test::CreateOp(&prog, "sqrt", {{"X", "std_dev_eps_out"}},
                 {{"Out", "std_dev_eps_sqrt_out"}}, false);
  test::CreateOp(&prog, "elementwise_div",
                 {{"X", "x_sub_mean_out"}, {"Y", "std_dev_eps_sqrt_out"}},
                 {{"Out", "division_out"}}, false);
  test::CreateOp(&prog, "elementwise_mul",
                 {{"X", "division_out"}, {"Y", "gamma"}},
                 {{"Out", "scale_out"}}, false);
  test::CreateOp(&prog, "elementwise_add", {{"X", "scale_out"}, {"Y", "beta"}},
                 {{"Out", "shift_out"}}, false);
  return prog;
}

bool CheckFusedSubgraphOpsCount(const Graph& graph) {
  return test::AssertOpsCount(graph, {{"reduce_mean", 0},
                                      {"elementwise_sub", 0},
                                      {"elementwise_pow", 0},
                                      {"elementwise_add", 0},
                                      {"sqrt", 0},
                                      {"elementwise_div", 0},
                                      {"elementwise_mul", 0},
                                      {"layer_norm", 1}});
}

}  // namespace

// ------------------------------ Test cases -----------------------------------

TEST(FuseLayerNormPass, TestFuse) {
  ProgramDesc prog = BuildGraphProgram();

  Graph graph(prog);
  // 9 subgraph ops + 10 intermediate/constant variables, matching the
  // GraphSafeRemoveNodes list in the pass.
  constexpr int removed_nodes = 19;
  // LayerNorm + outputs: {Mean, Variance}
  constexpr int added_nodes = 3;

  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  float eps_value = 1e-5f;
  // Init scope, as it is used in the pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<float>(&scope, place, "eps", {1}, &eps_value);

  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_TRUE(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                     "shift_out", removed_nodes, added_nodes));
  EXPECT_TRUE(CheckFusedSubgraphOpsCount(graph));

  for (const auto* node : graph.Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "layer_norm") {
      const auto* op = node->Op();
      ASSERT_TRUE(op->HasAttr("is_test"));
      EXPECT_TRUE(BOOST_GET_CONST(bool, op->GetAttr("is_test")));
      ASSERT_TRUE(op->HasAttr("begin_norm_axis"));
      ASSERT_TRUE(op->HasAttr("epsilon"));
    }
  }
}

TEST(FuseLayerNormPass, TestInvalidEpsNumel) {
  ProgramDesc prog = BuildGraphProgram();

  auto* eps_var_desc = prog.Block(0).FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP32);
  eps_var_desc->SetShape({2});

  Graph graph(prog);
  constexpr int removed_nodes = 19;
  constexpr int added_nodes = 3;

  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  auto eps_values = std::vector<float>{1e-5f, 1e-5f};
  // Init scope, as it is used in the pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<float>(&scope, place, "eps", {2},
                                   eps_values.data());

  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                      "shift_out", removed_nodes, added_nodes),
               paddle::platform::EnforceNotMet);
}

TEST(FuseLayerNormPass, TestInvalidEpsDataType) {
  ProgramDesc prog = BuildGraphProgram();

  auto* eps_var_desc = prog.Block(0).FindVar("eps");
  eps_var_desc->SetDataType(proto::VarType::FP64);
  eps_var_desc->SetShape({1});

  Graph graph(prog);
  constexpr int removed_nodes = 19;
  constexpr int added_nodes = 3;

  auto place = paddle::platform::CPUPlace();
  NaiveExecutor exe{place};
  Scope scope;
  double eps_value = 1e-5;
  // Init scope, as it is used in the pass
  exe.CreateVariables(prog, 0, true, &scope);
  test::InitLoDTensorHolder<double>(&scope, place, "eps", {1}, &eps_value);

  graph.SetNotOwned(kParamScopeAttr, &scope);
  EXPECT_THROW(test::RunPassAndAssert(&graph, "layer_norm_fuse_pass", "x",
                                      "shift_out", removed_nodes, added_nodes),
               paddle::platform::EnforceNotMet);
}

TEST(FuseLayerNormPass, pass_op_version_check) {
  ASSERT_TRUE(
      paddle::framework::compatible::PassVersionCheckerRegistrar::GetInstance()
          .IsPassCompatible("layer_norm_fuse_pass"));
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(layer_norm_fuse_pass);

python/paddle/fluid/tests/unittests/ir/inference/test_layer_norm_fuse_pass.py
@@ -0,0 +1,64 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for fusion of subgraph expressing layer normalization."""

import unittest

import numpy as np
import paddle
import paddle.fluid as fluid
from inference_pass_test import InferencePassTest
from paddle import enable_static
from paddle.fluid.core import PassVersionChecker


class LayerNormFusePassTest(InferencePassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data = fluid.data(name="data", shape=[3, 64, 120], dtype="float32")
            sqr_pow = fluid.layers.fill_constant(
                shape=[1], value=2, dtype="float32")
            eps = fluid.layers.fill_constant(
                shape=[1], value=1e-5, dtype="float32")
            gamma = fluid.layers.create_parameter(
                shape=[120], dtype="float32", is_bias=True)
            beta = fluid.layers.create_parameter(
                shape=[120], dtype="float32", is_bias=True)

            x_mean_out = fluid.layers.reduce_mean(data, dim=-1, keep_dim=True)
            x_sub_mean_out = fluid.layers.elementwise_sub(data, x_mean_out)
            x_sub_mean_sqr_out = fluid.layers.elementwise_pow(x_sub_mean_out,
                                                              sqr_pow)
            std_dev_out = fluid.layers.reduce_mean(
                x_sub_mean_sqr_out, dim=-1, keep_dim=True)
            std_dev_eps_out = fluid.layers.elementwise_add(std_dev_out, eps)
            std_dev_eps_sqrt_out = fluid.layers.sqrt(std_dev_eps_out)
            division_out = fluid.layers.elementwise_div(x_sub_mean_out,
                                                        std_dev_eps_sqrt_out)
            scale_out = fluid.layers.elementwise_mul(division_out, gamma)
            shift_out = fluid.layers.elementwise_add(scale_out, beta)

        self.feeds = {
            "data": np.random.random((3, 64, 120)).astype("float32"),
        }
        self.fetch_list = [shift_out]

    def test_check_output(self):
        use_gpu = False
        self.check_output_with_option(use_gpu)
        self.assertTrue(PassVersionChecker.IsCompatible("layer_norm_fuse_pass"))


if __name__ == "__main__":
    enable_static()
    unittest.main()