Add Inplace strategy (Output reuse Input Varbase) in dygraph (#30103)

* add view strategy on squeeze,unsqueeze,reshape,flatten

* add squeeze unittest

* add unittests

* use View strategy as the name rather than Reuse Allocation

* fix view api doc

* fix format

* use core.ops when input of reshape2 is Tensor

* fix test_cross_entropy_loss error because of reshape2

* fix test_cross_entropy_loss error because of reshape2

* add inplace strategy

* add elementwise_add sub

* let backward op not use inplace

* grad op do not use inplace

* fix memory increase error and add leaf error message

* delete selected_rows

* change op_function

* little change

* solve HandleViewBetweenInputAndOutput

* add unittest and leaf error message

* merge view error

* optimize op_function_generator format and support sum inplace op

* fix format of basic_engine

* fix format for framework

* little change of variable wrapper

* add reshape, squeeze, unsqueeze, scatter api

* add relu elu tanh softmax inplace api

* fix test_squeeze_op unittest

* fix test_relu_op unittest

* fix comment problems

* delete sample code of inplace api

* add reference of grad_pending_nodes in basic_engine

* fix unittest name

* add inplace apis into wlist

* fix error message

* add PADDLE_ENFORCE for set grad op twice

* fix head file error
pangyoki committed via GitHub · commit 13d757362c (parent 008b0a8b56)
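For orientation, here is a minimal dygraph sketch of the user-facing inplace APIs this commit adds, condensed from the unit tests in this diff (the inplace variant returns the same Tensor object and bumps its inplace_version counter):

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor(np.random.rand(2, 3, 1).astype("float32"))

# Inplace APIs reuse the input VarBase as the output instead of allocating a new one.
y = paddle.squeeze_(x)
assert id(y) == id(x)                 # same Tensor object
assert x.inplace_version == 1         # each inplace op bumps the version counter

paddle.reshape_(x, [-1])
paddle.unsqueeze_(x, -1)
F.relu_(x)
paddle.tanh_(x)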

@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <tuple>
@ -247,8 +248,9 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs) {
T maker(type, var_base_map_in, var_base_map_out, attrs);
const framework::AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
return maker();
};
}

@ -221,6 +221,10 @@ class SingleGradOpMaker<imperative::OpBase>
std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
if (!inplace_map.empty()) {
node->SetInplaceGradNameMap(inplace_map);
}
{
imperative::TracedGradOp traced_grad_op(node);
try {

@ -59,7 +59,8 @@ using DygraphGradOpMakerFN =
const std::string& /*op_type*/,
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/)>;
const framework::AttributeMap& /*attributes*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;
using InferVarTypeFN =
std::function<void(framework::InferVarTypeContext* /*context*/)>;

(File diff suppressed because it is too large.)

@ -39,15 +39,33 @@ class BasicEngine : public Engine {
void CheckBackwardInputs(const OpBase& op);
void PrepareGradAccumulators(const OpBase& op);
void PrepareGradAccumulators(
const OpBase& op,
const std::vector<std::shared_ptr<GradOpNode>>& grad_pending_nodes);
void Clear();
private:
std::shared_ptr<GradOpNode> init_node_;
std::unordered_map<GradOpNode*, size_t> node_deps_;
// The input and output of an Inplace op share the same variable. If only
// `var` were used as the key, the gradients of the inplace op's input and
// output would be accumulated into the same accumulator. Therefore, the
// `grad_node` is added to the key to avoid this problem for inplace ops.
std::unordered_map<std::shared_ptr<GradOpNode>,
std::unordered_map<VariableWrapper*,
std::unique_ptr<GradientAccumulator>>>
accumulators_with_grad_node_;
// Leaf var doesn't have grad_node, and leaf var with `stop_gradient=False`
// can't use Inplace strategy. If a var doesn't have grad_node, only use
// `var` as the key.
std::unordered_map<VariableWrapper*, std::unique_ptr<GradientAccumulator>>
accumulators_;
// The output grad var of Inplace grad op. Because Inplace grad op does not
// use the Inplace strategy, a new output grad var needs to be created.
std::vector<std::pair<std::shared_ptr<VariableWrapper>,
std::shared_ptr<VariableWrapper>>>
inplace_output_grad_var_list_;
std::vector<std::pair<GradientAccumulator*, std::shared_ptr<VariableWrapper>>>
need_accu_var_list_;
// leaf_accumulators_ is only for leaf tensor(hooks/accumulate grad)
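The observable effect of keying accumulators by (grad_node, var) is that gradients stay correct when an intermediate Tensor is modified in place. A condensed form of TestDygraphInplace.test_backward_success_1 from this diff:

import numpy as np
import paddle

data = np.random.rand(2, 3, 1).astype("float32")

def grad_of_a(use_inplace):
    a = paddle.to_tensor(data)
    a.stop_gradient = False
    b = a ** 2
    # b is modified in place before its value is needed by any grad op.
    c = paddle.squeeze_(b) if use_inplace else paddle.squeeze(b)
    d = c ** 2
    d.sum().backward()
    return a.grad

# The inplace and non-inplace paths must produce identical gradients.
assert np.array_equal(grad_of_a(True), grad_of_a(False))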

@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
@ -43,14 +44,16 @@ class TracedVarList : public std::vector<std::shared_ptr<T>> {
class GradOpBaseMakerBase {
public:
explicit GradOpBaseMakerBase(const std::string& type,
const NameVarBaseMap& var_base_map_in,
const NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs)
explicit GradOpBaseMakerBase(
const std::string& type, const NameVarBaseMap& var_base_map_in,
const NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map)
: type_(type),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
inplace_map_(inplace_map) {}
virtual ~GradOpBaseMakerBase() = default;
@ -141,6 +144,10 @@ class GradOpBaseMakerBase {
return std::make_shared<GradOpNode>();
}
const std::map<std::string, std::string>& GetInplaceMap() const {
return inplace_map_;
}
private:
template <TracedVarRole kRole>
TracedVarList<VarBase, kRole> GetVarBaseList(const std::string& name,
@ -192,6 +199,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const std::map<std::string, std::string>& inplace_map_;
};
class TracedGradOp {
@ -220,6 +228,10 @@ class TracedGradOp {
for (auto& var : vars) {
if (var && !var->OverridedStopGradient()) {
var->SetGraphIsFreed(false);
auto dirty_grad_node = var->GradNode();
if (dirty_grad_node) {
map_dirty_grad_node_[var] = dirty_grad_node;
}
var->SetGradNode(node_);
}
}
@ -246,7 +258,11 @@ class TracedGradOp {
} else {
for (auto& var : vars) {
if (var && !var->OverridedStopGradient() && var->GradNode()) {
node_->InsertGradPendingNode(var->GradNode());
if (map_dirty_grad_node_.find(var) != map_dirty_grad_node_.end()) {
node_->InsertGradPendingNode(map_dirty_grad_node_[var]);
} else {
node_->InsertGradPendingNode(var->GradNode());
}
}
}
}
@ -329,6 +345,12 @@ class TracedGradOp {
private:
const std::shared_ptr<GradOpNode>& node_;
OpBase* op_;
// Inplace op has recursion problems when performing grad calculation.
// Because the input and output of inplace op are the same, the grad
// node of inplace var will be overwritten.
// This map is used to temporarily store the grad node of the inplace var.
std::unordered_map<std::shared_ptr<VarBase>, std::shared_ptr<GradOpNode>>
map_dirty_grad_node_;
};
} // namespace imperative
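map_dirty_grad_node_ matters when inplace ops are chained on the same VarBase: each trace overwrites the var's grad node, and the saved "dirty" node is what gets re-linked as a grad-pending node. A hypothetical sketch of that situation (the tests in this diff chain inplace ops in test_forward_version, though without this exact backward call):

import numpy as np
import paddle

a = paddle.to_tensor(np.random.rand(2, 3, 1).astype("float32"))
a.stop_gradient = False
b = a ** 2
paddle.unsqueeze_(b, -1)   # b's grad node now belongs to unsqueeze_'s grad op
paddle.reshape_(b, [-1])   # ...and is overwritten again by reshape_'s grad op
# The previously recorded grad node is kept in map_dirty_grad_node_ so the
# pending-node chain is intended to still reach the unsqueeze and pow grad ops.
b.sum().backward()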

@ -451,13 +451,15 @@ static void ClearNoNeedBufferInputs(OpBase* op) {
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place) {
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map) {
const auto& info = op.Info();
if (!info.dygraph_grad_op_maker_) {
return nullptr;
}
auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs);
auto grad_node =
info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map);
if (grad_node && !grad_node->empty()) {
for (auto& grad_op : *grad_node) {
grad_op.SetId(OpBase::GenerateUniqueId());

@ -256,7 +256,8 @@ class Layer {
std::shared_ptr<GradOpNode> CreateGradOpNode(
const framework::OperatorBase& op, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, const framework::AttributeMap& attrs,
const platform::Place& place);
const platform::Place& place,
const std::map<std::string, std::string>& inplace_map);
} // namespace imperative
} // namespace paddle

@ -15,11 +15,13 @@
#pragma once
#include <atomic>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/imperative/saved_variable_wrapper_list.h"
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/place.h"
@ -227,6 +229,22 @@ class GradOpNode {
}
}
void SetInplaceGradNameMap(
const std::map<std::string, std::string>& inplace_input_map) {
for (auto& pair : inplace_input_map) {
VLOG(10) << "Set mapping relationship ("
<< framework::GradVarName(pair.first) << ", "
<< framework::GradVarName(pair.second)
<< ") for Inplace grad node.";
inplace_grad_name_map_[framework::GradVarName(pair.first)] =
framework::GradVarName(pair.second);
}
}
const std::map<std::string, std::string>& InplaceGradNameMap() const {
return inplace_grad_name_map_;
}
const std::vector<std::shared_ptr<GradOpNode>>& GradPendingNodes() const {
return grad_pending_nodes_;
}
@ -237,6 +255,9 @@ class GradOpNode {
private:
std::vector<OpBase> ops_;
std::vector<std::shared_ptr<GradOpNode>> grad_pending_nodes_;
// Mapping relationship between grad output and grad input of the grad node of
// Inplace op.
std::map<std::string, std::string> inplace_grad_name_map_;
};
} // namespace imperative
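Illustrative only: SetInplaceGradNameMap stores the grad-name pairing derived from the forward inplace pair via framework::GradVarName, which appends the @GRAD suffix. Assuming a typical forward pair {"X": "Out"} (the output reuses the input), the equivalent mapping in Python-style pseudocode would be:

# Assumed forward inplace pair: output "Out" reuses input "X".
inplace_input_map = {"X": "Out"}

def grad_var_name(name):           # mirrors framework::GradVarName
    return name + "@GRAD"

inplace_grad_name_map = {grad_var_name(k): grad_var_name(v)
                         for k, v in inplace_input_map.items()}
print(inplace_grad_name_map)       # {'X@GRAD': 'Out@GRAD'}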

@ -884,7 +884,7 @@ void PartialGradTask::RunEachOp(OpBase *op) {
if (create_graph_) {
auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs,
op->Attrs(), op->place());
op->Attrs(), op->place(), {});
PADDLE_ENFORCE_NOT_NULL(
double_grad_node,
platform::errors::NotFound("The Op %s doesn't have any grad op. If you "

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/tracer.h"
#include <map>
#include <set>
#include <unordered_set>
#include <utility>
@ -130,7 +131,8 @@ paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_backward) {
const platform::Place& place, bool trace_backward,
const std::map<std::string, std::string>& inplace_map) {
VLOG(1) << "Trace Op: " << type;
if (FLAGS_use_mkldnn) {
// if both lists are empty all ops are enabled (default for
@ -182,16 +184,17 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
}
if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
CreateGradOpNode(*op, new_ins, outs, attrs, place);
CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map);
} else {
VLOG(3) << "No Grad to track for Op: " << type;
}
}
void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
framework::AttributeMap attrs) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_);
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map) {
TraceOp(type, ins, outs, std::move(attrs), expected_place_, has_grad_,
inplace_map);
}
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,

@ -21,7 +21,6 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "ThreadPool.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/imperative/basic_engine.h"
@ -63,10 +62,12 @@ class Tracer {
void TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const platform::Place& place, bool trace_bacward);
const platform::Place& place, bool trace_bacward,
const std::map<std::string, std::string>& inplace_map = {});
void TraceOp(const std::string& type, const NameVarBaseMap& ins,
const NameVarBaseMap& outs, framework::AttributeMap attrs);
const NameVarBaseMap& outs, framework::AttributeMap attrs,
const std::map<std::string, std::string>& inplace_map = {});
bool ComputeRequiredGrad(const NameVarBaseMap& ins,
const NameVarBaseMap& outs, bool trace_backward);

@ -20,6 +20,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/imperative/hooks.h"
#include "paddle/fluid/imperative/op_base.h"
namespace paddle {
namespace imperative {
@ -258,8 +259,13 @@ class VariableWrapper {
auto shared_node = grad_node_.lock();
if (shared_node != grad_node) {
PADDLE_ENFORCE_EQ(
shared_node, nullptr,
platform::errors::PermissionDenied("Cannot set gradient op twice"));
!shared_node || !grad_node->InplaceGradNameMap().empty(), true,
platform::errors::PermissionDenied(
"Cannot set gradient op twice unless using Inplace Strategy."));
if (shared_node) {
VLOG(3) << "The gradient op of Var (" << Name()
<< ") has been set twice. Because Inplace Strategy is used.";
}
grad_node_ = grad_node;
}
}
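The relaxed check can be seen from Python: a non-leaf VarBase that already carries a grad node gets its grad node set a second time when it is the output of an inplace op, which is now permitted instead of raising. A minimal sketch based on the tests in this diff:

import numpy as np
import paddle

a = paddle.to_tensor(np.random.rand(2, 3, 1).astype("float32"))
a.stop_gradient = False
b = a ** 2            # b's grad node is set by the pow op
paddle.squeeze_(b)    # b is both input and output, so its grad node is set
                      # again; allowed only because the new node carries an
                      # InplaceGradNameMap (see the PADDLE_ENFORCE_EQ above)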

(File diff suppressed because it is too large.)

@ -113,19 +113,23 @@ from .tensor.manipulation import flatten #DEFINE_ALIAS
from .tensor.manipulation import gather #DEFINE_ALIAS
from .tensor.manipulation import gather_nd #DEFINE_ALIAS
from .tensor.manipulation import reshape #DEFINE_ALIAS
from .tensor.manipulation import reshape_ #DEFINE_ALIAS
from .tensor.manipulation import flip as reverse #DEFINE_ALIAS
from .tensor.manipulation import scatter #DEFINE_ALIAS
from .tensor.manipulation import scatter_ #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd_add #DEFINE_ALIAS
from .tensor.manipulation import scatter_nd #DEFINE_ALIAS
from .tensor.manipulation import shard_index #DEFINE_ALIAS
from .tensor.manipulation import slice #DEFINE_ALIAS
from .tensor.manipulation import split #DEFINE_ALIAS
from .tensor.manipulation import squeeze #DEFINE_ALIAS
from .tensor.manipulation import squeeze_ #DEFINE_ALIAS
from .tensor.manipulation import stack #DEFINE_ALIAS
from .tensor.manipulation import strided_slice #DEFINE_ALIAS
from .tensor.manipulation import transpose #DEFINE_ALIAS
from .tensor.manipulation import unique #DEFINE_ALIAS
from .tensor.manipulation import unsqueeze #DEFINE_ALIAS
from .tensor.manipulation import unsqueeze_ #DEFINE_ALIAS
from .tensor.manipulation import unstack #DEFINE_ALIAS
from .tensor.manipulation import flip #DEFINE_ALIAS
from .tensor.manipulation import unbind #DEFINE_ALIAS
@ -172,6 +176,7 @@ from .tensor.math import square #DEFINE_ALIAS
from .tensor.math import stanh #DEFINE_ALIAS
from .tensor.math import sum #DEFINE_ALIAS
from .tensor.math import tanh #DEFINE_ALIAS
from .tensor.math import tanh_ #DEFINE_ALIAS
from .tensor.math import add_n #DEFINE_ALIAS
from .tensor.math import max #DEFINE_ALIAS
from .tensor.math import maximum #DEFINE_ALIAS

@ -221,12 +221,16 @@ class TestTanhAPI(unittest.TestCase):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype(self.dtype)
self.place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.executed_api()
def executed_api(self):
self.tanh = F.tanh
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12], self.dtype)
out1 = F.tanh(x)
out1 = self.tanh(x)
th = paddle.nn.Tanh()
out2 = th(x)
exe = paddle.static.Executor(self.place)
@ -261,15 +265,21 @@ class TestTanhAPI(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.tanh, 1)
self.assertRaises(TypeError, self.tanh, 1)
# The input dtype must be float16, float32.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[12, 10], dtype='int32')
self.assertRaises(TypeError, F.tanh, x_int32)
self.assertRaises(TypeError, self.tanh, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[12, 10], dtype='float16')
F.tanh(x_fp16)
self.tanh(x_fp16)
class TestTanhInplaceAPI(TestTanhAPI):
# test paddle.tanh_
def executed_api(self):
self.tanh = paddle.tanh_
class TestAtan(TestActivation, TestParameter):
@ -1044,12 +1054,16 @@ class TestReluAPI(unittest.TestCase):
self.x_np = np.random.uniform(-1, 1, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.executed_api()
def executed_api(self):
self.relu = F.relu
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12])
out1 = F.relu(x)
out1 = self.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
@ -1061,9 +1075,9 @@ class TestReluAPI(unittest.TestCase):
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.relu(x)
m = paddle.nn.ReLU()
out2 = m(x)
out1 = m(x)
out2 = self.relu(x)
out_ref = np.maximum(self.x_np, 0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
@ -1073,15 +1087,21 @@ class TestReluAPI(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.relu, 1)
self.assertRaises(TypeError, self.relu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.relu, x_int32)
self.assertRaises(TypeError, self.relu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[10, 12], dtype='float16')
F.relu(x_fp16)
self.relu(x_fp16)
class TestReluInplaceAPI(TestReluAPI):
# test paddle.nn.functional.relu_
def executed_api(self):
self.relu = F.relu_
def ref_leaky_relu(x, alpha=0.01):
@ -1609,12 +1629,16 @@ class TestELUAPI(unittest.TestCase):
self.x_np = np.random.uniform(-3, 3, [10, 12]).astype('float32')
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
self.executed_api()
def executed_api(self):
self.elu = F.elu
def test_static_api(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', [10, 12])
out1 = F.elu(x)
out1 = self.elu(x)
m = paddle.nn.ELU()
out2 = m(x)
exe = paddle.static.Executor(self.place)
@ -1626,14 +1650,16 @@ class TestELUAPI(unittest.TestCase):
def test_dygraph_api(self):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.elu(x)
out1 = self.elu(x)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.ELU()
out2 = m(x)
out_ref = elu(self.x_np, 1.0)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
out1 = F.elu(x, 0.2)
out1 = self.elu(x, 0.2)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.ELU(0.2)
out2 = m(x)
out_ref = elu(self.x_np, 0.2)
@ -1645,15 +1671,21 @@ class TestELUAPI(unittest.TestCase):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.elu, 1)
self.assertRaises(TypeError, self.elu, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[10, 12], dtype='int32')
self.assertRaises(TypeError, F.elu, x_int32)
self.assertRaises(TypeError, self.elu, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[10, 12], dtype='float16')
F.elu(x_fp16)
self.elu(x_fp16)
class TestELUInplaceAPI(TestELUAPI):
# test paddle.nn.functional.elu_
def executed_api(self):
self.elu = F.elu_
class TestReciprocal(TestActivation):

@ -95,5 +95,206 @@ class TestInplace(unittest.TestCase):
loss.backward()
class TestDygraphInplace(unittest.TestCase):
def setUp(self):
self.init_data()
def init_data(self):
self.input_var_numpy = np.random.rand(2, 3, 1)
self.dtype = "float32"
def non_inplace_api_processing(self, var):
return paddle.squeeze(var)
def inplace_api_processing(self, var):
return paddle.squeeze_(var)
def test_inplace_api(self):
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
inplace_var = self.inplace_api_processing(var)
self.assertTrue(id(var) == id(inplace_var))
inplace_var[0] = 2.
self.assertTrue(np.array_equal(var.numpy(), inplace_var.numpy()))
def test_forward_version(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
self.assertEqual(var.inplace_version, 0)
inplace_var = self.inplace_api_processing(var)
self.assertEqual(var.inplace_version, 1)
inplace_var[0] = 2.
self.assertEqual(var.inplace_version, 2)
inplace_var = self.inplace_api_processing(inplace_var)
self.assertEqual(var.inplace_version, 3)
def test_leaf_inplace_var_error(self):
with paddle.fluid.dygraph.guard():
var = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var.stop_gradient = False
def leaf_inplace_error():
self.inplace_api_processing(var)
self.assertRaises(ValueError, leaf_inplace_error)
def test_backward_error(self):
# It raises an error because the inplace operator will result
# in incorrect gradient computation.
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
# Here, the gradient computation will use the value of var_b
var_c = var_b**2
self.inplace_api_processing(var_b)
loss = paddle.nn.functional.relu(var_c)
with self.assertRaisesRegexp(
RuntimeError,
"received tensor_version:{} != wrapper_version_snapshot:{}".
format(1, 0)):
loss.backward()
def test_backward_success_1(self):
# var_b is modified inplace before using it, the inplace operator doesn't result
# in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
# Here, the gradient computation will use the value of var_b
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.non_inplace_api_processing(var_b)
var_d = var_c**2
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
def test_backward_success_2(self):
# Although var_b is modified inplace after using it, it is not used in gradient computation.
# The inplace operator doesn't result in incorrect gradient computation.
grad_var_a, grad_var_a_inplace = 0, 1
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a_inplace = var_a.grad
with paddle.fluid.dygraph.guard():
var_a = paddle.to_tensor(self.input_var_numpy).astype(self.dtype)
var_a.stop_gradient = False
var_b = var_a**2
var_c = self.non_inplace_api_processing(
var_b) # var_b is modified inplace before using it
var_d = var_c + var_c # Here, the grad op of sum doesn't use the value of var_b
loss = var_d.sum()
loss.backward()
grad_var_a = var_a.grad
self.assertTrue(np.array_equal(grad_var_a_inplace, grad_var_a))
class TestDygraphInplaceUnsqueeze(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.unsqueeze(var, -1)
def inplace_api_processing(self, var):
return paddle.unsqueeze_(var, -1)
class TestDygraphInplaceReshape(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.reshape(var, [-1])
def inplace_api_processing(self, var):
return paddle.reshape_(var, [-1])
class TestDygraphInplaceScatter(TestDygraphInplace):
def init_data(self):
self.input_var_numpy = np.array([[1, 1], [2, 2], [3, 3]])
self.dtype = "float32"
def non_inplace_api_processing(self, var):
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor(
[[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
return paddle.scatter(var, index, updates, overwrite=False)
def inplace_api_processing(self, var):
index = paddle.to_tensor([2, 1, 0, 1], dtype='int64')
updates = paddle.to_tensor(
[[1, 1], [2, 2], [3, 3], [4, 4]], dtype='float32')
return paddle.scatter_(var, index, updates, overwrite=False)
class TestDygraphInplaceElu(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.elu(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.elu_(var)
class TestDygraphInplaceRelu(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.relu(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.relu_(var)
class TestDygraphInplaceSoftmax(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.nn.functional.softmax(var)
def inplace_api_processing(self, var):
return paddle.nn.functional.softmax_(var)
class TestDygraphInplaceTanh(TestDygraphInplace):
def non_inplace_api_processing(self, var):
return paddle.tanh(var)
def inplace_api_processing(self, var):
return paddle.tanh_(var)
if __name__ == '__main__':
unittest.main()
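One behavioral guarantee worth calling out from these tests: a leaf Tensor that requires grad rejects inplace modification (the leaf error message added in basic_engine). Condensed from test_leaf_inplace_var_error above:

import numpy as np
import paddle

leaf = paddle.to_tensor(np.random.rand(2, 3, 1).astype("float32"))
leaf.stop_gradient = False          # a leaf var that requires grad
try:
    paddle.squeeze_(leaf)           # inplace op on such a leaf is rejected
except ValueError as e:
    print("rejected:", e)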

@ -250,8 +250,11 @@ class TestReshapeAPI(unittest.TestCase):
def _set_paddle_api(self):
self.fill_constant = paddle.fluid.layers.fill_constant
self.data = paddle.static.data
self.reshape = paddle.reshape
self.to_tensor = paddle.to_tensor
self._executed_api()
def _executed_api(self):
self.reshape = paddle.reshape
def _set_fluid_api(self):
self.fill_constant = fluid.layers.fill_constant
@ -322,6 +325,30 @@ class TestReshapeAPI(unittest.TestCase):
assert np.array_equal(out_3.numpy(), input.reshape(shape))
class TestStaticReshape_(TestReshapeAPI):
def _executed_api(self):
self.reshape = paddle.reshape_
def test_imperative(self):
self._set_paddle_api()
input = np.random.random([2, 25]).astype("float32")
shape = [2, 5, 5]
with fluid.dygraph.guard():
x = self.to_tensor(input)
positive_five = self.fill_constant([1], "int32", 5)
out_1 = self.reshape(x, shape)
out_2 = self.reshape(x, shape=[positive_five, 10])
shape_tensor = self.to_tensor(np.array([2, 5, 5]).astype("int32"))
out_3 = self.reshape(x, shape=shape_tensor)
assert np.array_equal(out_1.numpy(), input.reshape(shape))
assert np.array_equal(out_2.numpy(), input.reshape(shape))
assert np.array_equal(out_3.numpy(), input.reshape(shape))
# Test Input Error
class TestReshapeOpError(unittest.TestCase):
def _set_paddle_api(self):
@ -397,12 +424,18 @@ class TestReshapeOpError(unittest.TestCase):
self._test_errors()
class API_TestDygraphReshape(unittest.TestCase):
class TestDygraphReshapeAPI(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.reshape = paddle.reshape
def test_out(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
output = self.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
@ -411,7 +444,7 @@ class API_TestDygraphReshape(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("uint8")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
output = self.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
@ -420,11 +453,16 @@ class API_TestDygraphReshape(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("float32")
input = paddle.to_tensor(input_1)
output = paddle.reshape(x=input, shape=[5, 10])
output = self.reshape(x=input, shape=[5, 10])
out_np = output.numpy()
expected_out = np.reshape(input_1, newshape=[5, 10])
self.assertTrue(np.allclose(expected_out, out_np))
class TestDygraphReshapeInplaceAPI(TestDygraphReshapeAPI):
def executed_api(self):
self.reshape = paddle.reshape_
if __name__ == "__main__":
unittest.main()

@ -180,13 +180,17 @@ class TestScatterAPI(unittest.TestCase):
self.places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
self.places.append(fluid.CUDAPlace(0))
self.executed_api()
def executed_api(self):
self.scatter = paddle.scatter
def check_static_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = fluid.data(name="input", shape=[3, 2], dtype="float64")
index = fluid.data(name="index", shape=[4], dtype="int64")
updates = fluid.data(name="updates", shape=[4, 2], dtype="float64")
result = paddle.scatter(input, index, updates, False)
result = self.scatter(input, index, updates, False)
input_data = np.array([[1, 1], [2, 2], [3, 3]]).astype(np.float64)
index_data = np.array([2, 1, 0, 1]).astype(np.int64)
@ -220,10 +224,15 @@ class TestScatterAPI(unittest.TestCase):
index = fluid.dygraph.to_variable(index_data)
updates = fluid.dygraph.to_variable(updates_data)
output1 = paddle.scatter(x, index, updates, overwrite=False)
output1 = self.scatter(x, index, updates, overwrite=False)
self.assertEqual((output1.numpy() == \
np.array([[3., 3.],[6., 6.],[1., 1.]])).all(), True)
class TestScatterInplaceAPI(TestScatterAPI):
def executed_api(self):
self.scatter = paddle.scatter_
if __name__ == "__main__":
unittest.main()
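For reference, the inplace scatter variant exercised by TestScatterInplaceAPI, condensed into a standalone dygraph snippet; the expected values match the assertion in the dygraph test above:

import numpy as np
import paddle

x = paddle.to_tensor(np.array([[1., 1.], [2., 2.], [3., 3.]], dtype="float32"))
index = paddle.to_tensor([2, 1, 0, 1], dtype="int64")
updates = paddle.to_tensor([[1., 1.], [2., 2.], [3., 3.], [4., 4.]], dtype="float32")

out = paddle.scatter_(x, index, updates, overwrite=False)
# out is x itself, updated in place to [[3., 3.], [6., 6.], [1., 1.]]
assert id(out) == id(x)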

@ -301,11 +301,15 @@ class TestSoftmaxAPI(unittest.TestCase):
) else paddle.CPUPlace()
self.x_np = np.random.uniform(-1., 1., [2, 3, 4, 5]).astype('float32')
self.out_ref = np.apply_along_axis(stable_softmax, -1, self.x_np)
self.executed_api()
def executed_api(self):
self.softmax = F.softmax
def test_static_check(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.fluid.data('X', self.x_np.shape, 'float32')
out1 = F.softmax(x)
out1 = self.softmax(x)
m = paddle.nn.Softmax()
out2 = m(x)
exe = paddle.static.Executor(self.place)
@ -318,21 +322,23 @@ class TestSoftmaxAPI(unittest.TestCase):
paddle.disable_static(self.place)
x = paddle.to_tensor(self.x_np)
out1 = F.softmax(x)
out1 = self.softmax(x)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.Softmax()
out2 = m(x)
out_ref = ref_softmax(self.x_np, axis=-1, dtype=None)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
out1 = F.softmax(x, axis=0)
out1 = self.softmax(x, axis=0)
x = paddle.to_tensor(self.x_np)
m = paddle.nn.Softmax(axis=0)
out2 = m(x)
out_ref = ref_softmax(self.x_np, axis=0, dtype=None)
for r in [out1, out2]:
self.assertEqual(np.allclose(out_ref, r.numpy()), True)
out = F.softmax(x, dtype=np.float64)
out = self.softmax(x, dtype=np.float64)
out_ref = ref_softmax(self.x_np, axis=-1, dtype=np.float64)
self.assertEqual(np.allclose(out_ref, out.numpy()), True)
@ -341,15 +347,20 @@ class TestSoftmaxAPI(unittest.TestCase):
def test_error(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
self.assertRaises(TypeError, F.softmax, 1)
self.assertRaises(TypeError, self.softmax, 1)
# The input dtype must be float16, float32, float64.
x_int32 = paddle.fluid.data(
name='x_int32', shape=[2, 3], dtype='int32')
self.assertRaises(TypeError, F.softmax, x_int32)
self.assertRaises(TypeError, self.softmax, x_int32)
# support the input dtype is float16
x_fp16 = paddle.fluid.data(
name='x_fp16', shape=[2, 3], dtype='float16')
F.softmax(x_fp16)
self.softmax(x_fp16)
class TestSoftmaxInplaceAPI(TestSoftmaxAPI):
def executed_api(self):
self.softmax = F.softmax_
if __name__ == "__main__":

@ -98,13 +98,19 @@ class TestSqueezeOpError(unittest.TestCase):
class API_TestSqueeze(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.squeeze = paddle.squeeze
def test_out(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
data1 = paddle.static.data(
'data1', shape=[-1, 1, 10], dtype='float64')
result_squeeze = paddle.squeeze(data1, axis=[1])
result_squeeze = self.squeeze(data1, axis=[1])
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
input1 = np.random.random([5, 1, 10]).astype('float64')
@ -114,12 +120,23 @@ class API_TestSqueeze(unittest.TestCase):
self.assertTrue(np.allclose(expected_result, result))
class API_TestStaticSqueeze_(API_TestSqueeze):
def executed_api(self):
self.squeeze = paddle.squeeze_
class API_TestDygraphSqueeze(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.squeeze = paddle.squeeze
def test_out(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.squeeze(input, axis=[1])
output = self.squeeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -128,7 +145,7 @@ class API_TestDygraphSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int8")
input = paddle.to_tensor(input_1)
output = paddle.squeeze(input, axis=[1])
output = self.squeeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -137,7 +154,7 @@ class API_TestDygraphSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("uint8")
input = paddle.to_tensor(input_1)
output = paddle.squeeze(input, axis=[1])
output = self.squeeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -146,7 +163,7 @@ class API_TestDygraphSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.squeeze(input, axis=1)
output = self.squeeze(input, axis=1)
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -155,11 +172,16 @@ class API_TestDygraphSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.squeeze(input, axis=(1, 2))
output = self.squeeze(input, axis=(1, 0))
out_np = output.numpy()
expected_out = np.squeeze(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
class API_TestDygraphSqueezeInplace(API_TestDygraphSqueeze):
def executed_api(self):
self.squeeze = paddle.squeeze_
if __name__ == "__main__":
unittest.main()

@ -208,6 +208,12 @@ class TestUnsqueezeOp4_AxesTensor(TestUnsqueezeOp_AxesTensor):
# test api
class TestUnsqueezeAPI(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.unsqueeze = paddle.unsqueeze
def test_api(self):
input = np.random.random([3, 2, 5]).astype("float64")
x = paddle.static.data(name='x', shape=[3, 2, 5], dtype="float64")
@ -218,12 +224,11 @@ class TestUnsqueezeAPI(unittest.TestCase):
axes_tensor_int64 = paddle.static.data(
name='axes_tensor_int64', shape=[3], dtype="int64")
out_1 = paddle.unsqueeze(x, axis=[3, 1, 1])
out_2 = paddle.unsqueeze(
x, axis=[positive_3_int32, positive_1_int64, 1])
out_3 = paddle.unsqueeze(x, axis=axes_tensor_int32)
out_4 = paddle.unsqueeze(x, axis=3)
out_5 = paddle.unsqueeze(x, axis=axes_tensor_int64)
out_1 = self.unsqueeze(x, axis=[3, 1, 1])
out_2 = self.unsqueeze(x, axis=[positive_3_int32, positive_1_int64, 1])
out_3 = self.unsqueeze(x, axis=axes_tensor_int32)
out_4 = self.unsqueeze(x, axis=3)
out_5 = self.unsqueeze(x, axis=axes_tensor_int64)
exe = paddle.static.Executor(place=paddle.CPUPlace())
res_1, res_2, res_3, res_4, res_5 = exe.run(
@ -244,10 +249,15 @@ class TestUnsqueezeAPI(unittest.TestCase):
def test_error(self):
def test_axes_type():
x2 = paddle.static.data(name="x2", shape=[2, 25], dtype="int32")
paddle.unsqueeze(x2, axis=2.1)
self.unsqueeze(x2, axis=2.1)
self.assertRaises(TypeError, test_axes_type)
class TestUnsqueezeInplaceAPI(TestUnsqueezeAPI):
def executed_api(self):
self.unsqueeze = paddle.unsqueeze_
if __name__ == "__main__":
unittest.main()

@ -203,11 +203,17 @@ class API_TestDyUnsqueezeAxisTensorList(unittest.TestCase):
class API_TestDygraphUnSqueeze(unittest.TestCase):
def setUp(self):
self.executed_api()
def executed_api(self):
self.unsqueeze = paddle.unsqueeze
def test_out(self):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.unsqueeze(input, axis=[1])
output = self.unsqueeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.expand_dims(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -216,7 +222,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int8")
input = paddle.to_tensor(input_1)
output = paddle.unsqueeze(input, axis=[1])
output = self.unsqueeze(input, axis=[1])
out_np = output.numpy()
expected_out = np.expand_dims(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -225,7 +231,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("uint8")
input = paddle.to_tensor(input_1)
output = paddle.unsqueeze(input, axis=1)
output = self.unsqueeze(input, axis=1)
out_np = output.numpy()
expected_out = np.expand_dims(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -234,7 +240,7 @@ class API_TestDygraphUnSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.unsqueeze(input, axis=1)
output = self.unsqueeze(input, axis=1)
out_np = output.numpy()
expected_out = np.expand_dims(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
@ -243,11 +249,16 @@ class API_TestDygraphUnSqueeze(unittest.TestCase):
paddle.disable_static()
input_1 = np.random.random([5, 1, 10]).astype("int32")
input = paddle.to_tensor(input_1)
output = paddle.unsqueeze(input, axis=(1, 2))
output = self.unsqueeze(input, axis=(1, 2))
out_np = output.numpy()
expected_out = np.expand_dims(input_1, axis=1)
self.assertTrue(np.allclose(expected_out, out_np))
class API_TestDygraphUnSqueezeInplace(API_TestDygraphUnSqueeze):
def executed_api(self):
self.unsqueeze = paddle.unsqueeze_
if __name__ == "__main__":
unittest.main()

@ -30,6 +30,7 @@ __all__ += pooling.__all__
from . import loss
__all__ += loss.__all__
from .activation import elu #DEFINE_ALIAS
from .activation import elu_ #DEFINE_ALIAS
# from .activation import erf #DEFINE_ALIAS
from .activation import gelu #DEFINE_ALIAS
from .activation import hardshrink #DEFINE_ALIAS
@ -41,16 +42,19 @@ from .activation import log_sigmoid #DEFINE_ALIAS
from .activation import maxout #DEFINE_ALIAS
from .activation import prelu #DEFINE_ALIAS
from .activation import relu #DEFINE_ALIAS
from .activation import relu_ #DEFINE_ALIAS
from .activation import relu6 #DEFINE_ALIAS
from .activation import selu #DEFINE_ALIAS
from .activation import sigmoid #DEFINE_ALIAS
# from .activation import soft_relu #DEFINE_ALIAS
from .activation import softmax #DEFINE_ALIAS
from .activation import softmax_ #DEFINE_ALIAS
from .activation import softplus #DEFINE_ALIAS
from .activation import softshrink #DEFINE_ALIAS
from .activation import softsign #DEFINE_ALIAS
from .activation import swish #DEFINE_ALIAS
from .activation import tanh #DEFINE_ALIAS
from .activation import tanh_ #DEFINE_ALIAS
from .activation import tanhshrink #DEFINE_ALIAS
from .activation import thresholded_relu #DEFINE_ALIAS
from .activation import log_softmax #DEFINE_ALIAS

@ -20,10 +20,14 @@ from ...fluid.layers import maxout #DEFINE_ALIAS
from ...fluid.layers import swish #DEFINE_ALIAS
from ...fluid.layers import sigmoid #DEFINE_ALIAS
from ...tensor.math import tanh #DEFINE_ALIAS
from ...tensor.math import tanh_ #DEFINE_ALIAS
from ...tensor.manipulation import _print_warning_in_static_mode
__all__ = [
'brelu',
'elu',
'elu_',
'gelu',
'hardshrink',
'hardtanh',
@ -34,15 +38,18 @@ __all__ = [
'maxout',
'prelu',
'relu',
'relu_',
'relu6',
'selu',
'softmax',
'softmax_',
'softplus',
'softshrink',
'softsign',
'sigmoid',
'swish',
'tanh',
'tanh_',
'tanhshrink',
'thresholded_relu',
'log_softmax',
@ -99,6 +106,19 @@ def elu(x, alpha=1.0, name=None):
return out
def elu_(x, alpha=1.0, name=None):
r"""
Inplace version of ``elu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_elu`.
"""
if in_dygraph_mode():
return core.ops.elu_(x, 'alpha', alpha)
_print_warning_in_static_mode("elu")
return elu(x, alpha, name)
def gelu(x, approximate=False, name=None):
r"""
gelu activation.
@ -514,6 +534,19 @@ def relu(x, name=None):
return out
def relu_(x, name=None):
"""
Inplace version of ``relu`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_relu`.
"""
if in_dygraph_mode():
return core.ops.relu_(x)
_print_warning_in_static_mode("relu")
return relu(x, name)
def log_sigmoid(x, name=None):
r"""
log_sigmoid activation.
@ -879,6 +912,23 @@ def softmax(x, axis=-1, dtype=None, name=None):
return outs_softmax
def softmax_(x, axis=-1, dtype=None, name=None):
r"""
Inplace version of ``softmax`` API, the output Tensor will be inplaced with input ``x``.
Please refer to :ref:`api_nn_cn_softmax`.
"""
if (dtype is not None) and (not isinstance(dtype, core.VarDesc.VarType)):
dtype = convert_np_dtype_to_dtype_(dtype)
use_cudnn = True
if in_dygraph_mode():
return core.ops.softmax_(x, 'axis', axis, 'use_cudnn', use_cudnn)
_print_warning_in_static_mode("softmax")
return softmax(x, axis, dtype, name)
def softplus(x, beta=1, threshold=20, name=None):
r"""
softplus activation
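A short dygraph usage sketch for the new functional inplace activations added here; in static mode each one prints a warning via _print_warning_in_static_mode and falls back to its out-of-place counterpart:

import numpy as np
import paddle
import paddle.nn.functional as F

paddle.disable_static()
x = paddle.to_tensor(np.random.uniform(-1., 1., [10, 12]).astype("float32"))

F.relu_(x)              # x now holds max(x, 0)
F.elu_(x, alpha=0.2)
F.softmax_(x, axis=-1)
paddle.tanh_(x)         # tanh_ is also exposed under paddle / paddle.tensor.math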

(Some files were not shown because too many files have changed in this diff.)
