Home | History | Annotate | Download | only in servlet
      1 /**
      2  * Copyright (C) 2008 Google Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 package com.google.inject.servlet;
     17 
import com.google.common.base.Throwables;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

import java.io.IOException;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;
     25 
     26 import javax.servlet.Filter;
     27 import javax.servlet.FilterChain;
     28 import javax.servlet.ServletException;
     29 import javax.servlet.ServletRequest;
     30 import javax.servlet.ServletResponse;
     31 import javax.servlet.http.HttpServletRequest;
     32 import javax.servlet.http.HttpServletResponse;
     33 
     34 /**
     35  * A Filter chain impl which basically passes itself to the "current" filter and iterates the chain
     36  * on {@code doFilter()}. Modeled on something similar in Apache Tomcat.
     37  *
     38  * Following this, it attempts to dispatch to guice-servlet's registered servlets using the
     39  * ManagedServletPipeline.
     40  *
     41  * And the end, it proceeds to the web.xml (default) servlet filter chain, if needed.
     42  *
     43  * @author Dhanji R. Prasanna
     44  * @since 1.0
     45  */
     46 class FilterChainInvocation implements FilterChain {
     47 
     48   private static final Set<String> SERVLET_INTERNAL_METHODS = ImmutableSet.of(
     49       FilterChainInvocation.class.getName() + ".doFilter");
     50 
     51   private final FilterDefinition[] filterDefinitions;
     52   private final FilterChain proceedingChain;
     53   private final ManagedServletPipeline servletPipeline;
     54 
     55   //state variable tracks current link in filterchain
     56   private int index = -1;
     57   // whether or not we've caught an exception & cleaned up stack traces
     58   private boolean cleanedStacks = false;
     59 
     60   public FilterChainInvocation(FilterDefinition[] filterDefinitions,
     61       ManagedServletPipeline servletPipeline, FilterChain proceedingChain) {
     62 
     63     this.filterDefinitions = filterDefinitions;
     64     this.servletPipeline = servletPipeline;
     65     this.proceedingChain = proceedingChain;
     66   }
     67 
     68   public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse)
     69       throws IOException, ServletException {
     70     GuiceFilter.Context previous = GuiceFilter.localContext.get();
     71     HttpServletRequest request = (HttpServletRequest) servletRequest;
     72     HttpServletResponse response = (HttpServletResponse) servletResponse;
     73     HttpServletRequest originalRequest
     74         = (previous != null) ? previous.getOriginalRequest() : request;
     75     GuiceFilter.localContext.set(new GuiceFilter.Context(originalRequest, request, response));
     76     try {
     77       Filter filter = findNextFilter(request);
     78       if (filter != null) {
     79         // call to the filter, which can either consume the request or
     80         // recurse back into this method. (in which case we will go to find the next filter,
     81         // or dispatch to the servlet if no more filters are left)
     82         filter.doFilter(servletRequest, servletResponse, this);
     83       } else {
     84         //we've reached the end of the filterchain, let's try to dispatch to a servlet
     85         final boolean serviced = servletPipeline.service(servletRequest, servletResponse);
     86 
     87         //dispatch to the normal filter chain only if one of our servlets did not match
     88         if (!serviced) {
     89           proceedingChain.doFilter(servletRequest, servletResponse);
     90         }
     91       }
     92     } catch (Throwable t) {
     93       // Only clean on the first pass through -- one exception deep in a filter
     94       // will propogate upward & hit this catch clause multiple times.  We don't
     95       // want to iterate through the stack elements for every filter.
     96       if (!cleanedStacks) {
     97         cleanedStacks = true;
     98         pruneStacktrace(t);
     99       }
    100       Throwables.propagateIfInstanceOf(t, ServletException.class);
    101       Throwables.propagateIfInstanceOf(t, IOException.class);
    102       throw Throwables.propagate(t);
    103     } finally {
    104       GuiceFilter.localContext.set(previous);
    105     }
    106   }
    107 
    108   /**
    109    * Iterates over the remaining filter definitions.
    110    * Returns the first applicable filter, or null if none apply.
    111    */
    112   private Filter findNextFilter(HttpServletRequest request) {
    113     while (++index < filterDefinitions.length) {
    114       Filter filter = filterDefinitions[index].getFilterIfMatching(request);
    115       if (filter != null) {
    116         return filter;
    117       }
    118     }
    119     return null;
    120   }
    121 
    122   /**
    123    * Removes stacktrace elements related to AOP internal mechanics from the
    124    * throwable's stack trace and any causes it may have.
    125    */
    126   private void pruneStacktrace(Throwable throwable) {
    127     for (Throwable t = throwable; t != null; t = t.getCause()) {
    128       StackTraceElement[] stackTrace = t.getStackTrace();
    129       List<StackTraceElement> pruned = Lists.newArrayList();
    130       for (StackTraceElement element : stackTrace) {
    131         String name = element.getClassName() + "." + element.getMethodName();
    132         if (!SERVLET_INTERNAL_METHODS.contains(name)) {
    133           pruned.add(element);
    134         }
    135       }
    136       t.setStackTrace(pruned.toArray(new StackTraceElement[pruned.size()]));
    137     }
    138   }
    139 }
    140