Home | History | Annotate | Download | only in lang_id
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
     18 #define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
     19 
     20 #include <string>
     21 #include <vector>
     22 
     23 #include "common/embedding-feature-extractor.h"
     24 #include "common/feature-extractor.h"
     25 #include "common/task-context.h"
     26 #include "common/workspace.h"
     27 #include "lang_id/light-sentence-features.h"
     28 #include "lang_id/light-sentence.h"
     29 #include "util/base/macros.h"
     30 
     31 namespace libtextclassifier {
     32 namespace nlp_core {
     33 namespace lang_id {
     34 
     35 // Specialization of EmbeddingFeatureExtractor that extracts from LightSentence.
     36 class LangIdEmbeddingFeatureExtractor
     37     : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
     38  public:
     39   LangIdEmbeddingFeatureExtractor() {}
     40   const std::string ArgPrefix() const override { return "language_identifier"; }
     41 
     42   TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor);
     43 };
     44 
     45 // Handles sentence -> numeric_features and numeric_prediction -> language
     46 // conversions.
     47 class LangIdBrainInterface {
     48  public:
     49   LangIdBrainInterface() {}
     50 
     51   // Initializes resources and parameters.
     52   bool Init(TaskContext *context) {
     53     if (!feature_extractor_.Init(context)) {
     54       return false;
     55     }
     56     feature_extractor_.RequestWorkspaces(&workspace_registry_);
     57     return true;
     58   }
     59 
     60   // Extract features from sentence.  On return, FeatureVector features[i]
     61   // contains the features for the embedding space #i.
     62   void GetFeatures(LightSentence *sentence,
     63                    std::vector<FeatureVector> *features) const {
     64     WorkspaceSet workspace;
     65     workspace.Reset(workspace_registry_);
     66     feature_extractor_.Preprocess(&workspace, sentence);
     67     return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
     68   }
     69 
     70   int NumEmbeddings() const {
     71     return feature_extractor_.NumEmbeddings();
     72   }
     73 
     74  private:
     75   // Typed feature extractor for embeddings.
     76   LangIdEmbeddingFeatureExtractor feature_extractor_;
     77 
     78   // The registry of shared workspaces in the feature extractor.
     79   WorkspaceRegistry workspace_registry_;
     80 
     81   TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface);
     82 };
     83 
     84 }  // namespace lang_id
     85 }  // namespace nlp_core
     86 }  // namespace libtextclassifier
     87 
     88 #endif  // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
     89