Home | History | Annotate | Download | only in base
      1 // Copyright (c) 2008 The Chromium Authors. All rights reserved.
      2 // Use of this source code is governed by a BSD-style license that can be
      3 // found in the LICENSE file.
      4 // Written in NSPR style to also be suitable for adding to the NSS demo suite
      5 
      6 /* memio is a simple NSPR I/O layer that lets you decouple NSS from
      7  * the real network.  It's rather like openssl's memory bio,
      8  * and is useful when your app absolutely, positively doesn't
      9  * want to let NSS do its own networking.
     10  */
     11 
     12 #include <stdlib.h>
     13 #include <string.h>
     14 
     15 #include <prerror.h>
     16 #include <prinit.h>
     17 #include <prlog.h>
     18 
     19 #include "nss_memio.h"
     20 
     21 /*--------------- private memio types -----------------------*/
     22 
     23 /*----------------------------------------------------------------------
     24  Simple private circular buffer class.  Size cannot be changed once allocated.
     25 ----------------------------------------------------------------------*/
     26 
     27 struct memio_buffer {
     28     int head;     /* where to take next byte out of buf */
     29     int tail;     /* where to put next byte into buf */
     30     int bufsize;  /* number of bytes allocated to buf */
     31     /* TODO(port): error handling is pessimistic right now.
     32      * Once an error is set, the socket is considered broken
     33      * (PR_WOULD_BLOCK_ERROR not included).
     34      */
     35     PRErrorCode last_err;
     36     char *buf;
     37 };
     38 
     39 
     40 /* The 'secret' field of a PRFileDesc created by memio_CreateIOLayer points
     41  * to one of these.
     42  * In the public header, we use struct memio_Private as a typesafe alias
     43  * for this.  This causes a few ugly typecasts in the private file, but
     44  * seems safer.
     45  */
     46 struct PRFilePrivate {
     47     /* read requests are satisfied from this buffer */
     48     struct memio_buffer readbuf;
     49 
     50     /* write requests are satisfied from this buffer */
     51     struct memio_buffer writebuf;
     52 
     53     /* SSL needs to know socket peer's name */
     54     PRNetAddr peername;
     55 
     56     /* if set, empty I/O returns EOF instead of EWOULDBLOCK */
     57     int eof;
     58 };
     59 
     60 /*--------------- private memio_buffer functions ---------------------*/
     61 
     62 /* Forward declarations.  */
     63 
     64 /* Allocate a memio_buffer of given size. */
     65 static void memio_buffer_new(struct memio_buffer *mb, int size);
     66 
     67 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     68 static void memio_buffer_destroy(struct memio_buffer *mb);
     69 
     70 /* How many bytes can be read out of the buffer without wrapping */
     71 static int memio_buffer_used_contiguous(const struct memio_buffer *mb);
     72 
     73 /* How many bytes can be written into the buffer without wrapping */
     74 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb);
     75 
     76 /* Write n bytes into the buffer.  Returns number of bytes written. */
     77 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n);
     78 
     79 /* Read n bytes from the buffer.  Returns number of bytes read. */
     80 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n);
     81 
     82 /* Allocate a memio_buffer of given size. */
     83 static void memio_buffer_new(struct memio_buffer *mb, int size)
     84 {
     85     mb->head = 0;
     86     mb->tail = 0;
     87     mb->bufsize = size;
     88     mb->buf = malloc(size);
     89 }
     90 
     91 /* Deallocate a memio_buffer allocated by memio_buffer_new. */
     92 static void memio_buffer_destroy(struct memio_buffer *mb)
     93 {
     94     free(mb->buf);
     95     mb->buf = NULL;
     96     mb->head = 0;
     97     mb->tail = 0;
     98 }
     99 
    100 /* How many bytes can be read out of the buffer without wrapping */
    101 static int memio_buffer_used_contiguous(const struct memio_buffer *mb)
    102 {
    103     return (((mb->tail >= mb->head) ? mb->tail : mb->bufsize) - mb->head);
    104 }
    105 
    106 /* How many bytes can be written into the buffer without wrapping */
    107 static int memio_buffer_unused_contiguous(const struct memio_buffer *mb)
    108 {
    109     if (mb->head > mb->tail) return mb->head - mb->tail - 1;
    110     return mb->bufsize - mb->tail - (mb->head == 0);
    111 }
    112 
    113 /* Write n bytes into the buffer.  Returns number of bytes written. */
    114 static int memio_buffer_put(struct memio_buffer *mb, const char *buf, int n)
    115 {
    116     int len;
    117     int transferred = 0;
    118 
    119     /* Handle part before wrap */
    120     len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    121     if (len > 0) {
    122         /* Buffer not full */
    123         memcpy(&mb->buf[mb->tail], buf, len);
    124         mb->tail += len;
    125         if (mb->tail == mb->bufsize)
    126             mb->tail = 0;
    127         n -= len;
    128         buf += len;
    129         transferred += len;
    130 
    131         /* Handle part after wrap */
    132         len = PR_MIN(n, memio_buffer_unused_contiguous(mb));
    133         if (len > 0) {
    134             /* Output buffer still not full, input buffer still not empty */
    135             memcpy(&mb->buf[mb->tail], buf, len);
    136             mb->tail += len;
    137             if (mb->tail == mb->bufsize)
    138                 mb->tail = 0;
    139                 transferred += len;
    140         }
    141     }
    142 
    143     return transferred;
    144 }
    145 
    146 
    147 /* Read n bytes from the buffer.  Returns number of bytes read. */
    148 static int memio_buffer_get(struct memio_buffer *mb, char *buf, int n)
    149 {
    150     int len;
    151     int transferred = 0;
    152 
    153     /* Handle part before wrap */
    154     len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    155     if (len) {
    156         memcpy(buf, &mb->buf[mb->head], len);
    157         mb->head += len;
    158         if (mb->head == mb->bufsize)
    159             mb->head = 0;
    160         n -= len;
    161         buf += len;
    162         transferred += len;
    163 
    164         /* Handle part after wrap */
    165         len = PR_MIN(n, memio_buffer_used_contiguous(mb));
    166         if (len) {
    167         memcpy(buf, &mb->buf[mb->head], len);
    168         mb->head += len;
    169             if (mb->head == mb->bufsize)
    170                 mb->head = 0;
    171                 transferred += len;
    172         }
    173     }
    174 
    175     return transferred;
    176 }
    177 
    178 /*--------------- private memio functions -----------------------*/
    179 
    180 static PRStatus PR_CALLBACK memio_Close(PRFileDesc *fd)
    181 {
    182     struct PRFilePrivate *secret = fd->secret;
    183     memio_buffer_destroy(&secret->readbuf);
    184     memio_buffer_destroy(&secret->writebuf);
    185     free(secret);
    186     fd->dtor(fd);
    187     return PR_SUCCESS;
    188 }
    189 
    190 static PRStatus PR_CALLBACK memio_Shutdown(PRFileDesc *fd, PRIntn how)
    191 {
    192     /* TODO: pass shutdown status to app somehow */
    193     return PR_SUCCESS;
    194 }
    195 
    196 /* If there was a network error in the past taking bytes
    197  * out of the buffer, return it to the next call that
    198  * tries to read from an empty buffer.
    199  */
    200 static int PR_CALLBACK memio_Recv(PRFileDesc *fd, void *buf, PRInt32 len,
    201                                   PRIntn flags, PRIntervalTime timeout)
    202 {
    203     struct PRFilePrivate *secret;
    204     struct memio_buffer *mb;
    205     int rv;
    206 
    207     if (flags) {
    208         PR_SetError(PR_NOT_IMPLEMENTED_ERROR, 0);
    209         return -1;
    210     }
    211 
    212     secret = fd->secret;
    213     mb = &secret->readbuf;
    214     PR_ASSERT(mb->bufsize);
    215     rv = memio_buffer_get(mb, buf, len);
    216     if (rv == 0 && !secret->eof) {
    217         if (mb->last_err)
    218             PR_SetError(mb->last_err, 0);
    219         else
    220             PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    221         return -1;
    222     }
    223 
    224     return rv;
    225 }
    226 
    227 static int PR_CALLBACK memio_Read(PRFileDesc *fd, void *buf, PRInt32 len)
    228 {
    229     /* pull bytes from buffer */
    230     return memio_Recv(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    231 }
    232 
    233 static int PR_CALLBACK memio_Send(PRFileDesc *fd, const void *buf, PRInt32 len,
    234                                   PRIntn flags, PRIntervalTime timeout)
    235 {
    236     struct PRFilePrivate *secret;
    237     struct memio_buffer *mb;
    238     int rv;
    239 
    240     secret = fd->secret;
    241     mb = &secret->writebuf;
    242     PR_ASSERT(mb->bufsize);
    243 
    244     if (mb->last_err) {
    245         PR_SetError(mb->last_err, 0);
    246         return -1;
    247     }
    248     rv = memio_buffer_put(mb, buf, len);
    249     if (rv == 0) {
    250         PR_SetError(PR_WOULD_BLOCK_ERROR, 0);
    251         return -1;
    252     }
    253     return rv;
    254 }
    255 
    256 static int PR_CALLBACK memio_Write(PRFileDesc *fd, const void *buf, PRInt32 len)
    257 {
    258     /* append bytes to buffer */
    259     return memio_Send(fd, buf, len, 0, PR_INTERVAL_NO_TIMEOUT);
    260 }
    261 
    262 static PRStatus PR_CALLBACK memio_GetPeerName(PRFileDesc *fd, PRNetAddr *addr)
    263 {
    264     /* TODO: fail if memio_SetPeerName has not been called */
    265     struct PRFilePrivate *secret = fd->secret;
    266     *addr = secret->peername;
    267     return PR_SUCCESS;
    268 }
    269 
    270 static PRStatus memio_GetSocketOption(PRFileDesc *fd, PRSocketOptionData *data)
    271 {
    272     /*
    273      * Even in the original version for real tcp sockets,
    274      * PR_SockOpt_Nonblocking is a special case that does not
    275      * translate to a getsockopt() call
    276      */
    277     if (PR_SockOpt_Nonblocking == data->option) {
    278         data->value.non_blocking = PR_TRUE;
    279         return PR_SUCCESS;
    280     }
    281     PR_SetError(PR_OPERATION_NOT_SUPPORTED_ERROR, 0);
    282     return PR_FAILURE;
    283 }
    284 
    285 /*--------------- private memio data -----------------------*/
    286 
    287 /*
    288  * Implement just the bare minimum number of methods needed to make ssl happy.
    289  *
    290  * Oddly, PR_Recv calls ssl_Recv calls ssl_SocketIsBlocking calls
    291  * PR_GetSocketOption, so we have to provide an implementation of
    292  * PR_GetSocketOption that just says "I'm nonblocking".
    293  */
    294 
    295 static struct PRIOMethods  memio_layer_methods = {
    296     PR_DESC_LAYERED,
    297     memio_Close,
    298     memio_Read,
    299     memio_Write,
    300     NULL,
    301     NULL,
    302     NULL,
    303     NULL,
    304     NULL,
    305     NULL,
    306     NULL,
    307     NULL,
    308     NULL,
    309     NULL,
    310     NULL,
    311     NULL,
    312     memio_Shutdown,
    313     memio_Recv,
    314     memio_Send,
    315     NULL,
    316     NULL,
    317     NULL,
    318     NULL,
    319     NULL,
    320     NULL,
    321     memio_GetPeerName,
    322     NULL,
    323     NULL,
    324     memio_GetSocketOption,
    325     NULL,
    326     NULL,
    327     NULL,
    328     NULL,
    329     NULL,
    330     NULL,
    331     NULL,
    332 };
    333 
    334 static PRDescIdentity memio_identity = PR_INVALID_IO_LAYER;
    335 
    336 static PRStatus memio_InitializeLayerName(void)
    337 {
    338     memio_identity = PR_GetUniqueIdentity("memio");
    339     return PR_SUCCESS;
    340 }
    341 
    342 /*--------------- public memio functions -----------------------*/
    343 
    344 PRFileDesc *memio_CreateIOLayer(int bufsize)
    345 {
    346     PRFileDesc *fd;
    347     struct PRFilePrivate *secret;
    348     static PRCallOnceType once;
    349 
    350     PR_CallOnce(&once, memio_InitializeLayerName);
    351 
    352     fd = PR_CreateIOLayerStub(memio_identity, &memio_layer_methods);
    353     secret = malloc(sizeof(struct PRFilePrivate));
    354     memset(secret, 0, sizeof(*secret));
    355 
    356     memio_buffer_new(&secret->readbuf, bufsize);
    357     memio_buffer_new(&secret->writebuf, bufsize);
    358     fd->secret = secret;
    359     return fd;
    360 }
    361 
    362 void memio_SetPeerName(PRFileDesc *fd, const PRNetAddr *peername)
    363 {
    364     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    365     struct PRFilePrivate *secret =  memiofd->secret;
    366     secret->peername = *peername;
    367 }
    368 
    369 memio_Private *memio_GetSecret(PRFileDesc *fd)
    370 {
    371     PRFileDesc *memiofd = PR_GetIdentitiesLayer(fd, memio_identity);
    372     struct PRFilePrivate *secret =  memiofd->secret;
    373     return (memio_Private *)secret;
    374 }
    375 
    376 int memio_GetReadParams(memio_Private *secret, char **buf)
    377 {
    378     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    379     PR_ASSERT(mb->bufsize);
    380 
    381     *buf = &mb->buf[mb->tail];
    382     return memio_buffer_unused_contiguous(mb);
    383 }
    384 
    385 void memio_PutReadResult(memio_Private *secret, int bytes_read)
    386 {
    387     struct memio_buffer* mb = &((PRFilePrivate *)secret)->readbuf;
    388     PR_ASSERT(mb->bufsize);
    389 
    390     if (bytes_read > 0) {
    391         mb->tail += bytes_read;
    392         if (mb->tail == mb->bufsize)
    393             mb->tail = 0;
    394     } else if (bytes_read == 0) {
    395         /* Record EOF condition and report to caller when buffer runs dry */
    396         ((PRFilePrivate *)secret)->eof = PR_TRUE;
    397     } else /* if (bytes_read < 0) */ {
    398         mb->last_err = bytes_read;
    399     }
    400 }
    401 
    402 int memio_GetWriteParams(memio_Private *secret, const char **buf)
    403 {
    404     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    405     PR_ASSERT(mb->bufsize);
    406 
    407     *buf = &mb->buf[mb->head];
    408     return memio_buffer_used_contiguous(mb);
    409 }
    410 
    411 void memio_PutWriteResult(memio_Private *secret, int bytes_written)
    412 {
    413     struct memio_buffer* mb = &((PRFilePrivate *)secret)->writebuf;
    414     PR_ASSERT(mb->bufsize);
    415 
    416     if (bytes_written > 0) {
    417         mb->head += bytes_written;
    418         if (mb->head == mb->bufsize)
    419             mb->head = 0;
    420     } else if (bytes_written < 0) {
    421         mb->last_err = bytes_written;
    422     }
    423 }
    424 
    425 /*--------------- private memio_buffer self-test -----------------*/
    426 
    427 /* Even a trivial unit test is very helpful when doing circular buffers. */
    428 /*#define TRIVIAL_SELF_TEST*/
    429 #ifdef TRIVIAL_SELF_TEST
    430 #include <stdio.h>
    431 
    432 #define TEST_BUFLEN 7
    433 
    434 #define CHECKEQ(a, b) { \
    435     if ((a) != (b)) { \
    436         printf("%d != %d, Test failed line %d\n", a, b, __LINE__); \
    437         exit(1); \
    438     } \
    439 }
    440 
    441 int main()
    442 {
    443     struct memio_buffer mb;
    444     char buf[100];
    445     int i;
    446 
    447     memio_buffer_new(&mb, TEST_BUFLEN);
    448 
    449     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1);
    450     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    451 
    452     CHECKEQ(memio_buffer_put(&mb, "howdy", 5), 5);
    453 
    454     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    455     CHECKEQ(memio_buffer_used_contiguous(&mb), 5);
    456 
    457     CHECKEQ(memio_buffer_put(&mb, "!", 1), 1);
    458 
    459     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    460     CHECKEQ(memio_buffer_used_contiguous(&mb), 6);
    461 
    462     CHECKEQ(memio_buffer_get(&mb, buf, 6), 6);
    463     CHECKEQ(memcmp(buf, "howdy!", 6), 0);
    464 
    465     CHECKEQ(memio_buffer_unused_contiguous(&mb), 1);
    466     CHECKEQ(memio_buffer_used_contiguous(&mb), 0);
    467 
    468     CHECKEQ(memio_buffer_put(&mb, "01234", 5), 5);
    469 
    470     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    471     CHECKEQ(memio_buffer_unused_contiguous(&mb), TEST_BUFLEN-1-5);
    472 
    473     CHECKEQ(memio_buffer_put(&mb, "5", 1), 1);
    474 
    475     CHECKEQ(memio_buffer_unused_contiguous(&mb), 0);
    476     CHECKEQ(memio_buffer_used_contiguous(&mb), 1);
    477 
    478     /* TODO: add more cases */
    479 
    480     printf("Test passed\n");
    481     exit(0);
    482 }
    483 
    484 #endif
    485