Home | History | Annotate | Download | only in servlet
      1 /**
      2  * Copyright (C) 2008 Google Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.google.inject.servlet;
     17 
     18 import com.google.common.base.Preconditions;
     19 import com.google.common.collect.Lists;
     20 import com.google.common.collect.Sets;
     21 import com.google.inject.Binding;
     22 import com.google.inject.Inject;
     23 import com.google.inject.Injector;
     24 import com.google.inject.Singleton;
     25 import com.google.inject.TypeLiteral;
     26 
     27 import java.io.IOException;
     28 import java.util.List;
     29 import java.util.Set;
     30 
     31 import javax.servlet.RequestDispatcher;
     32 import javax.servlet.ServletContext;
     33 import javax.servlet.ServletException;
     34 import javax.servlet.ServletRequest;
     35 import javax.servlet.ServletResponse;
     36 import javax.servlet.http.HttpServlet;
     37 import javax.servlet.http.HttpServletRequest;
     38 import javax.servlet.http.HttpServletRequestWrapper;
     39 
     40 /**
     41  * A wrapping dispatcher for servlets, in much the same way as {@link ManagedFilterPipeline} is for
     42  * filters.
     43  *
     44  * @author dhanji (at) gmail.com (Dhanji R. Prasanna)
     45  */
     46 @Singleton
     47 class ManagedServletPipeline {
     48   private final ServletDefinition[] servletDefinitions;
     49   private static final TypeLiteral<ServletDefinition> SERVLET_DEFS =
     50       TypeLiteral.get(ServletDefinition.class);
     51 
     52   @Inject
     53   public ManagedServletPipeline(Injector injector) {
     54     this.servletDefinitions = collectServletDefinitions(injector);
     55   }
     56 
     57   boolean hasServletsMapped() {
     58     return servletDefinitions.length > 0;
     59   }
     60 
     61   /**
     62    * Introspects the injector and collects all instances of bound {@code List<ServletDefinition>}
     63    * into a master list.
     64    *
     65    * We have a guarantee that {@link com.google.inject.Injector#getBindings()} returns a map
     66    * that preserves insertion order in entry-set iterators.
     67    */
     68   private ServletDefinition[] collectServletDefinitions(Injector injector) {
     69     List<ServletDefinition> servletDefinitions = Lists.newArrayList();
     70     for (Binding<ServletDefinition> entry : injector.findBindingsByType(SERVLET_DEFS)) {
     71         servletDefinitions.add(entry.getProvider().get());
     72     }
     73 
     74     // Copy to a fixed size array for speed.
     75     return servletDefinitions.toArray(new ServletDefinition[servletDefinitions.size()]);
     76   }
     77 
     78   public void init(ServletContext servletContext, Injector injector) throws ServletException {
     79     Set<HttpServlet> initializedSoFar = Sets.newIdentityHashSet();
     80 
     81     for (ServletDefinition servletDefinition : servletDefinitions) {
     82       servletDefinition.init(servletContext, injector, initializedSoFar);
     83     }
     84   }
     85 
     86   public boolean service(ServletRequest request, ServletResponse response)
     87       throws IOException, ServletException {
     88 
     89     //stop at the first matching servlet and service
     90     for (ServletDefinition servletDefinition : servletDefinitions) {
     91       if (servletDefinition.service(request, response)) {
     92         return true;
     93       }
     94     }
     95 
     96     //there was no match...
     97     return false;
     98   }
     99 
    100   public void destroy() {
    101     Set<HttpServlet> destroyedSoFar = Sets.newIdentityHashSet();
    102     for (ServletDefinition servletDefinition : servletDefinitions) {
    103       servletDefinition.destroy(destroyedSoFar);
    104     }
    105   }
    106 
    107   /**
    108    * @return Returns a request dispatcher wrapped with a servlet mapped to
    109    * the given path or null if no mapping was found.
    110    */
    111   RequestDispatcher getRequestDispatcher(String path) {
    112     final String newRequestUri = path;
    113 
    114     // TODO(dhanji): check servlet spec to see if the following is legal or not.
    115     // Need to strip query string if requested...
    116 
    117     for (final ServletDefinition servletDefinition : servletDefinitions) {
    118       if (servletDefinition.shouldServe(path)) {
    119         return new RequestDispatcher() {
    120           public void forward(ServletRequest servletRequest, ServletResponse servletResponse)
    121               throws ServletException, IOException {
    122             Preconditions.checkState(!servletResponse.isCommitted(),
    123                 "Response has been committed--you can only call forward before"
    124                 + " committing the response (hint: don't flush buffers)");
    125 
    126             // clear buffer before forwarding
    127             servletResponse.resetBuffer();
    128 
    129             ServletRequest requestToProcess;
    130             if (servletRequest instanceof HttpServletRequest) {
    131                requestToProcess = wrapRequest((HttpServletRequest)servletRequest, newRequestUri);
    132             } else {
    133               // This should never happen, but instead of throwing an exception
    134               // we will allow a happy case pass thru for maximum tolerance to
    135               // legacy (and internal) code.
    136               requestToProcess = servletRequest;
    137             }
    138 
    139             // now dispatch to the servlet
    140             doServiceImpl(servletDefinition, requestToProcess, servletResponse);
    141           }
    142 
    143           public void include(ServletRequest servletRequest, ServletResponse servletResponse)
    144               throws ServletException, IOException {
    145             // route to the target servlet
    146             doServiceImpl(servletDefinition, servletRequest, servletResponse);
    147           }
    148 
    149           private void doServiceImpl(ServletDefinition servletDefinition, ServletRequest servletRequest,
    150               ServletResponse servletResponse) throws ServletException, IOException {
    151             servletRequest.setAttribute(REQUEST_DISPATCHER_REQUEST, Boolean.TRUE);
    152 
    153             try {
    154               servletDefinition.doService(servletRequest, servletResponse);
    155             } finally {
    156               servletRequest.removeAttribute(REQUEST_DISPATCHER_REQUEST);
    157             }
    158           }
    159         };
    160       }
    161     }
    162 
    163     //otherwise, can't process
    164     return null;
    165   }
    166 
    167   // visible for testing
    168   static HttpServletRequest wrapRequest(HttpServletRequest request, String newUri) {
    169     return new RequestDispatcherRequestWrapper(request, newUri);
    170   }
    171 
    172   /**
    173    * A Marker constant attribute that when present in the request indicates to Guice servlet that
    174    * this request has been generated by a request dispatcher rather than the servlet pipeline.
    175    * In accordance with section 8.4.2 of the Servlet 2.4 specification.
    176    */
    177   public static final String REQUEST_DISPATCHER_REQUEST = "javax.servlet.forward.servlet_path";
    178 
    179   private static class RequestDispatcherRequestWrapper extends HttpServletRequestWrapper {
    180     private final String newRequestUri;
    181 
    182     public RequestDispatcherRequestWrapper(HttpServletRequest servletRequest, String newRequestUri) {
    183       super(servletRequest);
    184       this.newRequestUri = newRequestUri;
    185     }
    186 
    187     @Override
    188     public String getRequestURI() {
    189       return newRequestUri;
    190     }
    191 
    192     @Override
    193     public StringBuffer getRequestURL() {
    194       StringBuffer url = new StringBuffer();
    195       String scheme = getScheme();
    196       int port = getServerPort();
    197 
    198       url.append(scheme);
    199       url.append("://");
    200       url.append(getServerName());
    201       // port might be -1 in some cases (see java.net.URL.getPort)
    202       if (port > 0 &&
    203           (("http".equals(scheme) && (port != 80)) ||
    204            ("https".equals(scheme) && (port != 443)))) {
    205         url.append(':');
    206         url.append(port);
    207       }
    208       url.append(getRequestURI());
    209 
    210       return (url);
    211     }
    212   }
    213 }
    214