Home | History | Annotate | Download | only in markupsafe
      1 /**
      2  * markupsafe._speedups
      3  * ~~~~~~~~~~~~~~~~~~~~
      4  *
      5  * This module implements functions for automatic escaping in C for better
      6  * performance.
      7  *
      8  * :copyright: (c) 2010 by Armin Ronacher.
      9  * :license: BSD.
     10  */
     11 
     12 #include <Python.h>
     13 
     14 #define ESCAPED_CHARS_TABLE_SIZE 63
     15 #define UNICHR(x) (PyUnicode_AS_UNICODE((PyUnicodeObject*)PyUnicode_DecodeASCII(x, strlen(x), NULL)));
     16 
     17 #if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN)
     18 typedef int Py_ssize_t;
     19 #define PY_SSIZE_T_MAX INT_MAX
     20 #define PY_SSIZE_T_MIN INT_MIN
     21 #endif
     22 
     23 
     24 static PyObject* markup;
     25 static Py_ssize_t escaped_chars_delta_len[ESCAPED_CHARS_TABLE_SIZE];
     26 static Py_UNICODE *escaped_chars_repl[ESCAPED_CHARS_TABLE_SIZE];
     27 
     28 static int
     29 init_constants(void)
     30 {
     31 	PyObject *module;
     32 	/* happing of characters to replace */
     33 	escaped_chars_repl['"'] = UNICHR("&#34;");
     34 	escaped_chars_repl['\''] = UNICHR("&#39;");
     35 	escaped_chars_repl['&'] = UNICHR("&amp;");
     36 	escaped_chars_repl['<'] = UNICHR("&lt;");
     37 	escaped_chars_repl['>'] = UNICHR("&gt;");
     38 
     39 	/* lengths of those characters when replaced - 1 */
     40 	memset(escaped_chars_delta_len, 0, sizeof (escaped_chars_delta_len));
     41 	escaped_chars_delta_len['"'] = escaped_chars_delta_len['\''] = \
     42 		escaped_chars_delta_len['&'] = 4;
     43 	escaped_chars_delta_len['<'] = escaped_chars_delta_len['>'] = 3;
     44 
     45 	/* import markup type so that we can mark the return value */
     46 	module = PyImport_ImportModule("markupsafe");
     47 	if (!module)
     48 		return 0;
     49 	markup = PyObject_GetAttrString(module, "Markup");
     50 	Py_DECREF(module);
     51 
     52 	return 1;
     53 }
     54 
     55 static PyObject*
     56 escape_unicode(PyUnicodeObject *in)
     57 {
     58 	PyUnicodeObject *out;
     59 	Py_UNICODE *inp = PyUnicode_AS_UNICODE(in);
     60 	const Py_UNICODE *inp_end = PyUnicode_AS_UNICODE(in) + PyUnicode_GET_SIZE(in);
     61 	Py_UNICODE *next_escp;
     62 	Py_UNICODE *outp;
     63 	Py_ssize_t delta=0, erepl=0, delta_len=0;
     64 
     65 	/* First we need to figure out how long the escaped string will be */
     66 	while (*(inp) || inp < inp_end) {
     67 		if (*inp < ESCAPED_CHARS_TABLE_SIZE) {
     68 			delta += escaped_chars_delta_len[*inp];
     69 			erepl += !!escaped_chars_delta_len[*inp];
     70 		}
     71 		++inp;
     72 	}
     73 
     74 	/* Do we need to escape anything at all? */
     75 	if (!erepl) {
     76 		Py_INCREF(in);
     77 		return (PyObject*)in;
     78 	}
     79 
     80 	out = (PyUnicodeObject*)PyUnicode_FromUnicode(NULL, PyUnicode_GET_SIZE(in) + delta);
     81 	if (!out)
     82 		return NULL;
     83 
     84 	outp = PyUnicode_AS_UNICODE(out);
     85 	inp = PyUnicode_AS_UNICODE(in);
     86 	while (erepl-- > 0) {
     87 		/* look for the next substitution */
     88 		next_escp = inp;
     89 		while (next_escp < inp_end) {
     90 			if (*next_escp < ESCAPED_CHARS_TABLE_SIZE &&
     91 			    (delta_len = escaped_chars_delta_len[*next_escp])) {
     92 				++delta_len;
     93 				break;
     94 			}
     95 			++next_escp;
     96 		}
     97 
     98 		if (next_escp > inp) {
     99 			/* copy unescaped chars between inp and next_escp */
    100 			Py_UNICODE_COPY(outp, inp, next_escp-inp);
    101 			outp += next_escp - inp;
    102 		}
    103 
    104 		/* escape 'next_escp' */
    105 		Py_UNICODE_COPY(outp, escaped_chars_repl[*next_escp], delta_len);
    106 		outp += delta_len;
    107 
    108 		inp = next_escp + 1;
    109 	}
    110 	if (inp < inp_end)
    111 		Py_UNICODE_COPY(outp, inp, PyUnicode_GET_SIZE(in) - (inp - PyUnicode_AS_UNICODE(in)));
    112 
    113 	return (PyObject*)out;
    114 }
    115 
    116 
    117 static PyObject*
    118 escape(PyObject *self, PyObject *text)
    119 {
    120 	PyObject *s = NULL, *rv = NULL, *html;
    121 
    122 	/* we don't have to escape integers, bools or floats */
    123 	if (PyLong_CheckExact(text) ||
    124 #if PY_MAJOR_VERSION < 3
    125 	    PyInt_CheckExact(text) ||
    126 #endif
    127 	    PyFloat_CheckExact(text) || PyBool_Check(text) ||
    128 	    text == Py_None)
    129 		return PyObject_CallFunctionObjArgs(markup, text, NULL);
    130 
    131 	/* if the object has an __html__ method that performs the escaping */
    132 	html = PyObject_GetAttrString(text, "__html__");
    133 	if (html) {
    134 		rv = PyObject_CallObject(html, NULL);
    135 		Py_DECREF(html);
    136 		return rv;
    137 	}
    138 
    139 	/* otherwise make the object unicode if it isn't, then escape */
    140 	PyErr_Clear();
    141 	if (!PyUnicode_Check(text)) {
    142 #if PY_MAJOR_VERSION < 3
    143 		PyObject *unicode = PyObject_Unicode(text);
    144 #else
    145 		PyObject *unicode = PyObject_Str(text);
    146 #endif
    147 		if (!unicode)
    148 			return NULL;
    149 		s = escape_unicode((PyUnicodeObject*)unicode);
    150 		Py_DECREF(unicode);
    151 	}
    152 	else
    153 		s = escape_unicode((PyUnicodeObject*)text);
    154 
    155 	/* convert the unicode string into a markup object. */
    156 	rv = PyObject_CallFunctionObjArgs(markup, (PyObject*)s, NULL);
    157 	Py_DECREF(s);
    158 	return rv;
    159 }
    160 
    161 
    162 static PyObject*
    163 escape_silent(PyObject *self, PyObject *text)
    164 {
    165 	if (text != Py_None)
    166 		return escape(self, text);
    167 	return PyObject_CallFunctionObjArgs(markup, NULL);
    168 }
    169 
    170 
    171 static PyObject*
    172 soft_unicode(PyObject *self, PyObject *s)
    173 {
    174 	if (!PyUnicode_Check(s))
    175 #if PY_MAJOR_VERSION < 3
    176 		return PyObject_Unicode(s);
    177 #else
    178 		return PyObject_Str(s);
    179 #endif
    180 	Py_INCREF(s);
    181 	return s;
    182 }
    183 
    184 
    185 static PyMethodDef module_methods[] = {
    186 	{"escape", (PyCFunction)escape, METH_O,
    187 	 "escape(s) -> markup\n\n"
    188 	 "Convert the characters &, <, >, ', and \" in string s to HTML-safe\n"
    189 	 "sequences.  Use this if you need to display text that might contain\n"
    190 	 "such characters in HTML.  Marks return value as markup string."},
    191 	{"escape_silent", (PyCFunction)escape_silent, METH_O,
    192 	 "escape_silent(s) -> markup\n\n"
    193 	 "Like escape but converts None to an empty string."},
    194 	{"soft_unicode", (PyCFunction)soft_unicode, METH_O,
    195 	 "soft_unicode(object) -> string\n\n"
    196          "Make a string unicode if it isn't already.  That way a markup\n"
    197          "string is not converted back to unicode."},
    198 	{NULL, NULL, 0, NULL}		/* Sentinel */
    199 };
    200 
    201 
    202 #if PY_MAJOR_VERSION < 3
    203 
    204 #ifndef PyMODINIT_FUNC	/* declarations for DLL import/export */
    205 #define PyMODINIT_FUNC void
    206 #endif
    207 PyMODINIT_FUNC
    208 init_speedups(void)
    209 {
    210 	if (!init_constants())
    211 		return;
    212 
    213 	Py_InitModule3("markupsafe._speedups", module_methods, "");
    214 }
    215 
    216 #else /* Python 3.x module initialization */
    217 
    218 static struct PyModuleDef module_definition = {
    219         PyModuleDef_HEAD_INIT,
    220 	"markupsafe._speedups",
    221 	NULL,
    222 	-1,
    223 	module_methods,
    224 	NULL,
    225 	NULL,
    226 	NULL,
    227 	NULL
    228 };
    229 
    230 PyMODINIT_FUNC
    231 PyInit__speedups(void)
    232 {
    233 	if (!init_constants())
    234 		return NULL;
    235 
    236 	return PyModule_Create(&module_definition);
    237 }
    238 
    239 #endif
    240