Home | History | Annotate | Download | only in servlet
      1 /**
      2  * Copyright (C) 2010 Google Inc.
      3  *
      4  * Licensed under the Apache License, Version 2.0 (the "License");
      5  * you may not use this file except in compliance with the License.
      6  * You may obtain a copy of the License at
      7  *
      8  * http://www.apache.org/licenses/LICENSE-2.0
      9  *
     10  * Unless required by applicable law or agreed to in writing, software
     11  * distributed under the License is distributed on an "AS IS" BASIS,
     12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     13  * See the License for the specific language governing permissions and
     14  * limitations under the License.
     15  */
     16 
     17 package com.google.inject.servlet;
     18 
     19 import com.google.common.base.Objects;
     20 import com.google.common.collect.ImmutableSet;
     21 import com.google.common.collect.Lists;
     22 import com.google.inject.Binding;
     23 import com.google.inject.Injector;
     24 import com.google.inject.Stage;
     25 import com.google.inject.spi.DefaultBindingTargetVisitor;
     26 
     27 import junit.framework.AssertionFailedError;
     28 
     29 import java.util.List;
     30 import java.util.Map;
     31 import java.util.Set;
     32 import java.util.logging.Logger;
     33 
     34 import javax.servlet.Filter;
     35 import javax.servlet.ServletContext;
     36 import javax.servlet.ServletRequest;
     37 import javax.servlet.ServletResponse;
     38 import javax.servlet.http.HttpServlet;
     39 import javax.servlet.http.HttpServletRequest;
     40 import javax.servlet.http.HttpServletResponse;
     41 import javax.servlet.http.HttpSession;
     42 
     43 /**
     44  * A visitor for testing the servlet SPI extension.
     45  *
     46  * @author sameb (at) google.com (Sam Berlin)
     47  */
     48 class ServletSpiVisitor
     49     extends DefaultBindingTargetVisitor<Object, Integer>
     50     implements ServletModuleTargetVisitor<Object, Integer> {
     51 
     52   int otherCount = 0;
     53   int currentCount = 0;
     54   List<Params> actual = Lists.newArrayList();
     55 
     56   /* The set of classes that are allowed to be "other" bindings. */
     57   Set<Class> allowedClasses;
     58 
     59   ServletSpiVisitor(boolean forInjector) {
     60     ImmutableSet.Builder<Class> builder = ImmutableSet.builder();
     61     // always ignore these things...
     62     builder.add(ServletRequest.class,
     63         ServletResponse.class, ManagedFilterPipeline.class, ManagedServletPipeline.class,
     64         FilterPipeline.class, ServletContext.class, HttpServletRequest.class, Filter.class,
     65         HttpServletResponse.class, HttpSession.class, Map.class, HttpServlet.class,
     66         InternalServletModule.BackwardsCompatibleServletContextProvider.class,
     67         GuiceFilter.class);
     68     if(forInjector) {
     69       // only ignore these if this is for the live injector, any other time it'd be an error!
     70       builder.add(Injector.class, Stage.class, Logger.class);
     71     }
     72     this.allowedClasses = builder.build();
     73   }
     74 
     75   public Integer visit(InstanceFilterBinding binding) {
     76     actual.add(new Params(binding, binding.getFilterInstance()));
     77     return currentCount++;
     78   }
     79 
     80   public Integer visit(InstanceServletBinding binding) {
     81     actual.add(new Params(binding, binding.getServletInstance()));
     82     return currentCount++;
     83   }
     84 
     85   public Integer visit(LinkedFilterBinding binding) {
     86     actual.add(new Params(binding, binding.getLinkedKey()));
     87     return currentCount++;
     88   }
     89 
     90   public Integer visit(LinkedServletBinding binding) {
     91     actual.add(new Params(binding, binding.getLinkedKey()));
     92     return currentCount++;
     93   }
     94 
     95   @Override
     96   protected Integer visitOther(Binding<? extends Object> binding) {
     97     if(!allowedClasses.contains(binding.getKey().getTypeLiteral().getRawType())) {
     98       throw new AssertionFailedError("invalid other binding: " + binding);
     99     }
    100     otherCount++;
    101     return currentCount++;
    102   }
    103 
    104   static class Params {
    105     private final String pattern;
    106     private final Object keyOrInstance;
    107     private final Map<String, String> params;
    108     private final UriPatternType patternType;
    109 
    110     Params(ServletModuleBinding binding, Object keyOrInstance) {
    111       this.pattern = binding.getPattern();
    112       this.keyOrInstance = keyOrInstance;
    113       this.params = binding.getInitParams();
    114       this.patternType = binding.getUriPatternType();
    115     }
    116 
    117     Params(String pattern, Object keyOrInstance, Map params, UriPatternType patternType) {
    118       this.pattern = pattern;
    119       this.keyOrInstance = keyOrInstance;
    120       this.params = params;
    121       this.patternType = patternType;
    122     }
    123 
    124     @Override
    125     public boolean equals(Object obj) {
    126       if(obj instanceof Params) {
    127         Params o = (Params)obj;
    128         return Objects.equal(pattern, o.pattern)
    129             && Objects.equal(keyOrInstance, o.keyOrInstance)
    130             && Objects.equal(params, o.params)
    131             && Objects.equal(patternType, o.patternType);
    132       } else {
    133         return false;
    134       }
    135     }
    136 
    137     @Override
    138     public int hashCode() {
    139       return Objects.hashCode(pattern, keyOrInstance, params, patternType);
    140     }
    141 
    142     @Override
    143     public String toString() {
    144       return Objects.toStringHelper(Params.class)
    145         .add("pattern", pattern)
    146         .add("keyOrInstance", keyOrInstance)
    147         .add("initParams", params)
    148         .add("patternType", patternType)
    149         .toString();
    150     }
    151   }
    152 }