Home | History | Annotate | Download | only in servlet
/**
 * Copyright (C) 2006 Google Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
     16 
     17 package com.google.inject.servlet;
     18 
     19 import static com.google.inject.Asserts.assertContains;
     20 import static com.google.inject.Asserts.reserialize;
     21 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletRequest;
     22 import static com.google.inject.servlet.ServletTestUtils.newFakeHttpServletResponse;
     23 import static java.lang.annotation.ElementType.FIELD;
     24 import static java.lang.annotation.ElementType.METHOD;
     25 import static java.lang.annotation.ElementType.PARAMETER;
     26 import static java.lang.annotation.RetentionPolicy.RUNTIME;
     27 
     28 import com.google.common.collect.ImmutableMap;
     29 import com.google.common.collect.Lists;
     30 import com.google.inject.AbstractModule;
     31 import com.google.inject.BindingAnnotation;
     32 import com.google.inject.CreationException;
     33 import com.google.inject.Guice;
     34 import com.google.inject.Inject;
     35 import com.google.inject.Injector;
     36 import com.google.inject.Key;
     37 import com.google.inject.Module;
     38 import com.google.inject.Provider;
     39 import com.google.inject.Provides;
     40 import com.google.inject.ProvisionException;
     41 import com.google.inject.internal.Errors;
     42 import com.google.inject.name.Named;
     43 import com.google.inject.name.Names;
     44 import com.google.inject.servlet.ServletScopes.NullObject;
     45 import com.google.inject.util.Providers;
     46 
     47 import junit.framework.TestCase;
     48 
     49 import java.io.IOException;
     50 import java.io.Serializable;
     51 import java.lang.annotation.Retention;
     52 import java.lang.annotation.Target;
     53 import java.util.Map;
     54 
     55 import javax.servlet.Filter;
     56 import javax.servlet.FilterChain;
     57 import javax.servlet.FilterConfig;
     58 import javax.servlet.ServletException;
     59 import javax.servlet.ServletRequest;
     60 import javax.servlet.ServletResponse;
     61 import javax.servlet.http.HttpServlet;
     62 import javax.servlet.http.HttpServletRequest;
     63 import javax.servlet.http.HttpServletRequestWrapper;
     64 import javax.servlet.http.HttpServletResponse;
     65 import javax.servlet.http.HttpServletResponseWrapper;
     66 import javax.servlet.http.HttpSession;
     67 
     68 /**
     69  * @author crazybob (at) google.com (Bob Lee)
     70  */
     71 public class ServletTest extends TestCase {
     72   private static final Key<HttpServletRequest> HTTP_REQ_KEY = Key.get(HttpServletRequest.class);
     73   private static final Key<HttpServletResponse> HTTP_RESP_KEY = Key.get(HttpServletResponse.class);
     74   private static final Key<Map<String, String[]>> REQ_PARAMS_KEY
     75       = new Key<Map<String, String[]>>(RequestParameters.class) {};
     76 
     77   private static final Key<InRequest> IN_REQUEST_NULL_KEY = Key.get(InRequest.class, Null.class);
     78   private static final Key<InSession> IN_SESSION_KEY = Key.get(InSession.class);
     79   private static final Key<InSession> IN_SESSION_NULL_KEY = Key.get(InSession.class, Null.class);
     80 
     81   @Override
     82   public void setUp() {
     83     //we need to clear the reference to the pipeline every test =(
     84     GuiceFilter.reset();
     85   }
     86 
     87   public void testScopeExceptions() throws Exception {
     88     Injector injector = Guice.createInjector(new AbstractModule() {
     89       @Override protected void configure() {
     90         install(new ServletModule());
     91       }
     92       @Provides @RequestScoped String provideString() { return "foo"; }
     93       @Provides @SessionScoped Integer provideInteger() { return 1; }
     94       @Provides @RequestScoped @Named("foo") String provideNamedString() { return "foo"; }
     95     });
     96 
     97     try {
     98       injector.getInstance(String.class);
     99       fail();
    100     } catch(ProvisionException oose) {
    101       assertContains(oose.getMessage(), "Cannot access scoped [java.lang.String].");
    102     }
    103 
    104     try {
    105       injector.getInstance(Integer.class);
    106       fail();
    107     } catch(ProvisionException oose) {
    108       assertContains(oose.getMessage(), "Cannot access scoped [java.lang.Integer].");
    109     }
    110 
    111     Key<?> key = Key.get(String.class, Names.named("foo"));
    112     try {
    113       injector.getInstance(key);
    114       fail();
    115     } catch(ProvisionException oose) {
    116       assertContains(oose.getMessage(), "Cannot access scoped [" + Errors.convert(key) + "]");
    117     }
    118   }
    119 
    120   public void testRequestAndResponseBindings() throws Exception {
    121     final Injector injector = createInjector();
    122     final HttpServletRequest request = newFakeHttpServletRequest();
    123     final HttpServletResponse response = newFakeHttpServletResponse();
    124 
    125     final boolean[] invoked = new boolean[1];
    126     GuiceFilter filter = new GuiceFilter();
    127     FilterChain filterChain = new FilterChain() {
    128       public void doFilter(ServletRequest servletRequest,
    129           ServletResponse servletResponse) {
    130         invoked[0] = true;
    131         assertSame(request, servletRequest);
    132         assertSame(request, injector.getInstance(ServletRequest.class));
    133         assertSame(request, injector.getInstance(HTTP_REQ_KEY));
    134 
    135         assertSame(response, servletResponse);
    136         assertSame(response, injector.getInstance(ServletResponse.class));
    137         assertSame(response, injector.getInstance(HTTP_RESP_KEY));
    138 
    139         assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
    140       }
    141     };
    142     filter.doFilter(request, response, filterChain);
    143 
    144     assertTrue(invoked[0]);
    145   }
    146 
    147   public void testRequestAndResponseBindings_wrappingFilter() throws Exception {
    148     final HttpServletRequest request = newFakeHttpServletRequest();
    149     final ImmutableMap<String, String[]> wrappedParamMap
    150         = ImmutableMap.of("wrap", new String[]{"a", "b"});
    151     final HttpServletRequestWrapper requestWrapper = new HttpServletRequestWrapper(request) {
    152       @Override public Map getParameterMap() {
    153         return wrappedParamMap;
    154       }
    155 
    156       @Override public Object getAttribute(String attr) {
    157         // Ensure that attributes are stored on the original request object.
    158         throw new UnsupportedOperationException();
    159       }
    160     };
    161     final HttpServletResponse response = newFakeHttpServletResponse();
    162     final HttpServletResponseWrapper responseWrapper = new HttpServletResponseWrapper(response);
    163 
    164     final boolean[] filterInvoked = new boolean[1];
    165     final Injector injector = createInjector(new ServletModule() {
    166       @Override protected void configureServlets() {
    167         filter("/*").through(new Filter() {
    168           @Inject Provider<ServletRequest> servletReqProvider;
    169           @Inject Provider<HttpServletRequest> reqProvider;
    170           @Inject Provider<ServletResponse> servletRespProvider;
    171           @Inject Provider<HttpServletResponse> respProvider;
    172 
    173           public void init(FilterConfig filterConfig) {}
    174 
    175           public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
    176               throws IOException, ServletException {
    177             filterInvoked[0] = true;
    178             assertSame(req, servletReqProvider.get());
    179             assertSame(req, reqProvider.get());
    180 
    181             assertSame(resp, servletRespProvider.get());
    182             assertSame(resp, respProvider.get());
    183 
    184             chain.doFilter(requestWrapper, responseWrapper);
    185 
    186             assertSame(req, reqProvider.get());
    187             assertSame(resp, respProvider.get());
    188           }
    189 
    190           public void destroy() {}
    191         });
    192       }
    193     });
    194 
    195     GuiceFilter filter = new GuiceFilter();
    196     final boolean[] chainInvoked = new boolean[1];
    197     FilterChain filterChain = new FilterChain() {
    198       public void doFilter(ServletRequest servletRequest,
    199           ServletResponse servletResponse) {
    200         chainInvoked[0] = true;
    201         assertSame(requestWrapper, servletRequest);
    202         assertSame(requestWrapper, injector.getInstance(ServletRequest.class));
    203         assertSame(requestWrapper, injector.getInstance(HTTP_REQ_KEY));
    204 
    205         assertSame(responseWrapper, servletResponse);
    206         assertSame(responseWrapper, injector.getInstance(ServletResponse.class));
    207         assertSame(responseWrapper, injector.getInstance(HTTP_RESP_KEY));
    208 
    209         assertSame(servletRequest.getParameterMap(), injector.getInstance(REQ_PARAMS_KEY));
    210 
    211         InRequest inRequest = injector.getInstance(InRequest.class);
    212         assertSame(inRequest, injector.getInstance(InRequest.class));
    213       }
    214     };
    215     filter.doFilter(request, response, filterChain);
    216 
    217     assertTrue(chainInvoked[0]);
    218     assertTrue(filterInvoked[0]);
    219   }
    220 
    221   public void testRequestAndResponseBindings_matchesPassedParameters() throws Exception {
    222     final int[] filterInvoked = new int[1];
    223     final boolean[] servletInvoked = new boolean[1];
    224     createInjector(new ServletModule() {
    225       @Override protected void configureServlets() {
    226         final HttpServletRequest[] previousReq = new HttpServletRequest[1];
    227         final HttpServletResponse[] previousResp = new HttpServletResponse[1];
    228 
    229         final Provider<ServletRequest> servletReqProvider = getProvider(ServletRequest.class);
    230         final Provider<HttpServletRequest> reqProvider = getProvider(HttpServletRequest.class);
    231         final Provider<ServletResponse> servletRespProvider = getProvider(ServletResponse.class);
    232         final Provider<HttpServletResponse> respProvider = getProvider(HttpServletResponse.class);
    233 
    234         Filter filter = new Filter() {
    235           public void init(FilterConfig filterConfig) {}
    236 
    237           public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain)
    238               throws IOException, ServletException {
    239             filterInvoked[0]++;
    240             assertSame(req, servletReqProvider.get());
    241             assertSame(req, reqProvider.get());
    242             if (previousReq[0] != null) {
    243               assertEquals(req, previousReq[0]);
    244             }
    245 
    246             assertSame(resp, servletRespProvider.get());
    247             assertSame(resp, respProvider.get());
    248             if (previousResp[0] != null) {
    249               assertEquals(resp, previousResp[0]);
    250             }
    251 
    252             chain.doFilter(
    253                 previousReq[0] = new HttpServletRequestWrapper((HttpServletRequest) req),
    254                 previousResp[0] = new HttpServletResponseWrapper((HttpServletResponse) resp));
    255 
    256             assertSame(req, reqProvider.get());
    257             assertSame(resp, respProvider.get());
    258           }
    259 
    260           public void destroy() {}
    261         };
    262 
    263         filter("/*").through(filter);
    264         filter("/*").through(filter);  // filter twice to test wrapping in filters
    265         serve("/*").with(new HttpServlet() {
    266           @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
    267             servletInvoked[0] = true;
    268             assertSame(req, servletReqProvider.get());
    269             assertSame(req, reqProvider.get());
    270 
    271             assertSame(resp, servletRespProvider.get());
    272             assertSame(resp, respProvider.get());
    273           }
    274         });
    275       }
    276     });
    277 
    278     GuiceFilter filter = new GuiceFilter();
    279     filter.doFilter(newFakeHttpServletRequest(), newFakeHttpServletResponse(), new FilterChain() {
    280       public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse) {
    281         throw new IllegalStateException("Shouldn't get here");
    282       }
    283     });
    284 
    285     assertEquals(2, filterInvoked[0]);
    286     assertTrue(servletInvoked[0]);
    287   }
    288 
    289   public void testNewRequestObject()
    290       throws CreationException, IOException, ServletException {
    291     final Injector injector = createInjector();
    292     final HttpServletRequest request = newFakeHttpServletRequest();
    293 
    294     GuiceFilter filter = new GuiceFilter();
    295     final boolean[] invoked = new boolean[1];
    296     FilterChain filterChain = new FilterChain() {
    297       public void doFilter(ServletRequest servletRequest,
    298           ServletResponse servletResponse) {
    299         invoked[0] = true;
    300         assertNotNull(injector.getInstance(InRequest.class));
    301         assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
    302       }
    303     };
    304 
    305     filter.doFilter(request, null, filterChain);
    306 
    307     assertTrue(invoked[0]);
    308   }
    309 
    310   public void testExistingRequestObject()
    311       throws CreationException, IOException, ServletException {
    312     final Injector injector = createInjector();
    313     final HttpServletRequest request = newFakeHttpServletRequest();
    314 
    315     GuiceFilter filter = new GuiceFilter();
    316     final boolean[] invoked = new boolean[1];
    317     FilterChain filterChain = new FilterChain() {
    318       public void doFilter(ServletRequest servletRequest,
    319           ServletResponse servletResponse) {
    320         invoked[0] = true;
    321 
    322         InRequest inRequest = injector.getInstance(InRequest.class);
    323         assertSame(inRequest, injector.getInstance(InRequest.class));
    324 
    325         assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
    326         assertNull(injector.getInstance(IN_REQUEST_NULL_KEY));
    327       }
    328     };
    329 
    330     filter.doFilter(request, null, filterChain);
    331 
    332     assertTrue(invoked[0]);
    333   }
    334 
    335   public void testNewSessionObject()
    336       throws CreationException, IOException, ServletException {
    337     final Injector injector = createInjector();
    338     final HttpServletRequest request = newFakeHttpServletRequest();
    339 
    340     GuiceFilter filter = new GuiceFilter();
    341     final boolean[] invoked = new boolean[1];
    342     FilterChain filterChain = new FilterChain() {
    343       public void doFilter(ServletRequest servletRequest,
    344           ServletResponse servletResponse) {
    345         invoked[0] = true;
    346         assertNotNull(injector.getInstance(InSession.class));
    347         assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
    348       }
    349     };
    350 
    351     filter.doFilter(request, null, filterChain);
    352 
    353     assertTrue(invoked[0]);
    354   }
    355 
    356   public void testExistingSessionObject()
    357       throws CreationException, IOException, ServletException {
    358     final Injector injector = createInjector();
    359     final HttpServletRequest request = newFakeHttpServletRequest();
    360 
    361     GuiceFilter filter = new GuiceFilter();
    362     final boolean[] invoked = new boolean[1];
    363     FilterChain filterChain = new FilterChain() {
    364       public void doFilter(ServletRequest servletRequest,
    365           ServletResponse servletResponse) {
    366         invoked[0] = true;
    367 
    368         InSession inSession = injector.getInstance(InSession.class);
    369         assertSame(inSession, injector.getInstance(InSession.class));
    370 
    371         assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
    372         assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
    373       }
    374     };
    375 
    376     filter.doFilter(request, null, filterChain);
    377 
    378     assertTrue(invoked[0]);
    379   }
    380 
    381   public void testHttpSessionIsSerializable() throws Exception {
    382     final Injector injector = createInjector();
    383     final HttpServletRequest request = newFakeHttpServletRequest();
    384     final HttpSession session = request.getSession();
    385 
    386     GuiceFilter filter = new GuiceFilter();
    387     final boolean[] invoked = new boolean[1];
    388     FilterChain filterChain = new FilterChain() {
    389       public void doFilter(ServletRequest servletRequest,
    390           ServletResponse servletResponse) {
    391         invoked[0] = true;
    392         assertNotNull(injector.getInstance(InSession.class));
    393         assertNull(injector.getInstance(IN_SESSION_NULL_KEY));
    394       }
    395     };
    396 
    397     filter.doFilter(request, null, filterChain);
    398 
    399     assertTrue(invoked[0]);
    400 
    401     HttpSession deserializedSession = reserialize(session);
    402 
    403     String inSessionKey = IN_SESSION_KEY.toString();
    404     String inSessionNullKey = IN_SESSION_NULL_KEY.toString();
    405     assertTrue(deserializedSession.getAttribute(inSessionKey) instanceof InSession);
    406     assertEquals(NullObject.INSTANCE, deserializedSession.getAttribute(inSessionNullKey));
    407   }
    408 
    409   public void testGuiceFilterConstructors() throws Exception {
    410     final RuntimeException servletException = new RuntimeException();
    411     final RuntimeException chainException = new RuntimeException();
    412     final Injector injector = createInjector(new ServletModule() {
    413       @Override protected void configureServlets() {
    414         serve("/*").with(new HttpServlet() {
    415           @Override protected void doGet(HttpServletRequest req, HttpServletResponse resp) {
    416             throw servletException;
    417           }
    418         });
    419       }
    420     });
    421     final HttpServletRequest request = newFakeHttpServletRequest();
    422     FilterChain filterChain = new FilterChain() {
    423       public void doFilter(ServletRequest servletRequest,
    424           ServletResponse servletResponse) {
    425         throw chainException;
    426       }
    427     };
    428 
    429     try {
    430       new GuiceFilter().doFilter(request, null, filterChain);
    431       fail();
    432     } catch (RuntimeException e) {
    433       assertSame(servletException, e);
    434     }
    435     try {
    436       injector.getInstance(GuiceFilter.class).doFilter(request, null, filterChain);
    437       fail();
    438     } catch (RuntimeException e) {
    439       assertSame(servletException, e);
    440     }
    441     try {
    442       injector.getInstance(Key.get(GuiceFilter.class, ScopingOnly.class))
    443           .doFilter(request, null, filterChain);
    444       fail();
    445     } catch (RuntimeException e) {
    446       assertSame(chainException, e);
    447     }
    448   }
    449 
    450   private Injector createInjector(Module... modules) throws CreationException {
    451     return Guice.createInjector(Lists.<Module>asList(new AbstractModule() {
    452       @Override
    453       protected void configure() {
    454         install(new ServletModule());
    455         bind(InSession.class);
    456         bind(IN_SESSION_NULL_KEY).toProvider(Providers.<InSession>of(null)).in(SessionScoped.class);
    457         bind(InRequest.class);
    458         bind(IN_REQUEST_NULL_KEY).toProvider(Providers.<InRequest>of(null)).in(RequestScoped.class);
    459       }
    460     }, modules));
    461   }
    462 
    463   @SessionScoped
    464   static class InSession implements Serializable {}
    465 
    466   @RequestScoped
    467   static class InRequest {}
    468 
    469   @BindingAnnotation @Retention(RUNTIME) @Target({PARAMETER, METHOD, FIELD})
    470   @interface Null {}
    471 }
    472