161 lines
5.9 KiB
Java
Raw Normal View History

package edu.whut.infrastructure.aop;
import edu.whut.types.annotations.DCCValue;
import edu.whut.types.annotations.RateLimiterAccessInterceptor;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.redisson.api.RAtomicLong;
import org.redisson.api.RRateLimiter;
import org.redisson.api.RateIntervalUnit;
import org.redisson.api.RateType;
import org.redisson.api.RedissonClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.commons.lang3.StringUtils;
import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
/**
 * Distributed rate-limiting aspect built on Redisson {@link RRateLimiter} (one limiter per
 * limiting dimension) and {@link RAtomicLong} (one blacklist rejection counter per dimension).
 *
 * <p>Methods annotated with {@link RateLimiterAccessInterceptor} are intercepted. The limiting
 * dimension is resolved from the method arguments using the annotation's {@code key}. When a call
 * is rejected (rate exceeded, or the dimension is blacklisted) the annotation's
 * {@code fallbackMethod} is invoked on the target object instead of the original method.
 */
@Aspect
public class RateLimiterAOP {

    private static final Logger log = LoggerFactory.getLogger(RateLimiterAOP.class);

    /** Redis key prefix of the per-dimension rate limiter. */
    private static final String LIMITER_KEY_PREFIX = "rl:limiter:";

    /** Redis key prefix of the per-dimension blacklist rejection counter. */
    private static final String BLACKLIST_KEY_PREFIX = "rl:bl:";

    /** Lifetime of a blacklist counter, measured from the first recorded rejection. */
    private static final long BLACKLIST_TTL_HOURS = 24;

    /**
     * Global switch: "open" enables limiting; "close" or a blank value disables it.
     * NOTE(review): the annotation value presumably follows a "key:default" format,
     * i.e. default "open" — confirm against the DCC implementation.
     */
    @DCCValue("rateLimiterSwitch:open")
    private String rateLimiterSwitch;

    /** Redisson client used for both the rate limiters and the blacklist counters. */
    @Resource
    private RedissonClient redissonClient;

    @Pointcut("@annotation(edu.whut.types.annotations.RateLimiterAccessInterceptor)")
    public void aopPoint() {}

