Home | History | Annotate | Download | only in shared
      1 import cgi, datetime, re, time, urllib
      2 from django import http
      3 import django.core.exceptions
      4 from django.core import urlresolvers
      5 from django.utils import datastructures
      6 import json
      7 from autotest_lib.frontend.shared import exceptions, query_lib
      8 from autotest_lib.frontend.afe import model_logic
      9 
     10 
     11 _JSON_CONTENT_TYPE = 'application/json'
     12 
     13 
     14 def _resolve_class_path(class_path):
     15     module_path, class_name = class_path.rsplit('.', 1)
     16     module = __import__(module_path, {}, {}, [''])
     17     return getattr(module, class_name)
     18 
     19 
     20 _NO_VALUE_SPECIFIED = object()
     21 
     22 class _InputDict(dict):
     23     def get(self, key, default=_NO_VALUE_SPECIFIED):
     24         return super(_InputDict, self).get(key, default)
     25 
     26 
     27     @classmethod
     28     def remove_unspecified_fields(cls, field_dict):
     29         return dict((key, value) for key, value in field_dict.iteritems()
     30                     if value is not _NO_VALUE_SPECIFIED)
     31 
     32 
     33 class Resource(object):
     34     _permitted_methods = None # subclasses must override this
     35 
     36 
     37     def __init__(self, request):
     38         assert self._permitted_methods
     39         # this request should be used for global environment info, like
     40         # constructing absolute URIs.  it should not be used for query
     41         # parameters, because the request may not have been for this particular
     42         # resource.
     43         self._request = request
     44         # this dict will contain the applicable query parameters
     45         self._query_params = datastructures.MultiValueDict()
     46 
     47 
     48     @classmethod
     49     def dispatch_request(cls, request, *args, **kwargs):
     50         # handle a request directly
     51         try:
     52             try:
     53                 instance = cls.from_uri_args(request, **kwargs)
     54             except django.core.exceptions.ObjectDoesNotExist, exc:
     55                 raise http.Http404(exc)
     56 
     57             instance.read_query_parameters(request.GET)
     58             return instance.handle_request()
     59         except exceptions.RequestError, exc:
     60             return exc.response
     61 
     62 
     63     def handle_request(self):
     64         if self._request.method.upper() not in self._permitted_methods:
     65             return http.HttpResponseNotAllowed(self._permitted_methods)
     66 
     67         handler = getattr(self, self._request.method.lower())
     68         return handler()
     69 
     70 
     71     # the handler methods below only need to be overridden if the resource
     72     # supports the method
     73 
     74     def get(self):
     75         """Handle a GET request.
     76 
     77         @returns an HttpResponse
     78         """
     79         raise NotImplementedError
     80 
     81 
     82     def post(self):
     83         """Handle a POST request.
     84 
     85         @returns an HttpResponse
     86         """
     87         raise NotImplementedError
     88 
     89 
     90     def put(self):
     91         """Handle a PUT request.
     92 
     93         @returns an HttpResponse
     94         """
     95         raise NotImplementedError
     96 
     97 
     98     def delete(self):
     99         """Handle a DELETE request.
    100 
    101         @returns an HttpResponse
    102         """
    103         raise NotImplementedError
    104 
    105 
    106     @classmethod
    107     def from_uri_args(cls, request, **kwargs):
    108         """Construct an instance from URI args.
    109 
    110         Default implementation for resources with no URI args.
    111         """
    112         return cls(request)
    113 
    114 
    115     def _uri_args(self):
    116         """Return kwargs for a URI reference to this resource.
    117 
    118         Default implementation for resources with no URI args.
    119         """
    120         return {}
    121 
    122 
    123     def _query_parameters_accepted(self):
    124         """Return sequence of tuples (name, description) for query parameters.
    125 
    126         Documents the available query parameters for GETting this resource.
    127         Default implementation for resources with no parameters.
    128         """
    129         return ()
    130 
    131 
    132     def read_query_parameters(self, parameters):
    133         """Read relevant query parameters from a Django MultiValueDict."""
    134         params_acccepted = set(param_name for param_name, _
    135                                in self._query_parameters_accepted())
    136         for name, values in parameters.iterlists():
    137             base_name = name.split(':', 1)[0]
    138             if base_name in params_acccepted:
    139                 self._query_params.setlist(name, values)
    140 
    141 
    142     def set_query_parameters(self, **parameters):
    143         """Set query parameters programmatically."""
    144         self._query_params.update(parameters)
    145 
    146 
    147     def href(self, query_params=None):
    148         """Return URI to this resource."""
    149         kwargs = self._uri_args()
    150         path = urlresolvers.reverse(self.dispatch_request, kwargs=kwargs)
    151         full_query_params = datastructures.MultiValueDict(self._query_params)
    152         if query_params:
    153             full_query_params.update(query_params)
    154         if full_query_params:
    155             path += '?' + urllib.urlencode(full_query_params.lists(),
    156                                            doseq=True)
    157         return self._request.build_absolute_uri(path)
    158 
    159 
    160     def resolve_uri(self, uri):
    161         # check for absolute URIs
    162         match = re.match(r'(?P<root>https?://[^/]+)(?P<path>/.*)', uri)
    163         if match:
    164             # is this URI for a different host?
    165             my_root = self._request.build_absolute_uri('/')
    166             request_root = match.group('root') + '/'
    167             if my_root != request_root:
    168                 # might support this in the future, but not now
    169                 raise exceptions.BadRequest('Unable to resolve remote URI %s'
    170                                             % uri)
    171             uri = match.group('path')
    172 
    173         try:
    174             view_method, args, kwargs = urlresolvers.resolve(uri)
    175         except http.Http404:
    176             raise exceptions.BadRequest('Unable to resolve URI %s' % uri)
    177         resource_class = view_method.im_self # class owning this classmethod
    178         return resource_class.from_uri_args(self._request, **kwargs)
    179 
    180 
    181     def resolve_link(self, link):
    182         if isinstance(link, dict):
    183             uri = link['href']
    184         elif isinstance(link, basestring):
    185             uri = link
    186         else:
    187             raise exceptions.BadRequest('Unable to understand link %s' % link)
    188         return self.resolve_uri(uri)
    189 
    190 
    191     def link(self, query_params=None):
    192         return {'href': self.href(query_params=query_params)}
    193 
    194 
    195     def _query_parameters_response(self):
    196         return dict((name, description)
    197                     for name, description in self._query_parameters_accepted())
    198 
    199 
    200     def _basic_response(self, content):
    201         """Construct and return a simple 200 response."""
    202         assert isinstance(content, dict)
    203         query_parameters = self._query_parameters_response()
    204         if query_parameters:
    205             content['query_parameters'] = query_parameters
    206         encoded_content = json.dumps(content)
    207         return http.HttpResponse(encoded_content,
    208                                  content_type=_JSON_CONTENT_TYPE)
    209 
    210 
    211     def _decoded_input(self):
    212         content_type = self._request.META.get('CONTENT_TYPE',
    213                                               _JSON_CONTENT_TYPE)
    214         raw_data = self._request.raw_post_data
    215         if content_type == _JSON_CONTENT_TYPE:
    216             try:
    217                 raw_dict = json.loads(raw_data)
    218             except ValueError, exc:
    219                 raise exceptions.BadRequest('Error decoding request body: '
    220                                             '%s\n%r' % (exc, raw_data))
    221             if not isinstance(raw_dict, dict):
    222                 raise exceptions.BadRequest('Expected dict input, got %s: %r' %
    223                                             (type(raw_dict), raw_dict))
    224         elif content_type == 'application/x-www-form-urlencoded':
    225             cgi_dict = cgi.parse_qs(raw_data) # django won't do this for PUT
    226             raw_dict = {}
    227             for key, values in cgi_dict.items():
    228                 value = values[-1] # take last value if multiple were given
    229                 try:
    230                     # attempt to parse numbers, booleans and nulls
    231                     raw_dict[key] = json.loads(value)
    232                 except ValueError:
    233                     # otherwise, leave it as a string
    234                     raw_dict[key] = value
    235         else:
    236             raise exceptions.RequestError(415, 'Unsupported media type: %s'
    237                                           % content_type)
    238 
    239         return _InputDict(raw_dict)
    240 
    241 
    242     def _format_datetime(self, date_time):
    243         """Return ISO 8601 string for the given datetime"""
    244         if date_time is None:
    245             return None
    246         timezone_hrs = time.timezone / 60 / 60  # convert seconds to hours
    247         if timezone_hrs >= 0:
    248             timezone_join = '+'
    249         else:
    250             timezone_join = '' # minus sign comes from number itself
    251         timezone_spec = '%s%s:00' % (timezone_join, timezone_hrs)
    252         return date_time.strftime('%Y-%m-%dT%H:%M:%S') + timezone_spec
    253 
    254 
    255     @classmethod
    256     def _check_for_required_fields(cls, input_dict, fields):
    257         assert isinstance(fields, (list, tuple)), fields
    258         missing_fields = ', '.join(field for field in fields
    259                                    if field not in input_dict)
    260         if missing_fields:
    261             raise exceptions.BadRequest('Missing input: ' + missing_fields)
    262 
    263 
    264 class Entry(Resource):
    265     @classmethod
    266     def add_query_selectors(cls, query_processor):
    267         """Sbuclasses may override this to support querying."""
    268         pass
    269 
    270 
    271     def short_representation(self):
    272         return self.link()
    273 
    274 
    275     def full_representation(self):
    276         return self.short_representation()
    277 
    278 
    279     def get(self):
    280         return self._basic_response(self.full_representation())
    281 
    282 
    283     def put(self):
    284         try:
    285             self.update(self._decoded_input())
    286         except model_logic.ValidationError, exc:
    287             raise exceptions.BadRequest('Invalid input: %s' % exc)
    288         return self._basic_response(self.full_representation())
    289 
    290 
    291     def _delete_entry(self):
    292         raise NotImplementedError
    293 
    294 
    295     def delete(self):
    296         self._delete_entry()
    297         return http.HttpResponse(status=204) # No content
    298 
    299 
    300     def create_instance(self, input_dict, containing_collection):
    301         raise NotImplementedError
    302 
    303 
    304     def update(self, input_dict):
    305         raise NotImplementedError
    306 
    307 
    308 class InstanceEntry(Entry):
    309     class NullEntry(object):
    310         def link(self):
    311             return None
    312 
    313 
    314         def short_representation(self):
    315             return None
    316 
    317 
    318     _null_entry = NullEntry()
    319     _permitted_methods = ('GET', 'PUT', 'DELETE')
    320     model = None # subclasses must override this with a Django model class
    321 
    322 
    323     def __init__(self, request, instance):
    324         assert self.model is not None
    325         super(InstanceEntry, self).__init__(request)
    326         self.instance = instance
    327         self._is_prepared_for_full_representation = False
    328 
    329 
    330     @classmethod
    331     def from_optional_instance(cls, request, instance):
    332         if instance is None:
    333             return cls._null_entry
    334         return cls(request, instance)
    335 
    336 
    337     def _delete_entry(self):
    338         self.instance.delete()
    339 
    340 
    341     def full_representation(self):
    342         self.prepare_for_full_representation([self])
    343         return super(InstanceEntry, self).full_representation()
    344 
    345 
    346     @classmethod
    347     def prepare_for_full_representation(cls, entries):
    348         """
    349         Prepare the given list of entries to generate full representations.
    350 
    351         This method delegates to _do_prepare_for_full_representation(), which
    352         subclasses may override as necessary to do the actual processing.  This
    353         method also marks the instance as prepared, so it's safe to call this
    354         multiple times with the same instance(s) without wasting work.
    355         """
    356         not_prepared = [entry for entry in entries
    357                         if not entry._is_prepared_for_full_representation]
    358         cls._do_prepare_for_full_representation([entry.instance
    359                                                  for entry in not_prepared])
    360         for entry in not_prepared:
    361             entry._is_prepared_for_full_representation = True
    362 
    363 
    364     @classmethod
    365     def _do_prepare_for_full_representation(cls, instances):
    366         """
    367         Subclasses may override this to gather data as needed for full
    368         representations of the given model instances.  Typically, this involves
    369         querying over related objects, and this method offers a chance to query
    370         for many instances at once, which can provide a great performance
    371         benefit.
    372         """
    373         pass
    374 
    375 
    376 class Collection(Resource):
    377     _DEFAULT_ITEMS_PER_PAGE = 50
    378 
    379     _permitted_methods=('GET', 'POST')
    380 
    381     # subclasses must override these
    382     queryset = None # or override _fresh_queryset() directly
    383     entry_class = None
    384 
    385 
    386     def __init__(self, request):
    387         super(Collection, self).__init__(request)
    388         assert self.entry_class is not None
    389         if isinstance(self.entry_class, basestring):
    390             type(self).entry_class = _resolve_class_path(self.entry_class)
    391 
    392         self._query_processor = query_lib.QueryProcessor()
    393         self.entry_class.add_query_selectors(self._query_processor)
    394 
    395 
    396     def _query_parameters_accepted(self):
    397         params = [('start_index', 'Index of first member to include'),
    398                   ('items_per_page', 'Number of members to include'),
    399                   ('full_representations',
    400                    'True to include full representations of members')]
    401         for selector in self._query_processor.selectors():
    402             params.append((selector.name, selector.doc))
    403         return params
    404 
    405 
    406     def _fresh_queryset(self):
    407         assert self.queryset is not None
    408         # always copy the queryset before using it to avoid caching
    409         return self.queryset.all()
    410 
    411 
    412     def _entry_from_instance(self, instance):
    413         return self.entry_class(self._request, instance)
    414 
    415 
    416     def _representation(self, entry_instances):
    417         entries = [self._entry_from_instance(instance)
    418                    for instance in entry_instances]
    419 
    420         want_full_representation = self._read_bool_parameter(
    421                 'full_representations')
    422         if want_full_representation:
    423             self.entry_class.prepare_for_full_representation(entries)
    424 
    425         members = []
    426         for entry in entries:
    427             if want_full_representation:
    428                 rep = entry.full_representation()
    429             else:
    430                 rep = entry.short_representation()
    431             members.append(rep)
    432 
    433         rep = self.link()
    434         rep.update({'members': members})
    435         return rep
    436 
    437 
    438     def _read_bool_parameter(self, name):
    439         if name not in self._query_params:
    440             return False
    441         return (self._query_params[name].lower() == 'true')
    442 
    443 
    444     def _read_int_parameter(self, name, default):
    445         if name not in self._query_params:
    446             return default
    447         input_value = self._query_params[name]
    448         try:
    449             return int(input_value)
    450         except ValueError:
    451             raise exceptions.BadRequest('Invalid non-numeric value for %s: %r'
    452                                         % (name, input_value))
    453 
    454 
    455     def _apply_form_query(self, queryset):
    456         """Apply any query selectors passed as form variables."""
    457         for parameter, values in self._query_params.lists():
    458             if ':' in parameter:
    459                 parameter, comparison_type = parameter.split(':', 1)
    460             else:
    461                 comparison_type = None
    462 
    463             if not self._query_processor.has_selector(parameter):
    464                 continue
    465             for value in values: # forms keys can have multiple values
    466                 queryset = self._query_processor.apply_selector(
    467                         queryset, parameter, value,
    468                         comparison_type=comparison_type)
    469         return queryset
    470 
    471 
    472     def _filtered_queryset(self):
    473         return self._apply_form_query(self._fresh_queryset())
    474 
    475 
    476     def get(self):
    477         queryset = self._filtered_queryset()
    478 
    479         items_per_page = self._read_int_parameter('items_per_page',
    480                                                   self._DEFAULT_ITEMS_PER_PAGE)
    481         start_index = self._read_int_parameter('start_index', 0)
    482         page = queryset[start_index:(start_index + items_per_page)]
    483 
    484         rep = self._representation(page)
    485         rep.update({'total_results': len(queryset),
    486                     'start_index': start_index,
    487                     'items_per_page': items_per_page})
    488         return self._basic_response(rep)
    489 
    490 
    491     def full_representation(self):
    492         # careful, this rep can be huge for large collections
    493         return self._representation(self._fresh_queryset())
    494 
    495 
    496     def post(self):
    497         input_dict = self._decoded_input()
    498         try:
    499             instance = self.entry_class.create_instance(input_dict, self)
    500             entry = self._entry_from_instance(instance)
    501             entry.update(input_dict)
    502         except model_logic.ValidationError, exc:
    503             raise exceptions.BadRequest('Invalid input: %s' % exc)
    504         # RFC 2616 specifies that we provide the new URI in both the Location
    505         # header and the body
    506         response = http.HttpResponse(status=201, # Created
    507                                      content=entry.href())
    508         response['Location'] = entry.href()
    509         return response
    510 
    511 
    512 class Relationship(Entry):
    513     _permitted_methods = ('GET', 'DELETE')
    514 
    515     # subclasses must override this with a dict mapping name to entry class
    516     related_classes = None
    517 
    518 
    519     def __init__(self, **kwargs):
    520         assert len(self.related_classes) == 2
    521         self.entries = dict((name, kwargs[name])
    522                             for name in self.related_classes)
    523         for name in self.related_classes: # sanity check
    524             assert isinstance(self.entries[name], self.related_classes[name])
    525 
    526         # just grab the request from one of the entries
    527         some_entry = self.entries.itervalues().next()
    528         super(Relationship, self).__init__(some_entry._request)
    529 
    530 
    531     @classmethod
    532     def from_uri_args(cls, request, **kwargs):
    533         # kwargs contains URI args for each entry
    534         entries = {}
    535         for name, entry_class in cls.related_classes.iteritems():
    536             entries[name] = entry_class.from_uri_args(request, **kwargs)
    537         return cls(**entries)
    538 
    539 
    540     def _uri_args(self):
    541         kwargs = {}
    542         for name, entry in self.entries.iteritems():
    543             kwargs.update(entry._uri_args())
    544         return kwargs
    545 
    546 
    547     def short_representation(self):
    548         rep = self.link()
    549         for name, entry in self.entries.iteritems():
    550             rep[name] = entry.short_representation()
    551         return rep
    552 
    553 
    554     @classmethod
    555     def _get_related_manager(cls, instance):
    556         """Get the related objects manager for the given instance.
    557 
    558         The instance must be one of the related classes.  This method will
    559         return the related manager from that instance to instances of the other
    560         related class.
    561         """
    562         this_model = type(instance)
    563         models = [entry_class.model for entry_class
    564                   in cls.related_classes.values()]
    565         if isinstance(instance, models[0]):
    566             this_model, other_model = models
    567         else:
    568             other_model, this_model = models
    569 
    570         _, field = this_model.objects.determine_relationship(other_model)
    571         this_models_fields = (this_model._meta.fields
    572                               + this_model._meta.many_to_many)
    573         if field in this_models_fields:
    574             manager_name = field.attname
    575         else:
    576             # related manager is on other_model, get name of reverse related
    577             # manager on this_model
    578             manager_name = field.related.get_accessor_name()
    579 
    580         return getattr(instance, manager_name)
    581 
    582 
    583     def _delete_entry(self):
    584         # choose order arbitrarily
    585         entry, other_entry = self.entries.itervalues()
    586         related_manager = self._get_related_manager(entry.instance)
    587         related_manager.remove(other_entry.instance)
    588 
    589 
    590     @classmethod
    591     def create_instance(cls, input_dict, containing_collection):
    592         other_name = containing_collection.unfixed_name
    593         cls._check_for_required_fields(input_dict, (other_name,))
    594         entry = containing_collection.fixed_entry
    595         other_entry = containing_collection.resolve_link(input_dict[other_name])
    596         related_manager = cls._get_related_manager(entry.instance)
    597         related_manager.add(other_entry.instance)
    598         return other_entry.instance
    599 
    600 
    601     def update(self, input_dict):
    602         pass
    603 
    604 
    605 class RelationshipCollection(Collection):
    606     def __init__(self, request=None, fixed_entry=None):
    607         if request is None:
    608             request = fixed_entry._request
    609         super(RelationshipCollection, self).__init__(request)
    610 
    611         assert issubclass(self.entry_class, Relationship)
    612         self.related_classes = self.entry_class.related_classes
    613         self.fixed_name = None
    614         self.fixed_entry = None
    615         self.unfixed_name = None
    616         self.related_manager = None
    617 
    618         if fixed_entry is not None:
    619             self._set_fixed_entry(fixed_entry)
    620             entry_uri_arg = self.fixed_entry._uri_args().values()[0]
    621             self._query_params[self.fixed_name] = entry_uri_arg
    622 
    623 
    624     def _set_fixed_entry(self, entry):
    625         """Set the fixed entry for this collection.
    626 
    627         The entry must be an instance of one of the related entry classes.  This
    628         method must be called before a relationship is used.  It gets called
    629         either from the constructor (when collections are instantiated from
    630         other resource handling code) or from read_query_parameters() (when a
    631         request is made directly for the collection.
    632         """
    633         names = self.related_classes.keys()
    634         if isinstance(entry, self.related_classes[names[0]]):
    635             self.fixed_name, self.unfixed_name = names
    636         else:
    637             assert isinstance(entry, self.related_classes[names[1]])
    638             self.unfixed_name, self.fixed_name = names
    639         self.fixed_entry = entry
    640         self.unfixed_class = self.related_classes[self.unfixed_name]
    641         self.related_manager = self.entry_class._get_related_manager(
    642                 entry.instance)
    643 
    644 
    645     def _query_parameters_accepted(self):
    646         return [(name, 'Show relationships for this %s' % entry_class.__name__)
    647                 for name, entry_class
    648                 in self.related_classes.iteritems()]
    649 
    650 
    651     def _resolve_query_param(self, name, uri_arg):
    652         entry_class = self.related_classes[name]
    653         return entry_class.from_uri_args(self._request, uri_arg)
    654 
    655 
    656     def read_query_parameters(self, query_params):
    657         super(RelationshipCollection, self).read_query_parameters(query_params)
    658         if not self._query_params:
    659             raise exceptions.BadRequest(
    660                     'You must specify one of the parameters %s and %s'
    661                     % tuple(self.related_classes.keys()))
    662         query_items = self._query_params.items()
    663         fixed_entry = self._resolve_query_param(*query_items[0])
    664         self._set_fixed_entry(fixed_entry)
    665 
    666         if len(query_items) > 1:
    667             other_fixed_entry = self._resolve_query_param(*query_items[1])
    668             self.related_manager = self.related_manager.filter(
    669                     pk=other_fixed_entry.instance.id)
    670 
    671 
    672     def _entry_from_instance(self, instance):
    673         unfixed_entry = self.unfixed_class(self._request, instance)
    674         entries = {self.fixed_name: self.fixed_entry,
    675                    self.unfixed_name: unfixed_entry}
    676         return self.entry_class(**entries)
    677 
    678 
    679     def _fresh_queryset(self):
    680         return self.related_manager.all()
    681