161 lines
5.9 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package edu.whut.infrastructure.aop;
import edu.whut.types.annotations.DCCValue;
import edu.whut.types.annotations.RateLimiterAccessInterceptor;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.Signature;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.redisson.api.RAtomicLong;
import org.redisson.api.RRateLimiter;
import org.redisson.api.RateIntervalUnit;
import org.redisson.api.RateType;
import org.redisson.api.RedissonClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.commons.lang3.StringUtils;
import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.concurrent.TimeUnit;
/**
 * Distributed rate-limiting aspect built on Redisson's {@link RRateLimiter} (per-dimension
 * limiter) and {@link RAtomicLong} (per-dimension rejection counter used as a temporary
 * blacklist).
 *
 * <p>Methods annotated with {@link RateLimiterAccessInterceptor} are intercepted. The limiting
 * dimension is read from the method arguments using the annotation's {@code key}. When a call is
 * rejected (rate exceeded or dimension blacklisted), the annotation's {@code fallbackMethod} is
 * invoked on the target object with the same arguments and its result is returned instead of
 * proceeding.
 */
@Aspect
public class RateLimiterAOP {

    private static final Logger log = LoggerFactory.getLogger(RateLimiterAOP.class);

    /** Redis key prefix of the per-dimension rate limiter. */
    private static final String LIMITER_KEY_PREFIX = "rl:limiter:";

    /** Redis key prefix of the per-dimension rejection counter (blacklist). */
    private static final String BLACKLIST_KEY_PREFIX = "rl:bl:";

    /** Lifetime of a rejection counter, measured from the first rejection in a window. */
    private static final long BLACKLIST_TTL_HOURS = 24;

    /**
     * Global switch. A blank value or "close" bypasses rate limiting entirely; any other value
     * enables it. Injected via {@code @DCCValue} (presumably "open" is the default and the value
     * is refreshed dynamically by the DCC component — TODO confirm).
     */
    @DCCValue("rateLimiterSwitch:open")
    private String rateLimiterSwitch;

    /** Redisson client used for the limiter and the blacklist counters. */
    @Resource
    private RedissonClient redissonClient;

    /** Matches every method annotated with {@link RateLimiterAccessInterceptor}. */
    @Pointcut("@annotation(edu.whut.types.annotations.RateLimiterAccessInterceptor)")
    public void aopPoint() {}

    /**
     * Applies the blacklist check and the distributed rate limit before the annotated method.
     *
     * @param jp the intercepted join point
     * @param rateLimiterAccessInterceptor the annotation on the intercepted method
     * @return the method's result, or the fallback method's result when the call is rejected
     * @throws IllegalStateException if the annotation's key is blank or its rate is not positive
     * @throws Throwable anything thrown by the intercepted method or by the fallback method
     */
    @Around("aopPoint() && @annotation(rateLimiterAccessInterceptor)")
    public Object doRouter(ProceedingJoinPoint jp,
                           RateLimiterAccessInterceptor rateLimiterAccessInterceptor) throws Throwable {
        // 0. Global switch.
        if (StringUtils.isBlank(rateLimiterSwitch) || "close".equals(rateLimiterSwitch)) {
            return jp.proceed();
        }

        // 1. Resolve the limiting dimension from the method arguments.
        String key = rateLimiterAccessInterceptor.key();
        if (StringUtils.isBlank(key)) {
            throw new IllegalStateException("annotation RateLimiter key is null");
        }
        String keyAttr = getAttrValue(key, jp.getArgs());
        // NOTE(review): when no argument yields a value, keyAttr is null and all such calls
        // share the single "...:null" limiter and counter.
        log.debug("[RateLimiter] attr={}, permits={}, blacklistCount={}",
                keyAttr,
                rateLimiterAccessInterceptor.permitsPerSecond(),
                rateLimiterAccessInterceptor.blacklistCount());

        // 2. Blacklist check. The counter stores how many times this dimension was rejected in
        //    the current 24h window; more than blacklistLimit rejections means blacklisted until
        //    the counter key expires.
        double blacklistLimit = rateLimiterAccessInterceptor.blacklistCount();
        if (blacklistLimit > 0 && isBlacklisted(keyAttr, blacklistLimit)) {
            log.info("[RateLimiter] 黑名单拦截: {}", keyAttr);
            return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
        }

        // 3. Get (or lazily configure) the distributed limiter for this dimension.
        RRateLimiter limiter = redissonClient.getRateLimiter(LIMITER_KEY_PREFIX + keyAttr);
        configureRate(limiter, rateLimiterAccessInterceptor.permitsPerSecond());

        // 4. Try to take one permit without waiting.
        if (!limiter.tryAcquire()) {
            if (blacklistLimit > 0) {
                recordRejection(keyAttr);
            }
            log.info("[RateLimiter] 限流拦截: {}", keyAttr);
            return fallbackMethodResult(jp, rateLimiterAccessInterceptor.fallbackMethod());
        }

        // 5. Permit granted: run the intercepted method.
        return jp.proceed();
    }

    /**
     * Returns whether the dimension's rejection counter exceeds the blacklist limit.
     * A missing counter key means no rejections in the current window.
     */
    private boolean isBlacklisted(String keyAttr, double blacklistLimit) {
        RAtomicLong counter = redissonClient.getAtomicLong(BLACKLIST_KEY_PREFIX + keyAttr);
        return counter.isExists() && counter.get() > blacklistLimit;
    }

