parent
ec11135d54
commit
8e9ebebcef
@ -0,0 +1,84 @@
|
||||
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/operators/linspace_op.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
class LinspaceOp : public framework::OperatorWithKernel {
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("Start"),
|
||||
"Input(Start) of LinspaceOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Stop"),
|
||||
"Input(Stop) of LinspaceOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasInput("Num"),
|
||||
"Input(Num) of LinspaceOp should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(OUt) of LinspaceOp should not be null.");
|
||||
|
||||
auto s_dims = ctx->GetInputDim("Start");
|
||||
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1),
|
||||
"The shape of Input(Start) should be [1].");
|
||||
|
||||
auto e_dims = ctx->GetInputDim("Stop");
|
||||
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1),
|
||||
"The shape of Input(Stop) should be [1].");
|
||||
|
||||
auto step_dims = ctx->GetInputDim("Num");
|
||||
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1),
|
||||
"The shape of Input(Num) should be [1].");
|
||||
|
||||
ctx->SetOutputDim("Out", {-1});
|
||||
}
|
||||
|
||||
protected:
|
||||
framework::OpKernelType GetExpectedKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
framework::LibraryType library_{framework::LibraryType::kPlain};
|
||||
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
|
||||
return framework::OpKernelType(
|
||||
ctx.Input<framework::Tensor>("Start")->type(), ctx.device_context(),
|
||||
layout_, library_);
|
||||
}
|
||||
};
|
||||
|
||||
class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||
public:
|
||||
void Make() override {
|
||||
AddInput("Start",
|
||||
"First entry in the sequence. It is a tensor of shape [1], should "
|
||||
"be of type float32 or float64.");
|
||||
AddInput("Stop",
|
||||
"Last entry in the sequence. It is a tensor of shape [1], should "
|
||||
"be of type float32 or float64.");
|
||||
AddInput("Num",
|
||||
"Number of entry in the sequence. It is a tensor of shape [1], "
|
||||
"should be of type int32.");
|
||||
AddOutput("Out", "A sequence of numbers.");
|
||||
AddComment(R"DOC(
|
||||
Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy.
|
||||
)DOC");
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
|
||||
REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>,
|
||||
ops::CPULinspaceKernel<double>);
|
@ -0,0 +1,75 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/linspace_op.h"
|
||||
#include "paddle/fluid/platform/cuda_primitives.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
#define CUDA_1D_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
|
||||
i += blockDim.x * gridDim.x)
|
||||
|
||||
template <typename T>
|
||||
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
|
||||
CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void LinspaceSpecialKernel(T start, T* out) {
|
||||
out[0] = start;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class CUDALinspaceKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
auto* start_t = context.Input<framework::Tensor>("Start");
|
||||
auto* stop_t = context.Input<framework::Tensor>("Stop");
|
||||
auto* num_t = context.Input<framework::Tensor>("Num");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
|
||||
framework::Tensor n;
|
||||
framework::TensorCopy(*start_t, platform::CPUPlace(), &n);
|
||||
T start = n.data<T>()[0];
|
||||
framework::TensorCopy(*stop_t, platform::CPUPlace(), &n);
|
||||
T stop = n.data<T>()[0];
|
||||
framework::TensorCopy(*num_t, platform::CPUPlace(), &n);
|
||||
int32_t num = n.data<int32_t>()[0];
|
||||
|
||||
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
|
||||
|
||||
out->Resize(framework::make_ddim({num}));
|
||||
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
T step = 0;
|
||||
if (num != 1) {
|
||||
step = (stop - start) / (num - 1);
|
||||
}
|
||||
|
||||
auto stream = context.cuda_device_context().stream();
|
||||
int block = 512;
|
||||
int grid = (num + block - 1) / block;
|
||||
LinspaceKernel<T><<<grid, block, 0, stream>>>(start, step, num, out_data);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
|
||||
ops::CUDALinspaceKernel<double>);
|
@ -0,0 +1,51 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "paddle/fluid/framework/op_registry.h"
|
||||
#include "paddle/fluid/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T>
|
||||
class CPULinspaceKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
T start = context.Input<framework::Tensor>("Start")->data<T>()[0];
|
||||
T stop = context.Input<framework::Tensor>("Stop")->data<T>()[0];
|
||||
int32_t num = context.Input<framework::Tensor>("Num")->data<int32_t>()[0];
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0.");
|
||||
|
||||
out->Resize(framework::make_ddim({num}));
|
||||
|
||||
T* out_data = out->mutable_data<T>(context.GetPlace());
|
||||
|
||||
if (num > 1) {
|
||||
T step = (stop - start) / (num - 1);
|
||||
T value = start;
|
||||
for (int i = 0; i < num; ++i) {
|
||||
out_data[i] = value;
|
||||
value += step;
|
||||
}
|
||||
} else {
|
||||
out_data[0] = start;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,71 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
class TestLinspaceOpCommonCase(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "linspace"
|
||||
dtype = 'float32'
|
||||
self.inputs = {
|
||||
'Start': np.array([0]).astype(dtype),
|
||||
'Stop': np.array([10]).astype(dtype),
|
||||
'Num': np.array([11]).astype('int32')
|
||||
}
|
||||
|
||||
self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestLinspaceOpReverseCase(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "linspace"
|
||||
dtype = 'float32'
|
||||
self.inputs = {
|
||||
'Start': np.array([10]).astype(dtype),
|
||||
'Stop': np.array([0]).astype(dtype),
|
||||
'Num': np.array([11]).astype('int32')
|
||||
}
|
||||
|
||||
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
class TestLinspaceOpNumOneCase(OpTest):
|
||||
def setUp(self):
|
||||
self.op_type = "linspace"
|
||||
dtype = 'float32'
|
||||
self.inputs = {
|
||||
'Start': np.array([10]).astype(dtype),
|
||||
'Stop': np.array([0]).astype(dtype),
|
||||
'Num': np.array([1]).astype('int32')
|
||||
}
|
||||
|
||||
self.outputs = {'Out': np.array(10, dtype=dtype)}
|
||||
|
||||
def test_check_output(self):
|
||||
self.check_output()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in new issue