SourceCode Java Spring Code

Java通过AOP实现接口限制

Posted on 2021-01-30,9 min read

待解决问题

  • 没有解决无效数据清理的问题,会出现OOM(有时间更新),推荐使用Redis版本
    • 解决方案:
      • 将缓存改为LRU策略(例如用Spring 5.3起提供的ConcurrentLruCache替代ConcurrentSkipListMap;注意它并不是Map的直接替代品,需要在构造时传入容量上限和值生成函数,再通过get(key)取值)。按照LRU策略,被淘汰的是不常访问的接口记录,例如只被调用过一次的记录。这些记录可以直接删除,因为接口限流要约束的正是那些被频繁调用的接口。
      • 设置一个后台任务来处理。
  • 只实现了根据IP来进行限制

测试

测试代码

package springboot;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import springboot.limit.annotation.Limit;
import springboot.limit.enums.LimitType;

@SpringBootApplication
public class SpringBootApplicationStart {

    /**
     * Demo controller used to exercise the {@link Limit} aspect: {@code GET /test} is limited to
     * 10 calls per 200-unit window per client IP (the in-memory LimitAspect compares the window
     * against System.currentTimeMillis(), i.e. 200 ms).
     */
    @RestController
    @RequestMapping("/test")
    public static class MyController {

        /**
         * Returns a fixed payload; only the rate limiting around it is of interest.
         *
         * @return the constant string {@code "json"}
         */
        @GetMapping
        @Limit(type = LimitType.IP, count = 10, timeout = 200)
        public String show() {
            return "json";
        }
    }

    /**
     * Boots the Spring application context.
     *
     * @param args command-line arguments forwarded to Spring Boot
     */
    public static void main(String[] args) {
        SpringApplication.run(SpringBootApplicationStart.class, args);
    }
}

测试结果

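  • 接口限制为:时间间隔200毫秒(Java原生实现中timeout与System.currentTimeMillis()比较,单位为毫秒),请求限制10个,即每秒最多50个请求
  • 压力测试为:每秒200个请求,持续20秒,一共4000个请求
  • 结果为:异常率82%。理论上20秒内最多放行 50 × 20 = 1000 个请求,成功率为 1000 / 4000 = 25%,即预期异常率约75%,与实测结果基本吻合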
测试配置
结果

代码

注解

  • 定义一个注解,用于标记需要由AOP切面进行限流的方法
package springboot.limit.annotation;

import springboot.limit.enums.LimitType;

import java.lang.annotation.*;

/**
 * Marks a method whose invocations are rate-limited by the limit aspect.
 * <p>
 * At most {@link #count()} calls are accepted per caller (as identified by {@link #type()})
 * within one {@link #timeout()} window; further calls are rejected with an exception.
 *
 * @author Gloduck
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface Limit {
    /**
     * Maximum number of invocations accepted within one window.
     *
     * @return the per-window call limit
     */
    int count();

    /**
     * Length of the counting window.
     * <p>
     * The in-memory LimitAspect compares it against {@code System.currentTimeMillis()}, so it is
     * interpreted as <b>milliseconds</b> there.
     *
     * @return the window length
     */
    long timeout();

    /**
     * How the caller is identified when counting calls.
     *
     * @return the key strategy; defaults to {@link LimitType#IP}, currently the only one
     */
    LimitType type() default LimitType.IP;
}

记录实体

  • 定义一个实体,用于记录访问数以及起始记录时间
package springboot.limit;

import java.util.concurrent.atomic.AtomicInteger;

/**
 * Fixed-window counter: the number of accepted calls since the current window's start time.
 * <p>
 * The counter is an {@link AtomicInteger} updated by CAS; the window start is {@code volatile}
 * because {@link #isExpire(long)} reads it without holding any lock.
 *
 * @author Gloduck
 */
public class LimitEntry {
    /** Start of the current window, in epoch milliseconds. */
    private volatile long startTime;
    /** Calls accepted in the current window. */
    private final AtomicInteger counter;

    /**
     * Creates an entry whose window starts now with a count of zero.
     */
    public LimitEntry() {
        this.startTime = System.currentTimeMillis();
        this.counter = new AtomicInteger(0);
    }

    /**
     * @return the number of calls accepted in the current window
     */
    public int get() {
        return counter.get();
    }

    /**
     * Atomically sets the count to {@code newValue} if it currently equals {@code expectedValue}.
     *
     * @param expectedValue the count the caller last observed
     * @param newValue      the count to store
     * @return {@code true} if the update happened
     */
    public boolean compareAndSet(int expectedValue, int newValue) {
        return counter.compareAndSet(expectedValue, newValue);
    }

    /**
     * @param expire window length in milliseconds
     * @return {@code true} if more than {@code expire} ms have elapsed since the window started
     */
    public boolean isExpire(long expire) {
        return this.startTime + expire < System.currentTimeMillis();
    }

    /**
     * Starts a new window: resets the count to zero and moves the start time to now.
     */
    public void refresh() {
        // Reset the counter BEFORE publishing the new (volatile) start time: any thread that
        // observes the new window is then guaranteed to also observe the reset count, instead of
        // seeing "not expired" together with the previous window's exhausted count.
        this.counter.set(0);
        this.startTime = System.currentTimeMillis();
    }
}

切面

  • 通过切面来实现接口限制功能

Java原生实现

package springboot.limit.aspect;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import springboot.limit.LimitEntry;
import springboot.limit.annotation.Limit;
import springboot.limit.enums.LimitType;
import springboot.limit.exception.LimitAccessException;

import javax.servlet.http.HttpServletRequest;
import java.util.Objects;
import java.util.concurrent.ConcurrentSkipListMap;

/**
 * In-memory, fixed-window rate limiter for methods annotated with {@link Limit}.
 * <p>
 * One {@link LimitEntry} is kept per (annotated method, caller key) pair. Entries are never
 * removed ({@link #checkIfNeedClear()} is still a stub), so memory grows with the number of
 * distinct callers.
 */
@Aspect
@Component
public class LimitAspect {
    private static final Logger logger = LoggerFactory.getLogger(LimitAspect.class);
    /** Counters keyed by "method signature#caller key". */
    private final ConcurrentSkipListMap<String, LimitEntry> map = new ConcurrentSkipListMap<>();

    @Pointcut("@annotation(limit)")
    public void pointCut(Limit limit) {
    }

    /**
     * Counts the call against the caller's quota for this method and either proceeds or rejects it.
     *
     * @param joinPoint the intercepted invocation
     * @param limit     the method's {@link Limit} annotation
     * @return whatever the intercepted method returns
     * @throws LimitAccessException if the caller has exhausted the quota for the current window
     * @throws Throwable            anything thrown by the intercepted method
     */
    @Around(value = "pointCut(limit)", argNames = "joinPoint,limit")
    public Object interfaceLimit(ProceedingJoinPoint joinPoint, Limit limit) throws Throwable {
        // Scope the counter to the annotated method so each endpoint has its own quota
        // (otherwise every @Limit method would share one counter per IP).
        String key = joinPoint.getSignature().toLongString() + "#" + limit.type().generateKey(getRequest());
        // computeIfAbsent makes all racing first requests share the same entry; a
        // get-then-put sequence could let one thread overwrite another's entry and lose counts.
        LimitEntry limitEntry = map.computeIfAbsent(key, k -> new LimitEntry());
        boolean exceeded = !tryAcquire(limitEntry, limit);

        logger.info("获取到的key为:{},当前的是否超出访问限制:{}", key, (exceeded ? "是" : "否"));
        checkIfNeedClear();
        if (exceeded) {
            throw new LimitAccessException("访问超出限制");
        }
        return joinPoint.proceed();
    }

    /**
     * Tries to take one slot from {@code entry}, starting a new window first if the current one expired.
     *
     * @return {@code true} if a slot was taken, {@code false} if the window's quota is exhausted
     */
    private boolean tryAcquire(LimitEntry entry, Limit limit) {
        int current;
        do {
            if (entry.isExpire(limit.timeout())) {
                // Re-check under the lock so only one thread resets an expired window.
                synchronized (entry) {
                    if (entry.isExpire(limit.timeout())) {
                        entry.refresh();
                    }
                }
            }
            current = entry.get();
            if (current >= limit.count()) {
                return false;
            }
        } while (!entry.compareAndSet(current, current + 1));
        return true;
    }

    public HttpServletRequest getRequest() {
        return ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
    }

    /**
     * Hook for evicting stale entries; currently a no-op (TODO: see the OOM note in the post).
     */
    public void checkIfNeedClear(){

    }
}

结合Redis版

package cn.gloduck.gmall.limit.aspect;

import cn.gloduck.gmall.limit.annotation.Limit;
import cn.gloduck.gmall.limit.enums.LimitType;
import cn.gloduck.gmall.limit.exception.LimitAccessException;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.util.Collections;
import java.util.List;
import java.util.Objects;

/**
 * Redis-backed fixed-window rate limiter for methods annotated with {@link Limit}.
 * <p>
 * The counter lives in Redis under a per-(method, caller) key that receives an EXPIRE on its
 * first increment, so Redis itself discards finished windows.
 *
 * @author Gloduck
 */
@Component
@Aspect
public class RedisLimitAspect {
    private static final Logger logger = LoggerFactory.getLogger(RedisLimitAspect.class);
    @Autowired
    private RedisTemplate<String, Object> redisTemplate;

    /**
     * KEYS[1] = counter key, ARGV[1] = limit count, ARGV[2] = window passed to EXPIRE (seconds).
     * If the counter is already above the limit it is returned unchanged (no further writes);
     * otherwise it is incremented, and the first increment sets the expiry. Both branches
     * return an integer reply (tonumber) so the result matches the declared Long type.
     */
    private final static String LIMIT_SCRIPT = "local c" +
            "\nc = redis.call('get',KEYS[1])" +
            "\nif c and tonumber(c) > tonumber(ARGV[1]) then" +
            "\nreturn tonumber(c);" +
            "\nend" +
            "\nc = redis.call('incr',KEYS[1])" +
            "\nif tonumber(c) == 1 then" +
            "\nredis.call('expire',KEYS[1],ARGV[2])" +
            "\nend" +
            "\nreturn c;";

    /** Built once instead of on every intercepted call. */
    private static final RedisScript<Long> LIMIT_REDIS_SCRIPT = new DefaultRedisScript<>(LIMIT_SCRIPT, Long.class);

    @Pointcut("@annotation(limit)")
    public void pointcut(Limit limit) {
    }

    /**
     * Counts the call in Redis and either proceeds or rejects it.
     *
     * @param joinPoint the intercepted invocation
     * @param limit     the method's {@link Limit} annotation
     * @return whatever the intercepted method returns
     * @throws LimitAccessException if the quota for the current window is exhausted
     * @throws Throwable            anything thrown by the intercepted method
     */
    @Around(value = "pointcut(limit)", argNames = "joinPoint,limit")
    public Object around(ProceedingJoinPoint joinPoint, Limit limit) throws Throwable {
        // Scope the counter to the annotated method so each endpoint has its own quota.
        String counterKey = joinPoint.getSignature().toLongString() + "#" + limit.type().generateKey(getRequest());
        List<String> key = Collections.singletonList(counterKey);
        // NOTE(review): count/timeout are serialized with redisTemplate's value serializer;
        // tonumber(ARGV[n]) only works if it writes plain numbers (e.g. a String/JSON
        // serializer), not JDK serialization — confirm the template configuration.
        Number number = redisTemplate.execute(LIMIT_REDIS_SCRIPT, key, limit.count(), limit.timeout());
        if (number != null && number.intValue() <= limit.count()) {
            logger.info("第{}次访问接口,key为{}", number, key);
            return joinPoint.proceed();
        }
        throw new LimitAccessException("访问超出限制");
    }

    public HttpServletRequest getRequest() {
        return ((ServletRequestAttributes) Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
    }
}

枚举类

  • 定义一个枚举类,用于指定限制的规则,目前只实现了根据IP限制
package springboot.limit.enums;

import springboot.limit.utils.LimitUtils;

import javax.servlet.http.HttpServletRequest;

/**
 * Strategies for deriving the rate-limit key from an incoming request.
 * Currently only per-IP limiting is implemented.
 *
 * @author Gloduck
 */
public enum LimitType {
    /**
     * Limit per client IP; produces keys of the form {@code ip:<address>}.
     */
    IP {
        @Override
        public String generateKey(HttpServletRequest request) {
            return String.format("ip:%s", LimitUtils.getIpAddr(request));
        }
    };

    /**
     * Builds the key under which calls from this request are counted.
     * <p>
     * Declared {@code abstract} so every constant must implement it: a missing implementation is
     * a compile error instead of an {@link AbstractMethodError} at request time.
     *
     * @param request the current HTTP request
     * @return the counter key for the caller
     */
    public abstract String generateKey(HttpServletRequest request);
}

工具类

  • 使用到的工具类。
package springboot.limit.utils;

import cn.hutool.core.net.NetUtil;

import javax.servlet.http.HttpServletRequest;

/**
 * Helpers for resolving the client IP address that the IP-based rate limit is keyed on.
 *
 * @author Gloduck
 */
public class LimitUtils {
    private static final String UNKNOWN = "unknown";
    private static final String X_FORWARDED_FOR = "x-forwarded-for";
    private static final String PROXY_CLIENT_IP = "Proxy-Client-IP";
    private static final String WL_PROXY_CLIENT_IP = "WL-Proxy-Client-IP";
    private static final String LOCAL_IP = "127.0.0.1";

    /**
     * RFC 1918 private IPv4 ranges as inclusive {begin, end} pairs, computed once at class
     * initialisation instead of on every isInnerIp call.
     */
    private static final long[][] PRIVATE_RANGES = {
            {NetUtil.ipv4ToLong("10.0.0.0"), NetUtil.ipv4ToLong("10.255.255.255")},
            {NetUtil.ipv4ToLong("172.16.0.0"), NetUtil.ipv4ToLong("172.31.255.255")},
            {NetUtil.ipv4ToLong("192.168.0.0"), NetUtil.ipv4ToLong("192.168.255.255")}
    };

    private LimitUtils() {
    }

    /**
     * Resolves the client IP of a request.
     * <p>
     * Checks the x-forwarded-for, Proxy-Client-IP and WL-Proxy-Client-IP headers in that order,
     * then falls back to {@code request.getRemoteAddr()}. If the chosen value is a comma-separated
     * proxy chain, the first public (non-private) address is returned, otherwise the first
     * address in the chain. Blank entries and "unknown" placeholders anywhere in the chain are
     * skipped.
     * <p>
     * SECURITY NOTE(review): these headers are taken from the client as-is. Unless a trusted
     * reverse proxy overwrites them, a caller can send a different X-Forwarded-For on every
     * request and bypass an IP-based rate limit. Consider honouring them only when
     * {@code getRemoteAddr()} is a known proxy.
     *
     * @param request current HTTP request
     * @return the resolved client IP, or "127.0.0.1" if nothing usable was found
     */
    public static String getIpAddr(HttpServletRequest request) {
        String clientIp = request.getHeader(X_FORWARDED_FOR);
        if (isMissing(clientIp)) {
            clientIp = request.getHeader(PROXY_CLIENT_IP);
        }
        if (isMissing(clientIp)) {
            clientIp = request.getHeader(WL_PROXY_CLIENT_IP);
        }
        if (isMissing(clientIp)) {
            clientIp = request.getRemoteAddr();
        }
        if (clientIp == null) {
            return LOCAL_IP;
        }
        if (clientIp.indexOf(',') < 0) {
            // Single address: return it as-is, no classification needed.
            String single = clientIp.trim();
            return isMissing(single) ? LOCAL_IP : single;
        }
        // Proxy chain: prefer the first public address, otherwise the first usable entry.
        String firstAddress = null;
        for (String part : clientIp.split(",")) {
            String candidate = part.trim();
            if (isMissing(candidate)) {
                continue;
            }
            if (!isInnerIp(candidate)) {
                return candidate;
            }
            if (firstAddress == null) {
                firstAddress = candidate;
            }
        }
        return firstAddress != null ? firstAddress : LOCAL_IP;
    }

    /**
     * Tells whether an address is a private (RFC 1918) IPv4 address or exactly 127.0.0.1.
     *
     * @param ipAddress dotted-quad IPv4 address
     * @return {@code true} for private/loopback addresses; {@code false} otherwise, including for
     *         {@code null} and values that cannot be parsed as IPv4 (e.g. IPv6)
     */
    public static boolean isInnerIp(String ipAddress) {
        if (ipAddress == null) {
            return false;
        }
        if (LOCAL_IP.equals(ipAddress)) {
            return true;
        }
        long ipNum;
        try {
            ipNum = NetUtil.ipv4ToLong(ipAddress);
        } catch (IllegalArgumentException e) {
            // Assumption: depending on the hutool version, invalid input either yields 0 (which is
            // outside every private range) or throws; treat both as "not inner" so a malformed
            // header entry cannot fail the request.
            return false;
        }
        for (long[] range : PRIVATE_RANGES) {
            if (isInner(ipNum, range[0], range[1])) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return {@code true} if {@code value} is null, blank, or the "unknown" placeholder some proxies emit
     */
    private static boolean isMissing(String value) {
        return value == null || value.trim().isEmpty() || UNKNOWN.equalsIgnoreCase(value.trim());
    }

    /**
     * Tells whether an IP, as a long, lies in an inclusive range.
     *
     * @param userIp user IP
     * @param begin  range start
     * @param end    range end
     * @return whether {@code userIp} is within [begin, end]
     */
    private static boolean isInner(long userIp, long begin, long end) {
        return (userIp >= begin) && (userIp <= end);
    }
}

异常

  • 接口访问超限出现的异常。
package springboot.limit.exception;

/**
 * Thrown when a rate-limited call exceeds the caller's quota for the current window.
 */
public class LimitAccessException extends RuntimeException {
    private static final long serialVersionUID = 1L;

    /**
     * @param message description of the rejected access
     */
    public LimitAccessException(String message) {
        super(message);
    }

    /**
     * @param message description of the rejected access
     * @param cause   underlying failure, preserved for diagnostics
     */
    public LimitAccessException(String message, Throwable cause) {
        super(message, cause);
    }
}

下一篇: Java HashMap源码分析→

loading...