@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License . */
limitations under the License . */
# include <set>
# include <vector>
# include <vector>
# include "paddle/math/Vector.h"
# include "paddle/math/Vector.h"
@ -72,6 +73,7 @@ class ChunkEvaluator : public Evaluator {
std : : vector < Segment > labelSegments_ ;
std : : vector < Segment > labelSegments_ ;
std : : vector < Segment > outputSegments_ ;
std : : vector < Segment > outputSegments_ ;
std : : set < int > excludedChunkTypes_ ;
public :
public :
virtual void init ( const EvaluatorConfig & config ) {
virtual void init ( const EvaluatorConfig & config ) {
@ -105,6 +107,10 @@ public:
}
}
CHECK ( config . has_num_chunk_types ( ) ) < < " Missing num_chunk_types in config " ;
CHECK ( config . has_num_chunk_types ( ) ) < < " Missing num_chunk_types in config " ;
otherChunkType_ = numChunkTypes_ = config . num_chunk_types ( ) ;
otherChunkType_ = numChunkTypes_ = config . num_chunk_types ( ) ;
// the chunks of types in excludedChunkTypes_ will not be counted
auto & tmp = config . excluded_chunk_types ( ) ;
excludedChunkTypes_ . insert ( tmp . begin ( ) , tmp . end ( ) ) ;
}
}
virtual void start ( ) {
virtual void start ( ) {
@ -157,7 +163,8 @@ public:
size_t i = 0 , j = 0 ;
size_t i = 0 , j = 0 ;
while ( i < outputSegments_ . size ( ) & & j < labelSegments_ . size ( ) ) {
while ( i < outputSegments_ . size ( ) & & j < labelSegments_ . size ( ) ) {
if ( outputSegments_ [ i ] = = labelSegments_ [ j ] ) {
if ( outputSegments_ [ i ] = = labelSegments_ [ j ] ) {
+ + numCorrect_ ;
if ( excludedChunkTypes_ . count ( outputSegments_ [ i ] . type ) ! = 1 )
+ + numCorrect_ ;
}
}
if ( outputSegments_ [ i ] . end < labelSegments_ [ j ] . end ) {
if ( outputSegments_ [ i ] . end < labelSegments_ [ j ] . end ) {
+ + i ;
+ + i ;
@ -168,8 +175,12 @@ public:
+ + j ;
+ + j ;
}
}
}
}
numLabelSegments_ + = labelSegments_ . size ( ) ;
for ( auto & segment : labelSegments_ ) {
numOutputSegments_ + = outputSegments_ . size ( ) ;
if ( excludedChunkTypes_ . count ( segment . type ) ! = 1 ) + + numLabelSegments_ ;
}
for ( auto & segment : outputSegments_ ) {
if ( excludedChunkTypes_ . count ( segment . type ) ! = 1 ) + + numOutputSegments_ ;
}
}
}
void getSegments ( int * label , int length , std : : vector < Segment > & segments ) {
void getSegments ( int * label , int length , std : : vector < Segment > & segments ) {