    /**
     * Around advice that enforces the distributed rate limit for the annotated method.
     *
     * @param jp the intercepted join point
     * @param rateLimiterAccessInterceptor the annotation present on the intercepted method
     * @return the original method's result, or the fallback method's result if the call is rejected
     * @throws IllegalStateException if the annotation's {@code key} is blank
     * @throws Throwable anything thrown by the original method or by the fallback method
     */
    @Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
    public Object doRouter(ProceedingJoinPoint jp,
                           RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {
        // 0. Global switch: when limiting is disabled, run the method directly.
        if (StringUtils.isBlank(rateLimiterSwitch) || "close".equals(rateLimiterSwitch)) {
            return jp.proceed();
        }

        // 1. Resolve the limiting dimension from the method arguments.
        String key = rateLimiterAccessInterceptor.key();
        if (StringUtils.isBlank(key)) {
            throw new IllegalStateException("annotation RateLimiter key is null");
        }
        // NOTE: if no dimension can be resolved this is null, and every such call shares the
        // single "...:null" limiter/blacklist bucket.
        String keyAttr = getAttrValue(key, jp.getArgs());
        log.info("[RateLimiter] attr={}, permits={}, blacklistCount={}",
                keyAttr,
                rateLimiterAccessInterceptor.permitsPerSecond(),
                rateLimiterAccessInterceptor.blacklistCount());

        // 2. Blacklist check. The counter stores how many times this dimension has been rejected
        //    in the current window; once it exceeds blacklistCount every call is rejected until
        //    the counter's TTL expires and the key disappears.
        double blacklistLimit = rateLimiterAccessInterceptor.blacklistCount();
        RAtomicLong blCounter = null;
        if (blacklistLimit > 0) {
            blCounter = redissonClient.getAtomicLong(BLACKLIST_KEY_PREFIX + keyAttr);
            // NOTE(review): isExists() is kept before get(); some Redisson versions implement
            // get() as INCRBY 0, which would create a TTL-less key — confirm before removing.
            if (blCounter.isExists() && blCounter.get() > blacklistLimit) {
                log.info("[RateLimiter] 黑名单拦截: {}", keyAttr);
                return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
            }
        }

        // 3. Get (or lazily create) the distributed limiter for this dimension. trySetRate only
        //    takes effect if no rate is configured yet (it returns false otherwise), so changing
        //    permitsPerSecond later does not update an already-existing limiter.
        RRateLimiter limiter = redissonClient.getRateLimiter(LIMITER_KEY_PREFIX + keyAttr);
        limiter.trySetRate(RateType.OVERALL,
                (long) rateLimiterAccessInterceptor.permitsPerSecond(),
                1, RateIntervalUnit.SECONDS);

        // 4. Try to take one permit without waiting.
        if (!limiter.tryAcquire()) {
            if (blCounter != null) {
                long count = blCounter.incrementAndGet();
                // The TTL is set only on the first rejection, so the window starts there.
                // NOTE(review): increment and expire are two separate commands, not atomic.
                if (count == 1) {
                    blCounter.expire(BLACKLIST_TTL_HOURS, TimeUnit.HOURS);
                }
            }
            log.info("[RateLimiter] 限流拦截: {}", keyAttr);
            return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
        }

        // 5. Permit granted: run the original method.
        return jp.proceed();
    }

    /**
     * Invokes the user-configured fallback method on the target object.
     *
     * <p>The fallback must be a public method (looked up with {@link Class#getMethod}) with the
     * same parameter types as the intercepted method; it receives the same arguments.
     *
     * @param jp the intercepted join point
     * @param fallbackMethod name of the fallback method
     * @return whatever the fallback method returns
     * @throws NoSuchMethodException if no matching public fallback method exists
     * @throws Throwable the exception thrown by the fallback itself (unwrapped from
     *     {@link InvocationTargetException}), or a reflective access error
     */
    private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod) throws Throwable {
        MethodSignature signature = (MethodSignature) jp.getSignature();
        Object target = jp.getTarget();
        Method method = target.getClass().getMethod(fallbackMethod, signature.getParameterTypes());
        try {
            return method.invoke(target, jp.getArgs());
        } catch (InvocationTargetException e) {
            // Surface the fallback's own exception instead of the reflective wrapper.
            throw e.getTargetException();
        }
    }

    /**
     * Resolves the limiting dimension from the method arguments.
     *
     * <p>If the first argument is a String it is returned as-is, regardless of {@code attr}.
     * Otherwise the first non-blank value of a field named {@code attr} found on any argument
     * (searching superclasses too) is returned.
     *
     * @param attr field name to look up on the arguments
     * @param args the intercepted method's arguments; may be null
     * @return the resolved value, or null if nothing could be resolved
     */
    private String getAttrValue(String attr, Object[] args) {
        if (args == null || args.length == 0) {
            return null;
        }
        if (args[0] instanceof String) {
            return args[0].toString();
        }
        for (Object arg : args) {
            String val = extractField(arg, attr);
            if (StringUtils.isNotBlank(val)) {
                return val;
            }
        }
        return null;
    }

    /**
     * Best-effort reflective read of field {@code name} on {@code obj}.
     *
     * @return the field value's {@code toString()}, or null if {@code obj} is null, the field
     *     does not exist, its value is null, or it cannot be read (a warning is logged then)
     */
    private String extractField(Object obj, String name) {
        if (obj == null) {
            // Null arguments have no fields; skip them instead of logging a NullPointerException.
            return null;
        }
        try {
            Field field = getFieldByName(obj, name);
            if (field == null) {
                return null;
            }
            // getDeclaredField returns a fresh Field copy, so the accessibility change is local.
            field.setAccessible(true);
            Object v = field.get(obj);
            return v != null ? v.toString() : null;
        } catch (Exception e) {
            log.warn("[RateLimiter] 提取字段失败 {}", name, e);
            return null;
        }
    }

    /**
     * Finds a declared field by name on {@code obj}'s class or any of its superclasses.
     *
     * @param obj non-null object whose class hierarchy is searched
     * @return the field, or null if no class in the hierarchy declares it
     */
    private Field getFieldByName(Object obj, String name) {
        for (Class<?> cls = obj.getClass(); cls != null; cls = cls.getSuperclass()) {
            try {
                return cls.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Not declared here; keep walking up the hierarchy.
            }
        }
        return null;
    }
}