parent
f2f839af27
commit
48dea84bf0
@ -0,0 +1,39 @@
|
||||
#pragma once
|
||||
#include <nccl.h>
|
||||
|
||||
#include "paddle/platform/device_context.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace platform {
|
||||
|
||||
class NCCLManager {
|
||||
public:
|
||||
static NCCLManager* Get() {
|
||||
static NCCLManager m;
|
||||
return &m;
|
||||
}
|
||||
|
||||
NCCLManager() { _comms.resize(_gpu_worlds.size()); }
|
||||
~NCCLManager() {}
|
||||
|
||||
private:
|
||||
// clang-format off
|
||||
std::vector<ncclComm_t> _comms;
|
||||
std::vector<int> _gpu_worlds;
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
class NCCLContext : public DeviceContext {
|
||||
public:
|
||||
explicit NCCLContext(GPUPlace place);
|
||||
virtual ~NCCLContext();
|
||||
|
||||
private:
|
||||
// clang-format off
|
||||
std::vector<int> _gpu_ids;
|
||||
std::vector<cudaStream_t> _streams;
|
||||
int root_gpu;
|
||||
// clang-format on
|
||||
};
|
||||
}
|
||||
}
|
@ -0,0 +1,48 @@
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// AllreduceOp
|
||||
class NCCLAllreduceOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
// allreduce do nothing in infershape
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NCCLAllreduceOp : public framework::OpKernel {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext &context) const override {
|
||||
auto *ctx = static_cast<NCCLContext *>(context.device_context());
|
||||
// auto *comm = ;
|
||||
// auto *src = ;
|
||||
// ncclAllReduce(src, dest, )
|
||||
}
|
||||
};
|
||||
|
||||
// BcastSendOp
|
||||
template <typename T>
|
||||
class NCCLBroadcastSendOp final : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {}
|
||||
};
|
||||
|
||||
// BcastRecvOp
|
||||
template <typename T>
|
||||
class NCCLBroadcastRecvOp final : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
protected:
|
||||
void InferShape(const framework::InferShapeContext &ctx) const override {}
|
||||
};
|
||||
}
|
||||
}
|
@ -0,0 +1,7 @@
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/nccl/nccl_gpu_common.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {}
|
||||
}
|
Loading…
Reference in new issue