Recompute Offload (#30233)
parent
2e80857760
commit
75936d838f
@ -0,0 +1,146 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/memcpy_op.h"
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
class OpDesc;
|
||||||
|
class Variable;
|
||||||
|
} // namespace framework
|
||||||
|
namespace imperative {
|
||||||
|
class OpBase;
|
||||||
|
} // namespace imperative
|
||||||
|
namespace platform {
|
||||||
|
struct CPUPlace;
|
||||||
|
struct CUDAPlace;
|
||||||
|
struct float16;
|
||||||
|
} // namespace platform
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class MemcpyOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
MemcpyOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||||
|
const framework::VariableNameMap &outputs,
|
||||||
|
const framework::AttributeMap &attrs)
|
||||||
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||||
|
auto type = ctx->GetInputsVarType("X")[0];
|
||||||
|
if (type == framework::proto::VarType::SELECTED_ROWS ||
|
||||||
|
type == framework::proto::VarType::LOD_TENSOR) {
|
||||||
|
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
|
||||||
|
if (type == framework::proto::VarType::LOD_TENSOR) {
|
||||||
|
ctx->ShareLoD("X", /*->*/ "Out");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
protected:
|
||||||
|
framework::OpKernelType GetKernelTypeForVar(
|
||||||
|
const std::string &var_name, const framework::Tensor &tensor,
|
||||||
|
const framework::OpKernelType &expected_kernel_type) const override {
|
||||||
|
return framework::OpKernelType(expected_kernel_type.data_type_,
|
||||||
|
expected_kernel_type.place_,
|
||||||
|
tensor.layout());
|
||||||
|
}
|
||||||
|
|
||||||
|
framework::OpKernelType GetExpectedKernelType(
|
||||||
|
const framework::ExecutionContext &ctx) const override {
|
||||||
|
return framework::OpKernelType(
|
||||||
|
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
|
||||||
|
ctx.device_context());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Variable-type inference for `memcpy`: synchronizes the type and data type
// of output "Out" with those of input "X".
class MemcpyInferVarType : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    const std::string input_slot = "X";
    const std::string output_slot = "Out";
    ctx->SyncTypeAndDataType(input_slot, output_slot);
  }
};
|
||||||
|
|
||||||
|
class MemcpyKernel {
|
||||||
|
public:
|
||||||
|
void operator()(const framework::ExecutionContext &ctx) const {
|
||||||
|
auto *x = ctx.InputVar("X");
|
||||||
|
if (x == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
ctx.HasOutput("Out"), true,
|
||||||
|
platform::errors::NotFound("Output(Out) of memcpy_op is not found."));
|
||||||
|
auto *out = ctx.OutputVar("Out");
|
||||||
|
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
|
||||||
|
auto &dev_ctx = *pool.Get(ctx.GetPlace());
|
||||||
|
auto dst_place_type = ctx.Attr<int>("dst_place_type");
|
||||||
|
framework::VisitVarType(*x, MemcpyFunctor(out, dev_ctx, dst_place_type));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Declares the proto of the `memcpy` op: one input "X", one output "Out",
// the integer attribute `dst_place_type`, and the operator documentation.
class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(LoDTensor) The input variable ");
    AddOutput("Out",
              "(LoDTensor) The type of output "
              "is the same as input X.");
    // The value list below mirrors what MemcpyFunctor dispatches on:
    // 3 selects CUDAPinnedPlace, 2 selects the place of the device context
    // the kernel runs with, and every other value raises Unimplemented.
    // The previous text listed a different numbering (1/2/3 for
    // CUDA/pinned/XPU) that did not match the implementation.
    AddAttr<int>("dst_place_type",
                 "Determine the dst place of tensor copy. "
                 "By Now it ONLY support CUDAPlace and CUDAPinnedPlace. Other "
                 "place type is Unimplemented and will cause ERROR. "
                 "2: dst is on the place of the running kernel "
                 "(CUDAPlace when run on GPU). "
                 "3: dst is on CUDAPinnedPlace. ");
    AddComment(R"DOC(
    Memcpy Operator.
    By now, it ONLY supports the memcopy between CUDAPinnedPlace and CUDAPlace,
    and used as an internal op by Recompute-Offload.
    You would have to update it if you want other more capacities.

Out = X,  when type in [LoDTensor]
raise error if the type is not listed above.
)DOC");
  }
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace ops = paddle::operators;
|
||||||
|
namespace plat = paddle::platform;
|
||||||
|
REGISTER_OPERATOR(
|
||||||
|
memcpy, ops::MemcpyOp, ops::MemcpyOpProtoMaker, ops::MemcpyInferVarType,
|
||||||
|
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
|
||||||
|
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
||||||
|
|
||||||
|
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
|
||||||
|
ops::MemcpyKernel, int, ops::MemcpyKernel,
|
||||||
|
int64_t, ops::MemcpyKernel, bool,
|
||||||
|
ops::MemcpyKernel, plat::float16,
|
||||||
|
ops::MemcpyKernel);
|
||||||
|
|
||||||
|
#ifdef PADDLE_WITH_CUDA
|
||||||
|
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
|
||||||
|
ops::MemcpyKernel, int, ops::MemcpyKernel,
|
||||||
|
int64_t, ops::MemcpyKernel, bool,
|
||||||
|
ops::MemcpyKernel, plat::float16,
|
||||||
|
ops::MemcpyKernel);
|
||||||
|
#endif
|
@ -0,0 +1,75 @@
|
|||||||
|
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/data_type.h"
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/framework/var_type.h"
|
||||||
|
#include "paddle/fluid/platform/device_context.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace framework {
|
||||||
|
class LoDTensor;
|
||||||
|
class Variable;
|
||||||
|
} // namespace framework
|
||||||
|
} // namespace paddle
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
class MemcpyFunctor {
|
||||||
|
public:
|
||||||
|
MemcpyFunctor(framework::Variable *out,
|
||||||
|
const platform::DeviceContext &dev_ctx,
|
||||||
|
const int dst_place_type)
|
||||||
|
: out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {}
|
||||||
|
|
||||||
|
void operator()(const framework::LoDTensor &lod_tensor) const {
|
||||||
|
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
|
||||||
|
|
||||||
|
if (dst_place_type_ == 3) {
|
||||||
|
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
|
||||||
|
&out_tensor);
|
||||||
|
} else if (dst_place_type_ == 2) {
|
||||||
|
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
|
||||||
|
&out_tensor);
|
||||||
|
} else {
|
||||||
|
PADDLE_THROW(platform::errors::Unimplemented(
|
||||||
|
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
|
||||||
|
}
|
||||||
|
out_tensor.set_lod(lod_tensor.lod());
|
||||||
|
}
|
||||||
|
|
||||||
|
void operator()(const framework::SelectedRows &rows) const {
|
||||||
|
// (JZ-LIANG) to support SelectedRows
|
||||||
|
PADDLE_THROW(platform::errors::Unimplemented(
|
||||||
|
"Memcpy for SelectedRows is NOT support yet."));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void operator()(const T &v) const {
|
||||||
|
PADDLE_ENFORCE_EQ(
|
||||||
|
true, false,
|
||||||
|
platform::errors::PermissionDenied(
|
||||||
|
"Not support type for Memcpy op with type %s", typeid(T).name()));
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
framework::Variable *out_;
|
||||||
|
const platform::DeviceContext &dev_ctx_;
|
||||||
|
const int dst_place_type_;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,176 @@
|
|||||||
|
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import print_function
|
||||||
|
|
||||||
|
import op_test
|
||||||
|
import numpy as np
|
||||||
|
import unittest
|
||||||
|
import paddle
|
||||||
|
import paddle.fluid.core as core
|
||||||
|
from paddle.fluid.op import Operator
|
||||||
|
import paddle.fluid as fluid
|
||||||
|
from paddle.fluid import compiler, Program, program_guard
|
||||||
|
from paddle.fluid.backward import append_backward
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestMemcpy_FillConstant(unittest.TestCase):
    """Round-trip checks for the memcpy op between two fill_constant outputs.

    Both tests run on ``fluid.CUDAPlace(0)``, so the whole class is skipped
    when Paddle was built without CUDA instead of failing with an error.
    """

    def get_prog(self):
        """Build a program holding two 10x10 float32 variables.

        ``tensor@GPU`` is filled with 1.0 (fill_constant ``place_type`` 1) and
        ``tensor@Pinned`` with 0.0 (``place_type`` 2).
        NOTE(review): the place each ``place_type`` maps to is defined by
        fill_constant, not by this file; the variable names suggest
        CUDAPlace / CUDAPinnedPlace. Confirm against the op.

        Returns:
            tuple: ``(main_program, gpu_var, pinned_var)``.
        """
        paddle.enable_static()
        main_program = Program()
        with program_guard(main_program):
            pinned_var_name = "tensor@Pinned"
            gpu_var_name = "tensor@GPU"
            pinned_var = main_program.global_block().create_var(
                name=pinned_var_name,
                shape=[10, 10],
                dtype='float32',
                persistable=False,
                stop_gradient=True)
            gpu_var = main_program.global_block().create_var(
                name=gpu_var_name,
                shape=[10, 10],
                dtype='float32',
                persistable=False,
                stop_gradient=True)
            main_program.global_block().append_op(
                type="fill_constant",
                outputs={"Out": gpu_var_name},
                attrs={
                    "shape": [10, 10],
                    "dtype": gpu_var.dtype,
                    "value": 1.0,
                    "place_type": 1
                })
            main_program.global_block().append_op(
                type="fill_constant",
                outputs={"Out": pinned_var_name},
                attrs={
                    "shape": [10, 10],
                    "dtype": gpu_var.dtype,
                    "value": 0.0,
                    "place_type": 2
                })
        return main_program, gpu_var, pinned_var

    def test_gpu_cpoy_to_pinned(self):
        """memcpy with dst_place_type=3 copies the GPU var into the pinned var."""
        main_program, gpu_var, pinned_var = self.get_prog()
        main_program.global_block().append_op(
            type='memcpy',
            inputs={'X': gpu_var},
            outputs={'Out': pinned_var},
            attrs={'dst_place_type': 3})
        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        gpu_, pinned_ = exe.run(main_program,
                                feed={},
                                fetch_list=[gpu_var.name, pinned_var.name])
        # After the copy both variables must hold the GPU var's ones.
        self.assertTrue(np.allclose(gpu_, pinned_))
        self.assertTrue(np.allclose(pinned_, np.ones((10, 10))))

    def test_pinned_cpoy_gpu(self):
        """memcpy with dst_place_type=2 copies the pinned var into the GPU var."""
        main_program, gpu_var, pinned_var = self.get_prog()
        main_program.global_block().append_op(
            type='memcpy',
            inputs={'X': pinned_var},
            outputs={'Out': gpu_var},
            attrs={'dst_place_type': 2})
        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        gpu_, pinned_ = exe.run(main_program,
                                feed={},
                                fetch_list=[gpu_var.name, pinned_var.name])
        # After the copy both variables must hold the pinned var's zeros.
        self.assertTrue(np.allclose(gpu_, pinned_))
        self.assertTrue(np.allclose(gpu_, np.zeros((10, 10))))
|
||||||
|
|
||||||
|
|
||||||
|
@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "core is not compiled with CUDA")
class TestMemcpyOPError(unittest.TestCase):
    """Error-path checks for the memcpy op.

    Both tests run on ``fluid.CUDAPlace(0)``, so the whole class is skipped
    when Paddle was built without CUDA instead of failing with an error.
    """

    def get_prog(self):
        """Build a program with one 10x10 float32 variable filled with 0.0.

        Returns:
            tuple: ``(main_program, pinned_var)``.
        """
        paddle.enable_static()
        main_program = Program()
        with program_guard(main_program):
            pinned_var = main_program.global_block().create_var(
                name="tensor@Pinned_0",
                shape=[10, 10],
                dtype='float32',
                persistable=False,
                stop_gradient=True)
            main_program.global_block().append_op(
                type="fill_constant",
                outputs={"Out": "tensor@Pinned_0"},
                attrs={
                    "shape": [10, 10],
                    "dtype": pinned_var.dtype,
                    "value": 0.0,
                    "place_type": 2
                })
        return main_program, pinned_var

    def test_SELECTED_ROWS(self):
        """Running memcpy on a SELECTED_ROWS input raises NotImplementedError."""
        main_program, pinned_var = self.get_prog()
        selected_row_var = main_program.global_block().create_var(
            name="selected_row_0",
            dtype="float32",
            persistable=False,
            type=fluid.core.VarDesc.VarType.SELECTED_ROWS,
            stop_gradient=True)
        main_program.global_block().append_op(
            type="fill_constant",
            outputs={"Out": selected_row_var},
            attrs={
                "shape": selected_row_var.shape,
                "dtype": selected_row_var.dtype,
                "value": 1.0,
                "place_type": 1
            })
        main_program.global_block().append_op(
            type='memcpy',
            inputs={'X': selected_row_var},
            outputs={'Out': pinned_var},
            attrs={'dst_place_type': 3})
        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        # Only the run itself is expected to raise, so only it is guarded.
        with self.assertRaises(NotImplementedError):
            exe.run(main_program,
                    feed={},
                    fetch_list=[selected_row_var.name, pinned_var.name])

    def test_OTHER_PLACE_NotImplementedError(self):
        """Running memcpy with dst_place_type=0 raises NotImplementedError."""
        main_program, pinned_var = self.get_prog()
        lod_tensor_var = main_program.global_block().create_var(
            name="lod_tensor_0",
            dtype="float32",
            persistable=False,
            stop_gradient=True)
        main_program.global_block().append_op(
            type="fill_constant",
            outputs={"Out": lod_tensor_var},
            attrs={
                "shape": lod_tensor_var.shape,
                "dtype": lod_tensor_var.dtype,
                "value": 1.0,
                "place_type": 0
            })
        main_program.global_block().append_op(
            type='memcpy',
            inputs={'X': pinned_var},
            outputs={'Out': lod_tensor_var},
            attrs={'dst_place_type': 0, })
        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        # Only the run itself is expected to raise, so only it is guarded.
        with self.assertRaises(NotImplementedError):
            exe.run(main_program,
                    feed={},
                    fetch_list=[lod_tensor_var.name, pinned_var.name])
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Enable Paddle's static mode, then hand control to the unittest runner.
    paddle.enable_static()
    unittest.main()
|
Loading…
Reference in new issue