parent 821534efd3
commit 21d95be0db
@@ -0,0 +1,208 @@ paddle/fluid/operators/inplace_abn_op.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/inplace_abn_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/batch_norm_op.h"

namespace paddle {
namespace operators {

class InplaceABNOp : public paddle::operators::BatchNormOp {
 public:
  using paddle::operators::BatchNormOp::BatchNormOp;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    // By default, the data type of the Scale, Bias, Mean and Variance
    // tensors should be float (for float or float16 input tensors) or
    // double (for double input tensors).
    auto bn_param_type = framework::proto::VarType::FP32;
    if (input_data_type == framework::proto::VarType::FP64) {
      bn_param_type = framework::proto::VarType::FP64;
    }
    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Scale")->type(),
                      platform::errors::InvalidArgument(
                          "Scale input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Bias")->type(),
                      platform::errors::InvalidArgument(
                          "Bias input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Mean")->type(),
                      platform::errors::InvalidArgument(
                          "Mean input should be of float type"));
    PADDLE_ENFORCE_EQ(bn_param_type, ctx.Input<Tensor>("Variance")->type(),
                      platform::errors::InvalidArgument(
                          "Variance input should be of float type"));

    framework::LibraryType library = framework::LibraryType::kPlain;
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;

    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                   library);
  }
};

class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
 public:
  using paddle::operators::BatchNormGradOp::BatchNormGradOp;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto* var = ctx.InputVar(framework::GradVarName("Y"));
    auto input_data_type = ctx.Input<Tensor>("Y")->type();
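    // X is not an input of the grad op (it may have been overwritten
    // in-place), so the kernel data type is deduced from Y instead.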
    if (var == nullptr) {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "can't find gradient variable of Y"));
    }
    const Tensor* t = nullptr;
    if (var->IsType<Tensor>()) {
      t = &var->Get<Tensor>();
    } else if (var->IsType<LoDTensor>()) {
      t = &var->Get<LoDTensor>();
    }
    if (t == nullptr) {
      PADDLE_THROW(
          platform::errors::InvalidArgument("gradient variable of Y is empty"));
    }
    framework::LibraryType library = framework::LibraryType::kPlain;
    framework::DataLayout layout = framework::DataLayout::kAnyLayout;

    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
                                   library);
  }
};

class InplaceABNOpMaker : public paddle::operators::BatchNormOpMaker {
 public:
  void Make() override {
    BatchNormOpMaker::Make();
    AddAttr<std::string>(
        "activation",
        "(enum string, default identity, can be identity|elu|leaky_relu) "
        "The activation type to be fused with batch_norm.")
        .SetDefault("");
    AddAttr<float>("alpha",
                   "(float, default 0.1) Only used in the inplace_abn kernel. "
                   "The activation (identity|elu|leaky_relu) is fused with "
                   "batch_norm; this is the alpha value for elu|leaky_relu.")
        .SetDefault(0.1f);
    AddAttr<bool>("use_sync_bn",
                  "(bool, default false) Whether to use synchronized batch "
                  "normalization.")
        .SetDefault(false);
  }
};

template <typename T>
class InplaceABNOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType(this->ForwardOpType() + "_grad");
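    // Unlike batch_norm_grad, this grad op takes Y (and Y@GRAD) rather than
    // X: in in-place mode X has been overwritten by Y, and the grad kernel
    // recovers the pre-activation values from Y before running the regular
    // batch-norm backward.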
    op->SetInput("Y", this->Output("Y"));
    op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));

    op->SetInput("Scale", this->Input("Scale"));
    op->SetInput("Bias", this->Input("Bias"));
    op->SetInput("SavedMean", this->Output("SavedMean"));
    op->SetInput("SavedVariance", this->Output("SavedVariance"));

    // Used when use_global_stats is set to true during training.
    if (boost::get<bool>(this->GetAttr("use_global_stats"))) {
      op->SetInput("Mean", this->Output("MeanOut"));
      op->SetInput("Variance", this->Output("VarianceOut"));
    }

    op->SetAttrMap(this->Attrs());

    op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
    op->SetOutput(framework::GradVarName("Scale"), this->InputGrad("Scale"));
    op->SetOutput(framework::GradVarName("Bias"), this->InputGrad("Bias"));
  }
};

template <typename DeviceContext, typename T>
class InplaceABNKernel
    : public paddle::operators::BatchNormKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Output<Tensor>("Y");
    PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                "X and Y not inplaced in inplace mode"));
    auto activation =
        GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    BatchNormKernel<DeviceContext, T>::Compute(ctx);
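    // Batch norm has written its output into y, which shares the buffer of
    // x; now apply the fused activation in place on that same buffer.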
    auto cur_y = EigenVector<T>::Flatten(*y);
    InplaceABNActivation<DeviceContext, T> functor;
    functor.Compute(ctx, activation, place, cur_y, cur_y);
  }
};

template <typename DeviceContext, typename T>
class InplaceABNGradKernel
    : public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* y = ctx.Input<Tensor>("Y");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    PADDLE_ENFORCE_EQ(d_x, d_y,
                      platform::errors::InvalidArgument(
                          "X@GRAD and Y@GRAD not inplaced in inplace mode"));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto activation =
        GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));

    auto py = *y;
    auto pd_y = *d_y;
    auto cur_y = EigenVector<T>::Flatten(py);
    auto cur_dy = EigenVector<T>::Flatten(pd_y);
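    // Undo the activation first: GradCompute overwrites y with the recovered
    // pre-activation values and dy with the activation gradient, so the
    // batch-norm backward below sees the tensors it expects.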
    InplaceABNActivation<DeviceContext, T> functor;
    functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);

    BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(inplace_abn, ops::InplaceABNOp, ops::InplaceABNOpMaker,
                  ops::BatchNormOpInferVarType,
                  ops::InplaceABNOpGradMaker<paddle::framework::OpDesc>,
                  ops::InplaceABNOpGradMaker<paddle::imperative::OpBase>)
REGISTER_OPERATOR(inplace_abn_grad, ops::InplaceABNGradOp)

REGISTER_OP_CPU_KERNEL(
    inplace_abn,
    ops::InplaceABNKernel<paddle::platform::CPUDeviceContext, float>,
    ops::InplaceABNKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    inplace_abn_grad,
    ops::InplaceABNGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::InplaceABNGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -0,0 +1,92 @@ paddle/fluid/operators/inplace_abn_op.cu
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/inplace_abn_op.h"
#include "paddle/fluid/operators/sync_batch_norm_op.cu.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class InplaceABNKernel
    : public paddle::operators::SyncBatchNormKernel<DeviceContext, T>,
      public paddle::operators::BatchNormKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* y = ctx.Output<Tensor>("Y");
    auto* x = ctx.Input<Tensor>("X");
    PADDLE_ENFORCE_EQ(x, y, platform::errors::InvalidArgument(
                                "X and Y not inplaced in inplace mode"));
    auto activation =
        GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

    if (ctx.Attr<bool>("use_sync_bn")) {
      SyncBatchNormKernel<DeviceContext, T>::Compute(ctx);
    } else {
      BatchNormKernel<DeviceContext, T>::Compute(ctx);
    }

    auto cur_y = EigenVector<T>::Flatten(*y);
    InplaceABNActivation<DeviceContext, T> functor;
    functor.Compute(ctx, activation, place, cur_y, cur_y);
  }
};

// Deriving the Gradient for the Backward Pass of Batch Normalization
// https://kevinzakka.github.io/2016/09/14/batch_normalization/
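// For reference, with x_hat = (x - mu) / sqrt(var + eps) and a mini-batch of
// size m, the batch-norm input gradient is
//   dx_i = gamma / (m * sqrt(var + eps)) *
//          (m * dy_i - sum_j(dy_j) - x_hat_i * sum_j(dy_j * x_hat_j))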
template <typename DeviceContext, typename T>
class InplaceABNGradKernel
    : public paddle::operators::SyncBatchNormGradKernel<DeviceContext, T>,
      public paddle::operators::BatchNormGradKernel<DeviceContext, T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* y = ctx.Input<Tensor>("Y");
    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    PADDLE_ENFORCE_EQ(d_x, d_y,
                      platform::errors::InvalidArgument(
                          "X@GRAD and Y@GRAD not inplaced in inplace mode"));
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
    auto activation =
        GetInplaceABNActivationType(ctx.Attr<std::string>("activation"));

    auto py = *y;
    auto pd_y = *d_y;
    auto cur_y = EigenVector<T>::Flatten(py);
    auto cur_dy = EigenVector<T>::Flatten(pd_y);

    InplaceABNActivation<DeviceContext, T> functor;
    functor.GradCompute(ctx, activation, place, cur_y, cur_y, cur_dy, cur_dy);

    if (ctx.Attr<bool>("use_sync_bn")) {
      SyncBatchNormGradKernel<DeviceContext, T>::Compute(ctx);
    } else {
      BatchNormGradKernel<DeviceContext, T>::Compute(ctx);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(inplace_abn,
                        ops::InplaceABNKernel<plat::CUDADeviceContext, float>,
                        ops::InplaceABNKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    inplace_abn_grad, ops::InplaceABNGradKernel<plat::CUDADeviceContext, float>,
    ops::InplaceABNGradKernel<plat::CUDADeviceContext, double>);
@@ -0,0 +1,117 @@ paddle/fluid/operators/inplace_abn_op.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/math/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

enum InplaceABNActivationType { identity = 0, leakyrelu = 1, elu = 2 };

inline InplaceABNActivationType GetInplaceABNActivationType(
    const std::string& type) {
  if (type == "leaky_relu") {
    return InplaceABNActivationType::leakyrelu;
  } else if (type == "elu") {
    return InplaceABNActivationType::elu;
  } else if (type == "identity" || type == "") {
    return InplaceABNActivationType::identity;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "unsupported activation type %s for Op(inplace_abn)", type));
  }
}

template <typename DeviceContext, typename T>
class InplaceABNActivation {
 private:
  template <typename Functor>
  void setAttrs(const framework::ExecutionContext& ctx, Functor* functor) {
    auto attrs = functor->GetAttrs();
    for (auto& attr : attrs) {
      *attr.second = ctx.Attr<float>(attr.first);
    }
  }

  template <typename Functor, typename... Args>
  void compute(const framework::ExecutionContext& ctx, Functor* functor,
               Args... args) {
    setAttrs(ctx, functor);
    (*functor)(args...);
  }

 public:
  template <typename Device, typename X, typename Y>
  void Compute(const framework::ExecutionContext& ctx, const int act_type,
               const Device& d, X x, Y y) {
    if (act_type == InplaceABNActivationType::identity) {
      y.device(d) = x;
    } else if (act_type == InplaceABNActivationType::leakyrelu) {
      LeakyReluFunctor<T> functor;
      compute(ctx, &functor, d, x, y);
    } else if (act_type == InplaceABNActivationType::elu) {
      ELUFunctor<T> functor;
      compute(ctx, &functor, d, x, y);
    } else {
      PADDLE_THROW(
          platform::errors::InvalidArgument("unsupported activation type"));
    }
  }

  template <typename Device, typename X, typename Y, typename DX, typename DY>
  void GradCompute(const framework::ExecutionContext& ctx, const int act_type,
                   const Device& d, X x, Y y, DX dx, DY dy) {
    const float alpha = ctx.Attr<float>("alpha");
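    // The forward pass stored only the activated output, so before applying
    // the activation's gradient functor we first recover the pre-activation
    // value x from y by inverting the activation (y / alpha for negative
    // leaky_relu outputs, log(y / alpha + 1) for negative elu outputs).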
    if (act_type == InplaceABNActivationType::identity) {
      x.device(d) = y;
      dx.device(d) = dy;
    } else if (act_type == InplaceABNActivationType::leakyrelu) {
      auto temp1 = (y < static_cast<T>(0)).template cast<T>().eval() /
                   static_cast<T>(alpha);
      auto temp2 = (y >= static_cast<T>(0)).template cast<T>().eval();
      x.device(d) = y * (temp1 + temp2).template cast<T>();

      LeakyReluGradFunctor<T> functor;
      compute(ctx, &functor, d, x, y, dy, dx);
    } else if (act_type == InplaceABNActivationType::elu) {
      auto temp1 = (y >= static_cast<T>(0)).template cast<T>().eval();
      auto temp = (y < static_cast<T>(0)).template cast<T>().eval();
      auto temp2 = (y * temp / static_cast<T>(alpha) + static_cast<T>(1)).log();
      x.device(d) = (y * temp1 + temp2).template cast<T>();

      ELUGradFunctor<T> functor;
      compute(ctx, &functor, d, x, y, dy, dx);
    } else {
      PADDLE_THROW(
          platform::errors::InvalidArgument("unsupported activation type"));
    }
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,189 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np
import os
import six
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid import compiler
import paddle.fluid.unique_name as unique_name


class TestInplaceABNOpTraining(unittest.TestCase):
    def setUp(self):
        self.dtype = np.float64
        self.N = 4
        self.C = 5
        self.H = 7
        self.W = 9
        self.dshape = [self.N, self.C, self.H, self.W]

    def build_program(self,
                      place,
                      layout,
                      seed,
                      only_forward=False,
                      activation="identity",
                      alpha=1.0,
                      use_cuda=False,
                      inplace=False):
        main = fluid.Program()
        startup = fluid.Program()
        main.random_seed = seed
        startup.random_seed = seed
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                data = fluid.layers.data(
                    name='input',
                    shape=self.dshape,
                    dtype=self.dtype,
                    append_batch_size=False,
                    stop_gradient=False)
                if inplace:
                    bn = fluid.layers.inplace_abn(
                        data,
                        act=activation,
                        param_attr=fluid.ParamAttr(name='bn_scale'),
                        bias_attr=fluid.ParamAttr(name='bn_bias'),
                        moving_mean_name='bn_moving_mean',
                        moving_variance_name='bn_moving_variance',
                        data_layout=layout,
                        is_test=only_forward,
                        act_alpha=alpha)
                else:
                    bn = fluid.layers.batch_norm(
                        data,
                        param_attr=fluid.ParamAttr(name='bn_scale'),
                        bias_attr=fluid.ParamAttr(name='bn_bias'),
                        moving_mean_name='bn_moving_mean',
                        moving_variance_name='bn_moving_variance',
                        data_layout=layout,
                        is_test=only_forward,
                        in_place=inplace)
                    if activation == 'leaky_relu':
                        bn = fluid.layers.leaky_relu(bn, alpha)
                    if activation == 'elu':
                        bn = fluid.layers.elu(bn, alpha)

                # NOTE: in in-place mode the input and output of bn may share
                # the same name, so multiply by 1. to create a new Variable
                # for fetching.
                bn = bn * 1.

                sigmoid = fluid.layers.sigmoid(bn)
                out = fluid.layers.reduce_sum(sigmoid)
                if not only_forward:
                    sgd_opt = fluid.optimizer.SGD(learning_rate=0.0)
                    sgd_opt.backward(out)
        return main, startup, [out, bn]

    def compare(self, place, layout, only_forward, activation, alpha, use_cuda):
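        # Build the same network twice (separate batch_norm + activation ops
        # vs. the fused inplace_abn op), run both on identical data, and
        # check that all fetched outputs, statistics and gradients match
        # within tolerance.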
        seed = 10
        os.environ['FLAGS_cudnn_deterministic'] = "1"
        data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2

        fetch_outs = []
        fetch_names = []
        for inplace in [False, True]:
            main, startup, outs = self.build_program(
                place,
                layout,
                seed,
                only_forward,
                activation,
                alpha,
                inplace=inplace)
            exe = fluid.Executor(place)
            exe.run(startup)

            fetch_name = [v.name for v in outs] + [
                'bn_moving_mean', 'bn_moving_variance', 'bn_scale', 'bn_bias'
            ]
            if not only_forward:
                others = [
                    'inplace_abn_0.tmp_0' if inplace else 'batch_norm_0.tmp_0',
                    'inplace_abn_0.tmp_1' if inplace else 'batch_norm_0.tmp_1',
                    'bn_scale@GRAD',
                    'bn_bias@GRAD',
                    'input@GRAD',
                ]
                fetch_name += others
            for nm in fetch_name:
                fv = fluid.framework._get_var(str(nm), program=main)
                fv.persistable = True

            build_strategy = fluid.BuildStrategy()
            build_strategy.sync_batch_norm = use_cuda and \
                fluid.core.get_cuda_device_count() > 1
            build_strategy.enable_inplace = inplace
            exec_strategy = fluid.ExecutionStrategy()
            exec_strategy.num_threads = 1 if os.name == 'nt' else 0
            comp_prog1 = compiler.CompiledProgram(main).with_data_parallel(
                outs[0].name if not only_forward else None,
                build_strategy=build_strategy,
                exec_strategy=exec_strategy)
            bn_fetches = exe.run(program=comp_prog1,
                                 feed={'input': data},
                                 fetch_list=fetch_name)
            fetch_outs.append(bn_fetches)
            fetch_names.append(fetch_name)

        for bn_val, inplace_abn_val, name1, name2 in zip(*(fetch_outs +
                                                           fetch_names)):
            self.assertTrue(
                np.allclose(
                    bn_val, inplace_abn_val, atol=1e-2),
                "Output (" + name1 + ":" + name2 +
                ") has diff on {} with {} layout and {} activation. \n".format(
                    place, layout, activation) + "\nBN " + str(bn_val) +
                "\n" + "Inplace ABN " + str(inplace_abn_val))

    def test_op(self):
        use_cudas = [False, True] if core.is_compiled_with_cuda() else [False]
        for use_cuda in use_cudas:
            place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
            layouts = ["NCHW", "NHWC"]
            for layout in layouts:
                for activation, alpha in zip([None, 'elu', 'leaky_relu'],
                                             [0., 1., 0.02]):
                    for infer_only in [True, False]:
                        self.compare(place, layout, infer_only, activation,
                                     alpha, use_cuda)

    def test_all_branches(self):
        seed = 10
        os.environ['FLAGS_cudnn_deterministic'] = "1"
        data = np.random.random(size=self.dshape).astype(self.dtype) * 4. - 2
        use_cudas = [False, True] if core.is_compiled_with_cuda() else [False]
        alpha = 0.1
        layouts = ["NCHW", "NHWC"]
        for use_cuda in use_cudas:
            place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
            for layout in layouts:
                for activation in ['identity', 'leaky_relu']:
                    main, startup, outs = self.build_program(
                        place, layout, seed, False, activation, alpha, use_cuda,
                        True)
                    exe = fluid.Executor(place)
                    exe.run(startup)
                    exe.run(program=main, feed={'input': data})


if __name__ == '__main__':
    unittest.main()