commit
5229ccbdc7
@ -1,119 +0,0 @@
|
|||||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License. */
|
|
||||||
|
|
||||||
#include "paddle/fluid/framework/tensor_util.h"
|
|
||||||
|
|
||||||
namespace paddle {
|
|
||||||
namespace framework {
|
|
||||||
template <typename Predicate, typename DevCtx>
|
|
||||||
struct AnyDTypeVisitor {
|
|
||||||
Predicate predicate_;
|
|
||||||
const Tensor& tensor_;
|
|
||||||
const DevCtx& ctx_;
|
|
||||||
Tensor* out_;
|
|
||||||
|
|
||||||
AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
|
|
||||||
Tensor* out)
|
|
||||||
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void operator()() const {
|
|
||||||
auto t = EigenVector<T>::Flatten(tensor_);
|
|
||||||
auto o = EigenScalar<bool>::From(*out_);
|
|
||||||
// return any of predicate_(t) is true.
|
|
||||||
o.device(*ctx_.eigen_device()) = predicate_(t).any();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Predicate, typename DevCtx>
|
|
||||||
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
|
|
||||||
const DevCtx& ctx, framework::Tensor* out) {
|
|
||||||
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
|
|
||||||
predicate, tensor, ctx, out));
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename Predicate>
|
|
||||||
struct AnyVisitor : public boost::static_visitor<bool> {
|
|
||||||
const framework::Tensor& tensor_;
|
|
||||||
Predicate predicate_;
|
|
||||||
|
|
||||||
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
|
|
||||||
: tensor_(tensor), predicate_(std::move(predicate)) {}
|
|
||||||
|
|
||||||
template <typename Place>
|
|
||||||
bool operator()(const Place& place) const {
|
|
||||||
framework::Tensor out;
|
|
||||||
out.Resize({1});
|
|
||||||
out.mutable_data<bool>(place);
|
|
||||||
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
|
|
||||||
AnyImpl(predicate_, tensor_, *ctx, &out);
|
|
||||||
return this->GetResult(out, place);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GetResult(const framework::Tensor& out,
|
|
||||||
const platform::CUDAPlace& gpu) const {
|
|
||||||
platform::CPUPlace cpu;
|
|
||||||
framework::Tensor tmp;
|
|
||||||
tmp.Resize({1});
|
|
||||||
tmp.mutable_data<bool>(cpu);
|
|
||||||
auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
|
|
||||||
gpuctx->Wait();
|
|
||||||
Copy(out, cpu, *gpuctx, &tmp);
|
|
||||||
gpuctx->Wait();
|
|
||||||
return GetResult(tmp, cpu);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool GetResult(const framework::Tensor& out,
|
|
||||||
const platform::CPUPlace& cpu) const {
|
|
||||||
return *out.data<bool>();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
template <typename Predicate>
|
|
||||||
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
|
|
||||||
AnyVisitor<Predicate> visitor(tensor, predicate);
|
|
||||||
auto place = tensor.place();
|
|
||||||
return platform::VisitPlace(place, visitor);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HasNANPredicate {
|
|
||||||
template <typename T>
|
|
||||||
auto operator()(const T& eigen_vec) const
|
|
||||||
-> decltype(std::declval<T>().isnan()) {
|
|
||||||
// Cast eigen_vector to vector of bool. true if is inf.
|
|
||||||
return eigen_vec.isnan();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
bool HasNAN(const framework::Tensor& tensor) {
|
|
||||||
HasNANPredicate predicate;
|
|
||||||
return Any(tensor, predicate);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct HasInfPredicate {
|
|
||||||
template <typename T>
|
|
||||||
auto operator()(const T& eigen_vec) const
|
|
||||||
-> decltype(std::declval<T>().isinf()) {
|
|
||||||
// Cast eigen_vector to vector of bool. true if is inf.
|
|
||||||
return eigen_vec.isinf();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
bool HasInf(const framework::Tensor& tensor) {
|
|
||||||
HasInfPredicate predicate;
|
|
||||||
return Any(tensor, predicate);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace framework
|
|
||||||
} // namespace paddle
|
|
||||||
@ -0,0 +1 @@
|
|||||||
|
tensor_util.cc
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,64 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#include "paddle/fluid/operators/batch_size_like.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
void BatchSizeLikeOp::InferShape(framework::InferShapeContext *ctx) const {
|
||||||
|
PADDLE_ENFORCE(ctx->HasInput("Input"),
|
||||||
|
"Input(Input) of %s should not be null.", Type());
|
||||||
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of %s should not be null.",
|
||||||
|
Type());
|
||||||
|
|
||||||
|
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
|
||||||
|
PADDLE_ENFORCE_GT(shape.size(), 0);
|
||||||
|
std::vector<int64_t> shape_int64(shape.size(), 0);
|
||||||
|
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
|
||||||
|
[](int a) { return static_cast<int64_t>(a); });
|
||||||
|
auto output_dim = framework::make_ddim(shape_int64);
|
||||||
|
|
||||||
|
int input_dim_idx = ctx->Attrs().Get<int>("input_dim_idx");
|
||||||
|
PADDLE_ENFORCE_GE(input_dim_idx, 0);
|
||||||
|
PADDLE_ENFORCE_GT(ctx->GetInputDim("Input").size(), input_dim_idx);
|
||||||
|
|
||||||
|
int output_dim_idx = ctx->Attrs().Get<int>("output_dim_idx");
|
||||||
|
PADDLE_ENFORCE_GE(output_dim_idx, 0);
|
||||||
|
PADDLE_ENFORCE_GT(static_cast<int>(shape.size()), output_dim_idx);
|
||||||
|
|
||||||
|
output_dim[output_dim_idx] = ctx->GetInputDim("Input")[input_dim_idx];
|
||||||
|
ctx->SetOutputDim("Out", output_dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
BatchSizeLikeOpMaker::BatchSizeLikeOpMaker(OpProto *proto,
|
||||||
|
OpAttrChecker *op_checker)
|
||||||
|
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
|
||||||
|
AddInput("Input",
|
||||||
|
"(Tensor) Tensor "
|
||||||
|
"whose input_dim_idx'th dimension specifies the batch_size");
|
||||||
|
AddOutput("Out",
|
||||||
|
"(Tensor) Tensor of specified shape will be filled "
|
||||||
|
"with the specified value");
|
||||||
|
AddAttr<std::vector<int>>("shape", "(vector<int>) The shape of the output");
|
||||||
|
AddAttr<int>("input_dim_idx",
|
||||||
|
"(int, default 0) The index of input's batch size dimension")
|
||||||
|
.SetDefault(0);
|
||||||
|
AddAttr<int>("output_dim_idx",
|
||||||
|
"(int, default 0) The index of output's batch size dimension")
|
||||||
|
.SetDefault(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License. */
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "paddle/fluid/framework/op_registry.h"
|
||||||
|
#include "paddle/fluid/operators/math/math_function.h"
|
||||||
|
|
||||||
|
namespace paddle {
|
||||||
|
namespace operators {
|
||||||
|
|
||||||
|
class BatchSizeLikeOp : public framework::OperatorWithKernel {
|
||||||
|
public:
|
||||||
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||||
|
|
||||||
|
void InferShape(framework::InferShapeContext *ctx) const override;
|
||||||
|
};
|
||||||
|
|
||||||
|
class BatchSizeLikeOpMaker : public framework::OpProtoAndCheckerMaker {
|
||||||
|
public:
|
||||||
|
BatchSizeLikeOpMaker(OpProto *proto, OpAttrChecker *op_checker);
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace operators
|
||||||
|
} // namespace paddle
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in new issue