New custom operator extension mechanism (#30690)
* initial commit: simple demo
* polish copyright format
* add graph op simple demo
* adapt uncertain number of arguments
* change trait macro name
* add place & dtype support for add kernel
* add dispatch and infershape func
* polish code & add notes
* add dynamic_loader dep for paddle_framework
* add new custom op test dir
* polish impl details
* add unittest for new custom op
* fix failed unittest
* Custom op (#1): wrap framework tensor with LoDTensor; add CustomTensor default constructor, size(), and const place; refactor place-related API; make Tensor copyable; debug CustomTensor core; remove additional framework header; switch back to shared_ptr for the custom tensor; add GPU test; merge latest cwh code; adjust custom op unit tests
* Remove ShareData from user && change CustomTensor to Tensor && support more data types (#2): hide share-data-from/to from users; rename CustomTensor to Tensor; refactor register design & add test; change op_function to op_meta_info and split it into .h and .cc; move get methods into a friend class; move OpMetaInfoHelper and CustomTensorUtils into the framework namespace; change pybind API name; move the PD C API into op meta info; add register custom op API; remove inference cmake change
* refactor copy_to API && change Reshape to lowercase && support more dtypes && add more tests (#3): support multiple dtypes; remove LoD; make reshape lowercase; refactor the copy API and add copy tests; fix copy_to error; add more tests; polish details & error messages
* Add cast API && change copy-related API to copy_to && add more tests (#4): add type cast; merge cwh code; add more error logs; polish code; fix uint8 type errors; add tests for coverage; polish details per reviewer comments; add prefix for DISABLE_COPY_AND_ASSIGN

Co-authored-by: Jiabin Yang <360788950@qq.com>
parent 5c0332714f
commit f649442ddd
@@ -0,0 +1,18 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

// All paddle apis in C++ frontend
#include "paddle/fluid/extension/include/all.h"
@@ -0,0 +1,25 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#if !defined(_MSC_VER) && __cplusplus < 201103L  // C++11 sets __cplusplus to 201103L
#error C++11 or later compatible compiler is required to use Paddle.
#endif

#include "paddle/fluid/extension/include/dispatch.h"
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/op_meta_info.h"
#include "paddle/fluid/extension/include/place.h"
#include "paddle/fluid/extension/include/tensor.h"
@@ -0,0 +1,46 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/extension/include/dtype.h"

namespace paddle {

#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
  case enum_type: {                                                       \
    using HINT = type;                                                    \
    __VA_ARGS__();                                                        \
    break;                                                                \
  }

#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
  PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)

#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)                           \
  [&] {                                                                       \
    const auto& dtype = TYPE;                                                 \
    switch (dtype) {                                                          \
      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float,          \
                           __VA_ARGS__)                                       \
      PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double,         \
                           __VA_ARGS__)                                       \
      default:                                                                \
        throw std::runtime_error("function not implemented for this type.");  \
    }                                                                         \
  }()

// TODO(chenweihang): implement other DISPATCH macros in next PR

}  // namespace paddle
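PD_DISPATCH_FLOATING_TYPES expands into an immediately invoked lambda that switches on the runtime DataType and, through PD_PRIVATE_CASE_TYPE, aliases the matching C++ type as data_t inside each case body. A minimal sketch of how a kernel author drives it; fill_kernel and fill_ones here are hypothetical helpers for illustration, not part of this PR:

#include "paddle/fluid/extension/include/dispatch.h"
#include "paddle/fluid/extension/include/tensor.h"

// Hypothetical helper: fills a typed buffer with ones.
template <typename T>
void fill_kernel(T* data, int64_t n) {
  for (int64_t i = 0; i < n; ++i) data[i] = static_cast<T>(1);
}

void fill_ones(paddle::Tensor* t) {
  PD_DISPATCH_FLOATING_TYPES(t->type(), "fill_ones", ([&] {
                               // data_t is bound by the macro to float or
                               // double, whichever matches t->type().
                               fill_kernel<data_t>(t->mutable_data<data_t>(),
                                                   t->size());
                             }));
}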
@@ -0,0 +1,39 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {

enum DataType {
  FLOAT32,
  FLOAT64,
  BFLOAT16,
  COMPLEX128,
  COMPLEX64,
  FLOAT16,
  INT64,
  INT32,
  INT16,
  UINT8,
  INT8,
  BOOL,
  // TODO(JiabinYang) support more data types if needed.
};

}  // namespace paddle
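Each enumerator pairs with one concrete C++ type (FLOAT16 with paddle::platform::float16, BFLOAT16 with paddle::platform::bfloat16, and so on); that pairing is the contract the dispatch macros above and CustomTensorUtils below rely on. An illustrative compile-time form of the mapping; this DataTypeOf trait is a sketch, not part of this PR:

#include "paddle/fluid/extension/include/dtype.h"

// Illustrative trait mapping a C++ type to its DataType enumerator.
template <typename T>
struct DataTypeOf;

template <>
struct DataTypeOf<float> {
  static constexpr paddle::DataType value = paddle::DataType::FLOAT32;
};
template <>
struct DataTypeOf<paddle::platform::float16> {
  static constexpr paddle::DataType value = paddle::DataType::FLOAT16;
};
// ...the remaining specializations follow the same pattern.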
File diff suppressed because it is too large
@@ -0,0 +1,22 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

namespace paddle {

// TODO(yangjiabin): Add other place support in next PR
enum class PlaceType { kUNK = -1, kCPU, kGPU };

}  // namespace paddle
@@ -0,0 +1,95 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <vector>
#include "paddle/fluid/extension/include/dtype.h"
#include "paddle/fluid/extension/include/place.h"

namespace paddle {
namespace framework {
class CustomTensorUtils;
}  // namespace framework
class Tensor {
 public:
  /// \brief Construct a Tensor on the given place for a custom op.
  /// Generally it is only used by users to create a Tensor.
  explicit Tensor(const PlaceType& place);
  /// \brief Reset the shape of the tensor.
  /// Generally it is only used for the input tensor.
  /// reshape must be called before calling mutable_data().
  /// \param shape The shape to set.
  void reshape(const std::vector<int>& shape);

  /// \brief Get the memory pointer in CPU or GPU with
  /// a specific data type.
  /// Please reshape the tensor first before calling this.
  /// It is usually used to get the input data pointer.
  /// \param place The place of the tensor; this
  /// overrides the original place of the current tensor.
  template <typename T>
  T* mutable_data(const PlaceType& place);

  /// \brief Get the memory pointer in CPU or GPU with
  /// a specific data type. Please reshape the tensor
  /// first before calling this. It is usually used to get
  /// the input data pointer.
  template <typename T>
  T* mutable_data();

  /// \brief Get the memory pointer directly.
  /// It is usually used to get the output data pointer.
  /// \return The tensor data buffer pointer.
  template <typename T>
  T* data() const;

  /// \brief Copy the tensor data to another place.
  /// It is usually used to set the input tensor data.
  /// \param place The target place to which
  /// the tensor data will be copied.
  template <typename T>
  Tensor copy_to(const PlaceType& place);

  /// \brief Return the shape of the Tensor.
  std::vector<int> shape() const;

  /// \brief Return the data type of the tensor.
  /// It is usually used to get the output tensor data type.
  /// \return The data type of the tensor.
  DataType type() const;

  /// \brief Get the size (number of elements) of the current tensor.
  /// \return int64_t.
  int64_t size() const;

  /// \brief Get the place of the current tensor.
  /// \return Place.
  const PlaceType& place() const;

  /// \brief Cast the tensor from one data type to another.
  Tensor cast(const DataType& target_type);

 private:
  friend class framework::CustomTensorUtils;
  mutable std::shared_ptr<void> tensor_;
  mutable PlaceType place_;
};

}  // namespace paddle
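A minimal sketch of the intended Tensor lifecycle: construct with a place, reshape, request a typed buffer, then read or copy. MakeOnes is a hypothetical helper, not part of this PR:

#include "paddle/fluid/extension/include/tensor.h"

paddle::Tensor MakeOnes() {
  paddle::Tensor t(paddle::PlaceType::kCPU);
  t.reshape({2, 3});                     // must precede mutable_data()
  auto* data = t.mutable_data<float>();  // allocates on the CPU place
  for (int64_t i = 0; i < t.size(); ++i) {
    data[i] = 1.0f;
  }
  return t.copy_to<float>(paddle::PlaceType::kCPU);  // returns a copy
}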
@@ -0,0 +1,120 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/extension/include/op_meta_info.h"

#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/custom_operator.h"

namespace paddle {

////////////////////// Op Meta Info //////////////////////

OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) {
  inputs_ = std::forward<std::vector<std::string>>(inputs);
  return *this;
}
OpMetaInfo& OpMetaInfo::Outputs(std::vector<std::string>&& outputs) {
  outputs_ = std::forward<std::vector<std::string>>(outputs);
  return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
  kernel_fn_ = std::forward<KernelFunc>(func);
  return *this;
}
OpMetaInfo& OpMetaInfo::SetInferShapeFn(InferShapeFunc&& func) {
  infer_shape_fn_ = std::forward<InferShapeFunc>(func);
  return *this;
}
OpMetaInfo& OpMetaInfo::SetInferDtypeFn(InferDtypeFunc&& func) {
  infer_dtype_fn_ = std::forward<InferDtypeFunc>(func);
  return *this;
}

//////////////// Op Meta Info Map /////////////////

std::vector<OpMetaInfo>& OpMetaInfoMap::operator[](const std::string& name) {
  return map_[name];
}

const std::unordered_map<std::string, std::vector<OpMetaInfo>>&
OpMetaInfoMap::GetMap() const {
  return map_;
}

//////////////// Op Meta Info Builder /////////////////

OpMetaInfoBuilder::OpMetaInfoBuilder(std::string&& name) {
  name_ = std::forward<std::string>(name);
  auto& info_vector = OpMetaInfoMap::Instance()[name_];
  auto op_meta = OpMetaInfo(name_);
  info_vector.emplace_back(std::move(op_meta));
  info_ptr_ = &(info_vector.back());
}

OpMetaInfoBuilder& OpMetaInfoBuilder::Inputs(
    std::vector<std::string>&& inputs) {
  info_ptr_->Inputs(std::forward<std::vector<std::string>>(inputs));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::Outputs(
    std::vector<std::string>&& outputs) {
  info_ptr_->Outputs(std::forward<std::vector<std::string>>(outputs));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc&& func) {
  info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc&& func) {
  info_ptr_->SetInferShapeFn(std::forward<InferShapeFunc>(func));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc&& func) {
  info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
  return *this;
}

OpMetaInfoBuilder& OpMetaInfoBuilder::SetBackwardOp(
    const std::string& bwd_op_name) {
  auto& info_vector = OpMetaInfoMap::Instance()[name_];
  auto op_meta = OpMetaInfo(bwd_op_name);
  info_vector.emplace_back(std::move(op_meta));
  info_ptr_ = &(info_vector.back());
  return *this;
}

/////////////////////// Op register API /////////////////////////

void RegisterAllCustomOperator() {
  auto& op_meta_info_map = OpMetaInfoMap::Instance();
  framework::RegisterOperatorWithMetaInfoMap(op_meta_info_map);
}

}  // namespace paddle

extern "C" {

paddle::OpMetaInfoMap& PD_GetOpMetaInfoMap() {
  return paddle::OpMetaInfoMap::Instance();
}

}  // end extern "C"
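Note the builder's bookkeeping: the constructor appends one OpMetaInfo to the per-name vector in the OpMetaInfoMap singleton and keeps a pointer to it, and SetBackwardOp appends a second OpMetaInfo and repoints info_ptr_, so every call chained after it configures the grad op. A hedged sketch of what a registration desugars to without the PD_BUILD_OPERATOR macro; "my_op" and both functions are placeholders, not part of this PR:

#include "paddle/extension.h"

std::vector<paddle::Tensor> MyOpForward(const paddle::Tensor& x);  // placeholder
std::vector<std::vector<int64_t>> MyOpInferShape(
    std::vector<int64_t> x_shape);  // placeholder

void RegisterMyOp() {
  paddle::OpMetaInfoBuilder builder("my_op");
  builder.Inputs({"X"})
      .Outputs({"Out"})
      .SetKernelFn(PD_KERNEL(MyOpForward))
      .SetInferShapeFn(PD_INFER_SHAPE(MyOpInferShape));
}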
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -0,0 +1,32 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/extension/include/op_meta_info.h"

namespace paddle {
namespace framework {

// Load custom op API: load a user-compiled shared library and register its ops
void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name);

// Register custom op API: register ops from a meta info map directly
void RegisterOperatorWithMetaInfoMap(
    const paddle::OpMetaInfoMap& op_meta_info_map);

}  // namespace framework
}  // namespace paddle
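A minimal sketch of the loading path, assuming the user has already compiled an op into a shared library; the library path is a placeholder:

#include <string>

#include "paddle/fluid/framework/custom_operator.h"

void LoadUserOps(const std::string& lib_path) {
  // Opens the library and registers every op it finds, presumably by
  // reading its map through the exported PD_GetOpMetaInfoMap symbol
  // (e.g. lib_path = "librelu2_op.so").
  paddle::framework::LoadOpMetaInfoAndRegisterOp(lib_path);
}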
@@ -0,0 +1,246 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/extension/include/all.h"
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"

template <typename T>
paddle::Tensor InitCPUTensorForTest() {
  std::vector<int> tensor_shape{5, 5};
  auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
  t1.reshape(tensor_shape);
  auto* p_data_ptr = t1.mutable_data<T>(paddle::PlaceType::kCPU);
  for (int64_t i = 0; i < t1.size(); i++) {
    p_data_ptr[i] = 5;
  }
  return t1;
}

template <typename T>
void TestCopyTensor() {
  auto t1 = InitCPUTensorForTest<T>();
  auto t1_cpu_cp = t1.template copy_to<T>(paddle::PlaceType::kCPU);
  CHECK((paddle::PlaceType::kCPU == t1_cpu_cp.place()));
  for (int64_t i = 0; i < t1.size(); i++) {
    CHECK_EQ(t1_cpu_cp.template data<T>()[i], 5);
  }
#ifdef PADDLE_WITH_CUDA
  VLOG(2) << "Do GPU copy test";
  auto t1_gpu_cp = t1_cpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
  CHECK((paddle::PlaceType::kGPU == t1_gpu_cp.place()));
  auto t1_gpu_cp_cp = t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kGPU);
  CHECK((paddle::PlaceType::kGPU == t1_gpu_cp_cp.place()));
  auto t1_gpu_cp_cp_cpu =
      t1_gpu_cp.template copy_to<T>(paddle::PlaceType::kCPU);
  CHECK((paddle::PlaceType::kCPU == t1_gpu_cp_cp_cpu.place()));
  for (int64_t i = 0; i < t1.size(); i++) {
    CHECK_EQ(t1_gpu_cp_cp_cpu.template data<T>()[i], 5);
  }
#endif
}

void TestAPIPlace() {
  std::vector<int> tensor_shape = {5, 5};
#ifdef PADDLE_WITH_CUDA
  auto t1 = paddle::Tensor(paddle::PlaceType::kGPU);
  t1.reshape(tensor_shape);
  t1.mutable_data<float>();
  CHECK((paddle::PlaceType::kGPU == t1.place()));
#endif
  auto t2 = paddle::Tensor(paddle::PlaceType::kCPU);
  t2.reshape(tensor_shape);
  t2.mutable_data<float>();
  CHECK((paddle::PlaceType::kCPU == t2.place()));
}

void TestAPISizeAndShape() {
  std::vector<int> tensor_shape = {5, 5};
  auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
  t1.reshape(tensor_shape);
  CHECK_EQ(t1.size(), 25);
  CHECK(t1.shape() == tensor_shape);
}

template <typename T>
paddle::DataType TestDtype() {
  std::vector<int> tensor_shape = {5, 5};
  auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
  t1.reshape(tensor_shape);
  t1.template mutable_data<T>();
  return t1.type();
}

template <typename T>
void TestCast(paddle::DataType data_type) {
  std::vector<int> tensor_shape = {5, 5};
  auto t1 = paddle::Tensor(paddle::PlaceType::kCPU);
  t1.reshape(tensor_shape);
  t1.template mutable_data<T>();
  auto t2 = t1.cast(data_type);
  CHECK_EQ(t2.type(), data_type);
}

void GroupTestCopy() {
  VLOG(2) << "Float cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<float>();
  VLOG(2) << "Double cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<double>();
  // TODO(JiabinYang): Support these tests later
  // VLOG(2) << "Fp16 cpu-cpu-gpu-gpu-cpu";
  // TestCopyTensor<paddle::platform::float16>();
  // VLOG(2) << "BF16 cpu-cpu-gpu-gpu-cpu";
  // TestCopyTensor<paddle::platform::bfloat16>();
  // VLOG(2) << "complex128 cpu-cpu-gpu-gpu-cpu";
  // TestCopyTensor<paddle::platform::complex128>();
  // VLOG(2) << "complex64 cpu-cpu-gpu-gpu-cpu";
  // TestCopyTensor<paddle::platform::complex64>();
  // VLOG(2) << "int cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<int>();
  VLOG(2) << "int64 cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<int64_t>();
  VLOG(2) << "int16 cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<int16_t>();
  VLOG(2) << "int8 cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<int8_t>();
  VLOG(2) << "uint8 cpu-cpu-gpu-gpu-cpu";
  TestCopyTensor<uint8_t>();
}

void GroupTestCast() {
  VLOG(2) << "int cast";
  TestCast<int>(paddle::DataType::FLOAT32);
  VLOG(2) << "int32 cast";
  TestCast<int32_t>(paddle::DataType::FLOAT32);
  VLOG(2) << "int64 cast";
  TestCast<int64_t>(paddle::DataType::FLOAT32);
  VLOG(2) << "double cast";
  TestCast<double>(paddle::DataType::FLOAT32);
  VLOG(2) << "bfloat16 cast";
  TestCast<paddle::platform::bfloat16>(paddle::DataType::FLOAT32);
  VLOG(2) << "float16 cast";
  TestCast<paddle::platform::float16>(paddle::DataType::FLOAT32);
  VLOG(2) << "bool cast";
  TestCast<bool>(paddle::DataType::FLOAT32);
  VLOG(2) << "uint8 cast";
  TestCast<uint8_t>(paddle::DataType::FLOAT32);
  VLOG(2) << "float cast";
  TestCast<float>(paddle::DataType::FLOAT32);
}

void GroupTestDtype() {
  CHECK(TestDtype<float>() == paddle::DataType::FLOAT32);
  CHECK(TestDtype<double>() == paddle::DataType::FLOAT64);
  CHECK(TestDtype<paddle::platform::float16>() == paddle::DataType::FLOAT16);
  CHECK(TestDtype<paddle::platform::bfloat16>() == paddle::DataType::BFLOAT16);
  CHECK(TestDtype<paddle::platform::complex128>() ==
        paddle::DataType::COMPLEX128);
  CHECK(TestDtype<paddle::platform::complex64>() ==
        paddle::DataType::COMPLEX64);
  CHECK(TestDtype<int>() == paddle::DataType::INT32);
  CHECK(TestDtype<int64_t>() == paddle::DataType::INT64);
  CHECK(TestDtype<int16_t>() == paddle::DataType::INT16);
  CHECK(TestDtype<int8_t>() == paddle::DataType::INT8);
  CHECK(TestDtype<uint8_t>() == paddle::DataType::UINT8);
}

void GroupTestDtypeConvert() {
  // enum -> proto
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::COMPLEX128) ==
        paddle::framework::proto::VarType::COMPLEX128);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::COMPLEX64) ==
        paddle::framework::proto::VarType::COMPLEX64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::FLOAT64) ==
        paddle::framework::proto::VarType::FP64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::FLOAT32) ==
        paddle::framework::proto::VarType::FP32);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::FLOAT16) ==
        paddle::framework::proto::VarType::FP16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::BFLOAT16) ==
        paddle::framework::proto::VarType::BF16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::UINT8) ==
        paddle::framework::proto::VarType::UINT8);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::INT8) == paddle::framework::proto::VarType::INT8);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::INT32) ==
        paddle::framework::proto::VarType::INT32);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::INT64) ==
        paddle::framework::proto::VarType::INT64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::INT16) ==
        paddle::framework::proto::VarType::INT16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertEnumDTypeToInnerDType(
            paddle::DataType::BOOL) == paddle::framework::proto::VarType::BOOL);
  // proto -> enum
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::COMPLEX128) ==
        paddle::DataType::COMPLEX128);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::COMPLEX64) ==
        paddle::DataType::COMPLEX64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::FP64) ==
        paddle::DataType::FLOAT64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::FP32) ==
        paddle::DataType::FLOAT32);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::FP16) ==
        paddle::DataType::FLOAT16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::BF16) ==
        paddle::DataType::BFLOAT16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::INT64) ==
        paddle::DataType::INT64);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::INT32) ==
        paddle::DataType::INT32);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::INT8) == paddle::DataType::INT8);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::UINT8) ==
        paddle::DataType::UINT8);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::INT16) ==
        paddle::DataType::INT16);
  CHECK(paddle::framework::CustomTensorUtils::ConvertInnerDTypeToEnumDType(
            paddle::framework::proto::VarType::BOOL) == paddle::DataType::BOOL);
}

TEST(CustomTensor, copyTest) {
  VLOG(2) << "TestCopy";
  GroupTestCopy();
  VLOG(2) << "TestDtype";
  GroupTestDtype();
  VLOG(2) << "TestShape";
  TestAPISizeAndShape();
  VLOG(2) << "TestPlace";
  TestAPIPlace();
  VLOG(2) << "TestCast";
  GroupTestCast();
  VLOG(2) << "TestDtypeConvert";
  GroupTestDtypeConvert();
}
@@ -0,0 +1,145 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>

#include "paddle/fluid/extension/include/tensor.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/gpu_info.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace framework {

class CustomTensorUtils {
 public:
  /// \brief Share data TO another tensor.
  /// Use this to pass tensors between ops.
  /// \return void.
  static void ShareDataTo(const paddle::Tensor& src, void* dst);

  /// \brief Share data FROM another tensor.
  /// Use this to pass tensors between ops.
  /// \return void.
  static void ShareDataFrom(const void* src, const Tensor& dst);

  static framework::proto::VarType::Type ConvertEnumDTypeToInnerDType(
      const paddle::DataType& dtype) {
    switch (dtype) {
      case paddle::DataType::COMPLEX128:
        return framework::proto::VarType::COMPLEX128;
      case paddle::DataType::COMPLEX64:
        return framework::proto::VarType::COMPLEX64;
      case paddle::DataType::FLOAT64:
        return framework::proto::VarType::FP64;
      case paddle::DataType::FLOAT32:
        return framework::proto::VarType::FP32;
      case paddle::DataType::FLOAT16:
        return framework::proto::VarType::FP16;
      case paddle::DataType::BFLOAT16:
        return framework::proto::VarType::BF16;
      case paddle::DataType::UINT8:
        return framework::proto::VarType::UINT8;
      case paddle::DataType::INT8:
        return framework::proto::VarType::INT8;
      case paddle::DataType::INT32:
        return framework::proto::VarType::INT32;
      case paddle::DataType::INT64:
        return framework::proto::VarType::INT64;
      case paddle::DataType::INT16:
        return framework::proto::VarType::INT16;
      case paddle::DataType::BOOL:
        return framework::proto::VarType::BOOL;
      default:
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported data type code(%d) when casting enum data type into "
            "paddle data type.",
            static_cast<int>(dtype)));
    }
  }

  static paddle::DataType ConvertInnerDTypeToEnumDType(
      const framework::proto::VarType::Type& dtype) {
    switch (dtype) {
      case framework::proto::VarType::COMPLEX128:
        return paddle::DataType::COMPLEX128;
      case framework::proto::VarType::COMPLEX64:
        return paddle::DataType::COMPLEX64;
      case framework::proto::VarType::FP64:
        return paddle::DataType::FLOAT64;
      case framework::proto::VarType::FP32:
        return paddle::DataType::FLOAT32;
      case framework::proto::VarType::FP16:
        return paddle::DataType::FLOAT16;
      case framework::proto::VarType::BF16:
        return paddle::DataType::BFLOAT16;
      case framework::proto::VarType::INT64:
        return paddle::DataType::INT64;
      case framework::proto::VarType::INT32:
        return paddle::DataType::INT32;
      case framework::proto::VarType::INT8:
        return paddle::DataType::INT8;
      case framework::proto::VarType::UINT8:
        return paddle::DataType::UINT8;
      case framework::proto::VarType::INT16:
        return paddle::DataType::INT16;
      case framework::proto::VarType::BOOL:
        return paddle::DataType::BOOL;
      default:
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported data type `%s` when casting paddle data type into "
            "enum data type.",
            DataTypeToString(dtype)));
    }
  }

  // PaddlePlace <-> platform::Place
  static platform::Place ConvertEnumPlaceToInnerPlace(const PlaceType& pc) {
    if (pc == PlaceType::kCPU) {
      return platform::Place(platform::CPUPlace());
    } else if (pc == PlaceType::kGPU) {
#ifdef PADDLE_WITH_CUDA
      return platform::Place(
          platform::CUDAPlace(platform::GetCurrentDeviceId()));
#endif
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported place type code(%d) when "
          "casting enum place to paddle place.",
          static_cast<int>(pc)));
    }
    return platform::Place();
  }

  static PlaceType ConvertInnerPlaceToEnumPlace(const platform::Place& pc) {
    if (platform::is_cpu_place(pc)) {
      return PlaceType::kCPU;
    } else if (platform::is_gpu_place(pc)) {
#ifdef PADDLE_WITH_CUDA
      return PlaceType::kGPU;
#endif
    } else {
      PADDLE_THROW(
          platform::errors::Unimplemented("Unsupported place type `%s` when "
                                          "casting paddle place to enum place.",
                                          pc));
    }
    return PlaceType::kUNK;
  }
};

}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,54 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/extension/include/op_meta_info.h"

namespace paddle {
namespace framework {

class OpMetaInfoHelper {
 public:
  static const std::string& GetOpName(const paddle::OpMetaInfo& info) {
    return info.name_;
  }
  static const std::vector<std::string>& GetInputs(
      const paddle::OpMetaInfo& info) {
    return info.inputs_;
  }
  static const std::vector<std::string>& GetOutputs(
      const paddle::OpMetaInfo& info) {
    return info.outputs_;
  }
  static const std::vector<std::string>& GetAttrs(
      const paddle::OpMetaInfo& info) {
    return info.attrs_;
  }
  static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
    return info.kernel_fn_;
  }
  static const InferShapeFunc& GetInferShapeFn(const paddle::OpMetaInfo& info) {
    return info.infer_shape_fn_;
  }
  static const InferDtypeFunc& GetInferDtypeFn(const paddle::OpMetaInfo& info) {
    return info.infer_dtype_fn_;
  }
};

}  // namespace framework
}  // namespace paddle
@@ -0,0 +1,13 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,116 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>

#include "paddle/extension.h"

template <typename data_t>
void relu_cpu_forward_kernel(const data_t* x_data,
                             data_t* out_data,
                             int64_t x_numel) {
  for (int i = 0; i < x_numel; ++i) {
    out_data[i] = std::max(static_cast<data_t>(0.), x_data[i]);
  }
}

template <typename data_t>
void relu_cpu_backward_kernel(const data_t* grad_out_data,
                              const data_t* out_data,
                              data_t* grad_x_data,
                              int64_t out_numel) {
  for (int i = 0; i < out_numel; ++i) {
    grad_x_data[i] =
        grad_out_data[i] * (out_data[i] > static_cast<data_t>(0) ? 1. : 0.);
  }
}

std::vector<paddle::Tensor> relu_cpu_forward(const paddle::Tensor& x) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU);
  out.reshape(x.shape());

  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "relu_cpu_forward", ([&] {
        relu_cpu_forward_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
      }));

  return {out};
}

std::vector<paddle::Tensor> relu_cpu_backward(const paddle::Tensor& x,
                                              const paddle::Tensor& out,
                                              const paddle::Tensor& grad_out) {
  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU);
  grad_x.reshape(x.shape());

  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
                               relu_cpu_backward_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   out.data<data_t>(),
                                   grad_x.mutable_data<data_t>(x.place()),
                                   out.size());
                             }));

  return {grad_x};
}

std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x);
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
                                               const paddle::Tensor& out,
                                               const paddle::Tensor& grad_out);

std::vector<paddle::Tensor> ReluForward(const paddle::Tensor& x) {
  // TODO(chenweihang): Check Input
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_forward(x);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_forward(x);
  } else {
    throw std::runtime_error("Not implemented.");
  }
}

std::vector<paddle::Tensor> ReluBackward(const paddle::Tensor& x,
                                         const paddle::Tensor& out,
                                         const paddle::Tensor& grad_out) {
  // TODO(chenweihang): Check Input
  if (x.place() == paddle::PlaceType::kCPU) {
    return relu_cpu_backward(x, out, grad_out);
  } else if (x.place() == paddle::PlaceType::kGPU) {
    return relu_cuda_backward(x, out, grad_out);
  } else {
    throw std::runtime_error("Not implemented.");
  }
}

std::vector<std::vector<int64_t>> ReluInferShape(std::vector<int64_t> x_shape) {
  return {x_shape};
}

std::vector<paddle::DataType> ReluInferDType(paddle::DataType x_dtype) {
  return {x_dtype};
}

PD_BUILD_OPERATOR("relu2")
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(ReluForward))
    .SetInferShapeFn(PD_INFER_SHAPE(ReluInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(ReluInferDType))
    .SetBackwardOp("relu2_grad")
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .SetKernelFn(PD_KERNEL(ReluBackward));
|
||||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "paddle/extension.h"
|
||||
|
||||
template <typename data_t>
|
||||
__global__ void relu_cuda_forward_kernel(const data_t* x,
|
||||
data_t* y,
|
||||
const int num) {
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||
y[i] = max(x[i], static_cast<data_t>(0.));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename data_t>
|
||||
__global__ void relu_cuda_backward_kernel(const data_t* dy,
|
||||
const data_t* y,
|
||||
data_t* dx,
|
||||
const int num) {
|
||||
int gid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int i = gid; i < num; i += blockDim.x * gridDim.x) {
|
||||
dx[i] = dy[i] * (y[i] > 0 ? 1. : 0.);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cuda_forward(const paddle::Tensor& x) {
|
||||
auto out = paddle::Tensor(paddle::PlaceType::kGPU);
|
||||
out.reshape(x.shape());
|
||||
|
||||
int numel = x.size();
|
||||
int block = 512;
|
||||
int grid = (numel + block - 1) / block;
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
x.type(), "relu_cuda_forward_kernel", ([&] {
|
||||
relu_cuda_forward_kernel<data_t><<<grid, block>>>(
|
||||
x.data<data_t>(), out.mutable_data<data_t>(x.place()), numel);
|
||||
}));
|
||||
|
||||
return {out};
|
||||
}
|
||||
|
||||
std::vector<paddle::Tensor> relu_cuda_backward(const paddle::Tensor& x,
|
||||
const paddle::Tensor& out,
|
||||
const paddle::Tensor& grad_out) {
|
||||
auto grad_x = paddle::Tensor(paddle::PlaceType::kGPU);
|
||||
grad_x.reshape(x.shape());
|
||||
|
||||
int numel = out.size();
|
||||
int block = 512;
|
||||
int grid = (numel + block - 1) / block;
|
||||
PD_DISPATCH_FLOATING_TYPES(
|
||||
out.type(), "relu_cuda_backward_kernel", ([&] {
|
||||
relu_cuda_backward_kernel<data_t><<<grid, block>>>(
|
||||
grad_out.data<data_t>(),
|
||||
out.data<data_t>(),
|
||||
grad_x.mutable_data<data_t>(x.place()),
|
||||
numel);
|
||||
}));
|
||||
|
||||
return {grad_x};
|
||||
}
|
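Both kernels use a grid-stride loop, so correctness does not depend on launching exactly one thread per element. grid = (numel + block - 1) / block is ceiling division: for example, numel = 1300 with block = 512 gives grid = 3, covering all elements with at most one partially filled block.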
Some files were not shown because too many files have changed in this diff.