SpringBoot Java

SpringBoot Xss过滤

Posted on 2021-01-23,4 min read
  • 在互联网项目中常见的攻击方式有SQL注入,XSS攻击,以及CSRF。

  • XSS攻击通常指的是通过利用网页开发时留下的漏洞,通过巧妙的方法注入恶意指令代码到网页,使用户加载并执行攻击者恶意制造的网页程序。这些恶意网页程序通常是JavaScript,但实际上也可以包括Java、 VBScript、ActiveX、 Flash 或者甚至是普通的HTML。攻击成功后,攻击者可能得到包括但不限于更高的权限(如执行一些操作)、私密网页内容、会话和cookie等各种内容。

  • 防范XSS攻击的主要方式是对参数进行过滤。即:将提交的非法参数中的一些字符转义后再传送到后台。然后请求的时候再由前端转义。

  • 由于前端不一定可靠,存在破解的风险,所以后端也需要进行一次过滤。

  • 为了防止XSS攻击,主要利用两个类,一个是Servlet的Filter,用于过滤请求。还有一个是HttpServletRequestWrapper用于包装Request

代码

  • 注意:
    • 很多博文转义HTML使用的是common-langStringUtils。但是事实上StringUtils会将中文也随之转义了。其实使用SpringMvc自带的HtmlUtils就行了
    • 对于Post发送的Json数据是无法通过Request.getParameters获取的,只能通过读取InputStream来获取。所以需要单独过滤
    • 为了使Filter生效,SpringBoot启动类上添加一个ServletComponentScan的注解
    • Filter最好不要全匹配,因为不是所有的请求都需要XSS过滤

Filter

// filter
package cn.gloduck.onlinetest.security.xss;

import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import java.io.IOException;



/**
 * @author Gloduck
 */
@WebFilter("/**")
@Component
public class XSSFilter implements Filter {
    @Override
    public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {
        HttpServletRequest httpServletRequest = (HttpServletRequest) request;
        chain.doFilter(new XssHttpServletRequestWrapper(httpServletRequest),response);
    }

}

HttpServletRequestWrapper

package cn.gloduck.onlinetest.security.xss;


import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.web.util.HtmlUtils;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;

/**
 * @author Gloduck
 */
public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {
    private static ObjectMapper objectMapper = new ObjectMapper();

    public XssHttpServletRequestWrapper(HttpServletRequest request) {
        super(request);
    }

    @Override
    public String getHeader(String name) {
        String target = super.getHeader(name);
        return target == null ? null : HtmlUtils.htmlEscape(target);
    }

    @Override
    public String getQueryString() {
        String target = super.getQueryString();
        return target == null ? null : HtmlUtils.htmlEscape(target);
    }

    @Override
    public String getParameter(String name) {
        String target = super.getParameter(name);
        return target == null ? null : HtmlUtils.htmlEscape(target);
    }

    @Override
    public String[] getParameterValues(String name) {
        String[] values = super.getParameterValues(name);
        if(values != null) {
            int length = values.length;
            String[] escapseValues = new String[length];
            for(int i = 0; i < length; i++){
                escapseValues[i] = HtmlUtils.htmlEscape(values[i]);
            }
            return escapseValues;
        }
        return values;
    }

    @Override
    public ServletInputStream getInputStream() throws IOException {
        String str=getRequestBody(super.getInputStream());
        Map<String,Object> map= objectMapper.readValue(str, Map.class);
        Map<String,Object> resultMap=new HashMap<>(map.size());
        for(String key:map.keySet()){
            Object val=map.get(key);
            if(map.get(key) instanceof String){
                resultMap.put(key,HtmlUtils.htmlEscape(val.toString()));
            }else{
                resultMap.put(key,val);
            }
        }

        str=objectMapper.writeValueAsString(resultMap);
        final ByteArrayInputStream bais = new ByteArrayInputStream(str.getBytes());
        return new ServletInputStream() {
            @Override
            public int read() throws IOException {
                return bais.read();
            }
            @Override
            public boolean isFinished() {
                return false;
            }
            @Override
            public boolean isReady() {
                return false;
            }
            @Override
            public void setReadListener(ReadListener listener) {
            }
        };
    }

    private String getRequestBody(InputStream stream) {
        String line = "";
        StringBuilder body = new StringBuilder();
        int counter = 0;

        // 读取POST提交的数据内容
        BufferedReader reader = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8));
        try {
            while ((line = reader.readLine()) != null) {
                body.append(line);
                counter++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return body.toString();
    }
}

下一篇: JRebel配置→

loading...