|
|
|
|
@ -70,7 +70,7 @@ NpuOpRunner::NpuOpRunner(std::string op_type) : op_type_(op_type) {
|
|
|
|
|
|
|
|
|
|
NpuOpRunner::NpuOpRunner(std::string op_type, const std::vector<Tensor> &inputs,
|
|
|
|
|
const std::vector<Tensor> &outputs,
|
|
|
|
|
const AttributeMap &attrs)
|
|
|
|
|
const NPUAttributeMap &attrs)
|
|
|
|
|
: op_type_(op_type) {
|
|
|
|
|
attr_ = aclopCreateAttr();
|
|
|
|
|
AddInputs(inputs);
|
|
|
|
|
@ -85,7 +85,7 @@ NpuOpRunner::~NpuOpRunner() {
|
|
|
|
|
const std::string &NpuOpRunner::Type() { return op_type_; }
|
|
|
|
|
|
|
|
|
|
NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
|
|
|
|
|
const Attribute &attr) {
|
|
|
|
|
const NPUAttribute &attr) {
|
|
|
|
|
if (attr.type() == typeid(bool)) {
|
|
|
|
|
PADDLE_ENFORCE_NPU_SUCCESS(
|
|
|
|
|
aclopSetAttrBool(attr_, name.c_str(), BOOST_GET_CONST(bool, attr)));
|
|
|
|
|
@ -135,6 +135,16 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_NPU_SUCCESS(
|
|
|
|
|
aclopSetAttrListString(attr_, name.c_str(), s.size(), s.data()));
|
|
|
|
|
} else if (attr.type() == typeid(std::vector<std::vector<int64_t>>)) {
|
|
|
|
|
auto a = BOOST_GET_CONST(std::vector<std::vector<int64_t>>, attr);
|
|
|
|
|
std::vector<int64_t *> data;
|
|
|
|
|
std::vector<int> num;
|
|
|
|
|
for (auto &&v : a) {
|
|
|
|
|
data.push_back(v.data());
|
|
|
|
|
num.push_back(v.size());
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_NPU_SUCCESS(
|
|
|
|
|
aclopSetAttrListListInt(attr_, name.c_str(), data.size(), num.data(), data.data()));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW(platform::errors::Unimplemented(
|
|
|
|
|
"Can not convert attribubte '%s' to convert to aclopAttr", name));
|
|
|
|
|
@ -142,7 +152,7 @@ NpuOpRunner &NpuOpRunner::AddAttr(const std::string &name,
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
NpuOpRunner &NpuOpRunner::AddAttrs(const AttributeMap &attrs) {
|
|
|
|
|
NpuOpRunner &NpuOpRunner::AddAttrs(const NPUAttributeMap &attrs) {
|
|
|
|
|
for (const auto &pair : attrs) {
|
|
|
|
|
AddAttr(pair.first, pair.second);
|
|
|
|
|
}
|
|
|
|
|
|