Home | History | Annotate | Download | only in textclassifier
      1 /*
      2  * Copyright (C) 2018 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package androidx.textclassifier;
     18 
     19 import android.os.Parcel;
     20 import android.os.Parcelable;
     21 
     22 import androidx.annotation.FloatRange;
     23 import androidx.annotation.NonNull;
     24 import androidx.annotation.RestrictTo;
     25 import androidx.collection.ArrayMap;
     26 import androidx.collection.SimpleArrayMap;
     27 import androidx.core.util.Preconditions;
     28 
     29 import java.util.ArrayList;
     30 import java.util.Collections;
     31 import java.util.Comparator;
     32 import java.util.List;
     33 import java.util.Map;
     34 
     35 /**
     36  * Helper object for setting and getting entity scores for classified text.
     37  *
     38  * @hide
     39  */
     40 @RestrictTo(RestrictTo.Scope.LIBRARY)
     41 final class EntityConfidence implements Parcelable {
     42 
     43     private final ArrayMap<String, Float> mEntityConfidence = new ArrayMap<>();
     44     private final ArrayList<String> mSortedEntities = new ArrayList<>();
     45 
     46     EntityConfidence() {}
     47 
     48     EntityConfidence(@NonNull EntityConfidence source) {
     49         Preconditions.checkNotNull(source);
     50         mEntityConfidence.putAll((SimpleArrayMap<String, Float>) source.mEntityConfidence);
     51         mSortedEntities.addAll(source.mSortedEntities);
     52     }
     53 
     54     /**
     55      * Constructs an EntityConfidence from a map of entity to confidence.
     56      *
     57      * Map entries that have 0 confidence are removed, and values greater than 1 are clamped to 1.
     58      *
     59      * @param source a map from entity to a confidence value in the range 0 (low confidence) to
     60      *               1 (high confidence).
     61      */
     62     EntityConfidence(@NonNull Map<String, Float> source) {
     63         Preconditions.checkNotNull(source);
     64 
     65         // Prune non-existent entities and clamp to 1.
     66         mEntityConfidence.ensureCapacity(source.size());
     67         for (Map.Entry<String, Float> it : source.entrySet()) {
     68             if (it.getValue() <= 0) continue;
     69             mEntityConfidence.put(it.getKey(), Math.min(1, it.getValue()));
     70         }
     71         resetSortedEntitiesFromMap();
     72     }
     73 
     74     /**
     75      * Returns an immutable list of entities found in the classified text ordered from
     76      * high confidence to low confidence.
     77      */
     78     @NonNull
     79     public List<String> getEntities() {
     80         return Collections.unmodifiableList(mSortedEntities);
     81     }
     82 
     83     /**
     84      * Returns the confidence score for the specified entity. The value ranges from
     85      * 0 (low confidence) to 1 (high confidence). 0 indicates that the entity was not found for the
     86      * classified text.
     87      */
     88     @FloatRange(from = 0.0, to = 1.0)
     89     public float getConfidenceScore(String entity) {
     90         if (mEntityConfidence.containsKey(entity)) {
     91             return mEntityConfidence.get(entity);
     92         }
     93         return 0;
     94     }
     95 
     96     @Override
     97     public String toString() {
     98         return mEntityConfidence.toString();
     99     }
    100 
    101     @Override
    102     public int describeContents() {
    103         return 0;
    104     }
    105 
    106     @Override
    107     public void writeToParcel(Parcel dest, int flags) {
    108         dest.writeInt(mEntityConfidence.size());
    109         for (Map.Entry<String, Float> entry : mEntityConfidence.entrySet()) {
    110             dest.writeString(entry.getKey());
    111             dest.writeFloat(entry.getValue());
    112         }
    113     }
    114 
    115     public static final Parcelable.Creator<EntityConfidence> CREATOR =
    116             new Parcelable.Creator<EntityConfidence>() {
    117                 @Override
    118                 public EntityConfidence createFromParcel(Parcel in) {
    119                     return new EntityConfidence(in);
    120                 }
    121 
    122                 @Override
    123                 public EntityConfidence[] newArray(int size) {
    124                     return new EntityConfidence[size];
    125                 }
    126             };
    127 
    128     private EntityConfidence(Parcel in) {
    129         final int numEntities = in.readInt();
    130         mEntityConfidence.ensureCapacity(numEntities);
    131         for (int i = 0; i < numEntities; ++i) {
    132             mEntityConfidence.put(in.readString(), in.readFloat());
    133         }
    134         resetSortedEntitiesFromMap();
    135     }
    136 
    137     private void resetSortedEntitiesFromMap() {
    138         mSortedEntities.clear();
    139         mSortedEntities.ensureCapacity(mEntityConfidence.size());
    140         mSortedEntities.addAll(mEntityConfidence.keySet());
    141         Collections.sort(mSortedEntities, new EntityConfidenceComparator());
    142     }
    143 
    144     /** Helper to sort entities according to their confidence. */
    145     private class EntityConfidenceComparator implements Comparator<String> {
    146         @Override
    147         public int compare(String e1, String e2) {
    148             float score1 = mEntityConfidence.get(e1);
    149             float score2 = mEntityConfidence.get(e2);
    150             return Float.compare(score2, score1);
    151         }
    152     }
    153 }
    154