Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 // Written in NSPR style to also be suitable for adding to the NSS demo suite
      5 
      6 /* memio is a simple NSPR I/O layer that lets you decouple NSS from
      7  * the real network.  It's rather like openssl's memory bio,
      8  * and is useful when your app absolutely, positively doesn't
      9  * want to let NSS do its own networking.
     10  */
     11 
     12 #include <stdlib.h>
     13 #include <string.h>
     14 
     15 #include <prerror.h>
     16 #include <prinit.h>
     17 #include <prlog.h>
     18 
     19 #include "nss_memio.h"
     20 
     21 /*--------------- private memio types -----------------------*/
     22 
/*----------------------------------------------------------------------
 Simple private circular buffer class.  Size cannot be changed once allocated.
----------------------------------------------------------------------*/

/* Ring buffer of bytes.  Buffered data occupies [head, tail), wrapping
 * around the end of buf.  head == tail means empty; the write-side helpers
 * always leave one slot unused, so at most bufsize - 1 bytes are stored.
 */
struct memio_buffer {
    int head;     /* where to take next byte out of buf */
    int tail;     /* where to put next byte into buf */
    int bufsize;  /* number of bytes allocated to buf; the accessors in this
                   * file PR_ASSERT that it is nonzero */
    /* TODO(port): error handling is pessimistic right now.
     * Once an error is set, the socket is considered broken
     * (PR_WOULD_BLOCK_ERROR not included).
     */
    /* Sticky error: set by memio_PutReadResult / memio_PutWriteResult to the
     * negative value they are passed; 0 means no error has been recorded. */
    PRErrorCode last_err;
    char *buf;
};
     38 
     39 
/* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
 * to one of these.
 * In the public header, we use struct memio_Private as a typesafe alias
 * for this.  This causes a few ugly typecasts in the private file, but
 * seems safer.
 */
struct PRFilePrivate {
    /* read requests are satisfied from this buffer; it is filled through
     * memio_GetReadParams / memio_PutReadResult */
    struct memio_buffer readbuf;

    /* write requests are appended to this buffer; it is drained through
     * memio_GetWriteParams / memio_PutWriteResult */
    struct memio_buffer writebuf;

    /* SSL needs to know socket peer's name; stored by memio_SetPeerName and
     * returned by memio_GetPeerName */
    PRNetAddr peername;

    /* if set, memio_Recv on an empty readbuf returns 0 (EOF) instead of
     * failing with PR_WOULD_BLOCK_ERROR; set by memio_PutReadResult(.., 0) */
    int eof;

    /* if set, the number of bytes requested from readbuf that were not
     * fulfilled (due to readbuf being empty); read by memio_GetReadRequest */
    int read_requested;
};
     63 
     64 /*--------------- private memio_buffer functions ---------------------*/
     65 
/* Forward declarations of the private ring-buffer helpers defined below. */

/* Allocate a memio_buffer of given size. */
static void memio_buffer_new(struct memio_buffer *mb, int size);

/* Deallocate a memio_buffer allocated by memio_buffer_new. */
static void memio_buffer_destroy(struct memio_buffer *mb);

/* How many bytes can be read out of the buffer without wrapping */
static int memio_buffer_used_contiguous(const struct memio_buffer *mb);

/* How many buffered bytes sit at the start of buf, after the wrap? */
static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb);

/* How many bytes can be written into the buffer without wrapping */
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);

/* Write up to n bytes into the buffer.  Returns number of bytes written. */
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);

