Home | History | Annotate | Download | only in mtpd
      1 /*
      2  * Copyright (C) 2009 The Android Open Source Project
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  *      http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 #include <stdio.h>
     18 #include <stdlib.h>
     19 #include <stdarg.h>
     20 #include <string.h>
     21 #include <errno.h>
     22 #include <sys/types.h>
     23 #include <sys/socket.h>
     24 #include <sys/poll.h>
     25 #include <sys/wait.h>
     26 #include <netdb.h>
     27 #include <signal.h>
     28 #include <unistd.h>
     29 #include <fcntl.h>
     30 #include <time.h>
     31 
     32 #ifdef ANDROID_CHANGES
     33 #include <android/log.h>
     34 #include <cutils/sockets.h>
     35 #include "keystore_get.h"
     36 #endif
     37 
     38 #include "mtpd.h"
     39 
     40 int the_socket = -1;
     41 
     42 extern struct protocol l2tp;
     43 extern struct protocol pptp;
     44 static struct protocol *protocols[] = {&l2tp, &pptp, NULL};
     45 static struct protocol *the_protocol;
     46 
     47 static int pppd_argc;
     48 static char **pppd_argv;
     49 static pid_t pppd_pid;
     50 
     51 /* We redirect signals to a pipe in order to prevent race conditions. */
     52 static int signals[2];
     53 
     54 static void interrupt(int signal)
     55 {
     56     write(signals[1], &signal, sizeof(int));
     57 }
     58 
     59 static int initialize(int argc, char **argv)
     60 {
     61     int timeout = 0;
     62     int i;
     63 
     64     for (i = 2; i < argc; ++i) {
     65         if (!argv[i][0]) {
     66             pppd_argc = argc - i - 1;
     67             pppd_argv = &argv[i + 1];
     68             argc = i;
     69             break;
     70         }
     71     }
     72 
     73     if (argc >= 2) {
     74         for (i = 0; protocols[i]; ++i) {
     75             if (!strcmp(argv[1], protocols[i]->name)) {
     76                 log_print(INFO, "Using protocol %s", protocols[i]->name);
     77                 the_protocol = protocols[i];
     78                 timeout = the_protocol->connect(argc - 2, &argv[2]);
     79                 break;
     80             }
     81         }
     82     }
     83 
     84     if (!the_protocol || timeout == -USAGE_ERROR) {
     85         printf("Usage: %s <protocol-args> '' <pppd-args>, "
     86                "where protocol-args are one of:\n", argv[0]);
     87         for (i = 0; protocols[i]; ++i) {
     88             printf("       %s %s\n", protocols[i]->name, protocols[i]->usage);
     89         }
     90         exit(USAGE_ERROR);
     91     }
     92     return timeout;
     93 }
     94 
     95 static void stop_pppd()
     96 {
     97     if (pppd_pid) {
     98         log_print(INFO, "Sending signal to pppd (pid = %d)", pppd_pid);
     99         kill(pppd_pid, SIGTERM);
    100         sleep(5);
    101         pppd_pid = 0;
    102     }
    103 }
    104 
    105 #ifdef ANDROID_CHANGES
    106 
    107 static int get_control_and_arguments(int *argc, char ***argv)
    108 {
    109     static char *args[256];
    110     int control;
    111     int i;
    112 
    113     if ((i = android_get_control_socket("mtpd")) == -1) {
    114         return -1;
    115     }
    116     log_print(DEBUG, "Waiting for control socket");
    117     if (listen(i, 1) == -1 || (control = accept(i, NULL, 0)) == -1) {
    118         log_print(FATAL, "Cannot get control socket");
    119         exit(SYSTEM_ERROR);
    120     }
    121     close(i);
    122     fcntl(control, F_SETFD, FD_CLOEXEC);
    123 
    124     args[0] = (*argv)[0];
    125     for (i = 1; i < 256; ++i) {
    126         unsigned char length;
    127         if (recv(control, &length, 1, 0) != 1) {
    128             log_print(FATAL, "Cannot get argument length");
    129             exit(SYSTEM_ERROR);
    130         }
    131         if (length == 0xFF) {
    132             break;
    133         } else {
    134             int offset = 0;
    135             args[i] = malloc(length + 1);
    136             while (offset < length) {
    137                 int n = recv(control, &args[i][offset], length - offset, 0);
    138                 if (n > 0) {
    139                     offset += n;
    140                 } else {
    141                     log_print(FATAL, "Cannot get argument value");
    142                     exit(SYSTEM_ERROR);
    143                 }
    144             }
    145             args[i][length] = 0;
    146         }
    147     }
    148     log_print(DEBUG, "Received %d arguments", i - 1);
    149 
    150     /* L2TP secret is the only thing stored in keystore. We do the query here
    151      * so other files are clean and free from android specific code. */
    152     if (i > 4 && !strcmp("l2tp", args[1]) && args[4][0]) {
    153         char value[KEYSTORE_MESSAGE_SIZE];
    154         int length = keystore_get(args[4], strlen(args[4]), value);
    155         if (length == -1) {
    156             log_print(FATAL, "Cannot get L2TP secret from keystore");
    157             exit(SYSTEM_ERROR);
    158         }
    159         free(args[4]);
    160         args[4] = malloc(length + 1);
    161         memcpy(args[4], value, length);
    162         args[4][length] = 0;
    163     }
    164 
    165     *argc = i;
    166     *argv = args;
    167     return control;
    168 }
    169 
    170 #endif
    171 
    172 int main(int argc, char **argv)
    173 {
    174     struct pollfd pollfds[2];
    175     int timeout;
    176     int status;
    177 #ifdef ANDROID_CHANGES
    178     int control = get_control_and_arguments(&argc, &argv);
    179     unsigned char code = argc - 1;
    180     send(control, &code, 1, 0);
    181 #endif
    182 
    183     srandom(time(NULL));
    184 
    185     if (pipe(signals) == -1) {
    186         log_print(FATAL, "Pipe() %s", strerror(errno));
    187         exit(SYSTEM_ERROR);
    188     }
    189     fcntl(signals[0], F_SETFD, FD_CLOEXEC);
    190     fcntl(signals[1], F_SETFD, FD_CLOEXEC);
    191 
    192     timeout = initialize(argc, argv);
    193 
    194     signal(SIGHUP, interrupt);
    195     signal(SIGINT, interrupt);
    196     signal(SIGTERM, interrupt);
    197     signal(SIGCHLD, interrupt);
    198     signal(SIGPIPE, SIG_IGN);
    199     atexit(stop_pppd);
    200 
    201     pollfds[0].fd = signals[0];
    202     pollfds[0].events = POLLIN;
    203     pollfds[1].fd = the_socket;
    204     pollfds[1].events = POLLIN;
    205 
    206     while (timeout >= 0) {
    207         if (poll(pollfds, 2, timeout ? timeout : -1) == -1 && errno != EINTR) {
    208             log_print(FATAL, "Poll() %s", strerror(errno));
    209             exit(SYSTEM_ERROR);
    210         }
    211         if (pollfds[0].revents) {
    212             break;
    213         }
    214         timeout = pollfds[1].revents ?
    215             the_protocol->process() : the_protocol->timeout();
    216     }
    217 
    218     if (timeout < 0) {
    219         status = -timeout;
    220     } else {
    221         int signal;
    222         read(signals[0], &signal, sizeof(int));
    223         log_print(INFO, "Received signal %d", signal);
    224         if (signal == SIGCHLD && waitpid(pppd_pid, &status, WNOHANG) == pppd_pid
    225             && WIFEXITED(status)) {
    226             status = WEXITSTATUS(status);
    227             log_print(INFO, "Pppd is terminated (status = %d)", status);
    228             status += PPPD_EXITED;
    229             pppd_pid = 0;
    230         } else {
    231             status = USER_REQUESTED;
    232         }
    233     }
    234 
    235     stop_pppd();
    236     the_protocol->shutdown();
    237 
    238 #ifdef ANDROID_CHANGES
    239     code = status;
    240     send(control, &code, 1, 0);
    241 #endif
    242     log_print(INFO, "Mtpd is terminated (status = %d)", status);
    243     return status;
    244 }
    245 
    246 void log_print(int level, char *format, ...)
    247 {
    248     if (level >= 0 && level <= LOG_MAX) {
    249 #ifdef ANDROID_CHANGES
    250         static int levels[5] = {
    251             ANDROID_LOG_DEBUG, ANDROID_LOG_INFO, ANDROID_LOG_WARN,
    252             ANDROID_LOG_ERROR, ANDROID_LOG_FATAL
    253         };
    254         va_list ap;
    255         va_start(ap, format);
    256         __android_log_vprint(levels[level], "mtpd", format, ap);
    257         va_end(ap);
    258 #else
    259         static char *levels = "DIWEF";
    260         va_list ap;
    261         fprintf(stderr, "%c: ", levels[level]);
    262         va_start(ap, format);
    263         vfprintf(stderr, format, ap);
    264         va_end(ap);
    265         fputc('\n', stderr);
    266 #endif
    267     }
    268 }
    269 
    270 void create_socket(int family, int type, char *server, char *port)
    271 {
    272     struct addrinfo hints = {
    273         .ai_flags = AI_NUMERICSERV,
    274         .ai_family = family,
    275         .ai_socktype = type,
    276     };
    277     struct addrinfo *records;
    278     struct addrinfo *r;
    279     int error;
    280 
    281     log_print(INFO, "Connecting to %s port %s", server, port);
    282 
    283     error = getaddrinfo(server, port, &hints, &records);
    284     if (error) {
    285         log_print(FATAL, "Getaddrinfo() %s", (error == EAI_SYSTEM) ?
    286                   strerror(errno) : gai_strerror(error));
    287         exit(NETWORK_ERROR);
    288     }
    289 
    290     for (r = records; r; r = r->ai_next) {
    291         the_socket = socket(r->ai_family, r->ai_socktype, r->ai_protocol);
    292         if (the_socket != -1) {
    293             if (connect(the_socket, r->ai_addr, r->ai_addrlen) == 0) {
    294                 break;
    295             }
    296             close(the_socket);
    297             the_socket = -1;
    298         }
    299     }
    300 
    301     freeaddrinfo(records);
    302 
    303     if (the_socket == -1) {
    304         log_print(FATAL, "Connect() %s", strerror(errno));
    305         exit(NETWORK_ERROR);
    306     }
    307 
    308     fcntl(the_socket, F_SETFD, FD_CLOEXEC);
    309     log_print(INFO, "Connection established (socket = %d)", the_socket);
    310 }
    311 
    312 void start_pppd(int pppox)
    313 {
    314     if (pppd_pid) {
    315         log_print(WARNING, "Pppd is already started (pid = %d)", pppd_pid);
    316         close(pppox);
    317         return;
    318     }
    319 
    320     log_print(INFO, "Starting pppd (pppox = %d)", pppox);
    321 
    322     pppd_pid = fork();
    323     if (pppd_pid < 0) {
    324         log_print(FATAL, "Fork() %s", strerror(errno));
    325         exit(SYSTEM_ERROR);
    326     }
    327 
    328     if (!pppd_pid) {
    329         char *args[pppd_argc + 5];
    330         char number[12];
    331 
    332         sprintf(number, "%d", pppox);
    333         args[0] = "pppd";
    334         args[1] = "nodetach";
    335         args[2] = "pppox";
    336         args[3] = number;
    337         memcpy(&args[4], pppd_argv, sizeof(char *) * pppd_argc);
    338         args[4 + pppd_argc] = NULL;
    339 
    340 #ifdef ANDROID_CHANGES
    341         {
    342             char envargs[65536];
    343             char *tail = envargs;
    344             int i;
    345             /* Hex encode the arguments using [A-P] instead of [0-9A-F]. */
    346             for (i = 0; args[i]; ++i) {
    347                 char *p = args[i];
    348                 do {
    349                     *tail++ = 'A' + ((*p >> 4) & 0x0F);
    350                     *tail++ = 'A' + (*p & 0x0F);
    351                 } while (*p++);
    352             }
    353             *tail = 0;
    354             setenv("envargs", envargs, 1);
    355             args[1] = NULL;
    356         }
    357 #endif
    358         execvp("pppd", args);
    359         log_print(FATAL, "Exec() %s", strerror(errno));
    360         exit(1); /* Pretending a fatal error in pppd. */
    361     }
    362 
    363     log_print(INFO, "Pppd started (pid = %d)", pppd_pid);
    364     close(pppox);
    365 }
    366