1、Lua 脚本如下:
-- Token-bucket rate limiter, run atomically via Redis EVAL.
-- KEYS[1]: hash holding the bucket state (fields 'last_time', 'current_token').
-- ARGV[1]: bucket capacity (maximum number of tokens).
-- ARGV[2]: tokens generated per second.
-- ARGV[3]: current time in milliseconds, supplied by the caller.
-- Returns 1 when a token was taken, 0 when the bucket is empty.
local ratelimit_info = redis.pcall('HMGET', KEYS[1], 'last_time', 'current_token')
-- Hash fields come back as strings (or nil); convert explicitly instead of
-- relying on implicit coercion in arithmetic.
local last_time = tonumber(ratelimit_info[1])
local current_token = tonumber(ratelimit_info[2])
local max_token = tonumber(ARGV[1])
local token_rate = tonumber(ARGV[2])
local current_time = tonumber(ARGV[3])

if current_token == nil or last_time == nil then
  -- No (usable) state yet: start with a full bucket.
  current_token = max_token
  last_time = current_time
else
  local past_time = current_time - last_time
  if past_time > 0 and token_rate > 0 then
    -- Refill in proportion to the elapsed time, counting whole tokens only.
    local new_tokens = math.floor(past_time * token_rate / 1000)
    if new_tokens > 0 then
      current_token = current_token + new_tokens
      -- Advance the clock only by the time those whole tokens "cost", so the
      -- fractional remainder carries over to the next call.
      last_time = last_time + new_tokens * 1000 / token_rate
    end
  end
  -- Prevent overflow: a full bucket must not bank extra time.
  if current_token >= max_token then
    current_token = max_token
    last_time = current_time
  end
end

local result = 0
if current_token > 0 then
  result = 1
  current_token = current_token - 1
  -- last_time is intentionally NOT reset here; resetting it on every
  -- acquire would stop refills while requests keep arriving.
end

redis.call('HMSET', KEYS[1], 'last_time', last_time, 'current_token', current_token)
-- Let idle buckets expire once they would have refilled completely; a
-- missing key is re-initialized as a full bucket, which is equivalent.
if token_rate > 0 then
  redis.call('PEXPIRE', KEYS[1], math.ceil((max_token - current_token) * 1000 / token_rate) + 1000)
end
return result
2、Spring Boot 代码实现
/**
 * Re-registers the primary {@code redisTemplate} bean.
 * Keys and hash keys are serialized as plain strings; values and hash values
 * are serialized as JSON via Jackson.
 *
 * @param redisConnectionFactory connection factory the template uses
 * @return the configured template
 */
@Bean(value = "redisTemplate")
@Primary
public RedisTemplate redisTemplate(RedisConnectionFactory redisConnectionFactory) {
    RedisTemplate<String, Object> template = new RedisTemplate<>();
    template.setConnectionFactory(redisConnectionFactory);

    ObjectMapper objectMapper = new ObjectMapper();
    // Omit null-valued properties from the stored JSON.
    objectMapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
    // Embed type information for non-final types so values can be read back
    // as their original classes.
    // NOTE(review): enableDefaultTyping is deprecated since Jackson 2.10 and is
    // unsafe if Redis can contain untrusted data; consider activateDefaultTyping
    // with a restrictive PolymorphicTypeValidator.
    objectMapper.enableDefaultTyping(ObjectMapper.DefaultTyping.NON_FINAL);

    // Keys: plain strings; values: JSON.
    StringRedisSerializer stringRedisSerializer = new StringRedisSerializer();
    Jackson2JsonRedisSerializer jsonRedisSerializer = new Jackson2JsonRedisSerializer(Object.class);
    jsonRedisSerializer.setObjectMapper(objectMapper);

    template.setEnableDefaultSerializer(false);
    template.setKeySerializer(stringRedisSerializer);
    template.setHashKeySerializer(stringRedisSerializer);
    template.setValueSerializer(jsonRedisSerializer);
    template.setHashValueSerializer(jsonRedisSerializer);
    return template;
}
/**
 * @Description Rate-limiting utility backed by a token bucket stored in Redis.
 * @Author CJB
 * @Date 2020/3/19 17:21
 */
public class RedisLimiterUtils {
    private static final StringRedisTemplate stringRedisTemplate =
            ApplicationContextUtils.applicationContext.getBean(StringRedisTemplate.class);

