161 lines
5.9 KiB
Java
161 lines
5.9 KiB
Java
package edu.whut.infrastructure.aop;
|
||
|
||
import edu.whut.types.annotations.DCCValue;
|
||
import edu.whut.types.annotations.RateLimiterAccessInterceptor;
|
||
import org.aspectj.lang.JoinPoint;
|
||
import org.aspectj.lang.ProceedingJoinPoint;
|
||
import org.aspectj.lang.Signature;
|
||
import org.aspectj.lang.annotation.Around;
|
||
import org.aspectj.lang.annotation.Aspect;
|
||
import org.aspectj.lang.annotation.Pointcut;
|
||
import org.aspectj.lang.reflect.MethodSignature;
|
||
import org.redisson.api.RAtomicLong;
|
||
import org.redisson.api.RRateLimiter;
|
||
import org.redisson.api.RateIntervalUnit;
|
||
import org.redisson.api.RateType;
|
||
import org.redisson.api.RedissonClient;
|
||
import org.slf4j.Logger;
|
||
import org.slf4j.LoggerFactory;
|
||
import org.apache.commons.lang3.StringUtils;
|
||
|
||
import javax.annotation.Resource;
|
||
import java.lang.reflect.Field;
|
||
import java.lang.reflect.InvocationTargetException;
|
||
import java.lang.reflect.Method;
|
||
import java.util.concurrent.TimeUnit;
|
||
|
||
/**
|
||
* 分布式限流切面,基于 Redisson 的 RRateLimiter 和 RAtomicLong 实现
|
||
*/
|
||
@Aspect
|
||
public class RateLimiterAOP {
|
||
|
||
private final Logger log = LoggerFactory.getLogger(RateLimiterAOP.class);
|
||
|
||
/**
|
||
* 全局开关:open/close
|
||
*/
|
||
@DCCValue("rateLimiterSwitch:open")
|
||
private String rateLimiterSwitch;
|
||
|
||
/**
|
||
* Redisson 客户端,注入使用
|
||
*/
|
||
@Resource
|
||
private RedissonClient redissonClient;
|
||
|
||
@Pointcut("@annotation(edu.whut.types.annotations.RateLimiterAccessInterceptor)")
|
||
public void aopPoint() {}
|
||
|
||
@Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
|
||
public Object doRouter(ProceedingJoinPoint jp,
|
||
RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {
|
||
// 0. 全局开关
|
||
if (StringUtils.isBlank(rateLimiterSwitch) || "close".equals(rateLimiterSwitch)) {
|
||
return jp.proceed();
|
||
}
|
||
|
||
// 1. 获取限流维度 key
|
||
String key = rateLimiterAccessInterceptor.key();
|
||
if (StringUtils.isBlank(key)) {
|
||
throw new RuntimeException("annotation RateLimiter key is null!");
|
||
}
|
||
String keyAttr = getAttrValue(key, jp.getArgs());
|
||
log.info("[RateLimiter] attr={}, permits={}, blacklistCount={}",
|
||
keyAttr,
|
||
rateLimiterAccessInterceptor.permitsPerSecond(),
|
||
rateLimiterAccessInterceptor.blacklistCount());
|
||
|
||
// 2. 黑名单检查(分布式,24h) rl:ratelimit bl:blacklist
|
||
// 存储的是 “用户在这一轮限流中被拒绝的次数”,大于blacklistLimit则被视作进入黑名单,等key释放解决黑名单
|
||
double blacklistLimit = rateLimiterAccessInterceptor.blacklistCount();
|
||
if (blacklistLimit > 0) {
|
||
RAtomicLong blCounter = redissonClient.getAtomicLong("rl:bl:" + keyAttr);
|
||
if (blCounter.isExists() && blCounter.get() > blacklistLimit) {
|
||
log.info("[RateLimiter] 黑名单拦截: {}", keyAttr);
|
||
return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
|
||
}
|
||
}
|
||
|
||
// 3. 获取或创建分布式 RateLimiter
|
||
RRateLimiter limiter = redissonClient.getRateLimiter("rl:limiter:" + keyAttr);
|
||
// 尝试设置速率,每秒放n个令牌 若已设置则返回 false
|
||
limiter.trySetRate(RateType.OVERALL,
|
||
(long) rateLimiterAccessInterceptor.permitsPerSecond(),
|
||
1, RateIntervalUnit.SECONDS);
|
||
|
||
// 4. 尝试获取令牌,如果取不到,则返回false
|
||
boolean allowed = limiter.tryAcquire();
|
||
if (!allowed) {
|
||
// 超限后计入黑名单
|
||
if (blacklistLimit > 0) {
|
||
RAtomicLong blCounter = redissonClient.getAtomicLong("rl:bl:" + keyAttr);
|
||
long count = blCounter.incrementAndGet();
|
||
if (count == 1) {
|
||
blCounter.expire(24, TimeUnit.HOURS);
|
||
}
|
||
}
|
||
log.info("[RateLimiter] 限流拦截: {}", keyAttr);
|
||
return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
|
||
}
|
||
|
||
// 5. 正常执行
|
||
return jp.proceed();
|
||
}
|
||
|
||
/**
|
||
* 调用用户配置的降级方法
|
||
*/
|
||
private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod)
|
||
throws NoSuchMethodException, InvocationTargetException, IllegalAccessException {
|
||
Signature sig = jp.getSignature();
|
||
MethodSignature ms = (MethodSignature) sig;
|
||
Method method = jp.getTarget().getClass()
|
||
.getMethod(fallbackMethod, ms.getParameterTypes());
|
||
return method.invoke(jp.getTarget(), jp.getArgs());
|
||
}
|
||
|
||
/**
|
||
* 从方法参数中获取 attr 字段值
|
||
*/
|
||
private String getAttrValue(String attr, Object[] args) {
|
||
if (args == null || args.length == 0) return null;
|
||
if (args[0] instanceof String) {
|
||
return args[0].toString();
|
||
}
|
||
for (Object arg : args) {
|
||
String val = extractField(arg, attr);
|
||
if (StringUtils.isNotBlank(val)) {
|
||
return val;
|
||
}
|
||
}
|
||
return null;
|
||
}
|
||
|
||
private String extractField(Object obj, String name) {
|
||
try {
|
||
Field field = getFieldByName(obj, name);
|
||
if (field == null) return null;
|
||
field.setAccessible(true);
|
||
Object v = field.get(obj);
|
||
field.setAccessible(false);
|
||
return v != null ? v.toString() : null;
|
||
} catch (Exception e) {
|
||
log.warn("[RateLimiter] 提取字段失败 {}", name, e);
|
||
return null;
|
||
}
|
||
}
|
||
|
||
private Field getFieldByName(Object obj, String name) {
|
||
Class<?> cls = obj.getClass();
|
||
while (cls != null) {
|
||
try {
|
||
return cls.getDeclaredField(name);
|
||
} catch (NoSuchFieldException e) {
|
||
cls = cls.getSuperclass();
|
||
}
|
||
}
|
||
return null;
|
||
}
|
||
}
|