You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Paddle/paddle/fluid/imperative/prepared_operator.h

66 lines
2.4 KiB

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/imperative/layer.h"
#include "paddle/fluid/imperative/type_defs.h"
namespace paddle {
namespace imperative {
// Returns a read-only pointer to the framework::Tensor associated with `var`.
// The pointer is non-owning; it refers into `var`, so it is only valid while
// `var` is alive.
// NOTE(review): presumably handles the tensor-holding variable types
// (e.g. LoDTensor, or the value of SelectedRows) and returns nullptr for other
// types -- confirm against the definition in prepared_operator.cc.
const framework::Tensor* GetTensorFromVar(const framework::Variable& var);
class PreparedOp {
public:
Add dygraph execution context (#20157) * add_dygraph_execution_context * add dygraph infershape context and execution context; test=develop * fix imperative bug; test=develop * remove inputs outputs interface from execution context, because it have same function with inputNames; test=develop * remove tracer_test ctest; test=develop * fix split op bug; test=develop * fix unitests bug; test=develop * fix distribute test bug; test=develop * fix ngraph compile bug; test=develop * fix grad maker bug; test=develop * fix load op bugs; test=develop * fix operator.cc construct bug; test=develop * remove useless name find in operator; test=develop * add tracer_test; test=develop * fix concat, split bug; test=develop * remove tracer_test unitest; test=develop * fix attribute check bug; test=develop * add test code to fix converage; test=develop * remove useless code, change check backward input in engin; test=develop * unlock var type infer shape;test=develop * add ShareAllLoD api; test=develop * add dygraph infershape context unitest; test=develop * remove increase and decrease lod in dygraph; test=develop * addd override; test=develop * fix increase descrease lod; test=develop * fix paddle_enforce; test=develop * disable lod op dygraph check; test=develop * fix paddle enforce error; test=develop * add comment for op_registry and OperatorBase; test=develop * optimize the comment of op_registry; test=develop * fix format of comment; test=develop * fix format of comment; test=develop * optimize the format of comment; test=develop * optimize the format of the comment; test=develop * optimize comment of op_registry; test=develop
5 years ago
static PreparedOp Prepare(const NameVarBaseMap& ins,
const NameVarBaseMap& outs,
const framework::OperatorWithKernel& op,
Add dygraph execution context (#20157) * add_dygraph_execution_context * add dygraph infershape context and execution context; test=develop * fix imperative bug; test=develop * remove inputs outputs interface from execution context, because it have same function with inputNames; test=develop * remove tracer_test ctest; test=develop * fix split op bug; test=develop * fix unitests bug; test=develop * fix distribute test bug; test=develop * fix ngraph compile bug; test=develop * fix grad maker bug; test=develop * fix load op bugs; test=develop * fix operator.cc construct bug; test=develop * remove useless name find in operator; test=develop * add tracer_test; test=develop * fix concat, split bug; test=develop * remove tracer_test unitest; test=develop * fix attribute check bug; test=develop * add test code to fix converage; test=develop * remove useless code, change check backward input in engin; test=develop * unlock var type infer shape;test=develop * add ShareAllLoD api; test=develop * add dygraph infershape context unitest; test=develop * remove increase and decrease lod in dygraph; test=develop * addd override; test=develop * fix increase descrease lod; test=develop * fix paddle_enforce; test=develop * disable lod op dygraph check; test=develop * fix paddle enforce error; test=develop * add comment for op_registry and OperatorBase; test=develop * optimize the comment of op_registry; test=develop * fix format of comment; test=develop * fix format of comment; test=develop * optimize the format of comment; test=develop * optimize the format of the comment; test=develop * optimize comment of op_registry; test=develop
5 years ago
platform::Place place,
const framework::AttributeMap* attrs);
inline platform::DeviceContext* GetDeviceContext() const { return dev_ctx_; }
Add dygraph execution context (#20157) * add_dygraph_execution_context * add dygraph infershape context and execution context; test=develop * fix imperative bug; test=develop * remove inputs outputs interface from execution context, because it have same function with inputNames; test=develop * remove tracer_test ctest; test=develop * fix split op bug; test=develop * fix unitests bug; test=develop * fix distribute test bug; test=develop * fix ngraph compile bug; test=develop * fix grad maker bug; test=develop * fix load op bugs; test=develop * fix operator.cc construct bug; test=develop * remove useless name find in operator; test=develop * add tracer_test; test=develop * fix concat, split bug; test=develop * remove tracer_test unitest; test=develop * fix attribute check bug; test=develop * add test code to fix converage; test=develop * remove useless code, change check backward input in engin; test=develop * unlock var type infer shape;test=develop * add ShareAllLoD api; test=develop * add dygraph infershape context unitest; test=develop * remove increase and decrease lod in dygraph; test=develop * addd override; test=develop * fix increase descrease lod; test=develop * fix paddle_enforce; test=develop * disable lod op dygraph check; test=develop * fix paddle enforce error; test=develop * add comment for op_registry and OperatorBase; test=develop * optimize the comment of op_registry; test=develop * fix format of comment; test=develop * fix format of comment; test=develop * optimize the format of comment; test=develop * optimize the format of the comment; test=develop * optimize comment of op_registry; test=develop
5 years ago
void Run(const NameVarBaseMap* in, const NameVarBaseMap* out,
const framework::AttributeMap* attrs);
static void PrepareData(const platform::Place& place,
const NameVarBaseMap& ins,
const framework::OperatorWithKernel& op,
const framework::OpKernelType& expected_kernel_key);
private:
PreparedOp(const framework::OperatorBase& op,
const framework::RuntimeContext& ctx,
framework::OperatorWithKernel::OpKernelFunc func,
platform::DeviceContext* dev_ctx,
std::vector<framework::KernelConfig>* kernel_configs);
private:
const framework::OperatorBase& op_;
const framework::RuntimeContext& ctx_;
framework::OperatorWithKernel::OpKernelFunc func_;
platform::DeviceContext* dev_ctx_;
std::vector<framework::KernelConfig>* kernel_configs_;
};
} // namespace imperative
} // namespace paddle