当前位置:网站首页>使用注解实现限流
使用注解实现限流
2022-08-10 02:40:00 【mobº】
使用注解实现限流
一、需求概述
- 使用注解添加到类上,则此类的所有方法都实现限流。
- 需要能精细化到一定时间窗口内,限流多少次,超过了次数则给予提示。
- 此限流针对相同用户且相同ip地址情况下,访问某个接口是达到限流的效果,因为不同用户,不在一个限流维度里。
- 相同用户,不同ip地址也不再一个相同限流维度里,因为可能出现一号多地登录,那么这个俩个人都对各自的请求限流,互不相干。
- 也可切换不针对用户,只针对某一个接口方法进行限流。
二、实现思路
利用Spring Aop 做前置增强,每次请求前,通过获取 Redis 中已经关于请求次数的数据进行对比,从而判断是否请求频繁。
三、实现代码
- 创建spring 项目,添加依赖:
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.6.7</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>org.javaboy</groupId>
<artifactId>rate_limiter</artifactId>
<version>0.0.1-SNAPSHOT</version>
<name>rate_limiter</name>
<description>Demo project for Spring Boot</description>
<properties>
<java.version>1.8</java.version>
</properties>
<dependencies>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
</plugins>
</build>
</project>
- 添加 Redis 配置信息
spring.redis.host=localhost
spring.redis.port=6379
- 自定义限流注解
package org.javaboy.rate_limiter.annotation;
import org.javaboy.rate_limiter.enums.LimitType;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {
/** * 限流的 key,主要是指前缀 * @return */
String key() default "rate_limit:";
/** * 限流时间窗 * @return */
int time() default 60;
/** * 在时间窗内的限流次数 * @return */
int count() default 100;
/** * 限流类型 * @return */
LimitType limitType() default LimitType.DEFAULT;
}
package org.javaboy.rate_limiter.enums;
/** * 限流的类型 */
public enum LimitType {
/** * 默认的限流策略,针对某一个接口进行限流 */
DEFAULT,
/** * 针对某一个 IP 进行限流 */
IP
}
- 对于redis 的相关操作,我们利用 lua脚本
local key = KEYS[1] --redis key
local time = tonumber(ARGV[1]) -- 时间窗
local count = tonumber(ARGV[2]) -- 时间窗内限定的次数
local current = redis.call('get', key)
if current and tonumber(current) > count then
return tonumber(current)
end
current = redis.call('incr', key)
if tonumber(current) == 1 then
redis.call('expire', key, time)
end
return tonumber(current)
这里为什么要用 lua 脚本进行redis 操作呢?
虽然redis服务是单线程的服务,单步的redis操作是线程安全的,但是当我们在高并发的情况下,需要一系列的redis逻辑操作,而这些操作需要保证线程安全和原子性。这时候就需要Lua登场。
Lua 为静态语言提供更多的灵活性,Lua体积小、启动速度快。 Redis Lua 脚本出现之前 Redis 是没有服务器端运算能力的,主要是用来存储,用做缓存,运算是在客户端进行。有了 Lua 的支持,客户端可以定义对键值的运算,减少编译的次数,总之。可以让 Redis 更为灵活。redis 甚至在源代码中加入了Lua脚本的解释器,eval。
该脚本里面的遇到的 Lua 语法介绍:
- call() 的参数就是发给Redis的命令:首先set key value ,然后 get key,这两个命令将依次执行,当这个脚本执行时,Redis服务不会做任何操作(单线程),它将非常快速运行。
- 我们将会访问两个Lua表:KEYS和ARGV。表单是关联性数组和结构化数据的Lua唯一机制。对于我们的意图,你可以把它们看做是一个你所熟悉的任意语言对等的数组,但是提醒两个很容易困扰到新手的两个Lua定则:
表是基于1的,也就是说索引以数值1开始。所以在表中的第一个元素就是KEYS[1],第二个就是KEY[2]等等。
表中不能有nil值。如果一个操作表中有[1, nil, 3, 4],那么结果将会是[1]——表将会在第一个nil截断。
当调用这个脚本时,我们还需要传递KEYS和ARGV表的值,为Redis编写Lua脚本时,每个KEY都是通过KEYS表指定。ARGV表用来传递参数。
- 添加 Redis 序列化配置以及 redis 读取该lua脚本配置
package org.javaboy.rate_limiter.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.io.ClassPathResource;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.scripting.support.ResourceScriptSource;
@Configuration
public class RedisConfig {
@Bean
RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory redisConnectionFactory) {
RedisTemplate<Object, Object> template = new RedisTemplate<>();
template.setConnectionFactory(redisConnectionFactory);
Jackson2JsonRedisSerializer<Object> serializer = new Jackson2JsonRedisSerializer<>(Object.class);
template.setKeySerializer(serializer);
template.setHashKeySerializer(serializer);
template.setValueSerializer(serializer);
template.setHashValueSerializer(serializer);
return template;
}
@Bean
DefaultRedisScript<Long> limitScript() {
DefaultRedisScript<Long> script = new DefaultRedisScript<>();
script.setResultType(Long.class);
script.setScriptSource(new ResourceScriptSource(new ClassPathResource("lua/limit.lua")));
return script;
}
}
- 自定义限流异常
package org.javaboy.rate_limiter.exception;
public class RateLimitException extends Exception {
public RateLimitException(String message) {
super(message);
}
}
package org.javaboy.rate_limiter.exception;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;
import java.util.HashMap;
import java.util.Map;
@RestControllerAdvice
public class GlobalException {
@ExceptionHandler(RateLimitException.class)
public Map<String, Object> rateLimitException(RateLimitException e) {
Map<String, Object> map = new HashMap<>();
map.put("status", 500);
map.put("message", e.getMessage());
return map;
}
}
7.添加 ip 获取工具
package org.javaboy.rate_limiter.utils;
import javax.servlet.http.HttpServletRequest;
import java.net.InetAddress;
import java.net.UnknownHostException;
/** * 获取IP方法 * * @author tienchin */
public class IpUtils {
/** * 获取客户端IP * * @param request 请求对象 * @return IP地址 */
public static String getIpAddr(HttpServletRequest request) {
if (request == null) {
return "unknown";
}
String ip = request.getHeader("x-forwarded-for");
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Forwarded-For");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("WL-Proxy-Client-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getHeader("X-Real-IP");
}
if (ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip)) {
ip = request.getRemoteAddr();
}
return "0:0:0:0:0:0:0:1".equals(ip) ? "127.0.0.1" : getMultistageReverseProxyIp(ip);
}
/** * 检查是否为内部IP地址 * * @param ip IP地址 * @return 结果 */
public static boolean internalIp(String ip) {
byte[] addr = textToNumericFormatV4(ip);
return internalIp(addr) || "127.0.0.1".equals(ip);
}
/** * 检查是否为内部IP地址 * * @param addr byte地址 * @return 结果 */
private static boolean internalIp(byte[] addr) {
if (addr == null || addr.length < 2) {
return true;
}
final byte b0 = addr[0];
final byte b1 = addr[1];
// 10.x.x.x/8
final byte SECTION_1 = 0x0A;
// 172.16.x.x/12
final byte SECTION_2 = (byte) 0xAC;
final byte SECTION_3 = (byte) 0x10;
final byte SECTION_4 = (byte) 0x1F;
// 192.168.x.x/16
final byte SECTION_5 = (byte) 0xC0;
final byte SECTION_6 = (byte) 0xA8;
switch (b0) {
case SECTION_1:
return true;
case SECTION_2:
if (b1 >= SECTION_3 && b1 <= SECTION_4) {
return true;
}
case SECTION_5:
switch (b1) {
case SECTION_6:
return true;
}
default:
return false;
}
}
/** * 将IPv4地址转换成字节 * * @param text IPv4地址 * @return byte 字节 */
public static byte[] textToNumericFormatV4(String text) {
if (text.length() == 0) {
return null;
}
byte[] bytes = new byte[4];
String[] elements = text.split("\\.", -1);
try {
long l;
int i;
switch (elements.length) {
case 1:
l = Long.parseLong(elements[0]);
if ((l < 0L) || (l > 4294967295L)) {
return null;
}
bytes[0] = (byte) (int) (l >> 24 & 0xFF);
bytes[1] = (byte) (int) ((l & 0xFFFFFF) >> 16 & 0xFF);
bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 2:
l = Integer.parseInt(elements[0]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[0] = (byte) (int) (l & 0xFF);
l = Integer.parseInt(elements[1]);
if ((l < 0L) || (l > 16777215L)) {
return null;
}
bytes[1] = (byte) (int) (l >> 16 & 0xFF);
bytes[2] = (byte) (int) ((l & 0xFFFF) >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 3:
for (i = 0; i < 2; ++i) {
l = Integer.parseInt(elements[i]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[i] = (byte) (int) (l & 0xFF);
}
l = Integer.parseInt(elements[2]);
if ((l < 0L) || (l > 65535L)) {
return null;
}
bytes[2] = (byte) (int) (l >> 8 & 0xFF);
bytes[3] = (byte) (int) (l & 0xFF);
break;
case 4:
for (i = 0; i < 4; ++i) {
l = Integer.parseInt(elements[i]);
if ((l < 0L) || (l > 255L)) {
return null;
}
bytes[i] = (byte) (int) (l & 0xFF);
}
break;
default:
return null;
}
} catch (NumberFormatException e) {
return null;
}
return bytes;
}
/** * 获取IP地址 * * @return 本地IP地址 */
public static String getHostIp() {
try {
return InetAddress.getLocalHost().getHostAddress();
} catch (UnknownHostException e) {
}
return "127.0.0.1";
}
/** * 获取主机名 * * @return 本地主机名 */
public static String getHostName() {
try {
return InetAddress.getLocalHost().getHostName();
} catch (UnknownHostException e) {
}
return "未知";
}
/** * 从多级反向代理中获得第一个非unknown IP地址 * * @param ip 获得的IP地址 * @return 第一个非unknown IP地址 */
public static String getMultistageReverseProxyIp(String ip) {
// 多级反向代理检测
if (ip != null && ip.indexOf(",") > 0) {
final String[] ips = ip.trim().split(",");
for (String subIp : ips) {
if (false == isUnknown(subIp)) {
ip = subIp;
break;
}
}
}
return ip;
}
/** * 检测给定字符串是否为未知,多用于检测HTTP请求相关 * * @param checkString 被检测的字符串 * @return 是否未知 */
public static boolean isUnknown(String checkString) {
return checkString == null || checkString.length() == 0 || "unknown".equalsIgnoreCase(checkString);
}
}
8.编写限流切面
package org.javaboy.rate_limiter.aspectj;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.javaboy.rate_limiter.annotation.RateLimiter;
import org.javaboy.rate_limiter.enums.LimitType;
import org.javaboy.rate_limiter.exception.RateLimitException;
import org.javaboy.rate_limiter.utils.IpUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import java.lang.reflect.Method;
import java.util.Collections;
@Aspect
@Component
public class RateLimiterAspect {
private static final Logger logger = LoggerFactory.getLogger(RateLimiterAspect.class);
@Autowired
RedisTemplate<Object, Object> redisTemplate;
@Autowired
RedisScript<Long> redisScript;
@Before("@annotation(rateLimiter)")
public void before(JoinPoint jp, RateLimiter rateLimiter) throws RateLimitException {
int time = rateLimiter.time();
int count = rateLimiter.count();
String combineKey = getCombineKey(rateLimiter, jp);
try {
Long number = redisTemplate.execute(redisScript, Collections.singletonList(combineKey), time, count);
if (number == null || number.intValue() > count) {
//超过限流阈值
logger.info("当前接口以达到最大限流次数");
throw new RateLimitException("访问过于频繁,请稍后访问");
}
logger.info("一个时间窗内请求次数:{},当前请求次数:{},缓存的 key 为 {}", count, number, combineKey);
} catch (Exception e) {
throw e;
}
}
/** * 这个 key 其实就是接口调用次数缓存在 redis 的 key * rate_limit:11.11.11.11-org.javaboy.ratelimit.controller.HelloController-hello * rate_limit:org.javaboy.ratelimit.controller.HelloController-hello * @param rateLimiter * @param jp * @return */
private String getCombineKey(RateLimiter rateLimiter, JoinPoint jp) {
StringBuffer key = new StringBuffer(rateLimiter.key());
if (rateLimiter.limitType() == LimitType.IP) {
key.append(IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.getRequestAttributes()).getRequest()))
.append("-");
}
MethodSignature signature = (MethodSignature) jp.getSignature();
Method method = signature.getMethod();
key.append(method.getDeclaringClass().getName())
.append("-")
.append(method.getName());
return key.toString();
}
}
- 测试
package org.javaboy.rate_limiter.controller;
import org.javaboy.rate_limiter.annotation.RateLimiter;
import org.javaboy.rate_limiter.enums.LimitType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
@RestController
public class HelloController {
@GetMapping("/hello")
/** * 限流,10 秒之内,这个接口可以访问 3 次 */
@RateLimiter(time = 10, count = 3,limitType = LimitType.IP)
public String hello() {
return "hello";
}
}
当 10 秒内访问超过 3 次之后,抛出异常:
边栏推荐
猜你喜欢
随机推荐
Error state based Kalman filter ESKF
Research on IC enterprises
Web mining traceability?Browser browsing history viewing tool Browsinghistoryview
2022.8.8考试摄像师老马(photographer)题解
IDEA自动生成serialVersionUID
flink 12 源码编译及使用idea运行、debug
2022.8.8考试清洁工老马(sweeper)题解
PostgreSQL相关语法及指令示例
liunx PS1 settings
网页挖矿溯源?浏览器浏览历史查看工具Browsinghistoryview
单体架构应用和分布式架构应用的区别
剑指offer专项突击版第25天
MySQL: What MySQL optimizations have you done?
MySQL:你做过哪些MySQL的优化?
Example 044: Matrix Addition
What is a Cross-Site Request Forgery (CSRF) attack?How to defend?
Chapter 21 Source Code File REST API Reference (3)
Day16 charles的基本使用
The Evolutionary History of the "Double Gun" Trojan Horse Virus
湖仓一体电商项目(四):项目数据种类与采集