Recompute Offload (#30233)

Author: JZ-LIANG
Commit: 75936d838f (parent 2e80857760)

@@ -22,7 +22,11 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server
}
message RecomputeConfig { repeated string checkpoints = 1; }
message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
}
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];

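For orientation, these two new proto fields are what the Python-level recompute_configs dict (documented further below in this diff) gets serialized into; a minimal sketch, with illustrative checkpoint names:

import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {
    "checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"],   # -> repeated string checkpoints
    "enable_offload": True,                        # -> optional bool enable_offload
    "checkpoint_shape": [256],                     # -> repeated int32 checkpoint_shape
}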
@@ -394,5 +394,5 @@ REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass)
.EQ("square", 0)
.LE("elementwise_mul", 1)
.LE("elementwise_sub", 1)
.EQ("fill_constant", 1)
.LE("fill_constant", 2)
.EQ("fusion_squared_mat_sub", 0));

@@ -116,6 +116,15 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddAttr<int>("place_type",
"(int, default -1) allow manually setting the place where the "
"variable should be held. "
"-1: not set manually, determine the place by executor. "
"0: CPUPlace. "
"1: CUDAPlace. "
"2: CUDAPinnedPlace. "
"3: XPUPlace. ")
.SetDefault(-1);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
@@ -154,4 +163,11 @@ REGISTER_OP_VERSION(fill_constant)
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"ValueTensor",
"In order to support new feature tensor support of Value"));
"In order to support new feature tensor support of Value"))
.AddCheckpoint(
R"ROC(
Upgrade fill_constant to add a new attribute [place_type].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"place_type",
"In order to support tensor in CUDAPinnedPlace and XPUPlace", -1));

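A hedged usage sketch of the new place_type attribute, mirroring the unit test added later in this diff (place_type=2 asks fill_constant to allocate its output in CUDAPinnedPlace, which assumes a CUDA build):

import paddle
from paddle.fluid import Program, program_guard

paddle.enable_static()
main_program = Program()
with program_guard(main_program):
    block = main_program.global_block()
    pinned_var = block.create_var(
        name="tensor@Pinned", shape=[10, 10], dtype='float32')
    block.append_op(
        type="fill_constant",
        outputs={"Out": "tensor@Pinned"},
        attrs={"shape": [10, 10], "dtype": pinned_var.dtype,
               "value": 0.0, "place_type": 2})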
@@ -39,6 +39,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
auto str_value = ctx.Attr<std::string>("str_value");
auto float_value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu");
auto place_type = ctx.Attr<int>("place_type");
framework::Tensor *tensor = nullptr;
framework::Variable *out_var = ctx.OutputVar("Out");
@@ -101,29 +102,59 @@ class FillConstantKernel : public framework::OpKernel<T> {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
int actual_place = place_type;
if (actual_place == -1) {
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
actual_place = 0;
} else if (platform::is_gpu_place(ctx.GetPlace())) {
actual_place = 1;
} else if (platform::is_xpu_place(ctx.GetPlace())) {
actual_place = 3;
}
}
if (actual_place == 0) {
tensor->mutable_data(platform::CPUPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
} else if (actual_place == 1) {
#ifdef PADDLE_WITH_CUDA
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
} else if (actual_place == 2) {
#ifdef PADDLE_WITH_CUDA
tensor->mutable_data(platform::CUDAPinnedPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
} else if (actual_place == 3) {
#ifdef PADDLE_WITH_XPU
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::XPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU."));
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Could NOT determine the place of variable, place_type = %d .",
actual_place));
}
}
};
} // namespace operators

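For reference, a sketch of the places the integer codes above dispatch to (-1 means "not set", in which case the kernel derives the code from the place the executor runs on; the CUDA/XPU entries obviously assume the matching build):

import paddle.fluid.core as core

# Illustrative mapping only; the kernel performs this dispatch in C++.
PLACE_TYPE_TO_PLACE = {
    0: core.CPUPlace(),
    1: core.CUDAPlace(0),        # assumes a CUDA build, device 0
    2: core.CUDAPinnedPlace(),   # assumes a CUDA build
    3: core.XPUPlace(0),         # assumes an XPU build
}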
@@ -0,0 +1,146 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/memcpy_op.h"
#include <string>
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
namespace platform {
struct CPUPlace;
struct CUDAPlace;
struct float16;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
class MemcpyOp : public framework::OperatorWithKernel {
public:
MemcpyOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class MemcpyInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
class MemcpyKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.InputVar("X");
if (x == nullptr) {
return;
}
PADDLE_ENFORCE_EQ(
ctx.HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of memcpy_op is not found."));
auto *out = ctx.OutputVar("Out");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
auto dst_place_type = ctx.Attr<int>("dst_place_type");
framework::VisitVarType(*x, MemcpyFunctor(out, dev_ctx, dst_place_type));
}
};
class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input variable ");
AddOutput("Out",
"(LoDTensor) The type of output "
"is the same as input X.");
AddAttr<int>("dst_place_type",
"Determine the dst place of tensor copy. "
"For now it ONLY supports CUDAPlace and CUDAPinnedPlace. Other "
"place types are unimplemented and will cause an ERROR. "
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
"3: dst is on XPUPlace. ");
AddComment(R"DOC(
Memcpy Operator.
For now, it ONLY supports memcpy between CUDAPinnedPlace and CUDAPlace,
and is used as an internal op by Recompute-Offload.
You would have to extend it if you need more capabilities.
Out = X, when the input type is LoDTensor;
an error is raised if the type is not listed above.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
memcpy, ops::MemcpyOp, ops::MemcpyOpProtoMaker, ops::MemcpyInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);
#endif

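A hedged end-to-end sketch of the new op, mirroring test_memcpy_op.py added further below: fill a GPU tensor, then memcpy it into pinned host memory (dst_place_type=3 maps to CUDAPinnedPlace in this commit; assumes a CUDA device is available):

import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard

paddle.enable_static()
main_program = Program()
with program_guard(main_program):
    block = main_program.global_block()
    gpu_var = block.create_var(name="tensor@GPU", shape=[10, 10], dtype='float32')
    pinned_var = block.create_var(name="tensor@Pinned", shape=[10, 10], dtype='float32')
    # fill the GPU tensor with ones (place_type=1 -> CUDAPlace)
    block.append_op(type="fill_constant", outputs={"Out": "tensor@GPU"},
                    attrs={"shape": [10, 10], "dtype": gpu_var.dtype,
                           "value": 1.0, "place_type": 1})
    # copy it to pinned host memory (dst_place_type=3 -> CUDAPinnedPlace)
    block.append_op(type="memcpy", inputs={"X": gpu_var},
                    outputs={"Out": pinned_var}, attrs={"dst_place_type": 3})

exe = fluid.Executor(fluid.CUDAPlace(0))
gpu_, pinned_ = exe.run(main_program, fetch_list=["tensor@GPU", "tensor@Pinned"])
assert np.allclose(gpu_, pinned_)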
@@ -0,0 +1,75 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class LoDTensor;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class MemcpyFunctor {
public:
MemcpyFunctor(framework::Variable *out,
const platform::DeviceContext &dev_ctx,
const int dst_place_type)
: out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {}
void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
if (dst_place_type_ == 3) {
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
&out_tensor);
} else if (dst_place_type_ == 2) {
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
}
out_tensor.set_lod(lod_tensor.lod());
}
void operator()(const framework::SelectedRows &rows) const {
// TODO(JZ-LIANG): support SelectedRows
PADDLE_THROW(platform::errors::Unimplemented(
"Memcpy for SelectedRows is NOT supported yet."));
}
template <typename T>
void operator()(const T &v) const {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::PermissionDenied(
"Unsupported data type %s for the Memcpy op.", typeid(T).name()));
}
private:
framework::Variable *out_;
const platform::DeviceContext &dev_ctx_;
const int dst_place_type_;
};
} // namespace operators
} // namespace paddle

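Note the mapping the functor actually implements; a minimal sketch, with codes as used in this commit (3 selects CUDAPinnedPlace, 2 selects the device context's own place):

import paddle.fluid.core as core

def memcpy_dst_place(dst_place_type, dev_ctx_place):
    # illustrative Python restatement of MemcpyFunctor's destination choice
    if dst_place_type == 3:
        return core.CUDAPinnedPlace()   # assumes a CUDA build
    if dst_place_type == 2:
        return dev_ctx_place            # e.g. the CUDAPlace the kernel runs on
    raise NotImplementedError(
        "memcpy dst_place_type: %d is not supported yet." % dst_place_type)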
@@ -632,8 +632,20 @@ class DistributedStrategy(object):
@property
def recompute_configs(self):
"""
Set recompute configurations. In general, the recompute strategy of current
implementation should have some manually assign checkpoints
Set recompute configurations.
**Note**:
checkpoints(list): list of string names of checkpoints. In general, the recompute
strategy of the current implementation requires some manually assigned checkpoints.
enable_offload(bool): enable the recompute-checkpoint offload feature. This feature
offloads the checkpoints to host memory to allow an even larger batch size. Since
the memcpy from host to device takes time, it is a trade-off between a larger batch
size and training speed.
checkpoint_shape(list): list of ints specifying the shape of the checkpoints. So far
recompute-offload requires all checkpoints to have the same shape, and every dimension
specified here must be determined ("-1" is not allowed).
Examples:
@@ -642,7 +654,10 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": ["x", "y"]}
strategy.recompute_configs = {
"checkpoints": ["x", "y"],
"enable_offload": True,
"checkpoint_shape": [100, 512, 1024] }
"""
return get_msg_dict(self.strategy.recompute_configs)
@@ -692,6 +707,14 @@ class DistributedStrategy(object):
This configuration will affect the communication speed in sharding training,
and should be an empirical value decided by your model size and network topology.
hybrid_dp(bool): enable hybrid data parallelism above the sharding parallelism.
You are supposed to have at least double the number of GPUs used in normal sharding
training to enable this feature.
sharding_group_size(int): attribute of hybrid_dp. Specifies the number of GPUs within
each sharding group; therefore, the number of hybrid data parallelism ways will be equal
to (global_size / sharding_group_size).
Examples:
.. code-block:: python
@@ -699,7 +722,10 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 32}
strategy.sharding_configs = {
"fuse_broadcast_MB": 32,
"hybrid_dp": True,
"sharding_group_size": 8}
"""
return get_msg_dict(self.strategy.sharding_configs)

@@ -39,9 +39,13 @@ class RecomputeOptimizer(MetaOptimizerBase):
return
configs = self.user_defined_strategy.recompute_configs
self.wrapped_opt = RO(self.inner_opt)
self.wrapped_opt._set_checkpoints(list(configs["checkpoints"]))
if configs["enable_offload"]:
self.wrapped_opt._enable_offload()
# TODO(JZ-LIANG): find a way to infer the checkpoint shape automatically
checkpoint_shapes = list(configs["checkpoint_shape"])
self.wrapped_opt.checkpoint_shape = checkpoint_shapes
def _can_apply(self):
if not self.role_maker._is_collective:

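A hedged sketch of how these configs reach the wrapped RecomputeOptimizer through fleet (fleet.init and the distributed launch environment are omitted, and loss is assumed to come from a network built elsewhere):

import paddle
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {
    "checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"],
    "enable_offload": True,
    "checkpoint_shape": [256],
}
optimizer = paddle.fluid.optimizer.SGD(learning_rate=0.01)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
# optimizer.minimize(loss)  # the recompute-offload pass is applied here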
@@ -99,8 +99,32 @@ class ProgramStats(object):
max_op_idx = max(max_op_idx, idx)
if min_op_idx >= max_op_idx:
return False, min_op_idx, max_op_idx
return True, min_op_idx, max_op_idx
def _update_segment_start(self, min_idx, pre_segment_end_idx):
"""
persistable vars produced by amp-related cast ops should be included in the recompute segment
"""
def is_amp_cast(op):
return op.desc.type() == 'cast' and self.block.var(
op.desc.input_arg_names()[0]).persistable
idx_ = min_idx - 1
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0]))
updated_min_idx = idx_
idx_ -= 1
else:
break
return updated_min_idx
def build_stats(self):
for i, op in enumerate(self.ops):
self.op_deps[i] = {"in_ops": [], "out_ops": []}
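A toy illustration (not the Paddle API) of the backward scan performed by _update_segment_start: the segment start is pulled back over consecutive amp-cast ops that sit immediately before min_idx:

def update_segment_start(ops, min_idx, pre_segment_end_idx, is_amp_cast):
    idx = min_idx - 1
    updated = min_idx
    while idx > pre_segment_end_idx and is_amp_cast(ops[idx]):
        updated = idx       # include this cast op in the segment
        idx -= 1
    return updated

ops = ["matmul", "cast", "cast", "relu"]   # hypothetical op types
assert update_segment_start(ops, 3, -1, lambda op: op == "cast") == 1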
@@ -751,20 +775,29 @@ def _append_backward_ops_with_checkpoints_(
if name not in program_stat.var_op_deps:
break
op_idx = program_stat.var_op_deps[name]["var_as_output_ops"]
# only count the last generate op
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if max_op_idx > 0:
segments.append([0, max_op_idx + 1])
else:
start_idx = 0
pre_segment_end_idx = -1
while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1:
break
# min_idx: checkpoint_1's input op
# max_idx: checkpoint_2's output op
flag, min_idx, max_idx = program_stat.is_subgraph(
[checkpoints_name[start_idx]],
[checkpoints_name[start_idx + 1]])
if flag:
# max_idx + 1 because the segment end index actually used is max_idx
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
start_idx += 1
if segments != [] and segments[0][0] != 0:
@@ -772,12 +805,31 @@ def _append_backward_ops_with_checkpoints_(
else:
recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory
vars_should_be_hold = []
# a. variables that are used across segments will be held in memory
for segment in recompute_segments:
vars_should_be_hold.extend(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory
vars_should_be_hold.extend(program_stat.get_reserved_vars())
# c. input variables are checkpoints
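A toy sketch (with made-up variable names) of how the held-in-memory set is assembled from the three sources a/b/c above:

cross_segment_vars = {"fc_0.tmp_1"}            # a. get_out_of_subgraph_vars(...)
reserved_vars = {"seed_0"}                     # b. get_reserved_vars()
checkpoints = {"fc_0.tmp_2", "fc_1.tmp_2"}     # c. the checkpoints themselves
vars_should_be_hold = cross_segment_vars | reserved_vars | checkpoints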
@@ -792,8 +844,6 @@ def _append_backward_ops_with_checkpoints_(
max_calculated_op_position = len(ops)
if recompute_segments == []:
# if there is no recompute segment, add backward ops like
# _append_backward_ops_ function
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
if op.has_attr("sub_block"):
@@ -807,7 +857,6 @@ def _append_backward_ops_with_checkpoints_(
grad_to_var.update(op_grad_to_var)
for i, segment in enumerate(recompute_segments[::-1]):
# add grad op for ops not in any segments
gap_ops = ops[segment[1]:max_calculated_op_position]
max_calculated_op_position = segment[0]
for op in reversed(gap_ops):
@@ -851,7 +900,7 @@ def _append_backward_ops_with_checkpoints_(
# added_descs should be in grad_op_descs because it is backward op desc
grad_op_descs.extend(buffer_descs)
# 3.c. add backward ops of current recomputation ops
# 3.c. add backward ops for all ops in current segment
for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
@@ -1480,9 +1529,11 @@ def append_backward(loss,
# TODO: support _append_backward_ops_with_checkpoints_ in
# sub-block (control flow)
is_recompute = False
if checkpoints != None and \
isinstance(checkpoints, list) and \
len(checkpoints) > 0:
is_recompute = True
program_stat, checkpoint_names, \
vars_should_be_hold, \
recompute_segments = \
@@ -1577,7 +1628,10 @@ def append_backward(loss,
attr_val.extend(g.op.attr(op_role_var_attr_name))
g.op._set_attr(op_role_var_attr_name, attr_val)
return params_and_grads
if is_recompute:
return params_and_grads, checkpoint_names
else:
return params_and_grads
def _as_list(x):

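A hedged sketch of the new return convention added above: when checkpoints are passed, append_backward also returns the checkpoint variable names. The small network here is purely illustrative:

import paddle
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward

paddle.enable_static()
main = fluid.Program()
with fluid.program_guard(main):
    x = fluid.data(name='x', shape=[None, 8], dtype='float32')
    hidden1 = fluid.layers.fc(input=x, size=8)
    hidden2 = fluid.layers.fc(input=hidden1, size=4)
    loss = fluid.layers.reduce_mean(hidden2)
    # with checkpoints, the checkpoint names come back as a second value
    params_grads, checkpoint_names = append_backward(
        loss, checkpoints=[hidden1.name, hidden2.name])
    # without checkpoints the return value is unchanged:
    # params_grads = append_backward(loss)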
(File diff suppressed because it is too large.)

@@ -83,6 +83,7 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
endif()
if(WIN32)

@@ -132,5 +132,12 @@ class TestFleetMetaOptimizer(unittest.TestCase):
elif name == "sharding":
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
elif name == "recompute-offload":
strategy.recompute = True
strategy.recompute_configs = {
"checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"],
"enable_offload": True,
"checkpoint_shape": [256]
}
else:
raise NotImplementedError()

@@ -153,6 +153,20 @@ class TestFleetRecomputeMetaOptimizer(TestFleetMetaOptimizer):
self.assertIn('subprog', ''.join(outs))
self.assertIn('lamb', ops)
def test_recompute_offload(self):
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'recompute-offload')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
outs = [
op.output('Out')[0] for op in avg_cost.block.ops
if op.type == 'memcpy'
]
self.assertIn('memcpy', ops)
self.assertIn('@Pinned', ''.join(outs))
self.assertIn('@Fetch', ''.join(outs))
if __name__ == "__main__":
unittest.main()

@@ -170,19 +170,19 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast',
'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh',
'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'fill_constant', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast',
'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',

@@ -0,0 +1,176 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
class TestMemcpy_FillConstant(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = Program()
with program_guard(main_program):
pinned_var_name = "tensor@Pinned"
gpu_var_name = "tensor@GPU"
pinned_var = main_program.global_block().create_var(
name=pinned_var_name,
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
gpu_var = main_program.global_block().create_var(
name=gpu_var_name,
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": gpu_var_name},
attrs={
"shape": [10, 10],
"dtype": gpu_var.dtype,
"value": 1.0,
"place_type": 1
})
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": pinned_var_name},
attrs={
"shape": [10, 10],
"dtype": gpu_var.dtype,
"value": 0.0,
"place_type": 2
})
return main_program, gpu_var, pinned_var
def test_gpu_cpoy_to_pinned(self):
main_program, gpu_var, pinned_var = self.get_prog()
main_program.global_block().append_op(
type='memcpy',
inputs={'X': gpu_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
feed={},
fetch_list=[gpu_var.name, pinned_var.name])
self.assertTrue(np.allclose(gpu_, pinned_))
self.assertTrue(np.allclose(pinned_, np.ones((10, 10))))
def test_pinned_cpoy_gpu(self):
main_program, gpu_var, pinned_var = self.get_prog()
main_program.global_block().append_op(
type='memcpy',
inputs={'X': pinned_var},
outputs={'Out': gpu_var},
attrs={'dst_place_type': 2})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
feed={},
fetch_list=[gpu_var.name, pinned_var.name])
self.assertTrue(np.allclose(gpu_, pinned_))
self.assertTrue(np.allclose(gpu_, np.zeros((10, 10))))
class TestMemcpyOPError(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = Program()
with program_guard(main_program):
pinned_var = main_program.global_block().create_var(
name="tensor@Pinned_0",
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": "tensor@Pinned_0"},
attrs={
"shape": [10, 10],
"dtype": pinned_var.dtype,
"value": 0.0,
"place_type": 2
})
return main_program, pinned_var
def test_SELECTED_ROWS(self):
main_program, pinned_var = self.get_prog()
selected_row_var = main_program.global_block().create_var( \
name="selected_row_0", dtype="float32", persistable=False, \
type=fluid.core.VarDesc.VarType.SELECTED_ROWS, stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": selected_row_var},
attrs={
"shape": selected_row_var.shape,
"dtype": selected_row_var.dtype,
"value": 1.0,
"place_type": 1
})
main_program.global_block().append_op(
type='memcpy',
inputs={'X': selected_row_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
with self.assertRaises(NotImplementedError):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
selected_row_var_, pinned_ = exe.run(
main_program,
feed={},
fetch_list=[selected_row_var.name, pinned_var.name])
def test_OTHER_PLACE_NotImplementedError(self):
main_program, pinned_var = self.get_prog()
lod_tensor_var = main_program.global_block().create_var( \
name="lod_tensor_0", dtype="float32", persistable=False, stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": lod_tensor_var},
attrs={
"shape": lod_tensor_var.shape,
"dtype": lod_tensor_var.dtype,
"value": 1.0,
"place_type": 0
})
main_program.global_block().append_op(
type='memcpy',
inputs={'X': pinned_var},
outputs={'Out': lod_tensor_var},
attrs={'dst_place_type': 0, })
with self.assertRaises(NotImplementedError):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
lod_tensor_var_, pinned_ = exe.run(
main_program,
feed={},
fetch_list=[lod_tensor_var.name, pinned_var.name])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()