|
|
|
@ -1,9 +1,23 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/framework/ddim.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
|
|
|
|
|
|
///@cond HIDDEN
|
|
|
|
|
/// @cond HIDDEN
|
|
|
|
|
|
|
|
|
|
template <int i>
|
|
|
|
|
Dim<i> make_dim(const int* d) {
|
|
|
|
@ -50,7 +64,7 @@ void make_ddim(DDim& ddim, const int* dims, int n) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///@endcond
|
|
|
|
|
/// @endcond
|
|
|
|
|
|
|
|
|
|
DDim make_ddim(std::initializer_list<int> dims) {
|
|
|
|
|
DDim result(make_dim(0));
|
|
|
|
@ -64,11 +78,11 @@ DDim make_ddim(const std::vector<int>& dims) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///@cond HIDDEN
|
|
|
|
|
/// @cond HIDDEN
|
|
|
|
|
// XXX For some reason, putting this in an anonymous namespace causes errors
|
|
|
|
|
class DynamicMutableIndexer : public boost::static_visitor<int&> {
|
|
|
|
|
public:
|
|
|
|
|
DynamicMutableIndexer(int idx) : idx_(idx) {}
|
|
|
|
|
explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
|
int& operator()(Dim<D>& dim) const {
|
|
|
|
@ -81,7 +95,7 @@ class DynamicMutableIndexer : public boost::static_visitor<int&> {
|
|
|
|
|
|
|
|
|
|
class DynamicConstIndexer : public boost::static_visitor<int> {
|
|
|
|
|
public:
|
|
|
|
|
DynamicConstIndexer(int idx) : idx_(idx) {}
|
|
|
|
|
explicit DynamicConstIndexer(int idx) : idx_(idx) {}
|
|
|
|
|
|
|
|
|
|
template <int D>
|
|
|
|
|
int operator()(const Dim<D>& dim) const {
|
|
|
|
@ -92,7 +106,7 @@ class DynamicConstIndexer : public boost::static_visitor<int> {
|
|
|
|
|
int idx_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
///@endcond
|
|
|
|
|
/// @endcond
|
|
|
|
|
|
|
|
|
|
int& DDim::operator[](int idx) {
|
|
|
|
|
return boost::apply_visitor(DynamicMutableIndexer(idx), var);
|
|
|
|
@ -155,11 +169,11 @@ int get(const DDim& ddim, int idx) { return ddim[idx]; }
|
|
|
|
|
|
|
|
|
|
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
|
|
|
|
|
|
|
|
|
|
///@cond HIDDEN
|
|
|
|
|
/// @cond HIDDEN
|
|
|
|
|
struct VectorizeVisitor : public boost::static_visitor<> {
|
|
|
|
|
std::vector<int>& vector;
|
|
|
|
|
|
|
|
|
|
VectorizeVisitor(std::vector<int>& v) : vector(v) {}
|
|
|
|
|
explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()(const T& t) {
|
|
|
|
@ -169,7 +183,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
|
|
|
|
|
|
|
|
|
|
void operator()(const Dim<1>& t) { vector.push_back(t.head); }
|
|
|
|
|
};
|
|
|
|
|
///@endcond
|
|
|
|
|
/// @endcond
|
|
|
|
|
|
|
|
|
|
std::vector<int> vectorize(const DDim& ddim) {
|
|
|
|
|
std::vector<int> result;
|
|
|
|
@ -187,7 +201,7 @@ ssize_t product(const DDim& ddim) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
///\cond HIDDEN
|
|
|
|
|
/// \cond HIDDEN
|
|
|
|
|
|
|
|
|
|
struct ArityVisitor : boost::static_visitor<int> {
|
|
|
|
|
template <int D>
|
|
|
|
@ -196,15 +210,15 @@ struct ArityVisitor : boost::static_visitor<int> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
///\endcond
|
|
|
|
|
/// \endcond
|
|
|
|
|
|
|
|
|
|
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); }
|
|
|
|
|
|
|
|
|
|
///\cond HIDDEN
|
|
|
|
|
/// \cond HIDDEN
|
|
|
|
|
|
|
|
|
|
struct DDimPrinter : boost::static_visitor<void> {
|
|
|
|
|
std::ostream& os;
|
|
|
|
|
DDimPrinter(std::ostream& os_) : os(os_) {}
|
|
|
|
|
explicit DDimPrinter(std::ostream& os_) : os(os_) {}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
void operator()(const T& t) {
|
|
|
|
@ -212,7 +226,7 @@ struct DDimPrinter : boost::static_visitor<void> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
///\endcond
|
|
|
|
|
/// \endcond
|
|
|
|
|
|
|
|
|
|
std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
|
|
|
|
|
DDimPrinter printer(os);
|
|
|
|
|