    /**
     * Lua token-bucket script, executed atomically by Redis.
     * KEYS[1] = bucket key; ARGV = capacity, tokens per second, current time (ms).
     * Returns 1 when a token was taken, 0 otherwise.
     * Each line ends with a real "\n"; without it adjacent lines run together
     * (e.g. "thenn") and Redis rejects the script.
     */
    private final static String TEXT =
            "local ratelimit_info = redis.pcall('HMGET', KEYS[1], 'last_time', 'current_token')\n" +
            "local last_time = tonumber(ratelimit_info[1])\n" +
            "local current_token = tonumber(ratelimit_info[2])\n" +
            "local max_token = tonumber(ARGV[1])\n" +
            "local token_rate = tonumber(ARGV[2])\n" +
            "local current_time = tonumber(ARGV[3])\n" +
            "if current_token == nil or last_time == nil then\n" +
            "  current_token = max_token\n" +
            "  last_time = current_time\n" +
            "else\n" +
            "  local past_time = current_time - last_time\n" +
            "  if past_time > 0 and token_rate > 0 then\n" +
            "    local new_tokens = math.floor(past_time * token_rate / 1000)\n" +
            "    if new_tokens > 0 then\n" +
            "      current_token = current_token + new_tokens\n" +
            "      last_time = last_time + new_tokens * 1000 / token_rate\n" +
            "    end\n" +
            "  end\n" +
            "  if current_token >= max_token then\n" +
            "    current_token = max_token\n" +
            "    last_time = current_time\n" +
            "  end\n" +
            "end\n" +
            "local result = 0\n" +
            "if current_token > 0 then\n" +
            "  result = 1\n" +
            "  current_token = current_token - 1\n" +
            "end\n" +
            "redis.call('HMSET', KEYS[1], 'last_time', last_time, 'current_token', current_token)\n" +
            "if token_rate > 0 then\n" +
            "  redis.call('PEXPIRE', KEYS[1], math.ceil((max_token - current_token) * 1000 / token_rate) + 1000)\n" +
            "end\n" +
            "return result";

    /**
     * The script object is built once and reused, so its SHA1 is computed only once.
     */
    private static final DefaultRedisScript<Long> SCRIPT = new DefaultRedisScript<>();

    static {
        SCRIPT.setResultType(Long.class);
        SCRIPT.setScriptText(TEXT);
    }

    /**
     * Tries to take one token from the bucket identified by {@code key}.
     *
     * @param key  bucket identifier (the caller passes the request id / URI)
     * @param max  bucket capacity, i.e. the largest burst that is admitted
     * @param rate tokens generated per second
     * @return {@code true} if a token was obtained, {@code false} otherwise
     */
    public static boolean tryAcquire(String key, int max, int rate) {
        List<String> keyList = new ArrayList<>(1);
        keyList.add(key);
        // NOTE(review): the timestamp comes from this JVM's clock, so clock skew
        // between application instances affects refill accuracy.
        Long result = stringRedisTemplate.execute(SCRIPT, keyList,
                Integer.toString(max),
                Integer.toString(rate),
                Long.toString(System.currentTimeMillis()));
        return Long.valueOf(1L).equals(result);
    }
}
/**
 * @Description Rate-limit marker for controller classes or handler methods.
 * A method-level annotation takes precedence over the class-level one,
 * mirroring how @Transactional is resolved.
 * @Author CJB
 * @Date 2020/3/20 13:36
 */
@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.TYPE, ElementType.METHOD})
@Inherited
public @interface RateLimit {

    /** Bucket capacity (maximum tokens held); defaults to 100. */
    int capacity() default 100;

    /** Tokens generated per second; defaults to 10. */
    int rate() default 10;
}
/**
 * @Description Rate-limiting interceptor: applies {@link RateLimit} to handler
 * methods before they run. A method-level annotation overrides a class-level one.
 * @Author CJB
 * @Date 2020/3/19 14:34
 */
@Component
public class RateLimiterIntercept implements HandlerInterceptor {

    /**
     * Rejects the request when the handler is rate limited and no token is available.
     *
     * @return {@code true} to continue processing
     * @throws TimeOutException when the rate limit has been reached
     */
    @Override
    public boolean preHandle(HttpServletRequest request, HttpServletResponse response, Object handler) throws Exception {
        if (!(handler instanceof HandlerMethod)) {
            // Not a controller method: nothing to limit.
            return true;
        }
        HandlerMethod handlerMethod = (HandlerMethod) handler;
        Method method = handlerMethod.getMethod();

        // Method-level annotation first; fall back to the declaring handler class.
        RateLimit rateLimit = AnnotationUtils.findAnnotation(method, RateLimit.class);
        if (Objects.isNull(rateLimit)) {
            // getBeanType() yields the user class (CGLIB subclasses unwrapped) and
            // does not depend on the bean instance having been resolved.
            rateLimit = AnnotationUtils.findAnnotation(handlerMethod.getBeanType(), RateLimit.class);
        }
        if (Objects.isNull(rateLimit)) {
            // Not annotated: let the request through.
            return true;
        }

        // NOTE(review): the bucket key is the raw request URI, so URIs with path
        // variables each get their own bucket — confirm this is intended.
        if (!RedisLimiterUtils.tryAcquire(request.getRequestURI(), rateLimit.capacity(), rateLimit.rate())) {
            // No token left: reject the request.
            throw new TimeOutException();
        }
        return true;
    }
}