/* Read up to n bytes from the buffer.  Returns number of bytes read. */
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
     88 
     89 /* Allocate a memio_buffer of given size. */
     90 static void memio_buffer_new(struct memio_buffer *mb, int size)
     91 {
     92     mb->head = 0;
     93     mb->tail = 0;
     94     mb->bufsize = size;
     95     mb->buf = malloc(size);
     96 }
     97 
     98 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     99 static void memio_buffer_destroy(struct memio_buffer *mb)
    100 {
    101     free(mb->buf);
    102     mb->buf = NULL;
    103     mb->head = 0;
    104     mb->tail = 0;
    105 }
    106 
    107 /* How many bytes can be read out of the buffer without wrapping */
    108 static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
    109 {
    110     return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
    111 }
    112 
    113 /* How many bytes exist after the wrap? */
    114 static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb)
    115 {
    116     return (mb->tail >= mb->head) ? 0 : mb->tail;
    117 }
    118 
    119 /* How many bytes can be written into the buffer without wrapping */
    120 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
    121 {
    122     if (mb->head > mb->tail) return mb->head - mb->tail - 1;
    123     return mb->bufsize - mb->tail - (mb->head == 0);
    124 }
    125 
    126 /* Write n bytes into the buffer.  Returns number of bytes written. */
    127 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
    128 {
    129     int len;
    130     int transferred = 0;
    131 
    132     /* Handle part before wrap */
    133     len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    134     if (len > 0) {
    135         /* Buffer not full */
    136         memcpy(&mb->buf[mb->tail], buf, len);
    137         mb->tail += len;
    138         if (mb->tail == mb->bufsize)
    139             mb->tail = 0;
    140         n -= len;
    141         buf += len;
    142         transferred += len;
    143 
    144         /* Handle part after wrap */
    145         len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    146         if (len > 0) {
    147             /* Output buffer still not full, input buffer still not empty */
    148             memcpy(&mb->buf[mb->tail], buf, len);
    149             mb->tail += len;
    150             if (mb->tail == mb->bufsize)
    151                 mb->tail = 0;
    152             transferred += len;
    153         }
    154     }
    155 
    156     return transferred;
    157 }
    158 
    159 
    160 /* Read n bytes from the buffer.  Returns number of bytes read. */
    161 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
    162 {
    163     int len;
    164     int transferred = 0;
    165 
    166     /* Handle part before wrap */
    167     len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    168     if (len) {
    169         memcpy(buf, &mb->buf[mb->head], len);
    170         mb->head += len;
    171         if (mb->head == mb->bufsize)
    172             mb->head = 0;
    173         n -= len;
    174         buf += len;
    175         transferred += len;
    176 
    177         /* Handle part after wrap */
    178         len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    179         if (len) {
    180             memcpy(buf, &mb->buf[mb->head], len);
    181             mb->head += len;
    182             if (mb->head == mb->bufsize)
    183                 mb->head = 0;
    184             transferred += len;
    185         }
    186     }
    187 
    188     return transferred;
    189 }
    190 
    191 /*--------------- private memio functions -----------------------*/
    192 
    193 static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
    194 {
    195     struct PRFilePrivate *secret = fd->secret;
    196     memio_buffer_destroy(&secret->readbuf);
    197     memio_buffer_destroy(&secret->writebuf);
    198     free(secret);
    199     fd->dtor(fd);
    200     return PR_SUCCESS;
    201 }
    202 
    203 static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
    204 {
    205     /* TODO: pass shutdown status to app somehow */
    206     return PR_SUCCESS;
    207 }
    208 
    209 /* If there was a network error in the past taking bytes
    210  * out of the buffer, return it to the next call that
    211  * tries to read from an empty buffer.
    212  */
    213 static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
    214                                   PRIntn flags, PRIntervalTime timeout)
    215 {
    216     struct PRFilePrivate *secret;
    217     struct memio_buffer *mb;
    218     int rv;
    219 
    220     if (flags) {
    221         PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
    222         return -1;
    223     }
    224 
    225     secret = fd->secret;
    226     mb = &secret->readbuf;
    227     PR_ASSERT(mb->bufsize);
    228     rv = memio_buffer_get(mb, buf, len);
    229     if (rv == 0 && !secret->eof) {
    230         secret->read_requested = len;
    231         /* If there is no more data in the buffer, report any pending errors
    232          * that were previously observed. Note that both the readbuf and the
    233          * writebuf are checked for errors, since the application may have
    234          * encountered a socket error while writing that would otherwise not
    235          * be reported until the application attempted to write again - which
    236          * it may never do.
    237          */
    238         if (mb->last_err)
    239             PR_SetError(mb->last_err, 0);
    240         else if (secret->writebuf.last_err)
    241             PR_SetError(secret->writebuf.last_err, 0);
    242         else
    243             PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    244         return -1;
    245     }
    246 
    247     secret->read_requested = 0;
    248     return rv;
    249 }
    250 
    251 static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
    252 {
    253     /* pull bytes from buffer */
    254     return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    255 }
    256 
    257 static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
    258                                   PRIntn flags, PRIntervalTime timeout)
    259 {
    260     struct PRFilePrivate *secret;
    261     struct memio_buffer *mb;
    262     int rv;
    263 
    264     secret = fd->secret;
    265     mb = &secret->writebuf;
    266     PR_ASSERT(mb->bufsize);
    267 
    268     /* Note that the read error state is not reported, because it cannot be
    269      * reported until all buffered data has been read. If there is an error
    270      * with the next layer, attempting to call Send again will report the
    271      * error appropriately.
    272      */
    273     if (mb->last_err) {
    274         PR_SetError(mb->last_err, 0);
    275         return -1;
    276     }
    277     rv = memio_buffer_put(mb, buf, len);
    278     if (rv == 0) {
    279         PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    280         return -1;
    281     }
    282     return rv;
    283 }
    284 
    285 static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
    286 {
    287     /* append bytes to buffer */
    288     return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    289 }
    290 
    291 static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
    292 {
    293     /* TODO: fail if memio_SetPeerName has not been called */
    294     struct PRFilePrivate *secret = fd->secret;
    295     *addr = secret->peername;
    296     return PR_SUCCESS;
    297 }
    298 
    299 static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
    300 {
    301     /*
    302      * Even in the original version for real tcp sockets,
    303      * PR_SockOpt_Nonblocking is a special case that does not
    304      * translate to a getsockopt() call
    305      */
    306     if (PR_SockOpt_Nonblocking == data->option) {
    307         data->value.non_blocking = PR_TRUE;
    308         return PR_SUCCESS;
    309     }
    310     PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
    311     return PR_FAILURE;
    312 }
    313 
    314 /*--------------- private memio data -----------------------*/
    315 
    316 /*
    317  * Implement just the bare minimum number of methods needed to make ssl happy.
    318  *
    319  * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
    320  * PR_GetSocketOption, so we have to provide an implementation of
    321  * PR_GetSocketOption that just says "I'm nonblocking".
    322  */
    323 
    324 static struct PRIOMethods  memio_layer_methods = {
    325     PR_DESC_LAYERED,
    326     memio_Close,
    327     memio_Read,
    328     memio_Write,
    329     NULL,
    330     NULL,
    331     NULL,
    332     NULL,
    333     NULL,
    334     NULL,
    335     NULL,
    336     NULL,
    337     NULL,
    338     NULL,
    339     NULL,
    340     NULL,
    341     memio_Shutdown,
    342     memio_Recv,
    343     memio_Send,
    344     NULL,
    345     NULL,
    346     NULL,
    347     NULL,
    348     NULL,
    349     NULL,
    350     memio_GetPeerName,
    351     NULL,
    352     NULL,
    353     memio_GetSocketOption,
    354     NULL,
    355     NULL,
    356     NULL,
    357     NULL,
    358     NULL,
    359     NULL,
    360     NULL,
    361 };
    362 
    363 static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
    364 
    365 static PRStatus memio_InitializeLayerName(void)
    366 {
    367     memio_identity = PR_GetUniqueIdentity("memio");
    368     return PR_SUCCESS;
    369 }
    370 
    371 /*--------------- public memio functions -----------------------*/
    372 
    373 PRFileDesc *memio_CreateIOLayer(int readbufsize, int writebufsize)
    374 {
    375     PRFileDesc *fd;
    376     struct PRFilePrivate *secret;
    377     static PRCallOnceType once;
    378 
    379     PR_CallOnce(&once, memio_InitializeLayerName);
    380 
    381     fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
    382     secret = malloc(sizeof(struct PRFilePrivate));
    383     memset(secret, 0, sizeof(*secret));
    384 
    385     memio_buffer_new(&secret->readbuf, readbufsize);
    386     memio_buffer_new(&secret->writebuf, writebufsize);
    387     fd->secret = secret;
    388     return fd;
    389 }
    390 
    391 void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
    392 {
    393     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    394     struct PRFilePrivate *secret = memiofd->secret;
    395     secret->peername = *peername;
    396 }
    397 
    398 memio_Private *memio_GetSecret(PRFileDesc *fd)
    399 {
    400     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    401     struct PRFilePrivate *secret =  memiofd->secret;
    402     return (memio_Private *)secret;
    403 }
    404 
    405 int memio_GetReadRequest(memio_Private *secret)
    406 {
    407     return ((PRFilePrivate *)secret)->read_requested;
    408 }
    409 
    410 int memio_GetReadParams(memio_Private *secret, char **buf)
    411 {
    412     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    413     PR_ASSERT(mb->bufsize);
    414 
    415     *buf = &mb->buf[mb->tail];
    416     return memio_buffer_unused_contiguous(mb);
    417 }
    418 
    419 int memio_GetReadableBufferSize(memio_Private *secret)
    420 {
    421     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    422     PR_ASSERT(mb->bufsize);
    423 
    424     return memio_buffer_used_contiguous(mb);
    425 }
    426 
    427 void memio_PutReadResult(memio_Private *secret, int bytes_read)
    428 {
    429     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    430     PR_ASSERT(mb->bufsize);
    431 
    432     if (bytes_read > 0) {
    433         mb->tail += bytes_read;
    434         if (mb->tail == mb->bufsize)
    435             mb->tail = 0;
    436     } else if (bytes_read == 0) {
    437         /* Record EOF condition and report to caller when buffer runs dry */
    438         ((PRFilePrivate *)secret)->eof = PR_TRUE;
    439     } else /* if (bytes_read < 0) */ {
    440         mb->last_err = bytes_read;
    441     }
    442 }
    443 
    444 void memio_GetWriteParams(memio_Private *secret,
    445                           const char **buf1, unsigned int *len1,
    446                           const char **buf2, unsigned int *len2)
    447 {
    448     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    449     PR_ASSERT(mb->bufsize);
    450 
    451     *buf1 = &mb->buf[mb->head];
    452     *len1 = memio_buffer_used_contiguous(mb);
    453     *buf2 = mb->buf;
    454     *len2 = memio_buffer_wrapped_bytes(mb);
    455 }
    456 
    457 void memio_PutWriteResult(memio_Private *secret, int bytes_written)
    458 {
    459     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    460     PR_ASSERT(mb->bufsize);
    461 
    462     if (bytes_written > 0) {
    463         mb->head += bytes_written;
    464         if (mb->head >= mb->bufsize)
    465             mb->head -= mb->bufsize;
    466     } else if (bytes_written < 0) {
    467         mb->last_err = bytes_written;
    468     }
    469 }
    470 
    471 /*--------------- private memio_buffer self-test -----------------*/
    472 
    473 /* Even a trivial unit test is very helpful when doing circular buffers. */
    474 /*#define TRIVIAL_SELF_TEST*/
    475 #ifdef TRIVIAL_SELF_TEST
    476 #include <stdio.h>
    477 
    478 #define TEST_BUFLEN 7
    479 
    480 #define CHECKEQ(a, b) { \
    481     if ((a) != (b)) { \
    482         printf("%d != %d, Test failed line %d\n", a, b, __LINE__); \
    483         exit(1); \
    484     } \
    485 }
    486 
    487 int main()
    488 {
    489     struct memio_buffer mb;
    490     char buf[100];
    491     int i;
    492 
    493     memio_buffer_new(&mb, TEST_BUFLEN);
    494 
    495     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    496     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    497 
    498     CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);
    499 
    500     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    501     CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
    502     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    503 
    504     CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);
    505 
    506     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    507     CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
    508     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    509 
    510     CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    511     CHECKEQ(memcmp(buf, "howdy!", 6), 0);
    512 
    513     CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    514     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    515 
    516     CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);
    517 
    518     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    519     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 4);
    520     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    521 
    522     CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);
    523 
    524     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    525     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    526 
    527     /* TODO: add more cases */
    528 
    529     printf("Test passed\n");
    530     exit(0);
    531 }
    532 
    533 #endif
    534