Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into seq_expand_op
commit
9f32b61c27
@ -0,0 +1,3 @@
|
||||
if(WITH_GPU)
|
||||
nv_library(nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator )
|
||||
endif()
|
@ -0,0 +1,17 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
||||
#include "paddle/platform/gpu_info.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,63 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <condition_variable>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "paddle/platform/device_context.h"
|
||||
#include "paddle/platform/dynload/nccl.h"
|
||||
#include "paddle/platform/enforce.h"
|
||||
#include "paddle/platform/macros.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
constexpr int kInvalidGPUId = -1;
|
||||
|
||||
struct Communicator {
|
||||
std::vector<ncclComm_t> comms_;
|
||||
std::unordered_map<int, int> comm_id_map_;
|
||||
|
||||
Communicator() {}
|
||||
|
||||
int GetCommId(int device_id) const { return comm_id_map_.at(device_id); }
|
||||
|
||||
void InitAll(const std::vector<int>& gpus) {
|
||||
comms_.resize(gpus.size());
|
||||
for (size_t i = 0; i < gpus.size(); ++i) {
|
||||
comm_id_map_[gpus[i]] = i;
|
||||
}
|
||||
PADDLE_ENFORCE(
|
||||
dynload::ncclCommInitAll(comms_.data(), gpus.size(), gpus.data()));
|
||||
}
|
||||
|
||||
~Communicator() {
|
||||
for (size_t i = 0; i < comms_.size(); ++i) {
|
||||
// FIXME(dzh) : PADDLE_ENFORCE return void
|
||||
dynload::ncclCommDestroy(comms_[i]);
|
||||
}
|
||||
}
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(Communicator);
|
||||
};
|
||||
|
||||
} // namespace platform
|
||||
} // namespace paddle
|
@ -0,0 +1,206 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// NCCLinitOp
|
||||
class NCCLInitOp : public framework::OperatorBase {
|
||||
public:
|
||||
NCCLInitOp(const std::string &type, const framework::VariableNameMap &inputs,
|
||||
const framework::VariableNameMap &outputs,
|
||||
const framework::AttributeMap &attrs)
|
||||
: OperatorBase(type, inputs, outputs, attrs) {}
|
||||
|
||||
void Run(const framework::Scope &scope,
|
||||
const platform::DeviceContext &dev_ctx) const override {
|
||||
const auto &name = Output("Communicator");
|
||||
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
|
||||
"Can not find variable '%s' in the scope.", name);
|
||||
std::vector<int> gpus = Attr<std::vector<int>>("gpus");
|
||||
PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
|
||||
|
||||
if (scope.FindVar(name) == nullptr) {
|
||||
PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");
|
||||
}
|
||||
|
||||
platform::Communicator *comm =
|
||||
scope.FindVar(name)->GetMutable<platform::Communicator>();
|
||||
comm->InitAll(gpus);
|
||||
}
|
||||
};
|
||||
|
||||
class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
NCCLInitOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddOutput("Communicator",
|
||||
"Create Communicator for communicating between gpus");
|
||||
AddAttr<std::vector<int>>("gpus", "gpu id lists");
|
||||
AddAttr<int>("data_type", "output data type")
|
||||
.SetDefault(framework::DataType::FP32);
|
||||
AddComment(R"DOC(
|
||||
create communicator.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
// AllReduceOp
|
||||
class NCCLAllReduceOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
" Input(X) of AllReduce op input should not be NULL");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Communicator"),
|
||||
" Input(Communicator) of AllReduce op input should not be NULL");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
" Input(X) of AllReduce op input should not be NULL");
|
||||
|
||||
auto x_dims = ctx->GetInputsDim("X");
|
||||
|
||||
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
|
||||
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
|
||||
reduction == "ncclMin" || reduction == "ncclMax"),
|
||||
"invalid reduction.");
|
||||
|
||||
ctx->SetOutputsDim("Out", x_dims);
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// ReduceOp
|
||||
class NCCLReduceOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
" Input(X) of Reduce op input should not be NULL");
|
||||
PADDLE_ENFORCE(
|
||||
ctx->HasInput("Communicator"),
|
||||
" Input(Communicator) of Reduce op input should not be NULL");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
" Input(X) of Reduce op input should not be NULL");
|
||||
|
||||
std::string reduction = ctx->Attrs().Get<std::string>("reduction");
|
||||
PADDLE_ENFORCE((reduction == "ncclSum" || reduction == "ncclProd" ||
|
||||
reduction == "ncclMin" || reduction == "ncclMax"),
|
||||
"invalid reduction.");
|
||||
|
||||
auto x_dims = ctx->GetInputsDim("X");
|
||||
ctx->SetOutputsDim("Out", x_dims);
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// BcastOp
|
||||
class NCCLBcastOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(framework::InferShapeContext *ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
" Input(X) of Bcast op input should not be NULL");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Communicator"),
|
||||
" Input(Communicator) of Bcast op input should not be NULL");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
" Output(Out) of Bcast op output should not be NULL");
|
||||
|
||||
int root = ctx->Attrs().Get<int>("root");
|
||||
PADDLE_ENFORCE(root != platform::kInvalidGPUId, "Bcast root must be set.");
|
||||
|
||||
auto x_dims = ctx->GetInputsDim("X");
|
||||
ctx->SetOutputsDim("Out", x_dims);
|
||||
ctx->ShareLoD("X", /*->*/ "Out");
|
||||
}
|
||||
};
|
||||
|
||||
// AllreduceOp
|
||||
class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
NCCLAllReduceOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X", "The input of AllReduce op");
|
||||
AddInput("Communicator", "Communicator for communicating between gpus");
|
||||
AddOutput("Out", "The output of AllReduce op");
|
||||
AddAttr<std::string>("reduction",
|
||||
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
|
||||
.SetDefault("ncclSum");
|
||||
AddComment(R"DOC(
|
||||
AllReduce the input tensors.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
// ReduceOp
|
||||
class NCCLReduceOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
NCCLReduceOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X", "The input of Reduce op");
|
||||
AddInput("Communicator", "Communicator for communicating between gpus");
|
||||
AddOutput("Out", "The output of Reduce op");
|
||||
AddAttr<std::string>("reduction",
|
||||
"{'ncclMin', 'ncclMax', 'ncclProd', 'ncclSum'}.")
|
||||
.SetDefault("ncclSum");
|
||||
AddAttr<int>("root",
|
||||
"root gpu of the parameter. if not "
|
||||
"set(platform::kInvalidGPUId). hashed by name.")
|
||||
.SetDefault(platform::kInvalidGPUId);
|
||||
AddComment(R"DOC(
|
||||
Reduce the tensors)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
// BcastOp
|
||||
class NCCLBcastOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
NCCLBcastOpMaker(framework::OpProto *proto,
|
||||
framework::OpAttrChecker *op_checker)
|
||||
: OpProtoAndCheckerMaker(proto, op_checker) {
|
||||
AddInput("X", "The input of BcastSend op");
|
||||
AddInput("Communicator", "Communicator for communicating between gpus");
|
||||
AddOutput("Out", "The output of Bcast");
|
||||
AddAttr<int>("root",
|
||||
"root gpu of the parameter. if not "
|
||||
"set(platform::kInvalidGPUId). hashed by name.")
|
||||
.SetDefault(platform::kInvalidGPUId);
|
||||
AddComment(R"DOC(
|
||||
Bcast the tensors.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
|
||||
paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
|
||||
|
||||
REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
|
||||
ops::NCCLAllReduceOpMaker);
|
||||
REGISTER_OP_WITHOUT_GRADIENT(ncclBcast, ops::NCCLBcastOp,
|
||||
ops::NCCLBcastOpMaker);
|
||||
REGISTER_OP_WITHOUT_GRADIENT(ncclReduce, ops::NCCLReduceOp,
|
||||
ops::NCCLReduceOpMaker);
|
@ -0,0 +1,211 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenseshashernless required by applicable law or agreed
|
||||
to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "paddle/framework/lod_tensor.h"
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
using framework::Tensor;
|
||||
using platform::Communicator;
|
||||
using framework::LoDTensor;
|
||||
|
||||
template <typename Type>
|
||||
class NCCLTypeWrapper;
|
||||
|
||||
template <>
|
||||
class NCCLTypeWrapper<float> {
|
||||
public:
|
||||
static const ncclDataType_t type = ncclFloat;
|
||||
};
|
||||
|
||||
template <>
|
||||
class NCCLTypeWrapper<double> {
|
||||
public:
|
||||
static const ncclDataType_t type = ncclDouble;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NCCLAllReduceKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
|
||||
auto ins = ctx.MultiInput<LoDTensor>("X");
|
||||
auto outs = ctx.MultiOutput<LoDTensor>("Out");
|
||||
|
||||
std::string reduction = ctx.Attr<std::string>("reduction");
|
||||
ncclRedOp_t reduction_op_ = ncclSum;
|
||||
|
||||
if (reduction == "ncclMin") {
|
||||
reduction_op_ = ncclMin;
|
||||
} else if (reduction == "ncclMax") {
|
||||
reduction_op_ = ncclMax;
|
||||
} else if (reduction == "ncclSum") {
|
||||
reduction_op_ = ncclSum;
|
||||
} else if (reduction == "ncclProd") {
|
||||
reduction_op_ = ncclProd;
|
||||
} else {
|
||||
PADDLE_THROW("Invalid reduction. default ncclSum.");
|
||||
}
|
||||
|
||||
auto* comm = ctx.Input<Communicator>("Communicator");
|
||||
|
||||
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
||||
ctx.device_context())
|
||||
.stream();
|
||||
|
||||
// device id
|
||||
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
|
||||
int idx = comm->GetCommId(gpu_id);
|
||||
|
||||
for (size_t i = 0; i < ins.size(); ++i) {
|
||||
VLOG(1) << "gpu : "
|
||||
<< " invoke allreduce. send " << ins[i]->numel() << " recv "
|
||||
<< outs[i]->numel();
|
||||
|
||||
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
|
||||
ins[i]->data<T>(), outs[i]->mutable_data<T>(ctx.GetPlace()),
|
||||
outs[i]->numel(), NCCLTypeWrapper<T>::type, reduction_op_,
|
||||
comm->comms_[idx], stream));
|
||||
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
||||
|
||||
VLOG(1) << "gpu : "
|
||||
<< " finished allreduce. send " << ins[i]->numel() << " recv "
|
||||
<< outs[i]->numel();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NCCLReduceKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
|
||||
auto ins = ctx.MultiInput<LoDTensor>("X"); // x0, x1, x2
|
||||
auto outs = ctx.MultiOutput<LoDTensor>("Out");
|
||||
|
||||
std::string reduction = ctx.Attr<std::string>("reduction");
|
||||
ncclRedOp_t reduction_op_ = ncclSum;
|
||||
|
||||
if (reduction == "ncclMin") {
|
||||
reduction_op_ = ncclMin;
|
||||
} else if (reduction == "ncclMax") {
|
||||
reduction_op_ = ncclMax;
|
||||
} else if (reduction == "ncclSum") {
|
||||
reduction_op_ = ncclSum;
|
||||
} else if (reduction == "ncclProd") {
|
||||
reduction_op_ = ncclProd;
|
||||
} else {
|
||||
PADDLE_THROW("Invalid reduction. default ncclSum.");
|
||||
}
|
||||
|
||||
int root = ctx.Attr<int>("root");
|
||||
auto* comm = ctx.Input<Communicator>("Communicator");
|
||||
|
||||
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
||||
ctx.device_context())
|
||||
.stream();
|
||||
// device id
|
||||
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
|
||||
int idx = comm->GetCommId(gpu_id);
|
||||
|
||||
auto ins_names = ctx.Inputs("X");
|
||||
std::hash<std::string> hasher;
|
||||
for (size_t i = 0; i < ins.size(); ++i) {
|
||||
if (root == platform::kInvalidGPUId) {
|
||||
root = hasher(ins_names[i]) % comm->comms_.size();
|
||||
}
|
||||
T* recvbuffer = nullptr;
|
||||
if (root == gpu_id) {
|
||||
recvbuffer = outs[i]->mutable_data<T>(ctx.GetPlace());
|
||||
}
|
||||
|
||||
VLOG(1) << "gpu : " << gpu_id << " invoke reduce. send "
|
||||
<< ins[i]->numel() << " recv " << outs[i]->numel();
|
||||
|
||||
PADDLE_ENFORCE(platform::dynload::ncclReduce(
|
||||
ins[i]->data<T>(), recvbuffer, ins[i]->numel(),
|
||||
NCCLTypeWrapper<T>::type, reduction_op_, root, comm->comms_[idx],
|
||||
stream));
|
||||
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
||||
|
||||
VLOG(1) << "gpu : " << gpu_id << " finished reduce. send "
|
||||
<< ins[i]->numel() << " recv " << outs[i]->numel();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NCCLBcastKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& ctx) const override {
|
||||
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
|
||||
"This kernel only runs on GPU device.");
|
||||
|
||||
int root = ctx.Attr<int>("root");
|
||||
|
||||
auto* comm = ctx.Input<Communicator>("Communicator");
|
||||
|
||||
auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
|
||||
ctx.device_context())
|
||||
.stream();
|
||||
// device id
|
||||
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
|
||||
int idx = comm->GetCommId(gpu_id);
|
||||
|
||||
if (idx == root) {
|
||||
auto ins = ctx.MultiInput<LoDTensor>("X");
|
||||
for (size_t i = 0; i < ins.size(); ++i) {
|
||||
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. send "
|
||||
<< ins[i]->numel();
|
||||
|
||||
VLOG(1) << " before ncclBcast";
|
||||
PADDLE_ENFORCE(platform::dynload::ncclBcast(
|
||||
(void*)ins[i]->data<T>(), ins[i]->numel(), NCCLTypeWrapper<T>::type,
|
||||
root, comm->comms_[idx], stream));
|
||||
VLOG(1) << " after ncclBcast";
|
||||
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
||||
|
||||
VLOG(1) << "gpu : " << gpu_id << " finished Bcast.";
|
||||
}
|
||||
} else {
|
||||
auto outs = ctx.MultiOutput<LoDTensor>("Out");
|
||||
for (size_t i = 0; i < outs.size(); ++i) {
|
||||
VLOG(1) << "gpu : " << gpu_id << " invoke Bcast. recv buffer "
|
||||
<< framework::product(outs[i]->dims());
|
||||
|
||||
PADDLE_ENFORCE(platform::dynload::ncclBcast(
|
||||
outs[i]->mutable_data<T>(ctx.GetPlace()), outs[i]->numel(),
|
||||
NCCLTypeWrapper<T>::type, root, comm->comms_[idx], stream));
|
||||
PADDLE_ENFORCE(cudaStreamSynchronize(stream));
|
||||
|
||||
VLOG(1) << "gpu : " << gpu_id << " finished Bcast. recv "
|
||||
<< outs[i]->numel();
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_GPU_KERNEL(ncclAllReduce, ops::NCCLAllReduceKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(ncclBcast, ops::NCCLBcastKernel<float>);
|
||||
REGISTER_OP_GPU_KERNEL(ncclReduce, ops::NCCLReduceKernel<float>);
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,75 @@
|
||||
import unittest
|
||||
|
||||
import paddle.v2.framework.layers as layers
|
||||
import paddle.v2.framework.nets as nets
|
||||
from paddle.v2.framework.framework import Program
|
||||
|
||||
|
||||
def conv_block(input,
|
||||
num_filter,
|
||||
groups,
|
||||
dropouts,
|
||||
program=None,
|
||||
init_program=None):
|
||||
return nets.img_conv_group(
|
||||
input=input,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
conv_num_filter=[num_filter] * groups,
|
||||
conv_filter_size=3,
|
||||
conv_act='relu',
|
||||
conv_with_batchnorm=True,
|
||||
conv_batchnorm_drop_rate=dropouts,
|
||||
pool_type='max',
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
|
||||
|
||||
class TestLayer(unittest.TestCase):
|
||||
def test_batch_norm_layer(self):
|
||||
program = Program()
|
||||
init_program = Program()
|
||||
images = layers.data(
|
||||
name='pixel',
|
||||
shape=[3, 48, 48],
|
||||
data_type='float32',
|
||||
program=program)
|
||||
layers.batch_norm(
|
||||
input=images, program=program, init_program=init_program)
|
||||
|
||||
#print str(program)
|
||||
|
||||
def test_dropout_layer(self):
|
||||
program = Program()
|
||||
init_program = Program()
|
||||
images = layers.data(
|
||||
name='pixel',
|
||||
shape=[3, 48, 48],
|
||||
data_type='float32',
|
||||
program=program)
|
||||
layers.dropout(
|
||||
x=images,
|
||||
dropout_prob=0.5,
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
|
||||
#print str(program)
|
||||
|
||||
def test_img_conv_group(self):
|
||||
program = Program()
|
||||
init_program = Program()
|
||||
|
||||
images = layers.data(
|
||||
name='pixel',
|
||||
shape=[3, 48, 48],
|
||||
data_type='float32',
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
conv1 = conv_block(images, 64, 2, [0.3, 0], program, init_program)
|
||||
conv2 = conv_block(conv1, 256, 3, [0.4, 0.4, 0], program, init_program)
|
||||
|
||||
# print str(program)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
@ -0,0 +1,133 @@
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.framework.layers as layers
|
||||
import paddle.v2.framework.nets as nets
|
||||
import paddle.v2.framework.core as core
|
||||
import paddle.v2.framework.optimizer as optimizer
|
||||
|
||||
from paddle.v2.framework.framework import Program, g_program
|
||||
from paddle.v2.framework.executor import Executor
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def vgg16_bn_drop(input, program, init_program):
|
||||
def conv_block(input,
|
||||
num_filter,
|
||||
groups,
|
||||
dropouts,
|
||||
program=None,
|
||||
init_program=None):
|
||||
return nets.img_conv_group(
|
||||
input=input,
|
||||
pool_size=2,
|
||||
pool_stride=2,
|
||||
conv_num_filter=[num_filter] * groups,
|
||||
conv_filter_size=3,
|
||||
conv_act='relu',
|
||||
conv_with_batchnorm=True,
|
||||
conv_batchnorm_drop_rate=dropouts,
|
||||
pool_type='max',
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
|
||||
conv1 = conv_block(input, 64, 2, [0.3, 0], program, init_program)
|
||||
conv2 = conv_block(conv1, 128, 2, [0.4, 0], program, init_program)
|
||||
conv3 = conv_block(conv2, 256, 3, [0.4, 0.4, 0], program, init_program)
|
||||
conv4 = conv_block(conv3, 512, 3, [0.4, 0.4, 0], program, init_program)
|
||||
conv5 = conv_block(conv4, 512, 3, [0.4, 0.4, 0], program, init_program)
|
||||
|
||||
drop = layers.dropout(
|
||||
x=conv5, dropout_prob=0.5, program=program, init_program=init_program)
|
||||
fc1 = layers.fc(input=drop,
|
||||
size=512,
|
||||
act=None,
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
reshape1 = layers.reshape(
|
||||
x=fc1,
|
||||
shape=list(fc1.shape + (1, 1)),
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
bn = layers.batch_norm(
|
||||
input=reshape1, act='relu', program=program, init_program=init_program)
|
||||
drop2 = layers.dropout(
|
||||
x=bn, dropout_prob=0.5, program=program, init_program=init_program)
|
||||
fc2 = layers.fc(input=drop2,
|
||||
size=512,
|
||||
act=None,
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
return fc2
|
||||
|
||||
|
||||
init_program = Program()
|
||||
program = Program()
|
||||
|
||||
classdim = 10
|
||||
data_shape = [3, 32, 32]
|
||||
|
||||
images = layers.data(
|
||||
name='pixel', shape=data_shape, data_type='float32', program=program)
|
||||
|
||||
label = layers.data(
|
||||
name='label',
|
||||
shape=[1],
|
||||
data_type='int64',
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
vgg_net = vgg16_bn_drop(images, program, init_program)
|
||||
predict = layers.fc(input=vgg_net,
|
||||
size=classdim,
|
||||
act='softmax',
|
||||
program=program,
|
||||
init_program=init_program)
|
||||
cost = layers.cross_entropy(
|
||||
input=predict, label=label, program=program, init_program=init_program)
|
||||
avg_cost = layers.mean(x=cost, program=program, init_program=init_program)
|
||||
|
||||
sgd_optimizer = optimizer.SGDOptimizer(learning_rate=0.001)
|
||||
opts = sgd_optimizer.minimize(avg_cost)
|
||||
|
||||
BATCH_SIZE = 128
|
||||
PASS_NUM = 1
|
||||
|
||||
train_reader = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.cifar.train10(), buf_size=128 * 10),
|
||||
batch_size=BATCH_SIZE)
|
||||
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
|
||||
exe.run(init_program, feed={}, fetch_list=[])
|
||||
|
||||
for pass_id in range(PASS_NUM):
|
||||
batch_id = 0
|
||||
for data in train_reader():
|
||||
img_data = np.array(map(lambda x: x[0].reshape(data_shape),
|
||||
data)).astype("float32")
|
||||
y_data = np.array(map(lambda x: x[1], data)).astype("int64")
|
||||
batch_size = 1
|
||||
for i in y_data.shape:
|
||||
batch_size = batch_size * i
|
||||
y_data = y_data.reshape([batch_size, 1])
|
||||
|
||||
tensor_img = core.LoDTensor()
|
||||
tensor_y = core.LoDTensor()
|
||||
tensor_img.set(img_data, place)
|
||||
tensor_y.set(y_data, place)
|
||||
|
||||
outs = exe.run(program,
|
||||
feed={"pixel": tensor_img,
|
||||
"label": tensor_y},
|
||||
fetch_list=[avg_cost])
|
||||
|
||||
loss = np.array(outs[0])
|
||||
# print("pass_id:" + str(pass_id) + " batch_id:" + str(batch_id) +
|
||||
# " loss:" + str(loss))
|
||||
batch_id = batch_id + 1
|
||||
|
||||
if batch_id > 1:
|
||||
# this model is slow, so if we can train two mini batch, we think it works properly.
|
||||
exit(0)
|
||||
exit(1)
|
@ -0,0 +1,39 @@
|
||||
import unittest, os
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
from paddle.v2.framework.op import Operator
|
||||
import paddle.v2.framework.core as core
|
||||
from op_test import OpTest, create_op, set_input
|
||||
|
||||
if not core.is_compile_gpu():
|
||||
exit(0)
|
||||
|
||||
gpu_count = core.get_cuda_device_count()
|
||||
|
||||
if gpu_count <= 1:
|
||||
exit(0)
|
||||
|
||||
g_scope = core.Scope()
|
||||
g_ctx = core.DeviceContext.create(core.CPUPlace())
|
||||
|
||||
|
||||
class TestNCCLInit(unittest.TestCase):
|
||||
def test_init(self):
|
||||
self.op_type = "ncclInit"
|
||||
self.gpus = range(gpu_count)
|
||||
|
||||
self.inputs = {}
|
||||
self.attrs = {"gpus": self.gpus}
|
||||
g_scope.var("Communicator").get_communicator()
|
||||
self.outputs = {"Communicator": g_scope.find_var("Communicator")}
|
||||
nccl_init = create_op(
|
||||
g_scope,
|
||||
op_type=self.op_type,
|
||||
inputs=self.inputs,
|
||||
outputs=self.outputs,
|
||||
attrs=self.attrs)
|
||||
nccl_init.run(g_scope, g_ctx)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue