You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/framework/custom_operator.cc

878 lines
36 KiB

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/custom_operator.h"
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/fluid/extension/include/ext_tensor.h"
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/op_meta_info_helper.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/string/string_helper.h"
namespace paddle {
namespace framework {
namespace detail {
// Resolves the symbol `name` inside the already opened dynamic library
// `handle` and returns it reinterpreted as a pointer to T.
// The platform error value (dlerror() on non-Windows builds, GetLastError()
// on Windows) is captured right after the lookup and reported in the
// NotFound error that is raised when the lookup produced a null pointer.
template <typename T>
static T* DynLoad(void* handle, std::string name) {
  T* symbol = reinterpret_cast<T*>(dlsym(handle, name.c_str()));
#if defined(_WIN32)
  auto load_error = GetLastError();
#else
  auto load_error = dlerror();
#endif  // _WIN32
  PADDLE_ENFORCE_NOT_NULL(
      symbol,
      platform::errors::NotFound(
          "Failed to load dynamic operator library, error message(%s).",
          load_error));
  return symbol;
}
// Returns true when `var_name` contains the gradient marker kGradVarSuffix.
// The match is a substring search, so the marker does not have to be the
// very last part of the name.
inline bool IsGradVar(const std::string& var_name) {
  const std::string grad_marker = kGradVarSuffix;
  const bool marker_missing =
      var_name.rfind(grad_marker) == std::string::npos;
  return !marker_missing;
}
// Returns true when `var_name` contains the tensor-vector marker
// kTensorVectorSuffix, i.e. the variable stands for a list of tensors.
// The match is a substring search, so names that carry a further suffix after
// the marker are still recognised.
inline bool IsDuplicableVar(const std::string& var_name) {
  const std::string vector_marker = kTensorVectorSuffix;
  const bool marker_missing =
      var_name.rfind(vector_marker) == std::string::npos;
  return !marker_missing;
}
// Returns `var_name` with its last kGradVarSuffixSize characters removed,
// which yields the forward variable name for a gradient variable name.
// The function does not verify that the removed characters really are the
// gradient suffix; callers are expected to pass gradient names only.
// NOTE(review): for names shorter than the suffix the subtraction presumably
// wraps (unsigned arithmetic) and the whole name is returned - confirm the
// type of kGradVarSuffixSize.
inline std::string NoGrad(const std::string& var_name) {
  return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
}
// Returns true when `name` is equal to at least one element of `vec`.
// An empty `vec` never contains anything, so the result is false.
inline bool IsMemberOf(const std::vector<std::string>& vec,
                       const std::string& name) {
  for (const auto& candidate : vec) {
    if (candidate == name) {
      return true;
    }
  }
  return false;
}
// Splits an attribute descriptor of the form `<name>:<type>` at its first
// colon and returns {name, type}, both with surrounding spaces trimmed.
// Raises InvalidArgument when the descriptor contains no colon.
std::vector<std::string> ParseAttrStr(const std::string& attr) {
  auto colon_pos = attr.find_first_of(":");
  PADDLE_ENFORCE_NE(colon_pos, std::string::npos,
                    platform::errors::InvalidArgument(
                        "Invalid attribute string format. Attribute string "
                        "format is `<name>:<type>`."));

  std::vector<std::string> name_and_type;
  // Everything before the first colon is the attribute name ...
  std::string raw_name = attr.substr(0, colon_pos);
  name_and_type.emplace_back(string::trim_spaces(raw_name));
  // ... and everything after it is the attribute type string.
  std::string raw_type = attr.substr(colon_pos + 1);
  name_and_type.emplace_back(string::trim_spaces(raw_type));

  VLOG(1) << "attr name: " << name_and_type[0]
          << ", attr type str: " << name_and_type[1];
  return name_and_type;
}
} // namespace detail
////////////////// Kernel Define ////////////////////
// Bridges the framework execution context to a user supplied custom kernel.
//
// ctx:     execution context providing the input/output tensors and the
//          attribute values of the running operator.
// func:    the user kernel; it is invoked as
//          func(plain inputs, vector inputs, attributes).
// inputs:  input names; names accepted by detail::IsDuplicableVar are read
//          as a list of tensors, all others as a single tensor.
// outputs: output names, matched by position with the tensors returned by
//          `func`.
// attrs:   attribute descriptors in `<name>:<type>` form.
//
// Every input tensor must be non-null and initialized. Errors raised by the
// framework checks are propagated unchanged, any other std::exception is
// re-raised as an External error and unknown exceptions as a Fatal error.
static void RunKernelFunc(const framework::ExecutionContext& ctx,
                          const paddle::KernelFunc& func,
                          const std::vector<std::string>& inputs,
                          const std::vector<std::string>& outputs,
                          const std::vector<std::string>& attrs) {
  VLOG(1) << "Custom Operator: Start run KernelFunc.";
  std::vector<paddle::Tensor> custom_ins;
  std::vector<std::vector<paddle::Tensor>> custom_vec_ins;
  for (auto& in_name : inputs) {
    VLOG(1) << "Custom Operator: input name - " << in_name;
    if (detail::IsDuplicableVar(in_name)) {
      // return const std::vector<const Tensor*>
      auto vec_x = ctx.MultiInput<Tensor>(in_name);
      PADDLE_ENFORCE_NE(vec_x.empty(), true,
                        platform::errors::NotFound(
                            "Input vector<tensor> (%s) is empty.", in_name));
      std::vector<paddle::Tensor> custom_vec_in;
      custom_vec_in.reserve(vec_x.size());
      for (size_t i = 0; i < vec_x.size(); ++i) {
        auto* x = vec_x[i];
        PADDLE_ENFORCE_NOT_NULL(
            x, platform::errors::NotFound(
                   "The %d-th tensor in input vector<tensor> (%s) is nullptr.",
                   i, in_name));
        PADDLE_ENFORCE_EQ(x->IsInitialized(), true,
                          platform::errors::InvalidArgument(
                              "The %d-th tensor in input vector<tensor> (%s) "
                              "is not initialized.",
                              i, in_name));
        // Wrap the framework tensor into a custom tensor that shares its data.
        auto custom_t = paddle::Tensor(
            CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place()));
        CustomTensorUtils::ShareDataFrom(static_cast<const void*>(x), custom_t);
        CustomTensorUtils::SetTensorCurrentStream(&custom_t, ctx.GetPlace());
        custom_vec_in.emplace_back(std::move(custom_t));
      }
      custom_vec_ins.emplace_back(std::move(custom_vec_in));
    } else {
      auto* x = ctx.Input<Tensor>(in_name);
      PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound(
                                     "Input tensor (%s) is nullptr.", in_name));
      PADDLE_ENFORCE_EQ(x->IsInitialized(), true,
                        platform::errors::InvalidArgument(
                            "Input tensor (%s) is not initialized.", in_name));
      auto custom_in = paddle::Tensor(
          CustomTensorUtils::ConvertInnerPlaceToEnumPlace(x->place()));
      CustomTensorUtils::ShareDataFrom(static_cast<const void*>(x), custom_in);
      CustomTensorUtils::SetTensorCurrentStream(&custom_in, ctx.GetPlace());
      custom_ins.emplace_back(std::move(custom_in));
    }
  }

  // Read every declared attribute with the C++ type named in its descriptor.
  std::vector<boost::any> custom_attrs;
  for (auto& attr_str : attrs) {
    auto attr_name_and_type = detail::ParseAttrStr(attr_str);
    const auto& attr_name = attr_name_and_type[0];
    const auto& attr_type_str = attr_name_and_type[1];
    if (attr_type_str == "bool") {
      custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
    } else if (attr_type_str == "int") {
      custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
    } else if (attr_type_str == "float") {
      custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
    } else if (attr_type_str == "int64_t") {
      custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
    } else if (attr_type_str == "std::string") {
      custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
    } else if (attr_type_str == "std::vector<int>") {
      custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
    } else if (attr_type_str == "std::vector<float>") {
      custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
    } else if (attr_type_str == "std::vector<int64_t>") {
      custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
    } else if (attr_type_str == "std::vector<std::string>") {
      custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Unsupported `%s` type value as custom attribute now. "
          "Supported data types include `bool`, `int`, `float`, "
          "`int64_t`, `std::string`, `std::vector<int>`, "
          "`std::vector<float>`, `std::vector<int64_t>`, "
          "`std::vector<std::string>`, Please check whether "
          "the attribute data type and data type string are matched.",
          attr_type_str));
    }
  }

  VLOG(1) << "Custom Operator: Run ComputeFunc.";
  try {
    auto outs = func(custom_ins, custom_vec_ins, custom_attrs);

    VLOG(1) << "Custom Operator: Share outputs into ExecutionContext.";
    for (size_t i = 0; i < outputs.size(); ++i) {
      const auto& out_name = outputs[i];
      if (detail::IsDuplicableVar(out_name)) {
        // A vector output consumes all returned tensors, so it has to be the
        // one and only declared output.
        PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL,
                       platform::errors::PreconditionNotMet(
                           "If custom operator's outputs contains `paddle::Vec("
                           ")` type, "
                           "it only can hold one output."));
        auto vec_true_outs = ctx.MultiOutput<Tensor>(out_name);
        PADDLE_ENFORCE_EQ(
            vec_true_outs.size(), outs.size(),
            platform::errors::InvalidArgument(
                "The number of element in custom operator outputs is wrong, "
                "expected contains %d Tensors, but actually contains %d "
                "Tensors.",
                vec_true_outs.size(), outs.size()));
        for (size_t j = 0; j < vec_true_outs.size(); ++j) {
          CustomTensorUtils::ShareDataTo(outs.at(j), vec_true_outs.at(j));
        }
      } else {
        auto* true_out = ctx.Output<Tensor>(out_name);
        CustomTensorUtils::ShareDataTo(outs.at(i), true_out);
      }
    }
  } catch (platform::EnforceNotMet&) {
    // Re-raise the original framework error object as-is (no copy, no
    // slicing of a more derived exception type).
    throw;
  } catch (std::exception& ex) {
    PADDLE_THROW(platform::errors::External("%s", ex.what()));
  } catch (...) {
    PADDLE_THROW(platform::errors::Fatal(
        "Custom operator raises an unknown exception in runtime."));
  }
}
//////////////////// Operator Define /////////////////
// Framework operator type that is instantiated for every registered custom
// operator. Shape inference and kernels are attached through OpInfo and the
// kernel map at registration time, so the class itself only customizes the
// kernel selection.
class CustomOperator : public OperatorWithKernel {
 public:
  using OperatorWithKernel::OperatorWithKernel;

  // Dummy infershape
  // Because it is a pure virtual function, it must be implemented
  void InferShape(framework::InferShapeContext* ctx) const override {
    VLOG(1) << "Custom Operator: Dummy infer shape of custom operator.";
  }

  /**
   * NOTE: [Skip the Kernel Selection]
   * Custom Op only registers one Op kernel on each device. Data type
   * selection and promotion, which normally depend on
   * GetExpectedKernelType, as well as the adaptation of other special
   * situations are left to the user's kernel implementation, so that users
   * do not have to implement GetExpectedKernelType themselves.
   * The RAW type is used here as the data type, indicating that
   * it can only be determined at runtime.
   */
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(proto::VarType::RAW, ctx.GetPlace());
  }

  /**
   * NOTE: [Skip Input Variable Cast for DataType]
   * Because the kernel data type is RAW, we should skip the cast for
   * data type difference when PrepareData.
   *
   * The function must be `const` to match the virtual function of
   * OperatorWithKernel; without it this would be an unrelated overload that
   * the framework never calls. `override` lets the compiler check that.
   */
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const OpKernelType& expected_kernel_type) const override {
    // Report the expected data type and place, so only the layout of the
    // actual tensor can differ from the expected kernel type.
    return OpKernelType(expected_kernel_type.data_type_,
                        expected_kernel_type.place_, tensor.layout());
  }
};
// Builds the OpProto (inputs, outputs, attributes and comment) of a custom
// operator from the name lists collected at registration time.
//
// inputs / outputs: variable names; names accepted by
//                   detail::IsDuplicableVar are declared as duplicable.
// attrs:            attribute descriptors in `<name>:<type>` form.
class CustomOpMaker : public OpProtoAndCheckerMaker {
 public:
  explicit CustomOpMaker(const std::vector<std::string>& inputs,
                         const std::vector<std::string>& outputs,
                         const std::vector<std::string>& attrs)
      : inputs_(inputs), outputs_(outputs), attrs_(attrs) {}

  // Declares all inputs, outputs and attributes on the proto. Raises
  // Unimplemented for an attribute whose type string is not supported.
  void Make() override {
    for (auto& in_name : inputs_) {
      if (detail::IsDuplicableVar(in_name)) {
        AddInput(in_name, "The input " + in_name + " of Custom operator.")
            .AsDuplicable();
      } else {
        AddInput(in_name, "The input " + in_name + " of Custom operator.");
      }
    }
    for (auto& out_name : outputs_) {
      if (detail::IsDuplicableVar(out_name)) {
        AddOutput(out_name, "The output " + out_name + " of Custom Operator.")
            .AsDuplicable();
      } else {
        AddOutput(out_name, "The output " + out_name + " of Custom Operator.");
      }
    }
    // Each attribute is registered with the C++ type named in its descriptor
    // and a fixed default value.
    for (auto& attr : attrs_) {
      auto attr_name_and_type = detail::ParseAttrStr(attr);
      auto attr_name = attr_name_and_type[0];
      auto attr_type_str = attr_name_and_type[1];
      if (attr_type_str == "bool") {
        AddAttr<bool>(attr_name, "custom operator bool attribute.")
            .SetDefault(false);
      } else if (attr_type_str == "int") {
        AddAttr<int>(attr_name, "custom operator int attribute.").SetDefault(1);
      } else if (attr_type_str == "float") {
        AddAttr<float>(attr_name, "custom operator float attribute.")
            .SetDefault(1.0f);
      } else if (attr_type_str == "int64_t") {
        AddAttr<int64_t>(attr_name, "custom operator int64_t attribute.")
            .SetDefault(1);
      } else if (attr_type_str == "std::string") {
        AddAttr<std::string>(attr_name,
                             "custom operator std::string attribute.")
            .SetDefault("");
      } else if (attr_type_str == "std::vector<int>") {
        AddAttr<std::vector<int>>(attr_name,
                                  "custom operator std::vector<int> attribute.")
            .SetDefault({});
      } else if (attr_type_str == "std::vector<float>") {
        AddAttr<std::vector<float>>(
            attr_name, "custom operator std::vector<float> attribute.")
            .SetDefault({});
      } else if (attr_type_str == "std::vector<int64_t>") {
        AddAttr<std::vector<int64_t>>(
            attr_name, "custom operator std::vector<int64_t> attribute.")
            .SetDefault({});
      } else if (attr_type_str == "std::vector<std::string>") {
        AddAttr<std::vector<std::string>>(
            attr_name, "custom operator std::vector<std::string> attribute.")
            .SetDefault({});
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported `%s` type value as custom attribute now. "
            "Supported data types include `bool`, `int`, `float`, "
            "`int64_t`, `std::string`, `std::vector<int>`, "
            "`std::vector<float>`, `std::vector<int64_t>`, "
            "`std::vector<std::string>`, Please check whether "
            "the attribute data type and data type string are matched.",
            attr_type_str));
      }
    }
    AddComment(R"DOC(
Custom Operator.
According to the Tensor operation function implemented by the user
independently of the framework, it is encapsulated into a framework
operator to adapt to various execution scenarios such as dynamic graph
mode, static graph mode, and inference mode.
)DOC");
  }

 private:
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
  std::vector<std::string> attrs_;
};
// Grad op maker of custom operators; specialized below per op representation.
template <typename T>
class CustomGradOpMaker;

// Specialization that describes the grad op as an OpDesc.
//
// name:    type name of the grad op to create.
// inputs:  input names of the grad op. A name accepted by detail::IsGradVar
//          is bound to the gradient of the matching forward output; any other
//          name must be an input or an output of the forward op.
// outputs: output names of the grad op; each one is bound to the gradient of
//          the forward input obtained through detail::NoGrad.
template <>
class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
 public:
  explicit CustomGradOpMaker(
      const OpDesc& fwd_op, const std::unordered_set<std::string>& no_grad_set,
      std::unordered_map<std::string, std::string>* grad_to_var,
      const std::vector<BlockDesc*>& grad_block, const std::string& name,
      const std::vector<std::string>& inputs,
      const std::vector<std::string>& outputs)
      : SingleGradOpMaker<OpDesc>(fwd_op, no_grad_set, grad_to_var, grad_block),
        name_(name),
        inputs_(inputs),
        outputs_(outputs) {}

 protected:
  // Fills `grad_op` with its type, inputs, outputs and the attributes of the
  // forward op. Raises InvalidArgument for a non-gradient input name that is
  // neither a forward input nor a forward output.
  void Apply(GradOpPtr<OpDesc> grad_op) const override {
    grad_op->SetType(name_);

    auto forward_inputs = this->InputNames();
    auto forward_outputs = this->OutputNames();

    for (const auto& grad_in : inputs_) {
      VLOG(1) << "Custom Operator: GradOpDescMaker - input: " << grad_in;
      if (detail::IsGradVar(grad_in)) {
        // Gradient flowing in from the forward op's output.
        grad_op->SetInput(grad_in, this->OutputGrad(detail::NoGrad(grad_in)));
        continue;
      }
      // Otherwise the grad op re-uses a forward variable.
      if (detail::IsMemberOf(forward_inputs, grad_in)) {
        grad_op->SetInput(grad_in, this->Input(grad_in));
      } else if (detail::IsMemberOf(forward_outputs, grad_in)) {
        grad_op->SetInput(grad_in, this->Output(grad_in));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The input tensor name `%s` is invalid, expected it is the input "
            "or output of forward operator.",
            grad_in));
      }
    }

    for (const auto& grad_out : outputs_) {
      VLOG(1) << "Custom Operator: GradOpDescMaker - output: " << grad_out;
      const auto forward_name = detail::NoGrad(grad_out);
      if (!detail::IsDuplicableVar(grad_out)) {
        grad_op->SetOutput(grad_out, this->InputGrad(forward_name));
      } else {
        // Vector outputs keep their empty gradient slots.
        grad_op->SetOutput(grad_out,
                           this->InputGrad(forward_name,
                                           /*drop_empty_grad=*/false));
      }
    }

    grad_op->SetAttrMap(this->Attrs());
  }

 private:
  std::string name_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};
