Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 // Written in NSPR style to also be suitable for adding to the NSS demo suite
      5 
      6 /* memio is a simple NSPR I/O layer that lets you decouple NSS from
      7  * the real network.  It's rather like openssl's memory bio,
      8  * and is useful when your app absolutely, positively doesn't
      9  * want to let NSS do its own networking.
     10  */
     11 
     12 #include <stdlib.h>
     13 #include <string.h>
     14 
     15 #include <prerror.h>
     16 #include <prinit.h>
     17 #include <prlog.h>
     18 
     19 #include "nss_memio.h"
     20 
     21 /*--------------- private memio types -----------------------*/
     22 
     23 /*----------------------------------------------------------------------
     24  Simple private circular buffer class.  Size cannot be changed once allocated.
     25 ----------------------------------------------------------------------*/
     26 
     27 struct memio_buffer {
     28     int head;     /* where to take next byte out of buf */
     29     int tail;     /* where to put next byte into buf */
     30     int bufsize;  /* number of bytes allocated to buf */
     31     /* TODO(port): error handling is pessimistic right now.
     32      * Once an error is set, the socket is considered broken
     33      * (PR_WOULD_BLOCK_ERROR not included).
     34      */
     35     PRErrorCode last_err;
     36     char *buf;
     37 };
     38 
     39 
     40 /* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
     41  * to one of these.
     42  * In the public header, we use struct memio_Private as a typesafe alias
     43  * for this.  This causes a few ugly typecasts in the private file, but
     44  * seems safer.
     45  */
     46 struct PRFilePrivate {
     47     /* read requests are satisfied from this buffer */
     48     struct memio_buffer readbuf;
     49 
     50     /* write requests are satisfied from this buffer */
     51     struct memio_buffer writebuf;
     52 
     53     /* SSL needs to know socket peer's name */
     54     PRNetAddr peername;
     55 
     56     /* if set, empty I/O returns EOF instead of EWOULDBLOCK */
     57     int eof;
     58 
     59     /* if set, the number of bytes requested from readbuf that were not
     60      * fulfilled (due to readbuf being empty) */
     61     int read_requested;
     62 };
     63 
     64 /*--------------- private memio_buffer functions ---------------------*/
     65 
     66 /* Forward declarations.  */
     67 
     68 /* Allocate a memio_buffer of given size. */
     69 static void memio_buffer_new(struct memio_buffer *mb, int size);
     70 
     71 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     72 static void memio_buffer_destroy(struct memio_buffer *mb);
     73 
     74 /* How many bytes can be read out of the buffer without wrapping */
     75 static int memio_buffer_used_contiguous(const struct memio_buffer *mb);
     76 
     77 /* How many bytes exist after the wrap? */
     78 static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb);
     79 
     80 /* How many bytes can be written into the buffer without wrapping */
     81 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);
     82 
     83 /* Write n bytes into the buffer.  Returns number of bytes written. */
     84 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);
     85 
     86 /* Read n bytes from the buffer.  Returns number of bytes read. */
     87 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
     88 
     89 /* Allocate a memio_buffer of given size. */
     90 static void memio_buffer_new(struct memio_buffer *mb, int size)
     91 {
     92     mb->head = 0;
     93     mb->tail = 0;
     94     mb->bufsize = size;
     95     mb->buf = malloc(size);
     96 }
     97 
     98 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     99 static void memio_buffer_destroy(struct memio_buffer *mb)
    100 {
    101     free(mb->buf);
    102     mb->buf = NULL;
    103     mb->bufsize = 0;
    104     mb->head = 0;
    105     mb->tail = 0;
    106 }
    107 
    108 /* How many bytes can be read out of the buffer without wrapping */
    109 static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
    110 {
    111     return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
    112 }
    113 
    114 /* How many bytes exist after the wrap? */
    115 static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb)
    116 {
    117     return (mb->tail >= mb->head) ? 0 : mb->tail;
    118 }
    119 
    120 /* How many bytes can be written into the buffer without wrapping */
    121 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
    122 {
    123     if (mb->head > mb->tail) return mb->head - mb->tail - 1;
    124     return mb->bufsize - mb->tail - (mb->head == 0);
    125 }
    126 
    127 /* Write n bytes into the buffer.  Returns number of bytes written. */
    128 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
    129 {
    130     int len;
    131     int transferred = 0;
    132 
    133     /* Handle part before wrap */
    134     len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    135     if (len > 0) {
    136         /* Buffer not full */
    137         memcpy(&mb->buf[mb->tail], buf, len);
    138         mb->tail += len;
    139         if (mb->tail == mb->bufsize)
    140             mb->tail = 0;
    141         n -= len;
    142         buf += len;
    143         transferred += len;
    144 
    145         /* Handle part after wrap */
    146         len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    147         if (len > 0) {
    148             /* Output buffer still not full, input buffer still not empty */
    149             memcpy(&mb->buf[mb->tail], buf, len);
    150             mb->tail += len;
    151             if (mb->tail == mb->bufsize)
    152                 mb->tail = 0;
    153             transferred += len;
    154         }
    155     }
    156 
    157     return transferred;
    158 }
    159 
    160 
    161 /* Read n bytes from the buffer.  Returns number of bytes read. */
    162 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
    163 {
    164     int len;
    165     int transferred = 0;
    166 
    167     /* Handle part before wrap */
    168     len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    169     if (len) {
    170         memcpy(buf, &mb->buf[mb->head], len);
    171         mb->head += len;
    172         if (mb->head == mb->bufsize)
    173             mb->head = 0;
    174         n -= len;
    175         buf += len;
    176         transferred += len;
    177 
    178         /* Handle part after wrap */
    179         len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    180         if (len) {
    181             memcpy(buf, &mb->buf[mb->head], len);
    182             mb->head += len;
    183             if (mb->head == mb->bufsize)
    184                 mb->head = 0;
    185             transferred += len;
    186         }
    187     }
    188 
    189     return transferred;
    190 }
    191 
    192 /*--------------- private memio functions -----------------------*/
    193 
    194 static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
    195 {
    196     struct PRFilePrivate *secret = fd->secret;
    197     memio_buffer_destroy(&secret->readbuf);
    198     memio_buffer_destroy(&secret->writebuf);
    199     free(secret);
    200     fd->dtor(fd);
    201     return PR_SUCCESS;
    202 }
    203 
    204 static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
    205 {
    206     /* TODO: pass shutdown status to app somehow */
    207     return PR_SUCCESS;
    208 }
    209 
    210 /* If there was a network error in the past taking bytes
    211  * out of the buffer, return it to the next call that
    212  * tries to read from an empty buffer.
    213  */
    214 static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
    215                                   PRIntn flags, PRIntervalTime timeout)
    216 {
    217     struct PRFilePrivate *secret;
    218     struct memio_buffer *mb;
    219     int rv;
    220 
    221     if (flags) {
    222         PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
    223         return -1;
    224     }
    225 
    226     secret = fd->secret;
    227     mb = &secret->readbuf;
    228     PR_ASSERT(mb->bufsize);
    229     rv = memio_buffer_get(mb, buf, len);
    230     if (rv == 0 && !secret->eof) {
    231         secret->read_requested = len;
    232         /* If there is no more data in the buffer, report any pending errors
    233          * that were previously observed. Note that both the readbuf and the
    234          * writebuf are checked for errors, since the application may have
    235          * encountered a socket error while writing that would otherwise not
    236          * be reported until the application attempted to write again - which
    237          * it may never do.
    238          */
    239         if (mb->last_err)
    240             PR_SetError(mb->last_err, 0);
    241         else if (secret->writebuf.last_err)
    242             PR_SetError(secret->writebuf.last_err, 0);
    243         else
    244             PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    245         return -1;
    246     }
    247 
    248     secret->read_requested = 0;
    249     return rv;
    250 }
    251 
    252 static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
    253 {
    254     /* pull bytes from buffer */
    255     return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    256 }
    257 
    258 static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
    259                                   PRIntn flags, PRIntervalTime timeout)
    260 {
    261     struct PRFilePrivate *secret;
    262     struct memio_buffer *mb;
    263     int rv;
    264 
    265     secret = fd->secret;
    266     mb = &secret->writebuf;
    267     PR_ASSERT(mb->bufsize);
    268 
    269     /* Note that the read error state is not reported, because it cannot be
    270      * reported until all buffered data has been read. If there is an error
    271      * with the next layer, attempting to call Send again will report the
    272      * error appropriately.
    273      */
    274     if (mb->last_err) {
    275         PR_SetError(mb->last_err, 0);
    276         return -1;
    277     }
    278     rv = memio_buffer_put(mb, buf, len);
    279     if (rv == 0) {
    280         PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    281         return -1;
    282     }
    283     return rv;
    284 }
    285 
    286 static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
    287 {
    288     /* append bytes to buffer */
    289     return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    290 }
    291 
    292 static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
    293 {
    294     /* TODO: fail if memio_SetPeerName has not been called */
    295     struct PRFilePrivate *secret = fd->secret;
    296     *addr = secret->peername;
    297     return PR_SUCCESS;
    298 }
    299 
    300 static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
    301 {
    302     /*
    303      * Even in the original version for real tcp sockets,
    304      * PR_SockOpt_Nonblocking is a special case that does not
    305      * translate to a getsockopt() call
    306      */
    307     if (PR_SockOpt_Nonblocking == data->option) {
    308         data->value.non_blocking = PR_TRUE;
    309         return PR_SUCCESS;
    310     }
    311     PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
    312     return PR_FAILURE;
    313 }
    314 
    315 /*--------------- private memio data -----------------------*/
    316 
    317 /*
    318  * Implement just the bare minimum number of methods needed to make ssl happy.
    319  *
    320  * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
    321  * PR_GetSocketOption, so we have to provide an implementation of
    322  * PR_GetSocketOption that just says "I'm nonblocking".
    323  */
    324 
    325 static struct PRIOMethods  memio_layer_methods = {
    326     PR_DESC_LAYERED,
    327     memio_Close,
    328     memio_Read,
    329     memio_Write,
    330     NULL,
    331     NULL,
    332     NULL,
    333     NULL,
    334     NULL,
    335     NULL,
    336     NULL,
    337     NULL,
    338     NULL,
    339     NULL,
    340     NULL,
    341     NULL,
    342     memio_Shutdown,
    343     memio_Recv,
    344     memio_Send,
    345     NULL,
    346     NULL,
    347     NULL,
    348     NULL,
    349     NULL,
    350     NULL,
    351     memio_GetPeerName,
    352     NULL,
    353     NULL,
    354     memio_GetSocketOption,
    355     NULL,
    356     NULL,
    357     NULL,
    358     NULL,
    359     NULL,
    360     NULL,
    361     NULL,
    362 };
    363 
    364 static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
    365 
    366 static PRStatus memio_InitializeLayerName(void)
    367 {
    368     memio_identity = PR_GetUniqueIdentity("memio");
    369     return PR_SUCCESS;
    370 }
    371 
    372 /*--------------- public memio functions -----------------------*/
    373 
    374 PRFileDesc *memio_CreateIOLayer(int readbufsize, int writebufsize)
    375 {
    376     PRFileDesc *fd;
    377     struct PRFilePrivate *secret;
    378     static PRCallOnceType once;
    379 
    380     PR_CallOnce(&once, memio_InitializeLayerName);
    381 
    382     fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
    383     secret = malloc(sizeof(struct PRFilePrivate));
    384     memset(secret, 0, sizeof(*secret));
    385 
    386     memio_buffer_new(&secret->readbuf, readbufsize);
    387     memio_buffer_new(&secret->writebuf, writebufsize);
    388     fd->secret = secret;
    389     return fd;
    390 }
    391 
    392 void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
    393 {
    394     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    395     struct PRFilePrivate *secret = memiofd->secret;
    396     secret->peername = *peername;
    397 }
    398 
    399 memio_Private *memio_GetSecret(PRFileDesc *fd)
    400 {
    401     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    402     struct PRFilePrivate *secret =  memiofd->secret;
    403     return (memio_Private *)secret;
    404 }
    405 
    406 int memio_GetReadRequest(memio_Private *secret)
    407 {
    408     return ((PRFilePrivate *)secret)->read_requested;
    409 }
    410 
    411 int memio_GetReadParams(memio_Private *secret, char **buf)
    412 {
    413     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    414     PR_ASSERT(mb->bufsize);
    415 
    416     *buf = &mb->buf[mb->tail];
    417     return memio_buffer_unused_contiguous(mb);
    418 }
    419 
    420 int memio_GetReadableBufferSize(memio_Private *secret)
    421 {
    422     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    423     PR_ASSERT(mb->bufsize);
    424 
    425     return memio_buffer_used_contiguous(mb);
    426 }
    427 
    428 void memio_PutReadResult(memio_Private *secret, int bytes_read)
    429 {
    430     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    431     PR_ASSERT(mb->bufsize);
    432 
    433     if (bytes_read > 0) {
    434         mb->tail += bytes_read;
    435         if (mb->tail == mb->bufsize)
    436             mb->tail = 0;
    437     } else if (bytes_read == 0) {
    438         /* Record EOF condition and report to caller when buffer runs dry */
    439         ((PRFilePrivate *)secret)->eof = PR_TRUE;
    440     } else /* if (bytes_read < 0) */ {
    441         mb->last_err = bytes_read;
    442     }
    443 }
    444 
    445 int memio_GetWriteParams(memio_Private *secret,
    446                          const char **buf1, unsigned int *len1,
    447                          const char **buf2, unsigned int *len2)
    448 {
    449     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    450     PR_ASSERT(mb->bufsize);
    451 
    452     if (mb->last_err)
    453         return mb->last_err;
    454 
    455     *buf1 = &mb->buf[mb->head];
    456     *len1 = memio_buffer_used_contiguous(mb);
    457     *buf2 = mb->buf;
    458     *len2 = memio_buffer_wrapped_bytes(mb);
    459     return 0;
    460 }
    461 
    462 void memio_PutWriteResult(memio_Private *secret, int bytes_written)
    463 {
    464     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    465     PR_ASSERT(mb->bufsize);
    466 
    467     if (bytes_written > 0) {
    468         mb->head += bytes_written;
    469         if (mb->head >= mb->bufsize)
    470             mb->head -= mb->bufsize;
    471     } else if (bytes_written < 0) {
    472         mb->last_err = bytes_written;
    473     }
    474 }
    475 
    476 /*--------------- private memio_buffer self-test -----------------*/
    477 
    478 /* Even a trivial unit test is very helpful when doing circular buffers. */
    479 /*#define TRIVIAL_SELF_TEST*/
    480 #ifdef TRIVIAL_SELF_TEST
    481 #include <stdio.h>
    482 
    483 #define TEST_BUFLEN 7
    484 
    485 #define CHECKEQ(a, b) { \
    486     if ((a) != (b)) { \
    487         printf("%d != %d, Test failed line %d\n", a, b, __LINE__); \
    488         exit(1); \
    489     } \
    490 }
    491 
    492 int main()
    493 {
    494     struct memio_buffer mb;
    495     char buf[100];
    496     int i;
    497 
    498     memio_buffer_new(&mb, TEST_BUFLEN);
    499 
    500     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    501     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    502 
    503     CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);
    504 
    505     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    506     CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
    507     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    508 
    509     CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);
    510 
    511     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    512     CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
    513     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    514 
    515     CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    516     CHECKEQ(memcmp(buf, "howdy!", 6), 0);
    517 
    518     CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    519     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    520 
    521     CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);
    522 
    523     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    524     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 4);
    525     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    526 
    527     CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);
    528 
    529     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    530     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    531 
    532     /* TODO: add more cases */
    533 
    534     printf("Test passed\n");
    535     exit(0);
    536 }
    537 
    538 #endif
    539