    /**
     * Increments the dimension's rejection counter. The first rejection starts a
     * {@value #BLACKLIST_TTL_HOURS}-hour window, after which the counter key expires.
     */
    private void recordRejection(String keyAttr) {
        RAtomicLong counter = redissonClient.getAtomicLong(BLACKLIST_KEY_PREFIX + keyAttr);
        long count = counter.incrementAndGet();
        if (count == 1) {
            // NOTE(review): increment and expire are two separate calls. If the expire call is
            // lost, the counter never expires.
            counter.expire(BLACKLIST_TTL_HOURS, TimeUnit.HOURS);
        }
    }

    /**
     * Sets the limiter's rate via {@code trySetRate}, which only takes effect when no rate is
     * stored for this limiter yet. Later changes to the annotation are therefore not applied to
     * an existing Redis key.
     *
     * <p>Rates of at least one permit per second are truncated to whole permits per second.
     * Fractional rates below one are expressed as one permit every round(1/rate) seconds, instead
     * of being truncated to a zero rate that can never grant a permit.
     *
     * @throws IllegalStateException if {@code permitsPerSecond} is not positive (or is NaN)
     */
    private static void configureRate(RRateLimiter limiter, double permitsPerSecond) {
        if (!(permitsPerSecond > 0)) { // also rejects NaN
            throw new IllegalStateException(
                    "annotation RateLimiter permitsPerSecond must be > 0: " + permitsPerSecond);
        }
        if (permitsPerSecond >= 1) {
            limiter.trySetRate(RateType.OVERALL, (long) permitsPerSecond, 1, RateIntervalUnit.SECONDS);
        } else {
            long intervalSeconds = Math.max(1L, Math.round(1 / permitsPerSecond));
            limiter.trySetRate(RateType.OVERALL, 1, intervalSeconds, RateIntervalUnit.SECONDS);
        }
    }

    /**
     * Invokes the configured fallback method on the target object, passing the intercepted
     * method's arguments. The fallback must have the same parameter types as the intercepted
     * method.
     *
     * @return whatever the fallback method returns
     * @throws NoSuchMethodException if no matching fallback method exists
     * @throws Throwable the exception thrown by the fallback method itself (unwrapped)
     */
    private Object fallbackMethodResult(JoinPoint jp, String fallbackMethod) throws Throwable {
        MethodSignature signature = (MethodSignature) jp.getSignature();
        Object target = jp.getTarget();
        Method method = findMethod(target.getClass(), fallbackMethod, signature.getParameterTypes());
        try {
            return method.invoke(target, jp.getArgs());
        } catch (InvocationTargetException e) {
            // Surface the fallback's own exception rather than the reflective wrapper.
            throw e.getTargetException();
        }
    }

    /**
     * Finds a method by name and parameter types. Public methods (including inherited ones) are
     * tried first. Otherwise the declared methods of the class and its superclasses are searched,
     * and the match is made accessible.
     */
    private static Method findMethod(Class<?> type, String name, Class<?>[] parameterTypes)
            throws NoSuchMethodException {
        try {
            return type.getMethod(name, parameterTypes);
        } catch (NoSuchMethodException notPublic) {
            for (Class<?> c = type; c != null; c = c.getSuperclass()) {
                try {
                    Method declared = c.getDeclaredMethod(name, parameterTypes);
                    declared.setAccessible(true);
                    return declared;
                } catch (NoSuchMethodException ignored) {
                    // Not declared here; keep searching the superclass.
                }
            }
            throw notPublic;
        }
    }

    /**
     * Resolves the limiting dimension value from the method arguments.
     *
     * <p>If the first argument is a {@code String}, it is used as-is (regardless of
     * {@code attr}). Otherwise the first argument whose field named {@code attr} has a non-blank
     * string value wins.
     *
     * @return the dimension value, or {@code null} if none could be found
     */
    private String getAttrValue(String attr, Object[] args) {
        if (args == null || args.length == 0) {
            return null;
        }
        if (args[0] instanceof String) {
            return args[0].toString();
        }
        for (Object arg : args) {
            String val = extractField(arg, attr);
            if (StringUtils.isNotBlank(val)) {
                return val;
            }
        }
        return null;
    }

    /**
     * Reads the field {@code name} (declared on the object's class or any superclass)
     * reflectively and returns its {@code toString()}.
     *
     * @return the field value as a string, or {@code null} if {@code obj} is null, the field
     *     does not exist, its value is null, or reading it fails (best-effort, logged at WARN)
     */
    private String extractField(Object obj, String name) {
        if (obj == null) {
            return null;
        }
        try {
            Field field = getFieldByName(obj, name);
            if (field == null) {
                return null;
            }
            // getDeclaredField returns a fresh Field copy, so there is no need to reset this flag.
            field.setAccessible(true);
            Object v = field.get(obj);
            return v != null ? v.toString() : null;
        } catch (Exception e) {
            log.warn("[RateLimiter] 提取字段失败 {}", name, e);
            return null;
        }
    }

    /**
     * Finds a field by name on the object's class or the nearest superclass declaring it.
     *
     * @return the field, or {@code null} if no class in the hierarchy declares it
     */
    private Field getFieldByName(Object obj, String name) {
        for (Class<?> cls = obj.getClass(); cls != null; cls = cls.getSuperclass()) {
            try {
                return cls.getDeclaredField(name);
            } catch (NoSuchFieldException ignored) {
                // Not declared here; keep searching the superclass.
            }
        }
        return null;
    }
}