Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 // Written in NSPR style to also be suitable for adding to the NSS demo suite
      5 
      6 /* memio is a simple NSPR I/O layer that lets you decouple NSS from
      7  * the real network.  It's rather like openssl's memory bio,
      8  * and is useful when your app absolutely, positively doesn't
      9  * want to let NSS do its own networking.
     10  */
     11 
     12 #include <stdlib.h>
     13 #include <string.h>
     14 
     15 #include <prerror.h>
     16 #include <prinit.h>
     17 #include <prlog.h>
     18 
     19 #include "nss_memio.h"
     20 
     21 /*--------------- private memio types -----------------------*/
     22 
     23 /*----------------------------------------------------------------------
     24  Simple private circular buffer class.  Size cannot be changed once allocated.
     25 ----------------------------------------------------------------------*/
     26 
/* Ring-buffer state.  Buffered data occupies [head, tail) modulo bufsize;
 * head == tail means "empty", and one slot is always left unused so a full
 * buffer never looks empty (see memio_buffer_unused_contiguous, which never
 * lets tail advance onto head).  Usable capacity is therefore bufsize - 1.
 */
struct memio_buffer {
    int head;     /* where to take next byte out of buf */
    int tail;     /* where to put next byte into buf */
    int bufsize;  /* number of bytes allocated to buf */
    /* TODO(port): error handling is pessimistic right now.
     * Once an error is set, the socket is considered broken
     * (PR_WOULD_BLOCK_ERROR not included).
     */
    PRErrorCode last_err;  /* sticky error stored by memio_Put*Result; 0 = none */
    char *buf;             /* bufsize bytes obtained from malloc in memio_buffer_new */
};
     38 
     39 
/* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
 * to one of these.
 * In the public header, we use struct memio_Private as a typesafe alias
 * for this.  This causes a few ugly typecasts in the private file, but
 * seems safer.
 */
struct PRFilePrivate {
    /* read requests (memio_Recv/memio_Read) are satisfied from this buffer;
     * it is filled through memio_GetReadParams + memio_PutReadResult */
    struct memio_buffer readbuf;

    /* write requests (memio_Send/memio_Write) are appended to this buffer;
     * it is drained through memio_GetWriteParams + memio_PutWriteResult */
    struct memio_buffer writebuf;

    /* SSL needs to know socket peer's name; stored by memio_SetPeerName
     * and copied out by memio_GetPeerName */
    PRNetAddr peername;

    /* if set, empty I/O returns EOF instead of EWOULDBLOCK; set when
     * memio_PutReadResult is told 0 bytes were read */
    int eof;
};
     59 
     60 /*--------------- private memio_buffer functions ---------------------*/
     61 
/* Forward declarations for the private ring-buffer helpers defined below. */

/* Allocate a memio_buffer of given size (usable capacity is size - 1). */
static void memio_buffer_new(struct memio_buffer *mb, int size);

/* Deallocate a memio_buffer allocated by memio_buffer_new. */
static void memio_buffer_destroy(struct memio_buffer *mb);

/* How many bytes can be read out of the buffer without wrapping */
static int memio_buffer_used_contiguous(const struct memio_buffer *mb);

/* How many bytes exist after the wrap (stored at the start of buf)? */
static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb);

/* How many bytes can be written into the buffer without wrapping */
static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);

/* Write up to n bytes into the buffer.  Returns number of bytes written. */
static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);

