Resolving a multipart/form-data request in a Spring filter


Solution 1

You cannot cast HttpServletRequest to MultipartHttpServletRequest, because the request has to be resolved (parsed) first.

I used the CommonsMultipartResolver class and obtained a MultipartHttpServletRequest through its resolveMultipart(request) method (where request is an HttpServletRequest).
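To make "resolving" concrete: in a multipart/form-data body, the csrf field is not a URL-encoded parameter but one of several parts separated by a boundary string, which is why getParameter() on the unresolved request came back empty here. The sketch below is a deliberately simplified, JDK-only illustration of that parsing step. MultipartFieldSketch and findField are hypothetical names, and this is not how CommonsMultipartResolver is implemented. It assumes the whole body fits in memory, that the boundary has already been taken from the Content-Type header, and that text values are UTF-8; a real application should use a real multipart parser.

```java
import java.nio.charset.StandardCharsets;
import java.util.regex.Pattern;

// Hypothetical helper for illustration only; not part of Spring or Commons FileUpload.
class MultipartFieldSketch {

    /**
     * Returns the value of the form field {@code name} from a raw multipart/form-data
     * body, or null if there is no such field. Simplified sketch: the whole body is in
     * memory and text values are assumed to be UTF-8.
     */
    static String findField(byte[] body, String boundary, String name) {
        // ISO-8859-1 maps every byte to exactly one char, so binary file parts survive the split
        String text = new String(body, StandardCharsets.ISO_8859_1);
        // Match name="..." as its own parameter, so filename="..." is not mistaken for it
        Pattern fieldHeader = Pattern.compile(
                "(?i)content-disposition:[^\\r\\n]*;\\s*name=\"" + Pattern.quote(name) + "\"");

        // Every part is introduced by "--" + boundary
        for (String part : text.split(Pattern.quote("--" + boundary))) {
            int headerEnd = part.indexOf("\r\n\r\n");
            if (headerEnd < 0) {
                continue; // preamble or the closing "--" marker
            }
            if (!fieldHeader.matcher(part.substring(0, headerEnd)).find()) {
                continue; // some other field, e.g. the uploaded file
            }
            String value = part.substring(headerEnd + 4);
            // Drop the CRLF that belongs to the next delimiter line
            if (value.endsWith("\r\n")) {
                value = value.substring(0, value.length() - 2);
            }
            // Re-decode the raw bytes of the value as UTF-8
            return new String(value.getBytes(StandardCharsets.ISO_8859_1), StandardCharsets.UTF_8);
        }
        return null;
    }

    public static void main(String[] args) {
        String body = "--XyZ\r\n"
                + "Content-Disposition: form-data; name=\"file\"; filename=\"a.txt\"\r\n"
                + "Content-Type: text/plain\r\n\r\n"
                + "file contents\r\n"
                + "--XyZ\r\n"
                + "Content-Disposition: form-data; name=\"csrf\"\r\n\r\n"
                + "abc123\r\n"
                + "--XyZ--\r\n";
        byte[] bytes = body.getBytes(StandardCharsets.ISO_8859_1);
        System.out.println(findField(bytes, "XyZ", "csrf")); // abc123
    }
}
```

The point of the example is that the csrf value only becomes visible after the body has been read and split. That same read is what consumes the request stream in a servlet filter.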

Here is the checkPostedCsrfToken() method of my CSRF class:

private boolean checkPostedCsrfToken() {
    // Regular form posts: the container has already parsed the parameters
    if (request.getParameterMap().containsKey("csrf")) {
        String csrf = request.getParameter("csrf");
        if (csrf.equals(request.getSession().getAttribute("csrf"))) {
            return true;
        }
    } else if (request.getContentType() != null && request.getContentType().toLowerCase().contains("multipart/form-data")) {
        // Multipart form posts: resolve (parse) the body first, then read the field
        CommonsMultipartResolver commonsMultipartResolver = new CommonsMultipartResolver();
        MultipartHttpServletRequest multipartRequest = commonsMultipartResolver.resolveMultipart(request);
        try {
            if (multipartRequest.getParameterMap().containsKey("csrf")) {
                String csrf = multipartRequest.getParameter("csrf");
                if (csrf.equals(request.getSession().getAttribute("csrf"))) {
                    return true;
                }
            }
        } finally {
            // Delete any temporary files created while parsing the upload
            commonsMultipartResolver.cleanupMultipart(multipartRequest);
        }
    }

    log();
    return false;
}

Note, however, that with this approach you will lose all request parameters and body data further down the filter chain: parsing the multipart body consumes the request's input stream, which can only be read once. If you need the parameters to survive the rest of the filter chain, extend HttpServletRequestWrapper so that it caches the request bytes and serves them to every subsequent reader. In other words, you need a re-readable copy of your request.
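The parameters disappear because the request body is a one-shot stream: once a parser has read it to the end, every later reader sees end-of-stream. The JDK-only sketch below illustrates both the problem and the caching fix, using a ByteArrayInputStream as a stand-in for the servlet input stream (an assumption made so the example is self-contained; ReadOnceDemo and CachedBody are hypothetical names). CachedBody follows the same idea as the wrapper class in this answer: copy the bytes once, then hand every reader a fresh stream over the copy.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

// Illustration only: ByteArrayInputStream stands in for the request body stream.
class ReadOnceDemo {

    // Reads a stream to the end, as a multipart parser does with the request body
    static byte[] drain(InputStream in) {
        try {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buffer = new byte[4096];
            int n;
            while ((n = in.read(buffer)) != -1) {
                out.write(buffer, 0, n);
            }
            return out.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Read the original body once, then give each reader its own stream over the copy
    static final class CachedBody {
        private final byte[] bytes;

        CachedBody(InputStream original) {
            this.bytes = drain(original);
        }

        InputStream newStream() {
            return new ByteArrayInputStream(bytes);
        }
    }

    public static void main(String[] args) {
        InputStream body = new ByteArrayInputStream("csrf=abc123".getBytes(StandardCharsets.UTF_8));
        System.out.println(drain(body).length); // 11: the first reader gets everything
        System.out.println(drain(body).length); // 0: the stream is already exhausted

        CachedBody cached = new CachedBody(
                new ByteArrayInputStream("csrf=abc123".getBytes(StandardCharsets.UTF_8)));
        System.out.println(drain(cached.newStream()).length); // 11
        System.out.println(drain(cached.newStream()).length); // 11 again
    }
}
```

The trade-off is memory: the whole body, including any uploaded files, is held in a byte array, so this approach suits modest upload sizes.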

Here is a good helper class I found on Stack Overflow (I can't find the question again; I will edit this answer if I do).

MultiReadHttpServletRequest

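import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;

import org.apache.commons.io.IOUtils;

public class MultiReadHttpServletRequest extends HttpServletRequestWrapper {
    private ByteArrayOutputStream cachedBytes;

    public MultiReadHttpServletRequest(HttpServletRequest request) {
        super(request);
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        // Read the real body once, then serve every caller from the cache
        if (cachedBytes == null) {
            cacheInputStream();
        }
        return new CachedServletInputStream();
    }

    @Override
    public BufferedReader getReader() throws IOException {
        // Use the request's character encoding (servlet default ISO-8859-1), not the platform default
        String encoding = getCharacterEncoding() != null ? getCharacterEncoding() : "ISO-8859-1";
        return new BufferedReader(new InputStreamReader(getInputStream(), encoding));
    }

    private void cacheInputStream() throws IOException {
        // Cache the input stream so it can be read multiple times
        // (IOUtils comes from Apache Commons IO)
        cachedBytes = new ByteArrayOutputStream();
        IOUtils.copy(super.getInputStream(), cachedBytes);
    }

    /* An input stream that reads the cached request body */
    public class CachedServletInputStream extends ServletInputStream {
        private final ByteArrayInputStream input;

        public CachedServletInputStream() {
            // Create a fresh stream over the cached request body
            input = new ByteArrayInputStream(cachedBytes.toByteArray());
        }

        @Override
        public int read() throws IOException {
            return input.read();
        }

        // The three methods below are abstract in Servlet 3.1+; on Servlet 3.0 or older,
        // remove them together with the ReadListener import
        @Override
        public boolean isFinished() {
            return input.available() == 0;
        }

        @Override
        public boolean isReady() {
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            throw new UnsupportedOperationException("Asynchronous reads are not supported");
        }
    }
}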

Now all you need to do is use MultiReadHttpServletRequest instead of the plain HttpServletRequest in your filter:

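import java.io.IOException;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.web.filter.GenericFilterBean;

public class CSRFilter extends GenericFilterBean {
    @Override
    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest request = (HttpServletRequest) req;
        HttpServletResponse response = (HttpServletResponse) res;
        // The important part: wrap the request so its body can be read again downstream
        MultiReadHttpServletRequest multiReadHttpServletRequest = new MultiReadHttpServletRequest(request);
        CSRF csrf = new CSRF(multiReadHttpServletRequest);
        if (csrf.isOk()) {
            // Pass the wrapper, not the original request, down the chain
            chain.doFilter(multiReadHttpServletRequest, res);
        } else {
            // TODO: show an error page
            String redirect = request.getScheme() + "://" + request.getServerName() + ":" + request.getServerPort()
                    + request.getContextPath() + "/access-forbidden";
            response.sendRedirect(redirect);
        }
    }
}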

I hope this helps someone :)

Solution 2

I needed to inspect the request body without consuming it for the servlet or subsequent filters, so I created a mini-project that does just that.

The jar is under 10 KB, and if you're using Tomcat you don't need anything beyond it. It is also MIT licensed, so you can use it in any project you need.

You can find the project at https://github.com/isapir/servlet-filter-utils

All you have to do is wrap the incoming request with RereadableServletRequest, e.g.:

HttpServletRequest requestWrapper = new RereadableServletRequest(servletRequest);
Author: Sep GH

Updated on June 28, 2022

Question by Sep GH:

    I'm trying to develop my own CSRF filter in Spring MVC 3. (This is for some extra training, which is why I'm not considering Spring Security.)

    My filter works fine with all forms except those with enctype="multipart/form-data": for those, I cannot get the request parameters from the plain HttpServletRequest.

    I've tried casting HttpServletRequest to MultipartHttpServletRequest, but I found out that I cannot do that either.

    My objective is not to get files from the request, only a simple form input named csrf. (File uploading from my forms already works.)

    Here is my code so far:

    CSRFilter

    public class CSRFilter extends GenericFilterBean {
        @Override
        public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) throws IOException, ServletException {
            HttpServletRequest request = (HttpServletRequest) req;
            HttpServletResponse response = (HttpServletResponse) res;
    
            CSRF csrf = new CSRF(req);
            if(csrf.isOk()){
                chain.doFilter(req, res);
            }else {
                //todo : Show Error Page
                String redirect = request.getScheme() + "://" + request.getServerName() + ":" + request.getServerPort() + request.getContextPath() + "/access-forbidden";
                response.sendRedirect(redirect);
            }
    
        }
    }
    

    CSRF

    public class CSRF {
        HttpServletRequest request;
        ServletRequest req;
        String token;
        boolean ok;
        private static final Logger logger = Logger.getLogger(CSRF.class);
    
    
        public CSRF(ServletRequest request) {
            this.request = (HttpServletRequest) request;
            this.req = request;
            init();
        }
    
        public CSRF() {
        }
    
    
        public void setRequest(HttpServletRequest request) {
            this.request = (HttpServletRequest) request;
            this.req = request;
            init();
        }
    
        private void init() {
            if (request.getMethod().equals("GET")) {
                generateToken();
                addCSRFTokenToSession();
                addCSRFTokenToModelAttribute();
                ok = true;
            } else if (request.getMethod().equals("POST")) {
                if (checkPostedCsrfToken()) {
                    ok = true;
                }
            }
        }
    
        private void generateToken() {
            String token;
            java.util.Date date = new java.util.Date();
            UUID uuid = UUID.randomUUID();
            token = uuid.toString() + String.valueOf(new Timestamp(date.getTime()));
            try {
                this.token = sha1(token);
            } catch (NoSuchAlgorithmException e) {
                e.printStackTrace();
                this.token = token;
            }
        }
    
        private void addCSRFTokenToSession() {
            request.getSession().setAttribute("csrf", token);
        }
    
        private void addCSRFTokenToModelAttribute() {
            request.setAttribute("csrf", token);
        }
    
        private boolean checkPostedCsrfToken() {
            System.out.println("____ CSRF CHECK POST _____");
            if (request.getParameterMap().containsKey("csrf")) {
                String csrf = request.getParameter("csrf");
                if (csrf.equals(request.getSession().getAttribute("csrf"))) {
                    return true;
                }
            }else {
                //Check for multipart requests
    
                MultipartHttpServletRequest multiPartRequest = new DefaultMultipartHttpServletRequest((HttpServletRequest) req);
                if (multiPartRequest.getParameterMap().containsKey("csrf")) {
                    String csrf = multiPartRequest.getParameter("csrf");
                    if (csrf.equals(request.getSession().getAttribute("csrf"))) {
                        return true;
                    }
                }
            }
    
            log();
            return false;
        }
    
        private void log() {
            HttpSession session = request.getSession();
            String username = (String) session.getAttribute("username");
            if(username==null){
                username = "unknown (not logged in)";
            }
            String ipAddress = request.getHeader("X-FORWARDED-FOR");
            if (ipAddress == null) {
                ipAddress = request.getRemoteAddr();
            }
            String userAgent = request.getHeader("User-Agent");
            String address = request.getRequestURI();
            System.out.println("a CSRF attack detected from IP: " + ipAddress + " in address \"" + address + "\" - Client User Agent : " + userAgent + " Username: " + username);
    
            logger.error("a CSRF attack detected from IP: " + ipAddress + " in address \"" + address + "\" - Client User Agent : " + userAgent + " Username: " + username);
        }
    
        public boolean isOk() {
            return ok;
        }
    
        static String sha1(String input) throws NoSuchAlgorithmException {
            MessageDigest mDigest = MessageDigest.getInstance("SHA1");
            byte[] result = mDigest.digest(input.getBytes());
            StringBuffer sb = new StringBuffer();
            for (int i = 0; i < result.length; i++) {
                sb.append(Integer.toString((result[i] & 0xff) + 0x100, 16).substring(1));
            }
            return sb.toString();
        }
    }
    

    I also have this in my dispatcher configuration:

    <bean id="multipartResolver" class="org.springframework.web.multipart.commons.CommonsMultipartResolver">
        <!-- one of the properties available; the maximum file size in bytes -->
        <property name="maxUploadSize" value="40000000"/>
    </bean>
    

    and I also use Spring's MultipartFilter:

    <filter>
        <display-name>springMultipartFilter</display-name>
        <filter-name>springMultipartFilter</filter-name>
        <filter-class>org.springframework.web.multipart.support.MultipartFilter</filter-class>
    </filter>
    <filter-mapping>
        <filter-name>springMultipartFilter</filter-name>
        <url-pattern>/*</url-pattern>
    </filter-mapping>
    

    I get java.lang.IllegalStateException: Multipart request not initialized when I try it on multipart/form-data forms.

    I looked at many examples on the internet, but most of them were about file uploads and did not help me. I also tried different ways to cast HttpServletRequest to some other object that would give me a resolved multipart request, but without success.

    How can I do this?

    Thanks.