Maven 依赖
<!-- AspectJ libraries used by the @Aspect-based rate limiter in this document. -->
<!-- aspectjrt: the AspectJ runtime artifact. -->
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjrt</artifactId>
<version>1.9.2</version>
</dependency>
<!-- aspectjweaver: the AspectJ weaver artifact; kept at the same version as aspectjrt. -->
<dependency>
<groupId>org.aspectj</groupId>
<artifactId>aspectjweaver</artifactId>
<version>1.9.2</version>
</dependency>
注解类
/**
 * Marks a method whose invocations are to be limited by the accompanying aspect.
 *
 * <p>Both attributes default to {@code -1}; the aspect only acts on an attribute
 * whose value is greater than zero. When both are positive, only
 * {@link #timeout()} is applied.
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /**
     * Maximum time, in milliseconds, the annotated method may run before the
     * call is failed. Values of zero or below disable the timeout.
     *
     * @return the timeout in milliseconds, or {@code -1} when unset
     */
    int timeout() default -1;

    /**
     * Number of semaphore permits shared by all calls to the annotated method,
     * i.e. how many invocations may run at the same time. Values of zero or
     * below disable the limit.
     *
     * @return the permit count, or {@code -1} when unset
     */
    int count() default -1;
}
AOP 处理类
@Component
@Aspect
public class RatelimiterAop {
private static ConcurrentHashMap<String, Semaphore> LIMITER = new ConcurrentHashMap<>();
@Pointcut("@annotation(RateLimiter)")
public void point() {
}
@Around("point()")
public Object limit(ProceedingJoinPoint proceedingJoinPoint) {
MethodSignature methodSignature = (MethodSignature) proceedingJoinPoint.getSignature();
RateLimiter limit = methodSignature.getMethod().getDeclaredAnnotation(RateLimiter.class);
if (limit.timeout() > 0) {
ExecutorService es = Executors.newFixedThreadPool(2);
Future future = es.submit(() -> {
try {
return proceedingJoinPoint.proceed();
} catch (Throwable throwable) {
return null;
}
});
final Object obj;
try {
obj = future.get(limit.timeout(), TimeUnit.MILLISECONDS);
return obj;
} catch (Exception e) {
future.cancel(true);
throw new RuntimeException("处理超时");
}
}else if (limit.count() > 0) {
// key unique.
String cacheKey = proceedingJoinPoint.getTarget().getClass().getName() + "::" + methodSignature.getName()
+ "::" + Arrays.toString(methodSignature.getParameterNames());
LIMITER.putIfAbsent(cacheKey, new Semaphore(limit.count()));
System.out.println(cacheKey);
Semaphore semaphore = LIMITER.get(cacheKey);
try {
semaphore.acquire();
proceedingJoinPoint.proceed();
} catch (Throwable throwable) {
throw new RuntimeException("请求异常");
} finally {
// 释放
if (null != semaphore) {
semaphore.release();
}
}
}
try {
return proceedingJoinPoint.proceed();
} catch (Throwable throwable) {
return null;
}
}
测试
/**
 * Demo controller exercising {@link RateLimiter} with a 100 ms timeout.
 */
@RestController
public class TestController {

    /**
     * Sleeps for a random duration below 200 ms, so some calls exceed the
     * 100 ms timeout declared on the method and some do not.
     */
    @RateLimiter(timeout = 100)
    @PostMapping(value = "/test")
    public void test() {
        try {
            Random random = new Random();
            int time = random.nextInt(200);
            System.out.println(time + "ms");
            Thread.sleep(time);
            System.out.println("the end");
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Restore the interrupt status so code further up the stack can
            // still observe that this thread was asked to stop.
            Thread.currentThread().interrupt();
        }
    }
}