/* Read up to n bytes from the buffer.  Returns number of bytes read. */
static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
     84 
     85 /* Allocate a memio_buffer of given size. */
     86 static void memio_buffer_new(struct memio_buffer *mb, int size)
     87 {
     88     mb->head = 0;
     89     mb->tail = 0;
     90     mb->bufsize = size;
     91     mb->buf = malloc(size);
     92 }
     93 
     94 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     95 static void memio_buffer_destroy(struct memio_buffer *mb)
     96 {
     97     free(mb->buf);
     98     mb->buf = NULL;
     99     mb->head = 0;
    100     mb->tail = 0;
    101 }
    102 
    103 /* How many bytes can be read out of the buffer without wrapping */
    104 static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
    105 {
    106     return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
    107 }
    108 
    109 /* How many bytes exist after the wrap? */
    110 static int memio_buffer_wrapped_bytes(const struct memio_buffer *mb)
    111 {
    112     return (mb->tail >= mb->head) ? 0 : mb->tail;
    113 }
    114 
    115 /* How many bytes can be written into the buffer without wrapping */
    116 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
    117 {
    118     if (mb->head > mb->tail) return mb->head - mb->tail - 1;
    119     return mb->bufsize - mb->tail - (mb->head == 0);
    120 }
    121 
    122 /* Write n bytes into the buffer.  Returns number of bytes written. */
    123 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
    124 {
    125     int len;
    126     int transferred = 0;
    127 
    128     /* Handle part before wrap */
    129     len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    130     if (len > 0) {
    131         /* Buffer not full */
    132         memcpy(&mb->buf[mb->tail], buf, len);
    133         mb->tail += len;
    134         if (mb->tail == mb->bufsize)
    135             mb->tail = 0;
    136         n -= len;
    137         buf += len;
    138         transferred += len;
    139 
    140         /* Handle part after wrap */
    141         len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    142         if (len > 0) {
    143             /* Output buffer still not full, input buffer still not empty */
    144             memcpy(&mb->buf[mb->tail], buf, len);
    145             mb->tail += len;
    146             if (mb->tail == mb->bufsize)
    147                 mb->tail = 0;
    148                 transferred += len;
    149         }
    150     }
    151 
    152     return transferred;
    153 }
    154 
    155 
    156 /* Read n bytes from the buffer.  Returns number of bytes read. */
    157 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
    158 {
    159     int len;
    160     int transferred = 0;
    161 
    162     /* Handle part before wrap */
    163     len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    164     if (len) {
    165         memcpy(buf, &mb->buf[mb->head], len);
    166         mb->head += len;
    167         if (mb->head == mb->bufsize)
    168             mb->head = 0;
    169         n -= len;
    170         buf += len;
    171         transferred += len;
    172 
    173         /* Handle part after wrap */
    174         len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    175         if (len) {
    176         memcpy(buf, &mb->buf[mb->head], len);
    177         mb->head += len;
    178             if (mb->head == mb->bufsize)
    179                 mb->head = 0;
    180                 transferred += len;
    181         }
    182     }
    183 
    184     return transferred;
    185 }
    186 
    187 /*--------------- private memio functions -----------------------*/
    188 
    189 static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
    190 {
    191     struct PRFilePrivate *secret = fd->secret;
    192     memio_buffer_destroy(&secret->readbuf);
    193     memio_buffer_destroy(&secret->writebuf);
    194     free(secret);
    195     fd->dtor(fd);
    196     return PR_SUCCESS;
    197 }
    198 
    199 static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
    200 {
    201     /* TODO: pass shutdown status to app somehow */
    202     return PR_SUCCESS;
    203 }
    204 
    205 /* If there was a network error in the past taking bytes
    206  * out of the buffer, return it to the next call that
    207  * tries to read from an empty buffer.
    208  */
    209 static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
    210                                   PRIntn flags, PRIntervalTime timeout)
    211 {
    212     struct PRFilePrivate *secret;
    213     struct memio_buffer *mb;
    214     int rv;
    215 
    216     if (flags) {
    217         PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
    218         return -1;
    219     }
    220 
    221     secret = fd->secret;
    222     mb = &secret->readbuf;
    223     PR_ASSERT(mb->bufsize);
    224     rv = memio_buffer_get(mb, buf, len);
    225     if (rv == 0 && !secret->eof) {
    226         if (mb->last_err)
    227             PR_SetError(mb->last_err, 0);
    228         else
    229             PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    230         return -1;
    231     }
    232 
    233     return rv;
    234 }
    235 
    236 static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
    237 {
    238     /* pull bytes from buffer */
    239     return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    240 }
    241 
    242 static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
    243                                   PRIntn flags, PRIntervalTime timeout)
    244 {
    245     struct PRFilePrivate *secret;
    246     struct memio_buffer *mb;
    247     int rv;
    248 
    249     secret = fd->secret;
    250     mb = &secret->writebuf;
    251     PR_ASSERT(mb->bufsize);
    252 
    253     if (mb->last_err) {
    254         PR_SetError(mb->last_err, 0);
    255         return -1;
    256     }
    257     rv = memio_buffer_put(mb, buf, len);
    258     if (rv == 0) {
    259         PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    260         return -1;
    261     }
    262     return rv;
    263 }
    264 
    265 static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
    266 {
    267     /* append bytes to buffer */
    268     return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    269 }
    270 
    271 static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
    272 {
    273     /* TODO: fail if memio_SetPeerName has not been called */
    274     struct PRFilePrivate *secret = fd->secret;
    275     *addr = secret->peername;
    276     return PR_SUCCESS;
    277 }
    278 
    279 static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
    280 {
    281     /*
    282      * Even in the original version for real tcp sockets,
    283      * PR_SockOpt_Nonblocking is a special case that does not
    284      * translate to a getsockopt() call
    285      */
    286     if (PR_SockOpt_Nonblocking == data->option) {
    287         data->value.non_blocking = PR_TRUE;
    288         return PR_SUCCESS;
    289     }
    290     PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
    291     return PR_FAILURE;
    292 }
    293 
    294 /*--------------- private memio data -----------------------*/
    295 
    296 /*
    297  * Implement just the bare minimum number of methods needed to make ssl happy.
    298  *
    299  * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
    300  * PR_GetSocketOption, so we have to provide an implementation of
    301  * PR_GetSocketOption that just says "I'm nonblocking".
    302  */
    303 
