You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							279 lines
						
					
					
						
							11 KiB
						
					
					
				
			
		
		
	
	
							279 lines
						
					
					
						
							11 KiB
						
					
					
				| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 | |
| 
 | |
| Licensed under the Apache License, Version 2.0 (the "License");
 | |
| you may not use this file except in compliance with the License.
 | |
| You may obtain a copy of the License at
 | |
| 
 | |
|     http://www.apache.org/licenses/LICENSE-2.0
 | |
| 
 | |
| Unless required by applicable law or agreed to in writing, software
 | |
| distributed under the License is distributed on an "AS IS" BASIS,
 | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | |
| See the License for the specific language governing permissions and
 | |
| limitations under the License. */
 | |
| 
 | |
| #pragma once
 | |
| #include <set>
 | |
| #include <string>
 | |
| #include <vector>
 | |
| 
 | |
| #include "paddle/fluid/framework/eigen.h"
 | |
| #include "paddle/fluid/framework/op_registry.h"
 | |
| 
 | |
| namespace paddle {
 | |
| namespace operators {
 | |
| 
 | |
| using Tensor = framework::Tensor;
 | |
| using LoDTensor = framework::LoDTensor;
 | |
| 
 | |
| template <typename DeviceContext, typename T>
 | |
| class ChunkEvalKernel : public framework::OpKernel<T> {
 | |
|  public:
 | |
|   struct Segment {
 | |
|     int begin;
 | |
|     int end;
 | |
|     int type;
 | |
|     bool operator==(const Segment& y) const {
 | |
|       return begin == y.begin && end == y.end && type == y.type;
 | |
|     }
 | |
|   };
 | |
| 
 | |
|   void GetSegments(const int64_t* label, int length,
 | |
|                    std::vector<Segment>* segments, int num_chunk_types,
 | |
|                    int num_tag_types, int other_chunk_type, int tag_begin,
 | |
|                    int tag_inside, int tag_end, int tag_single) const {
 | |
|     segments->clear();
 | |
|     segments->reserve(length);
 | |
|     int chunk_start = 0;
 | |
|     bool in_chunk = false;
 | |
|     int tag = -1;
 | |
|     int type = other_chunk_type;
 | |
|     for (int i = 0; i < length; ++i) {
 | |
|       int prev_tag = tag;
 | |
|       int prev_type = type;
 | |
|       PADDLE_ENFORCE_LE(
 | |
|           label[i], num_chunk_types * num_tag_types,
 | |
|           platform::errors::InvalidArgument(
 | |
|               "The value of Input(Label) should be less than the number of "
 | |
|               "chunk types times the number of tag types, but received %d "
 | |
|               "(Label) vs %d (chunk types) * %d (tag types).",
 | |
|               label[i], num_chunk_types, num_tag_types));
 | |
|       tag = label[i] % num_tag_types;
 | |
|       type = label[i] / num_tag_types;
 | |
|       if (in_chunk && ChunkEnd(prev_tag, prev_type, tag, type, other_chunk_type,
 | |
|                                tag_begin, tag_inside, tag_end, tag_single)) {
 | |
|         Segment segment{
 | |
|             chunk_start,  // begin
 | |
|             i - 1,        // end
 | |
|             prev_type,
 | |
|         };
 | |
|         segments->push_back(segment);
 | |
|         in_chunk = false;
 | |
|       }
 | |
|       if (ChunkBegin(prev_tag, prev_type, tag, type, other_chunk_type,
 | |
|                      tag_begin, tag_inside, tag_end, tag_single)) {
 | |
|         chunk_start = i;
 | |
|         in_chunk = true;
 | |
|       }
 | |
|     }
 | |
|     if (in_chunk) {
 | |
|       Segment segment{
 | |
|           chunk_start,  // begin
 | |
|           length - 1,   // end
 | |
|           type,
 | |
|       };
 | |
|       segments->push_back(segment);
 | |
|     }
 | |
|   }
 | |
| 
 | |
|   bool ChunkEnd(int prev_tag, int prev_type, int tag, int type,
 | |
|                 int other_chunk_type, int tag_begin, int tag_inside,
 | |
|                 int tag_end, int tag_single) const {
 | |
|     if (prev_type == other_chunk_type) return false;
 | |
|     if (type == other_chunk_type) return true;
 | |
|     if (type != prev_type) return true;
 | |
|     if (prev_tag == tag_begin) return tag == tag_begin || tag == tag_single;
 | |
|     if (prev_tag == tag_inside) return tag == tag_begin || tag == tag_single;
 | |
|     if (prev_tag == tag_end) return true;
 | |
|     if (prev_tag == tag_single) return true;
 | |
|     return false;
 | |
|   }
 | |
| 
 | |
|   bool ChunkBegin(int prev_tag, int prev_type, int tag, int type,
 | |
|                   int other_chunk_type, int tag_begin, int tag_inside,
 | |
|                   int tag_end, int tag_single) const {
 | |
|     if (prev_type == other_chunk_type) return type != other_chunk_type;
 | |
|     if (type == other_chunk_type) return false;
 | |
|     if (type != prev_type) return true;
 | |
|     if (tag == tag_begin) return true;
 | |
|     if (tag == tag_inside) return prev_tag == tag_end || prev_tag == tag_single;
 | |
|     if (tag == tag_end) return prev_tag == tag_end || prev_tag == tag_single;
 | |
|     if (tag == tag_single) return true;
 | |
|     return false;
 | |
|   }
 | |
| 
 | |
|   void Compute(const framework::ExecutionContext& context) const override {
 | |
|     // initialize to parse configurations
 | |
|     int num_chunk_types, num_tag_types;
 | |
|     int other_chunk_type;
 | |
|     int tag_begin, tag_inside, tag_end, tag_single;
 | |
|     std::vector<Segment> label_segments;
 | |
|     std::vector<Segment> output_segments;
 | |
|     std::set<int> excluded_chunk_types;
 | |
| 
 | |
|     if (context.Attr<std::string>("chunk_scheme") == "IOB") {
 | |
|       num_tag_types = 2;
 | |
|       tag_begin = 0;
 | |
|       tag_inside = 1;
 | |
|       tag_end = -1;
 | |
|       tag_single = -1;
 | |
|     } else if (context.Attr<std::string>("chunk_scheme") == "IOE") {
 | |
|       num_tag_types = 2;
 | |
|       tag_begin = -1;
 | |
|       tag_inside = 0;
 | |
|       tag_end = 1;
 | |
|       tag_single = -1;
 | |
|     } else if (context.Attr<std::string>("chunk_scheme") == "IOBES") {
 | |
|       num_tag_types = 4;
 | |
|       tag_begin = 0;
 | |
|       tag_inside = 1;
 | |
|       tag_end = 2;
 | |
|       tag_single = 3;
 | |
|     } else if (context.Attr<std::string>("chunk_scheme") == "plain") {
 | |
|       num_tag_types = 1;
 | |
|       tag_begin = -1;
 | |
|       tag_inside = -1;
 | |
|       tag_end = -1;
 | |
|       tag_single = -1;
 | |
|     } else {
 | |
|       PADDLE_THROW(platform::errors::InvalidArgument("Unknown chunk scheme."));
 | |
|     }
 | |
|     other_chunk_type = num_chunk_types = context.Attr<int>("num_chunk_types");
 | |
|     excluded_chunk_types.insert(
 | |
|         context.Attr<std::vector<int>>("excluded_chunk_types").begin(),
 | |
|         context.Attr<std::vector<int>>("excluded_chunk_types").end());
 | |
| 
 | |
|     auto* inference = context.Input<LoDTensor>("Inference");
 | |
|     auto place = inference->place();
 | |
|     auto* label = context.Input<LoDTensor>("Label");
 | |
|     auto* precision = context.Output<Tensor>("Precision");
 | |
|     auto* recall = context.Output<Tensor>("Recall");
 | |
|     auto* f1 = context.Output<Tensor>("F1-Score");
 | |
|     auto* num_infer_chunks = context.Output<Tensor>("NumInferChunks");
 | |
|     auto* num_label_chunks = context.Output<Tensor>("NumLabelChunks");
 | |
|     auto* num_correct_chunks = context.Output<Tensor>("NumCorrectChunks");
 | |
| 
 | |
|     const int64_t* inference_data = inference->data<int64_t>();
 | |
|     const int64_t* label_data = label->data<int64_t>();
 | |
|     T* precision_data = precision->mutable_data<T>(place);
 | |
|     T* racall_data = recall->mutable_data<T>(place);
 | |
|     T* f1_data = f1->mutable_data<T>(place);
 | |
|     int64_t* num_infer_chunks_data =
 | |
|         num_infer_chunks->mutable_data<int64_t>(place);
 | |
|     int64_t* num_label_chunks_data =
 | |
|         num_label_chunks->mutable_data<int64_t>(place);
 | |
|     int64_t* num_correct_chunks_data =
 | |
|         num_correct_chunks->mutable_data<int64_t>(place);
 | |
|     *num_infer_chunks_data = 0;
 | |
|     *num_label_chunks_data = 0;
 | |
|     *num_correct_chunks_data = 0;
 | |
| 
 | |
|     auto lod = label->lod();
 | |
|     bool use_padding = lod.empty();
 | |
|     int num_sequences = 0;
 | |
| 
 | |
|     if (use_padding) {
 | |
|       auto dim1 = inference->dims()[1];
 | |
|       auto* seq_length_t = context.Input<Tensor>("SeqLength");
 | |
|       auto* seq_length_data = seq_length_t->data<int64_t>();
 | |
|       num_sequences = seq_length_t->dims()[0];
 | |
| 
 | |
|       for (int i = 0; i < num_sequences; ++i) {
 | |
|         int seq_length = seq_length_data[i];
 | |
|         EvalOneSeq(inference_data + i * dim1, label_data + i * dim1, seq_length,
 | |
|                    &output_segments, &label_segments, num_infer_chunks_data,
 | |
|                    num_label_chunks_data, num_correct_chunks_data,
 | |
|                    num_chunk_types, num_tag_types, other_chunk_type, tag_begin,
 | |
|                    tag_inside, tag_end, tag_single, excluded_chunk_types);
 | |
|       }
 | |
|     } else {
 | |
|       PADDLE_ENFORCE_EQ(
 | |
|           lod.size(), 1UL,
 | |
|           platform::errors::InvalidArgument(
 | |
|               "Only support one level LoD sequence now, but received %d.",
 | |
|               lod.size()));
 | |
|       PADDLE_ENFORCE_EQ(
 | |
|           lod, inference->lod(),
 | |
|           platform::errors::InvalidArgument(
 | |
|               "Input(Inference) and Input(Label) of Op(chunk_eval) should have "
 | |
|               "same LoD information."));
 | |
|       num_sequences = lod[0].size() - 1;
 | |
| 
 | |
|       for (int i = 0; i < num_sequences; ++i) {
 | |
|         int seq_length = lod[0][i + 1] - lod[0][i];
 | |
|         EvalOneSeq(inference_data + lod[0][i], label_data + lod[0][i],
 | |
|                    seq_length, &output_segments, &label_segments,
 | |
|                    num_infer_chunks_data, num_label_chunks_data,
 | |
|                    num_correct_chunks_data, num_chunk_types, num_tag_types,
 | |
|                    other_chunk_type, tag_begin, tag_inside, tag_end, tag_single,
 | |
|                    excluded_chunk_types);
 | |
|       }
 | |
|     }
 | |
| 
 | |
|     *precision_data = !(*num_infer_chunks_data)
 | |
|                           ? 0
 | |
|                           : static_cast<T>(*num_correct_chunks_data) /
 | |
|                                 (*num_infer_chunks_data);
 | |
|     *racall_data = !(*num_label_chunks_data)
 | |
|                        ? 0
 | |
|                        : static_cast<T>(*num_correct_chunks_data) /
 | |
|                              (*num_label_chunks_data);
 | |
|     *f1_data = !(*num_correct_chunks_data)
 | |
|                    ? 0
 | |
|                    : 2 * (*precision_data) * (*racall_data) /
 | |
|                          ((*precision_data) + (*racall_data));
 | |
|   }
 | |
| 
 | |
|   void EvalOneSeq(const int64_t* output, const int64_t* label, int length,
 | |
|                   std::vector<Segment>* output_segments,
 | |
|                   std::vector<Segment>* label_segments,
 | |
|                   int64_t* num_output_segments, int64_t* num_label_segments,
 | |
|                   int64_t* num_correct, int num_chunk_types, int num_tag_types,
 | |
|                   int other_chunk_type, int tag_begin, int tag_inside,
 | |
|                   int tag_end, int tag_single,
 | |
|                   const std::set<int>& excluded_chunk_types) const {
 | |
|     GetSegments(output, length, output_segments, num_chunk_types, num_tag_types,
 | |
|                 other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
 | |
|     GetSegments(label, length, label_segments, num_chunk_types, num_tag_types,
 | |
|                 other_chunk_type, tag_begin, tag_inside, tag_end, tag_single);
 | |
|     size_t i = 0, j = 0;
 | |
|     while (i < output_segments->size() && j < label_segments->size()) {
 | |
|       if (output_segments->at(i) == label_segments->at(j) &&
 | |
|           excluded_chunk_types.count(output_segments->at(i).type) != 1) {
 | |
|         ++(*num_correct);
 | |
|       }
 | |
|       if (output_segments->at(i).end < label_segments->at(j).end) {
 | |
|         ++i;
 | |
|       } else if (output_segments->at(i).end > label_segments->at(j).end) {
 | |
|         ++j;
 | |
|       } else {
 | |
|         ++i;
 | |
|         ++j;
 | |
|       }
 | |
|     }
 | |
|     for (auto& segment : (*label_segments)) {
 | |
|       if (excluded_chunk_types.count(segment.type) != 1) {
 | |
|         ++(*num_label_segments);
 | |
|       }
 | |
|     }
 | |
|     for (auto& segment : (*output_segments)) {
 | |
|       if (excluded_chunk_types.count(segment.type) != 1) {
 | |
|         ++(*num_output_segments);
 | |
|       }
 | |
|     }
 | |
|   }
 | |
| };
 | |
| 
 | |
| }  // namespace operators
 | |
| }  // namespace paddle
 |