Add collective ops (reduce) (#26340)

parent bdb805505e
commit e92f770c42
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace operators {

class CReduceMaxOpMaker : public CReduceOpMaker {
 protected:
  std::string GetName() const override { return "Max"; }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_reduce_max, ops::CReduceOp,
                             ops::CReduceMaxOpMaker);

REGISTER_OP_CPU_KERNEL(c_reduce_max,
                       ops::CReduceOpCPUKernel<ops::kRedMax, float>,
                       ops::CReduceOpCPUKernel<ops::kRedMax, double>,
                       ops::CReduceOpCPUKernel<ops::kRedMax, int>,
                       ops::CReduceOpCPUKernel<ops::kRedMax, int64_t>,
                       ops::CReduceOpCPUKernel<ops::kRedMax, plat::float16>);
@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(c_reduce_max,
                        ops::CReduceOpCUDAKernel<ops::kRedMax, float>,
                        ops::CReduceOpCUDAKernel<ops::kRedMax, double>,
                        ops::CReduceOpCUDAKernel<ops::kRedMax, int>,
                        ops::CReduceOpCUDAKernel<ops::kRedMax, int64_t>,
                        ops::CReduceOpCUDAKernel<ops::kRedMax, plat::float16>)
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace operators {

class CReduceMinOpMaker : public CReduceOpMaker {
 protected:
  std::string GetName() const override { return "Min"; }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_reduce_min, ops::CReduceOp,
                             ops::CReduceMinOpMaker);

REGISTER_OP_CPU_KERNEL(c_reduce_min,
                       ops::CReduceOpCPUKernel<ops::kRedMin, float>,
                       ops::CReduceOpCPUKernel<ops::kRedMin, double>,
                       ops::CReduceOpCPUKernel<ops::kRedMin, int>,
                       ops::CReduceOpCPUKernel<ops::kRedMin, int64_t>,
                       ops::CReduceOpCPUKernel<ops::kRedMin, plat::float16>);
@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(c_reduce_min,
                        ops::CReduceOpCUDAKernel<ops::kRedMin, float>,
                        ops::CReduceOpCUDAKernel<ops::kRedMin, double>,
                        ops::CReduceOpCUDAKernel<ops::kRedMin, int>,
                        ops::CReduceOpCUDAKernel<ops::kRedMin, int64_t>,
                        ops::CReduceOpCUDAKernel<ops::kRedMin, plat::float16>)
@@ -0,0 +1,151 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

enum ReduceType { kRedSum, kRedMax, kRedMin, kRedProd };

class CReduceOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

template <ReduceType red_type, typename T>
class CReduceOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        true, false,
        platform::errors::Unavailable("Unimplemented CReduceOpCPUKernel now."));
  }
};

template <ReduceType red_type, typename T>
class CReduceOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
    auto in = ctx.Input<framework::Tensor>("X");
    auto out = ctx.Output<framework::Tensor>("Out");

    auto place = ctx.GetPlace();
    ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
    int64_t numel = in->numel();
    const void* sendbuff = in->data<void>();
    out->Resize(in->dims());
    void* recvbuff = out->mutable_data<T>(place);

    int rid = ctx.Attr<int>("ring_id");
    int root = ctx.Attr<int>("root_id");
    auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

    cudaStream_t stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    ncclRedOp_t nccl_red_type = ncclSum;
    switch (red_type) {
      case kRedSum:
        nccl_red_type = ncclSum;
        break;

      case kRedMax:
        nccl_red_type = ncclMax;
        break;

      case kRedMin:
        nccl_red_type = ncclMin;
        break;

      case kRedProd:
        nccl_red_type = ncclProd;
        break;

      default:
        PADDLE_ENFORCE_EQ(true, false,
                          platform::errors::InvalidArgument(
                              "red_type must be one of kRedSum, "
                              "kRedMax, kRedMin, kRedProd."));
    }

    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclReduce(
        sendbuff, recvbuff, numel, dtype, nccl_red_type, root, comm->comm(),
        stream));
#else
    PADDLE_ENFORCE_EQ(true, false,
                      platform::errors::Unavailable(
                          "PaddlePaddle should compile with GPU."));
#endif
  }
};

class CReduceOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor), tensor to be reduced.");
    AddOutput("Out", "(Tensor) the reduced result.");
    AddAttr<int>("ring_id", "(int default 0) communication ring id.")
        .SetDefault(0);
    AddAttr<int>("root_id", "(int default 0) root id.").SetDefault(0);
    AddAttr<bool>(
        "use_calc_stream",
        "(bool default false) eject CUDA operations to calculation stream.")
        .SetDefault(false);
    AddComment(string::Sprintf(R"DOC(
CReduce %s Operator

Call collective Reduce with reduce type %s. If input and output are
the same variable, in-place reduce will be used.
)DOC",
                               GetName(), GetName()));
  }

 protected:
  virtual std::string GetName() const = 0;
};

}  // namespace operators
}  // namespace paddle
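The header above dispatches each ReduceType to its NCCL counterpart via ncclReduce, so only the root rank's output buffer is guaranteed to hold the elementwise reduction. As a rough semantic sketch of that contract, in plain NumPy (the per-rank arrays and the helper are hypothetical stand-ins, not Paddle API):

import numpy as np

# Hypothetical per-rank inputs, one array per rank in the ring.
rank_inputs = [np.array([1.0, 2.0]), np.array([10.0, 20.0])]

def c_reduce_sketch(rank_inputs, red_type, root_id):
    """Elementwise reduce across ranks; the result lands on root only,
    mirroring the ncclReduce call in CReduceOpCUDAKernel."""
    ops = {"sum": np.add, "max": np.maximum,
           "min": np.minimum, "prod": np.multiply}
    reduced = rank_inputs[0]
    for t in rank_inputs[1:]:
        reduced = ops[red_type](reduced, t)
    # Non-root outputs are unspecified after ncclReduce; model them
    # here as unchanged copies of each rank's input.
    return [reduced if r == root_id else x.copy()
            for r, x in enumerate(rank_inputs)]

print(c_reduce_sketch(rank_inputs, "sum", root_id=1)[1])  # [11. 22.]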
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace operators {

class CReduceProdOpMaker : public CReduceOpMaker {
 protected:
  std::string GetName() const override { return "Prod"; }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_reduce_prod, ops::CReduceOp,
                             ops::CReduceProdOpMaker);

REGISTER_OP_CPU_KERNEL(c_reduce_prod,
                       ops::CReduceOpCPUKernel<ops::kRedProd, float>,
                       ops::CReduceOpCPUKernel<ops::kRedProd, double>,
                       ops::CReduceOpCPUKernel<ops::kRedProd, int>,
                       ops::CReduceOpCPUKernel<ops::kRedProd, int64_t>,
                       ops::CReduceOpCPUKernel<ops::kRedProd, plat::float16>)
@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(c_reduce_prod,
                        ops::CReduceOpCUDAKernel<ops::kRedProd, float>,
                        ops::CReduceOpCUDAKernel<ops::kRedProd, double>,
                        ops::CReduceOpCUDAKernel<ops::kRedProd, int>,
                        ops::CReduceOpCUDAKernel<ops::kRedProd, int64_t>,
                        ops::CReduceOpCUDAKernel<ops::kRedProd, plat::float16>)
@@ -0,0 +1,39 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace paddle {
namespace operators {

class CReduceSumOpMaker : public CReduceOpMaker {
 protected:
  std::string GetName() const override { return "Sum"; }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_reduce_sum, ops::CReduceOp,
                             ops::CReduceSumOpMaker);

REGISTER_OP_CPU_KERNEL(c_reduce_sum,
                       ops::CReduceOpCPUKernel<ops::kRedSum, float>,
                       ops::CReduceOpCPUKernel<ops::kRedSum, double>,
                       ops::CReduceOpCPUKernel<ops::kRedSum, int>,
                       ops::CReduceOpCPUKernel<ops::kRedSum, int64_t>,
                       ops::CReduceOpCPUKernel<ops::kRedSum, plat::float16>)
@@ -0,0 +1,25 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_reduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(c_reduce_sum,
                        ops::CReduceOpCUDAKernel<ops::kRedSum, float>,
                        ops::CReduceOpCUDAKernel<ops::kRedSum, double>,
                        ops::CReduceOpCUDAKernel<ops::kRedSum, int>,
                        ops::CReduceOpCUDAKernel<ops::kRedSum, int64_t>,
                        ops::CReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
@@ -0,0 +1,92 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_scatter_op.h"

namespace paddle {
namespace operators {

class CScatterOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CScatter");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CScatter");
    int root_id = ctx->Attrs().Get<int>("root");
    int ring_id = ctx->Attrs().Get<int>("ring_id");
    int nranks = ctx->Attrs().Get<int>("nranks");
    PADDLE_ENFORCE_GE(nranks, 2,
                      platform::errors::InvalidArgument(
                          "The number of ranks (%d) must be greater than 1 "
                          "to use collective op (c_scatter op).",
                          nranks));
    PADDLE_ENFORCE_GE(
        root_id, 0,
        platform::errors::InvalidArgument(
            "The root_id (%d) for c_scatter_op must be non-negative.",
            root_id));
    PADDLE_ENFORCE_GE(
        ring_id, 0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for c_scatter_op must be non-negative.",
            ring_id));
    framework::DDim dim = ctx->GetInputDim("X");
    dim[0] = dim[0] / nranks;
    if (dim[0] < 0) dim[0] = -1;
    ctx->SetOutputDim("Out", dim);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};

class CScatterOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() {
    AddInput("X", "(Tensor) tensor to be scattered.");
    AddOutput("Out", "(Tensor) the result of scatter.");
    AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
        .SetDefault(0);
    AddAttr<int>("root", "(int default 0) root id for broadcasting.")
        .SetDefault(0);
    AddAttr<int>("nranks", "(int default 0) number of ranks.").SetDefault(0);
    AddAttr<bool>(
        "use_calc_stream",
        "(bool default false) eject CUDA operations to calculation stream.")
        .SetDefault(false);
    AddComment(R"DOC(
CScatter Operator
Scatter the source to all participators.
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(c_scatter, ops::CScatterOp, ops::CScatterOpMaker);

REGISTER_OP_CPU_KERNEL(c_scatter, ops::CScatterOpCPUKernel<float>,
                       ops::CScatterOpCPUKernel<double>,
                       ops::CScatterOpCPUKernel<int>,
                       ops::CScatterOpCPUKernel<int64_t>,
                       ops::CScatterOpCPUKernel<plat::float16>);
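CScatterOp::InferShape above splits only the leading dimension: the output shape is the input shape with dim 0 divided by nranks, with a negative result mapped back to -1 for a compile-time-unknown batch dimension. A minimal check of that rule, with a hypothetical Python helper (not Paddle API):

def c_scatter_out_shape(x_shape, nranks):
    # Mirrors CScatterOp::InferShape: split dim 0 across nranks
    # (C++ integer division truncates toward zero).
    dim0 = int(x_shape[0] / nranks)
    if dim0 < 0:
        dim0 = -1  # keep the "unknown" marker for a dynamic batch dim
    return [dim0] + list(x_shape[1:])

print(c_scatter_out_shape([10, 1000], nranks=2))  # [5, 1000]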
@@ -0,0 +1,101 @@
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/c_scatter_op.h"

#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename T>
class CScatterOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL)
    auto x = ctx.Input<framework::LoDTensor>("X");
    auto out = ctx.Output<framework::LoDTensor>("Out");
    int numel = x->numel();
    ncclDataType_t dtype = platform::ToNCCLDataType(x->type());

    int nranks = ctx.Attr<int>("nranks");
    int root_id = ctx.Attr<int>("root");
    int ring_id = ctx.Attr<int>("ring_id");
    auto place = ctx.GetPlace();
    auto comm = platform::NCCLCommContext::Instance().Get(ring_id, place);
    PADDLE_ENFORCE_EQ(nranks, comm->nranks(),
                      platform::errors::InvalidArgument(
                          "The number of ranks (%d) you set must "
                          "be equal to comm->nranks (%d).",
                          nranks, comm->nranks()));
    PADDLE_ENFORCE_GE(
        root_id, 0,
        platform::errors::InvalidArgument(
            "The root_id (%d) for c_scatter_op must be non-negative.",
            root_id));
    PADDLE_ENFORCE_GE(
        ring_id, 0,
        platform::errors::InvalidArgument(
            "The ring_id (%d) for c_scatter_op must be non-negative.",
            ring_id));

    cudaStream_t stream = nullptr;
    if (ctx.Attr<bool>("use_calc_stream")) {
      auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
      stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
    } else {
      stream = comm->stream();
    }

    framework::DDim x_dims = x->dims();
    framework::DDim out_dims(x_dims);
    framework::Tensor temp;
    auto in_data_ptr = x->data<T>();
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclBroadcast(
        reinterpret_cast<const void*>(in_data_ptr),
        temp.mutable_data<T>(out_dims, place), numel, dtype, root_id,
        comm->comm(), stream));
    VLOG(3) << "rank " << comm->rank() << " invoke Scatter.";

    out_dims[0] = out_dims[0] / nranks;
    auto start_index = out_dims[0] * comm->rank();
    auto end_index = start_index + out_dims[0];
    temp = temp.Slice(start_index, end_index);
    temp.Resize(out_dims);
    out->mutable_data<T>(out_dims, place);
    framework::TensorCopySync(*static_cast<const framework::Tensor*>(&temp),
                              place, static_cast<framework::Tensor*>(out));
    out->Resize(out_dims);
#else
    PADDLE_ENFORCE_EQ(
        true, false,
        platform::errors::Unavailable("PaddlePaddle should compile with GPU."));
#endif
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(c_scatter, ops::CScatterOpCUDAKernel<float>,
                        ops::CScatterOpCUDAKernel<double>,
                        ops::CScatterOpCUDAKernel<int>,
                        ops::CScatterOpCUDAKernel<int64_t>,
                        ops::CScatterOpCUDAKernel<plat::float16>);
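The CUDA kernel above implements scatter as ncclBroadcast from the root followed by a local slice: every rank receives the full tensor into a temporary, then keeps only its own block of rows. A NumPy sketch of that strategy (the helper is hypothetical, a single process standing in for the ring):

import numpy as np

def c_scatter_sketch(root_tensor, nranks, rank):
    # Step 1: broadcast -- every rank ends up with the full root tensor
    # (a plain copy stands in for ncclBroadcast into `temp`).
    temp = root_tensor.copy()
    # Step 2: slice out this rank's rows, as temp.Slice(start, end) does.
    rows = temp.shape[0] // nranks
    return temp[rank * rows:(rank + 1) * rows]

x = np.arange(10 * 4, dtype=np.float32).reshape(10, 4)
print(c_scatter_sketch(x, nranks=2, rank=1).shape)  # (5, 4)

This trades bandwidth for simplicity: each rank transfers the whole tensor even though it keeps only 1/nranks of it.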
@@ -0,0 +1,39 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class CScatterOpCPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(true, false,
                      platform::errors::Unavailable(
                          "Unimplemented cpu kernel for CScatterOp."));
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,70 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main


class TestCollectiveReduce(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        rootid = 1
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            toutdata = main_prog.current_block().create_var(
                name="outofreduce",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)
            main_prog.global_block().append_op(
                type="c_reduce_sum",
                inputs={'X': tindata},
                attrs={'ring_id': ring_id,
                       'root_id': rootid},
                outputs={'Out': toutdata})
            main_prog.global_block().append_op(
                type="c_sync_comm_stream",
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id})
            return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveReduce, "reduce", 0)
@@ -0,0 +1,73 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main


class TestCollectiveReduce(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        rootid = 1
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            toutdata = main_prog.current_block().create_var(
                name="outofreduce",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)
            main_prog.global_block().append_op(
                type="c_reduce_sum",
                inputs={'X': tindata},
                attrs={
                    'ring_id': ring_id,
                    'use_calc_stream': True,
                    'root_id': rootid
                },
                outputs={'Out': toutdata})
            main_prog.global_block().append_op(
                type="c_sync_comm_stream",
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id})
            return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveReduce, "reduce", 0)
@@ -0,0 +1,71 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import argparse
import os
import sys
import signal
import time
import socket
from contextlib import closing
from six import string_types
import math
import paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
import paddle.fluid.unique_name as nameGen
from paddle.fluid import core
import unittest
from multiprocessing import Process
import paddle.fluid.layers as layers
from functools import reduce
from test_collective_base import TestCollectiveRunnerBase, runtime_main


class TestCollectiveScatter(TestCollectiveRunnerBase):
    def __init__(self):
        self.global_ring_id = 0

    def get_model(self, main_prog, startup_program):
        ring_id = 0
        rootid = 1
        with fluid.program_guard(main_prog, startup_program):
            tindata = layers.data(
                name="tindata", shape=[10, 1000], dtype='float32')
            toutdata = main_prog.current_block().create_var(
                name="outofreduce",
                dtype='float32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False)
            main_prog.global_block().append_op(
                type="c_scatter",
                inputs={'X': tindata},
                attrs={'ring_id': ring_id,
                       'root': rootid,
                       'nranks': 2},
                outputs={'Out': toutdata})
            main_prog.global_block().append_op(
                type="c_sync_comm_stream",
                inputs={'X': toutdata},
                outputs={'Out': toutdata},
                attrs={'ring_id': ring_id})
            return toutdata


if __name__ == "__main__":
    runtime_main(TestCollectiveScatter, "scatter", 0)
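Given the model above (root 1, nranks 2, input shape [10, 1000]), the user-visible contract is that rank r ends up with rows [5r, 5r + 5) of the root's input. A short NumPy statement of what the harness presumably asserts (hypothetical data; the actual checks live in test_collective_base):

import numpy as np

root_data = np.random.random((10, 1000)).astype('float32')  # fed on rank 1
expected_rank0 = root_data[0:5]   # first half of the root's rows
expected_rank1 = root_data[5:10]  # second half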
@@ -0,0 +1,34 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np

from test_collective_base import TestDistBase


class TestCReduceOp(TestDistBase):
    def _setup_config(self):
        pass

    def test_reduce(self):
        self.check_with_place("collective_reduce_op.py", "reduce")

    def test_reduce_calc_stream(self):
        self.check_with_place("collective_reduce_op_calc_stream.py", "reduce")


if __name__ == '__main__':
    unittest.main()
@@ -0,0 +1,31 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
import numpy as np

from test_collective_base import TestDistBase


class TestCScatterOp(TestDistBase):
    def _setup_config(self):
        pass

    def test_scatter(self):
        self.check_with_place("collective_scatter_op.py", "scatter")


if __name__ == '__main__':
    unittest.main()