/* Method table for the memio layer.  This is a positional initializer, so
 * every entry must line up with the field order of NSPR's struct
 * PRIOMethods.  The slot names in the trailing comments follow NSPR's
 * prio.h -- NOTE(review): re-check them if the NSPR version changes.
 * NULL entries are operations memio does not implement.
 */
static struct PRIOMethods  memio_layer_methods = {
    PR_DESC_LAYERED,        /* file_type */
    memio_Close,            /* close */
    memio_Read,             /* read */
    memio_Write,            /* write */
    NULL,                   /* available */
    NULL,                   /* available64 */
    NULL,                   /* fsync */
    NULL,                   /* seek */
    NULL,                   /* seek64 */
    NULL,                   /* fileInfo */
    NULL,                   /* fileInfo64 */
    NULL,                   /* writev */
    NULL,                   /* connect */
    NULL,                   /* accept */
    NULL,                   /* bind */
    NULL,                   /* listen */
    memio_Shutdown,         /* shutdown */
    memio_Recv,             /* recv */
    memio_Send,             /* send */
    NULL,                   /* recvfrom */
    NULL,                   /* sendto */
    NULL,                   /* poll */
    NULL,                   /* acceptread */
    NULL,                   /* transmitfile */
    NULL,                   /* getsockname */
    memio_GetPeerName,      /* getpeername */
    NULL,                   /* reserved_fn_6 */
    NULL,                   /* reserved_fn_5 */
    memio_GetSocketOption,  /* getsocketoption */
    NULL,                   /* setsocketoption */
    NULL,                   /* sendfile */
    NULL,                   /* connectcontinue */
    NULL,                   /* reserved_fn_3 */
    NULL,                   /* reserved_fn_2 */
    NULL,                   /* reserved_fn_1 */
    NULL,                   /* reserved_fn_0 */
};
    342 
    343 static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
    344 
    345 static PRStatus memio_InitializeLayerName(void)
    346 {
    347     memio_identity = PR_GetUniqueIdentity("memio");
    348     return PR_SUCCESS;
    349 }
    350 
    351 /*--------------- public memio functions -----------------------*/
    352 
    353 PRFileDesc *memio_CreateIOLayer(int bufsize)
    354 {
    355     PRFileDesc *fd;
    356     struct PRFilePrivate *secret;
    357     static PRCallOnceType once;
    358 
    359     PR_CallOnce(&once, memio_InitializeLayerName);
    360 
    361     fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
    362     secret = malloc(sizeof(struct PRFilePrivate));
    363     memset(secret, 0, sizeof(*secret));
    364 
    365     memio_buffer_new(&secret->readbuf, bufsize);
    366     memio_buffer_new(&secret->writebuf, bufsize);
    367     fd->secret = secret;
    368     return fd;
    369 }
    370 
    371 void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
    372 {
    373     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    374     struct PRFilePrivate *secret = memiofd->secret;
    375     secret->peername = *peername;
    376 }
    377 
    378 memio_Private *memio_GetSecret(PRFileDesc *fd)
    379 {
    380     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    381     struct PRFilePrivate *secret =  memiofd->secret;
    382     return (memio_Private *)secret;
    383 }
    384 
    385 int memio_GetReadParams(memio_Private *secret, char **buf)
    386 {
    387     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    388     PR_ASSERT(mb->bufsize);
    389 
    390     *buf = &mb->buf[mb->tail];
    391     return memio_buffer_unused_contiguous(mb);
    392 }
    393 
    394 void memio_PutReadResult(memio_Private *secret, int bytes_read)
    395 {
    396     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    397     PR_ASSERT(mb->bufsize);
    398 
    399     if (bytes_read > 0) {
    400         mb->tail += bytes_read;
    401         if (mb->tail == mb->bufsize)
    402             mb->tail = 0;
    403     } else if (bytes_read == 0) {
    404         /* Record EOF condition and report to caller when buffer runs dry */
    405         ((PRFilePrivate *)secret)->eof = PR_TRUE;
    406     } else /* if (bytes_read < 0) */ {
    407         mb->last_err = bytes_read;
    408     }
    409 }
    410 
    411 void memio_GetWriteParams(memio_Private *secret,
    412                           const char **buf1, unsigned int *len1,
    413                           const char **buf2, unsigned int *len2)
    414 {
    415     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    416     PR_ASSERT(mb->bufsize);
    417 
    418     *buf1 = &mb->buf[mb->head];
    419     *len1 = memio_buffer_used_contiguous(mb);
    420     *buf2 = mb->buf;
    421     *len2 = memio_buffer_wrapped_bytes(mb);
    422 }
    423 
    424 void memio_PutWriteResult(memio_Private *secret, int bytes_written)
    425 {
    426     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    427     PR_ASSERT(mb->bufsize);
    428 
    429     if (bytes_written > 0) {
    430         mb->head += bytes_written;
    431         if (mb->head >= mb->bufsize)
    432             mb->head -= mb->bufsize;
    433     } else if (bytes_written < 0) {
    434         mb->last_err = bytes_written;
    435     }
    436 }
    437 
    438 /*--------------- private memio_buffer self-test -----------------*/
    439 
    440 /* Even a trivial unit test is very helpful when doing circular buffers. */
    441 /*#define TRIVIAL_SELF_TEST*/
    442 #ifdef TRIVIAL_SELF_TEST
    443 #include <stdio.h>
    444 
    445 #define TEST_BUFLEN 7
    446 
    447 #define CHECKEQ(a, b) { \
    448     if ((a) != (b)) { \
    449         printf("%d != %d, Test failed line %d\n", a, b, __LINE__); \
    450         exit(1); \
    451     } \
    452 }
    453 
    454 int main()
    455 {
    456     struct memio_buffer mb;
    457     char buf[100];
    458     int i;
    459 
    460     memio_buffer_new(&mb, TEST_BUFLEN);
    461 
    462     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    463     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    464 
    465     CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);
    466 
    467     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    468     CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
    469     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    470 
    471     CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);
    472 
    473     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    474     CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
    475     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 0);
    476 
    477     CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    478     CHECKEQ(memcmp(buf, "howdy!", 6), 0);
    479 
    480     CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    481     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    482 
    483     CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);
    484 
    485     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    486     CHECKEQ(memio_buffer_wrapped_bytes(&mb), 4);
    487     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    488 
    489     CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);
    490 
    491     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    492     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    493 
    494     /* TODO: add more cases */
    495 
    496     printf("Test passed\n");
    497     exit(0);
    498 }
    499 
    500 #endif
    501