Home | History | Annotate | Download | only in grpc
      1 /*
      2  * Copyright (C) 2017 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License
     15  */
     16 package com.android.voicemail.impl.transcribe.grpc;
     17 
     18 import android.content.Context;
     19 import android.content.pm.PackageInfo;
     20 import android.content.pm.PackageManager;
     21 import android.text.TextUtils;
     22 import com.android.dialer.common.Assert;
     23 import com.android.dialer.common.LogUtil;
     24 import com.android.voicemail.impl.transcribe.TranscriptionConfigProvider;
     25 import com.google.internal.communications.voicemailtranscription.v1.VoicemailTranscriptionServiceGrpc;
     26 import io.grpc.CallOptions;
     27 import io.grpc.Channel;
     28 import io.grpc.ClientCall;
     29 import io.grpc.ClientInterceptor;
     30 import io.grpc.ClientInterceptors;
     31 import io.grpc.ForwardingClientCall;
     32 import io.grpc.ManagedChannel;
     33 import io.grpc.ManagedChannelBuilder;
     34 import io.grpc.Metadata;
     35 import io.grpc.MethodDescriptor;
     36 import io.grpc.okhttp.OkHttpChannelBuilder;
     37 import java.security.MessageDigest;
     38 
     39 /**
     40  * Factory for creating grpc clients that talk to the transcription server. This allows all clients
     41  * to share the same channel, which is relatively expensive to create.
     42  */
     43 public class TranscriptionClientFactory {
     44   private static final String DIGEST_ALGORITHM_SHA1 = "SHA1";
     45   private static final char[] HEX_UPPERCASE = {
     46     '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
     47   };
     48 
     49   private final TranscriptionConfigProvider configProvider;
     50   private final ManagedChannel originalChannel;
     51   private final String packageName;
     52   private final String cert;
     53 
     54   public TranscriptionClientFactory(Context context, TranscriptionConfigProvider configProvider) {
     55     this(context, configProvider, getManagedChannel(configProvider));
     56   }
     57 
     58   public TranscriptionClientFactory(
     59       Context context, TranscriptionConfigProvider configProvider, ManagedChannel managedChannel) {
     60     this.configProvider = configProvider;
     61     this.packageName = context.getPackageName();
     62     this.cert = getCertificateFingerprint(context);
     63     originalChannel = managedChannel;
     64   }
     65 
     66   public TranscriptionClient getClient() {
     67     LogUtil.enterBlock("TranscriptionClientFactory.getClient");
     68     Assert.checkState(!originalChannel.isShutdown());
     69     Channel channel =
     70         ClientInterceptors.intercept(
     71             originalChannel,
     72             new Interceptor(
     73                 packageName, cert, configProvider.getApiKey(), configProvider.getAuthToken()));
     74     return new TranscriptionClient(VoicemailTranscriptionServiceGrpc.newBlockingStub(channel));
     75   }
     76 
     77   public void shutdown() {
     78     LogUtil.enterBlock("TranscriptionClientFactory.shutdown");
     79     originalChannel.shutdown();
     80   }
     81 
     82   private static ManagedChannel getManagedChannel(TranscriptionConfigProvider configProvider) {
     83     ManagedChannelBuilder<OkHttpChannelBuilder> builder =
     84         OkHttpChannelBuilder.forTarget(configProvider.getServerAddress());
     85     // Only use plaintext for debugging
     86     if (configProvider.shouldUsePlaintext()) {
     87       // Just passing 'false' doesnt have the same effect as not setting this field
     88       builder.usePlaintext(true);
     89     }
     90     return builder.build();
     91   }
     92 
     93   private static String getCertificateFingerprint(Context context) {
     94     try {
     95       PackageInfo packageInfo =
     96           context
     97               .getPackageManager()
     98               .getPackageInfo(context.getPackageName(), PackageManager.GET_SIGNATURES);
     99       if (packageInfo != null
    100           && packageInfo.signatures != null
    101           && packageInfo.signatures.length > 0) {
    102         MessageDigest messageDigest = MessageDigest.getInstance(DIGEST_ALGORITHM_SHA1);
    103         if (messageDigest == null) {
    104           LogUtil.w(
    105               "TranscriptionClientFactory.getCertificateFingerprint", "error getting digest.");
    106           return null;
    107         }
    108         byte[] bytes = messageDigest.digest(packageInfo.signatures[0].toByteArray());
    109         if (bytes == null) {
    110           LogUtil.w(
    111               "TranscriptionClientFactory.getCertificateFingerprint", "empty message digest.");
    112           return null;
    113         }
    114 
    115         int length = bytes.length;
    116         StringBuilder out = new StringBuilder(length * 2);
    117         for (int i = 0; i < length; i++) {
    118           out.append(HEX_UPPERCASE[(bytes[i] & 0xf0) >>> 4]);
    119           out.append(HEX_UPPERCASE[bytes[i] & 0x0f]);
    120         }
    121         return out.toString();
    122       } else {
    123         LogUtil.w(
    124             "TranscriptionClientFactory.getCertificateFingerprint",
    125             "failed to get package signature.");
    126       }
    127     } catch (Exception e) {
    128       LogUtil.e(
    129           "TranscriptionClientFactory.getCertificateFingerprint",
    130           "error getting certificate fingerprint.",
    131           e);
    132     }
    133 
    134     return null;
    135   }
    136 
    137   private static final class Interceptor implements ClientInterceptor {
    138     private final String packageName;
    139     private final String cert;
    140     private final String apiKey;
    141     private final String authToken;
    142 
    143     private static final Metadata.Key<String> API_KEY_HEADER =
    144         Metadata.Key.of("X-Goog-Api-Key", Metadata.ASCII_STRING_MARSHALLER);
    145     private static final Metadata.Key<String> ANDROID_PACKAGE_HEADER =
    146         Metadata.Key.of("X-Android-Package", Metadata.ASCII_STRING_MARSHALLER);
    147     private static final Metadata.Key<String> ANDROID_CERT_HEADER =
    148         Metadata.Key.of("X-Android-Cert", Metadata.ASCII_STRING_MARSHALLER);
    149     private static final Metadata.Key<String> AUTHORIZATION_HEADER =
    150         Metadata.Key.of("authorization", Metadata.ASCII_STRING_MARSHALLER);
    151 
    152     public Interceptor(String packageName, String cert, String apiKey, String authToken) {
    153       this.packageName = packageName;
    154       this.cert = cert;
    155       this.apiKey = apiKey;
    156       this.authToken = authToken;
    157     }
    158 
    159     @Override
    160     public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
    161         MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
    162       LogUtil.enterBlock(
    163           "TranscriptionClientFactory.interceptCall, intercepted " + method.getFullMethodName());
    164       ClientCall<ReqT, RespT> call = next.newCall(method, callOptions);
    165 
    166       call =
    167           new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(call) {
    168             @Override
    169             public void start(Listener<RespT> responseListener, Metadata headers) {
    170               if (!TextUtils.isEmpty(packageName)) {
    171                 LogUtil.i(
    172                     "TranscriptionClientFactory.interceptCall",
    173                     "attaching package name: " + packageName);
    174                 headers.put(ANDROID_PACKAGE_HEADER, packageName);
    175               }
    176               if (!TextUtils.isEmpty(cert)) {
    177                 LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching android cert");
    178                 headers.put(ANDROID_CERT_HEADER, cert);
    179               }
    180               if (!TextUtils.isEmpty(apiKey)) {
    181                 LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching API Key");
    182                 headers.put(API_KEY_HEADER, apiKey);
    183               }
    184               if (!TextUtils.isEmpty(authToken)) {
    185                 LogUtil.i("TranscriptionClientFactory.interceptCall", "attaching auth token");
    186                 headers.put(AUTHORIZATION_HEADER, "Bearer " + authToken);
    187               }
    188               super.start(responseListener, headers);
    189             }
    190           };
    191       return call;
    192     }
    193   }
    194 }
    195