// Specialization that creates the grad op for imperative::OpBase.
//
// name:    type name of the grad op to create.
// inputs:  input names of the grad op. A name accepted by detail::IsGradVar
//          is bound to the gradient of the matching forward output; any other
//          name must be an input or an output of the forward op.
// outputs: output names of the grad op; each one is bound to the gradient of
//          the forward input obtained through detail::NoGrad.
template <>
class CustomGradOpMaker<imperative::OpBase>
    : public SingleGradOpMaker<imperative::OpBase> {
 public:
  explicit CustomGradOpMaker(
      const std::string& type,
      const imperative::NameVarBaseMap& var_base_map_in,
      const imperative::NameVarBaseMap& var_base_map_out,
      const AttributeMap& attrs,
      const std::map<std::string, std::string>& inplace_map,
      const std::string& name, const std::vector<std::string>& inputs,
      const std::vector<std::string>& outputs)
      : SingleGradOpMaker<imperative::OpBase>(
            type, var_base_map_in, var_base_map_out, attrs, inplace_map),
        name_(name),
        inputs_(inputs),
        outputs_(outputs) {}

 protected:
  // TODO(chenweihang): This duplicates the Apply of the OpDesc
  // specialization. The OpMaker's Input, Output and related methods are
  // protected, so a shared implementation outside the class could not call
  // them; the logic therefore stays inside the class for the time being.
  //
  // Fills `grad_op` with its type, inputs, outputs and the attributes of the
  // forward op. Raises InvalidArgument for a non-gradient input name that is
  // neither a forward input nor a forward output.
  void Apply(GradOpPtr<imperative::OpBase> grad_op) const override {
    grad_op->SetType(name_);

    auto forward_inputs = this->InputNames();
    auto forward_outputs = this->OutputNames();

    for (const auto& grad_in : inputs_) {
      VLOG(1) << "Custom Operator: GradOpBaseMaker - input: " << grad_in;
      if (detail::IsGradVar(grad_in)) {
        // Gradient flowing in from the forward op's output.
        grad_op->SetInput(grad_in, this->OutputGrad(detail::NoGrad(grad_in)));
        continue;
      }
      // Otherwise the grad op re-uses a forward variable.
      if (detail::IsMemberOf(forward_inputs, grad_in)) {
        grad_op->SetInput(grad_in, this->Input(grad_in));
      } else if (detail::IsMemberOf(forward_outputs, grad_in)) {
        grad_op->SetInput(grad_in, this->Output(grad_in));
      } else {
        PADDLE_THROW(platform::errors::InvalidArgument(
            "The input tensor name `%s` is invalid, expected it is the input "
            "or output of forward operator.",
            grad_in));
      }
    }

    for (const auto& grad_out : outputs_) {
      VLOG(1) << "Custom Operator: GradOpBaseMaker - output: " << grad_out;
      const auto forward_name = detail::NoGrad(grad_out);
      grad_op->SetOutput(grad_out, this->InputGrad(forward_name));
    }

    grad_op->SetAttrMap(this->Attrs());
  }

 private:
  std::string name_;
  std::vector<std::string> inputs_;
  std::vector<std::string> outputs_;
};
//////////// Operator and Kernel Register //////////////
// Registers `kernel_func` as the kernel of operator `name` for the kernel key
// built from data type `type` and device `place`.
// The stored callable keeps its own copies of `kernel_func`, `inputs`,
// `outputs` and `attrs` and forwards them to RunKernelFunc on every run.
void RegisterOperatorKernelWithPlace(const std::string& name,
                                     const paddle::KernelFunc& kernel_func,
                                     const proto::VarType::Type type,
                                     const PlaceType& place,
                                     const std::vector<std::string>& inputs,
                                     const std::vector<std::string>& outputs,
                                     const std::vector<std::string>& attrs) {
  OpKernelType kernel_key(
      type, CustomTensorUtils::ConvertEnumPlaceToInnerPlace(place));
  VLOG(1) << "Custom Operator: op kernel key: " << kernel_key;

  auto kernel_runner = [kernel_func, inputs, outputs,
                        attrs](const framework::ExecutionContext& ctx) {
    VLOG(1) << "Custom Operator: run custom kernel func in lambda.";
    RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs);
  };
  OperatorWithKernel::AllOpKernels()[name][kernel_key] =
      std::move(kernel_runner);
}
// Registers `kernel_func` as the kernel of operator `name` on every device
// this build supports: always CPU, plus GPU when PADDLE_WITH_CUDA is defined.
// The kernel key always uses the RAW data type.
void RegisterOperatorKernel(const std::string& name,
                            const paddle::KernelFunc& kernel_func,
                            const std::vector<std::string>& inputs,
                            const std::vector<std::string>& outputs,
                            const std::vector<std::string>& attrs) {
  VLOG(1) << "Custom Operator: op name in kernel: " << name;
  // NOTE [ Dummy Op Kernel Key ]
  // TODO(chenweihang): The execution engine obtains the device context from
  // op_kernel_key.place_, so the kernel is registered once per device.
  // This is not entirely correct: if the user only provides a CPU kernel but
  // runs the op on a GPU device, it will cause an error.
  std::vector<PlaceType> target_places;
  target_places.emplace_back(PlaceType::kCPU);
#ifdef PADDLE_WITH_CUDA
  target_places.emplace_back(PlaceType::kGPU);
#endif
  for (const auto& target_place : target_places) {
    RegisterOperatorKernelWithPlace(name, kernel_func, proto::VarType::RAW,
                                    target_place, inputs, outputs, attrs);
  }
}
void RegisterOperatorWithMetaInfo(
const std::vector<OpMetaInfo>& op_meta_infos) {
/* Op register */
OpInfo info;
auto& base_op_meta = op_meta_infos.front();
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);
auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta);
auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta);
auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta);
auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
VLOG(1) << "Custom Operator: forward, op name: " << op_name;
VLOG(1) << "Custom Operator: forward, op inputs: "
<< string::join_strings(op_inputs, ',');
VLOG(1) << "Custom Operator: forward, op outputs: "
<< string::join_strings(op_outputs, ',');
VLOG(1) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ',');
// Op
info.creator_ = [](const std::string& op_name, const VariableNameMap& inputs,
const VariableNameMap& outputs,
const AttributeMap& attrs) {
return new CustomOperator(op_name, inputs, outputs, attrs);
};
// OpMaker
info.proto_ = new proto::OpProto;
info.proto_->set_type(op_name);
info.checker_ = new OpAttrChecker();
CustomOpMaker custom_maker(op_inputs, op_outputs, op_attrs);
custom_maker(info.proto_, info.checker_);
PADDLE_ENFORCE_EQ(
info.proto_->IsInitialized(), true,
platform::errors::PreconditionNotMet(
"Fail to initialize %s's OpProto, because %s is not initialized.",
op_name, info.proto_->InitializationErrorString()));
// InferShape
if (infer_shape_func == nullptr) {
// use default InferShape
info.infer_shape_ = [op_inputs, op_outputs](InferShapeContext* ctx) {
PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferShapeFn. "
"At this time, the input shape will be directly set to "
"the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferShapeFn. "
"At this time, the input shape will be directly set to "
"the output shape.\n"
"Please set the InferShapeFn of custom "
"operator by .SetInferShapeFn(PD_INFER_SHAPE(...))"));
VLOG(1) << "Custom Operator: Default InferShape - share ddim.";
ctx->ShareDim(op_inputs[0], op_outputs[0]);
};
} else {
info.infer_shape_ = [op_inputs, op_outputs, op_attrs,
infer_shape_func](InferShapeContext* ctx) {
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<std::vector<int64_t>>> vec_input_shapes;
VLOG(1) << "Custom Operator: InferShape - get input ddim.";
for (auto& in_name : op_inputs) {
if (detail::IsDuplicableVar(in_name)) {
OP_INOUT_CHECK(ctx->HasInputs(in_name), "Input", in_name, "Custom");
auto vec_ddim = ctx->GetInputsDim(in_name);
std::vector<std::vector<int64_t>> vec_shape;
vec_shape.reserve(vec_ddim.size());
std::transform(vec_ddim.begin(), vec_ddim.end(),
std::back_inserter(vec_shape),
[&](const DDim& ddim) -> std::vector<int64_t> {
return framework::vectorize(ddim);
});
vec_input_shapes.emplace_back(vec_shape);
} else {
OP_INOUT_CHECK(ctx->HasInput(in_name), "Input", in_name, "Custom");
auto ddim = ctx->GetInputDim(in_name);
input_shapes.emplace_back(framework::vectorize(ddim));
}
}
std::vector<boost::any> custom_attrs;
for (auto& attr_str : op_attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx->Attrs().Get<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx->Attrs().Get<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx->Attrs().Get<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx->Attrs().Get<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx->Attrs().Get<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
// NOTE(chenweihang): InferShape can't support std::vector<int64_t>
// attr type, because the input type is std::vector<int64_t>, only
// can use one rule to parse std::vector<int64_t> parameter
continue;
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(
ctx->Attrs().Get<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<std::string>`, "
"Please check whether the attribute data type and "
"data type string are matched.",
attr_type_str));
}
}
VLOG(1) << "Custom Operator: InferShape - calc output ddim.";
auto output_shapes =
infer_shape_func(input_shapes, vec_input_shapes, custom_attrs);
VLOG(1) << "Custom Operator: InferShape - set output ddim.";
for (size_t i = 0; i < op_outputs.size(); ++i) {
auto out_name = op_outputs[i];
if (detail::IsDuplicableVar(out_name)) {
std::vector<DDim> vec_ddim;
vec_ddim.reserve(output_shapes.size());
std::transform(output_shapes.begin(), output_shapes.end(),
std::back_inserter(vec_ddim),
[&](const std::vector<int64_t>& shape) -> DDim {
return framework::make_ddim(shape);
});
ctx->SetOutputsDim(out_name, vec_ddim);
} else {
ctx->SetOutputDim(out_name, framework::make_ddim(output_shapes[i]));
}
}
};
}
// Infer Dtype
if (infer_dtype_func == nullptr) {
// use defalut InferDtype
info.infer_var_type_ = [op_inputs, op_outputs](InferVarTypeContext* ctx) {
PADDLE_ENFORCE_EQ(
op_inputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple inputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferDtypeFn. "
"At this time, the input dtype will be directly set to "
"the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
PADDLE_ENFORCE_EQ(
op_outputs.size(), 1UL,
platform::errors::Unavailable(
"Your custom operator contains multiple outputs. "
"We only allow a custom operator that contains only one input "
"and only one output without setting the InferDtypeFn. "
"At this time, the input dtype will be directly set to "
"the output dtype.\n"
"Please set the InferDtypeFn of custom "
"operator by .SetInferDtypeFn(PD_INFER_DTYPE(...))"));
VLOG(1) << "Custom Operator: InferDtype - share dtype.";
auto dtype = ctx->GetInputDataType(op_inputs[0]);
ctx->SetOutputDataType(op_outputs[0], dtype);
};
} else {
info.infer_var_type_ = [op_inputs, op_outputs,
infer_dtype_func](InferVarTypeContext* ctx) {
std::vector<DataType> input_dtypes;
std::vector<std::vector<DataType>> vec_input_dtypes;
VLOG(1) << "Custom Operator: InferDtype - get input dtype.";
for (auto& in_name : op_inputs) {
if (detail::IsDuplicableVar(in_name)) {
std::vector<DataType> vec_custom_dtype;
for (size_t i = 0; i < ctx->InputSize(in_name); ++i) {
auto dtype = ctx->GetInputDataType(in_name, i);
vec_custom_dtype.emplace_back(
CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype));
}
vec_input_dtypes.emplace_back(vec_custom_dtype);
} else {
auto dtype = ctx->GetInputDataType(in_name);
input_dtypes.emplace_back(
CustomTensorUtils::ConvertInnerDTypeToEnumDType(dtype));
}
}
VLOG(1) << "Custom Operator: InferDtype - infer output dtype.";
auto output_dtypes = infer_dtype_func(input_dtypes, vec_input_dtypes);
VLOG(1) << "Custom Operator: InferDtype - set output dtype.";
for (size_t i = 0; i < op_outputs.size(); ++i) {
auto out_name = op_outputs[i];
if (detail::IsDuplicableVar(out_name)) {
for (size_t j = 0; j < output_dtypes.size(); ++j) {
auto dtype = CustomTensorUtils::ConvertEnumDTypeToInnerDType(
output_dtypes[i]);
ctx->SetOutputDataType(out_name, dtype, j);
}
} else {
ctx->SetOutputDataType(
out_name, CustomTensorUtils::ConvertEnumDTypeToInnerDType(
output_dtypes[i]));
}
}
};
}
// Kernel func
RegisterOperatorKernel(op_name, kernel_fn, op_inputs, op_outputs, op_attrs);
// If grad op or double grad op exists
std::string cur_op_name = op_name;
for (size_t i = 1; i < op_meta_infos.size(); ++i) {
auto& cur_grad_op = op_meta_infos[i];
auto& grad_op_name = OpMetaInfoHelper::GetOpName(cur_grad_op);
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
VLOG(1) << "Custom Operator: backward, op name: " << grad_op_name;
VLOG(1) << "Custom Operator: backward, op inputs: "
<< string::join_strings(grad_op_inputs, ',');
VLOG(1) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ',');
// GradOpDescMaker
info.grad_op_maker_ = [grad_op_name, grad_op_inputs, grad_op_outputs](
const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDesc*>& grad_block) {
CustomGradOpMaker<paddle::framework::OpDesc> maker(
fwd_op, no_grad_set, grad_to_var, grad_block, grad_op_name,
grad_op_inputs, grad_op_outputs);
return maker();
};
// GradOpBaseMaker
info.dygraph_grad_op_maker_ = [grad_op_name, grad_op_inputs,
grad_op_outputs](
const std::string& type,
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
return maker();
};
/* Grad op register */
OpInfo grad_info;
// Grad Op
grad_info.creator_ = [](
const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) {
return new CustomOperator(type, inputs, outputs, attrs);
};
// Grad InferShape
grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward input
// default
// [Suitable for most situations]
// 2. if forward input not exists, and only contains one grad input and
// output,
// use grad input shape as grad output shape
// [Suitable for the situation that forward input is not used as
// backward input]
// TODO(chenweihang): support set grad op infershape func if needed
for (auto& out_name : grad_op_outputs) {
auto fwd_name = detail::NoGrad(out_name);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
} else {
if (ctx->HasInput(fwd_name)) {
ctx->ShareDim(fwd_name, out_name);
} else {
PADDLE_ENFORCE_EQ(
grad_op_inputs.size() == 1UL && grad_op_outputs.size() == 1UL,
true,
platform::errors::Unavailable(
"Custom grad operator infershape error. "
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set to "
"the output shape. Otherwise, Please set the forward input "
"as the grad operator's input."));
ctx->ShareDim(grad_op_inputs[0], out_name);
}
}
}
};
// Kernel func
RegisterOperatorKernel(grad_op_name, grad_kernel_fn, grad_op_inputs,
grad_op_outputs, grad_op_attrs);
// update current info
OpInfoMap::Instance().Insert(cur_op_name, info);
cur_op_name = grad_op_name;
info = grad_info;
}
// insert last info
OpInfoMap::Instance().Insert(cur_op_name, info);
}
// Registers every custom operator contained in `op_meta_info_map`.
// Each map entry pairs an op type with the list of OpMetaInfo describing the
// op and its grad ops; the list is handed to RegisterOperatorWithMetaInfo.
void RegisterOperatorWithMetaInfoMap(
    const paddle::OpMetaInfoMap& op_meta_info_map) {
  const auto& registered_ops = op_meta_info_map.GetMap();
  VLOG(1) << "Custom Operator: size of op meta info map - "
          << registered_ops.size();
  for (auto it = registered_ops.begin(); it != registered_ops.end(); ++it) {
    // it->first: op type, it->second: meta infos of the op and its grad ops
    VLOG(1) << "Custom Operator: pair first -> op name: " << it->first;
    RegisterOperatorWithMetaInfo(it->second);
  }
}
////////////////////// User APIs ///////////////////////
// load op api
void LoadOpMetaInfoAndRegisterOp(const std::string& dso_name) {
void* handle = paddle::platform::dynload::GetOpDsoHandle(dso_name);
typedef OpMetaInfoMap& get_op_meta_info_map_t();
auto* get_op_meta_info_map =
detail::DynLoad<get_op_meta_info_map_t>(handle, "PD_GetOpMetaInfoMap");
auto& op_meta_info_map = get_op_meta_info_map();
RegisterOperatorWithMetaInfoMap(op_meta_info_map);
}
} // namespace framework
} // namespace paddle