Flexible Gateway: Spring Cloud Gateway

The head and tail of exception handling

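In a project that uses a gateway to manage microservices, we built a solution for authentication and rate limiting that runs before requests reach the business APIs. The flow is as follows: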

Project access control

(Figure: gateway control flow)

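Exception handling in Spring Cloud Gateway differs from global exception handling in Spring MVC or Spring Boot because the underlying handlers differ, which shows up in code as a different base class. The gateway runs on Spring WebFlux, so errors raised while routing and filtering reach an ErrorWebExceptionHandler; here we customize them by extending DefaultErrorWebExceptionHandler. In Spring MVC and Spring Boot web applications, by contrast, custom exception-handling logic is written with @ControllerAdvice and @ExceptionHandler.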

(Figure: project structure)

Global exception handling:

Order:@Order(Ordered.HIGHEST_PRECEDENCE)

The global exception-handling configuration class, ErrorHandlerConfiguration:

@Configuration
@EnableConfigurationProperties({ServerProperties.class, ResourceProperties.class})
public class ErrorHandlerConfiguration {

    private final ServerProperties serverProperties;

    private final ApplicationContext applicationContext;

    private final ResourceProperties resourceProperties;

    private final List<ViewResolver> viewResolvers;

    private final ServerCodecConfigurer serverCodecConfigurer;

    public ErrorHandlerConfiguration(ServerProperties serverProperties,
                                     ResourceProperties resourceProperties,
                                     ObjectProvider<List<ViewResolver>> viewResolversProvider,
                                     ServerCodecConfigurer serverCodecConfigurer,
                                     ApplicationContext applicationContext) {
        this.serverProperties = serverProperties;
        this.applicationContext = applicationContext;
        this.resourceProperties = resourceProperties;
        this.viewResolvers = viewResolversProvider.getIfAvailable(Collections::emptyList);
        this.serverCodecConfigurer = serverCodecConfigurer;
    }

    @Bean
    @Order(Ordered.HIGHEST_PRECEDENCE)
    public ErrorWebExceptionHandler errorWebExceptionHandler(ErrorAttributes errorAttributes) {
        ExceptionHandler exceptionHandler = new ExceptionHandler(
                errorAttributes,
                this.resourceProperties,
                this.serverProperties.getError(),
                this.applicationContext);
        exceptionHandler.setViewResolvers(this.viewResolvers);
        exceptionHandler.setMessageWriters(this.serverCodecConfigurer.getWriters());
        exceptionHandler.setMessageReaders(this.serverCodecConfigurer.getReaders());
        return exceptionHandler;
    }

}

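Because the bean is annotated with @Order(Ordered.HIGHEST_PRECEDENCE), this handler takes precedence over other exception handlers; the order of the other components is likewise determined by their Order values. The errorWebExceptionHandler bean installs the custom ExceptionHandler. If a request fails after passing through all filters and it is not on the whitelist (whitelisted requests pass straight through without authentication; user authentication is covered below), the error is returned in the following custom form: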

public class ExceptionHandler extends DefaultErrorWebExceptionHandler {

    @Autowired
    private ResultFilter resultFilter;

    ExceptionHandler(ErrorAttributes errorAttributes, ResourceProperties resourceProperties,
                     ErrorProperties errorProperties, ApplicationContext applicationContext) {
        super(errorAttributes, resourceProperties, errorProperties, applicationContext);
    }

    /**
     * Inspect the failed request and its error, and build the corresponding response body
     */
    @Override
    protected Map<String, Object> getErrorAttributes(ServerRequest request, boolean includeStackTrace) {
        String notFound = "404 NOT_FOUND";

        Result result;
        Throwable error = super.getError(request);
        if (error instanceof org.springframework.cloud.gateway.support.NotFoundException) {
            // The request passed every filter; if it is not whitelisted, give the user's call back
            result = new Result(FailureResult.valueOf("NULL_API"));
            returnCount(request);

        } else if (notFound.equals(error.getMessage())) {
            result = new Result(FailureResult.valueOf("NULL_API"));

        } else if (error instanceof java.net.UnknownHostException) {
            // The request passed every filter; if it is not whitelisted, give the user's call back
            result = new Result(FailureResult.valueOf("UNKNOW_HOST"));
            returnCount(request);
        } else {
            result = new Result(FailureResult.valueOf("INTERNAL_ERROR"));
            result.setData(result.getData() + ":" + error.toString() + ":" + error.getMessage());
            error.printStackTrace();

        }
        return result.toMap();
    }

    private void returnCount(ServerRequest request) {
        Boolean inWhiteList = request.exchange().getAttribute("inWhiteList");
        inWhiteList = inWhiteList == null ? Boolean.FALSE : inWhiteList;
        resultFilter.returnApiCount(request.exchange(), inWhiteList);
    }

    /**
     * Route every error to the JSON rendering method
     */
    @Override
    protected RouterFunction<ServerResponse> getRoutingFunction(ErrorAttributes errorAttributes) {
        return RouterFunctions.route(RequestPredicates.all(), this::renderErrorResponse);
    }

    /**
     * Always respond with HTTP 200; the error itself is carried in the body
     */
    @Override
    protected HttpStatus getHttpStatus(Map<String, Object> errorAttributes) {
        return HttpStatus.OK;
    }
}
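To make the branching in getErrorAttributes easier to test, it can be pulled out into a pure function. The sketch below is an illustration, not the project's code: RouteNotFoundException is a hypothetical stand-in for the gateway's NotFoundException, and plain strings replace the FailureResult constants. It returns the failure code and whether the user's call count would be refunded.

```java
import java.net.UnknownHostException;

/** Hypothetical stand-in for org.springframework.cloud.gateway.support.NotFoundException. */
class RouteNotFoundException extends RuntimeException {
    RouteNotFoundException(String message) {
        super(message);
    }
}

/** Failure-code name plus whether the call is given back to the user. */
final class ErrorDecision {
    final String failureCode;
    final boolean refundCount;

    ErrorDecision(String failureCode, boolean refundCount) {
        this.failureCode = failureCode;
        this.refundCount = refundCount;
    }
}

final class ErrorClassifier {
    private static final String NOT_FOUND = "404 NOT_FOUND";

    static ErrorDecision classify(Throwable error) {
        if (error instanceof RouteNotFoundException) {
            // Passed every filter (and was counted), but no target was found: refund
            return new ErrorDecision("NULL_API", true);
        }
        if (NOT_FOUND.equals(error.getMessage())) {
            // Presumably no route matched, so the counting filters never ran: nothing to refund
            return new ErrorDecision("NULL_API", false);
        }
        if (error instanceof UnknownHostException) {
            return new ErrorDecision("UNKNOW_HOST", true);
        }
        return new ErrorDecision("INTERNAL_ERROR", false);
    }
}
```

Keeping the decision separate from the Spring plumbing means each exception type can be checked without starting the gateway.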

Response wrapping

Before a normal result is returned, counters such as the user's total number of calls have to be updated, and ResultFilter wraps the response body:

@Component
public class ResultFilter implements GlobalFilter, Ordered {

    private RedisTemplate<String, String> redisTemplate;
    private ModifyResponseBodyGatewayFilterFactory modifyResponseBodyGatewayFilterFactory;
    private GatewayFilter filter;
    private SyncCountTask syncCountTask;


    @Autowired
    public ResultFilter(ModifyResponseBodyGatewayFilterFactory modifyResponseBodyGatewayFilterFactory,
                        RedisTemplate<String, String> redisTemplate,
                        SyncCountTask syncCountTask) {
        this.modifyResponseBodyGatewayFilterFactory = modifyResponseBodyGatewayFilterFactory;
        this.redisTemplate = redisTemplate;
        this.syncCountTask = syncCountTask;
        filter = init();
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        // Streaming endpoints are excluded: on WebFlux the modify-response-body filter has to collect
        // the whole body before rewriting it, so streams are passed through untouched
        String streamSign = "/stream/";
        if (exchange.getRequest().getPath().toString().contains(streamSign)) {
            return chain.filter(exchange);
        }
        return filter.filter(exchange, chain);
    }

    @Override
    public int getOrder() {
        // Must run before NettyWriteResponseFilter for the response body to be rewritable
        return NettyWriteResponseFilter.WRITE_RESPONSE_FILTER_ORDER - 1;
    }

    private GatewayFilter init() {
        return modifyResponseBodyGatewayFilterFactory.apply(c -> c.setRewriteFunction(String.class, String.class, (serverWebExchange, s) -> {
            serverWebExchange.getResponse().getHeaders().setContentType(MediaType.APPLICATION_JSON_UTF8);

            Boolean inWhiteList = serverWebExchange.getAttribute("inWhiteList");
            inWhiteList = inWhiteList == null ? Boolean.FALSE : inWhiteList;

            HttpStatus statusCode = serverWebExchange.getResponse().getStatusCode();
            Result result;
            if (statusCode == HttpStatus.OK) {
                try {
                    Result serverResult = JSONObject.parseObject(s, Result.class);
                    // A body with neither code nor message is not a valid business result
                    if (serverResult.getCode() == null && serverResult.getMessage() == null) {
                        throw new IllegalStateException();
                    }
                    result = new Result(serverResult.getData());
                } catch (Exception e) {
                    returnApiCount(serverWebExchange, inWhiteList);
                    result = new Result(FailureResult.valueOf("ERROR_RESULT"));
                    return Mono.just(result.toString());
                }
                // The API call succeeded: submit the follow-up count synchronization
                UserLimitEntity userLimitEntity = serverWebExchange.getAttribute("userLimitEntity");
                if (!inWhiteList && !userLimitEntity.notCountLimit()) {
                    syncCountTask.sync(StringUtil.toMD5(serverWebExchange.getAttribute("userId"),
                            serverWebExchange.getAttribute("md5Path")));
                }

            } else if (statusCode == HttpStatus.NOT_FOUND) {
                returnApiCount(serverWebExchange, inWhiteList);
                changeStatusCode(serverWebExchange, HttpStatus.OK);
                result = new Result(FailureResult.valueOf("UNSIGN_API"));
            } else {
                boolean isServerError = statusCode == null || statusCode.is5xxServerError();
                returnApiCount(serverWebExchange, inWhiteList);
                changeStatusCode(serverWebExchange, HttpStatus.OK);
                if (isServerError) {
                    result = new Result(FailureResult.valueOf("ERROR_SERVER"));
                } else {
                    result = new Result(FailureResult.valueOf("BAD_REQUEST"));
                }
                String message = "Service error";
                try {
                    JSONObject jsonObject = JSONObject.parseObject(s);
                    message = jsonObject.getString("message");
                    String data = jsonObject.getString("data");
                    if (!StringUtils.isEmpty(data)) {
                        message += " [" + data + "]";
                    }
                } catch (Exception e) {
                    System.err.println("API " + serverWebExchange.getRequest().getPath() + " returned a non-standard error body: " + s);
                }
                result.setData(result.getData() + ":" + message);
            }
            return Mono.just(result.toString());
        }));
    }

    private static Number str2num(String str) throws NumberFormatException {
        try {
            return new BigInteger(str);
        } catch (NumberFormatException e) {
            return new BigDecimal(str);
        }
    }

    /**
     * Give one call back to the user's total count (skipped for whitelisted or count-unlimited requests)
     */
    public void returnApiCount(ServerWebExchange exchange, boolean inWhiteList) {
        if (inWhiteList) {
            return;
        }
        UserLimitEntity userLimitEntity = exchange.getAttribute("userLimitEntity");
        if (userLimitEntity.notCountLimit()) {
            return;
        }
        returnRedisCount(exchange);
    }

    private void returnRedisCount(ServerWebExchange exchange) {
        String redisPath = RedisPath.USER_COUNT_LIMIT + exchange.getAttribute("userId");
        String md5Path = exchange.getAttribute("md5Path");
        redisTemplate.opsForZSet().incrementScore(redisPath, md5Path, -1);
    }

    private void changeStatusCode(ServerWebExchange exchange, HttpStatus httpStatus) {
        exchange.getResponse().setStatusCode(httpStatus);
    }


}
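The status-code branching in the rewrite function can likewise be sketched as a pure function. This is a hypothetical illustration: ResponseWrapper, WrapDecision and the "SUCCESS" label are invented names, `bodyWellFormed` stands for the JSON check on the upstream Result, and `counted` stands for "not whitelisted and subject to a count limit", the condition under which returnApiCount and syncCountTask actually do something.

```java
/** Outcome of wrapping one upstream response (names are illustrative). */
final class WrapDecision {
    final String code;    // "SUCCESS" or a FailureResult constant name
    final boolean refund; // the user's call is given back (returnApiCount)
    final boolean sync;   // the successful call is submitted to syncCountTask

    WrapDecision(String code, boolean refund, boolean sync) {
        this.code = code;
        this.refund = refund;
        this.sync = sync;
    }
}

final class ResponseWrapper {
    /**
     * @param status         upstream HTTP status, or null if unknown
     * @param bodyWellFormed whether the body parsed into a Result with a code or a message
     * @param counted        not whitelisted and subject to a count limit
     */
    static WrapDecision decide(Integer status, boolean bodyWellFormed, boolean counted) {
        if (status != null && status == 200) {
            if (!bodyWellFormed) {
                return new WrapDecision("ERROR_RESULT", counted, false);
            }
            return new WrapDecision("SUCCESS", false, counted);
        }
        if (status != null && status == 404) {
            return new WrapDecision("UNSIGN_API", counted, false);
        }
        boolean serverError = status == null || (status >= 500 && status <= 599);
        return new WrapDecision(serverError ? "ERROR_SERVER" : "BAD_REQUEST", counted, false);
    }
}
```

In every failure branch the filter also rewrites the HTTP status sent to the client to 200, so clients tell outcomes apart by the code in the body.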

Filtering requests

User permission filter

Order:Integer.MIN_VALUE

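With the head and tail covered, which filters does a request pass through in between? Following the beans' Order values, the next one is user authentication, AuthFilter, with Order Integer.MIN_VALUE: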

@Component
public class AuthFilter implements MyFilter {

    private RedisTemplate<String, String> redisTemplate;

    @Autowired
    public AuthFilter(RedisTemplate<String, String> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public Mono<Void> noPassFilter(ServerWebExchange exchange, GatewayFilterChain chain) {

        String path = exchange.getRequest().getPath().toString();
        String md5Path = StringUtil.toMD5(path);

        // Is the API enabled?
        Boolean exist = redisTemplate.opsForSet().isMember(RedisPath.ALL_URL, md5Path);
        if (exist == null || !exist) {
            Result nullApi = new Result(FailureResult.valueOf("NULL_API"));
            return ResponseUtil.newResponse(exchange, nullApi);
        }

        // Does the user have permission?
        List<String> accessTokenList = exchange.getRequest().getQueryParams().get("accessToken");
        if (accessTokenList == null || accessTokenList.isEmpty()) {
            Result missToken = new Result(FailureResult.valueOf("MISS_TOKEN"));
            return ResponseUtil.newResponse(exchange, missToken);
        }
        String token = accessTokenList.get(0);

        String userId = redisTemplate.opsForValue().get(RedisPath.PERMISSION + token);
        if (StringUtils.isEmpty(userId)) {
            Result expiredToken = new Result(FailureResult.valueOf("EXPIRED_TOKEN"));
            return ResponseUtil.newResponse(exchange, expiredToken);
        } else {
            // Fetch only this API's entry instead of the user's whole permission hash
            Object obj = redisTemplate.opsForHash().get(RedisPath.USER_INFO + userId, md5Path);
            if (obj == null) {
                Result impermissibleApi = new Result(FailureResult.valueOf("IMPERMISSIBLE_API"));
                return ResponseUtil.newResponse(exchange, impermissibleApi);
            } else {
                // Share the user's limits with the filters that run after this one
                UserLimitEntity userLimitEntity = new UserLimitEntity(obj.toString());
                exchange.getAttributes().put("userLimitEntity", userLimitEntity);
                exchange.getAttributes().put("userId", userId);
                exchange.getAttributes().put("md5Path", md5Path);
            }
        }

        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        return Integer.MIN_VALUE;
    }


}

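The filter first checks whether the API is enabled by looking it up in the set of all registered APIs, stored in Redis under the key RedisPath.ALL_URL. It then checks 1) whether the accessToken parameter is present and 2) whether the token still maps to a user ID (otherwise it has expired). If both checks pass, it uses the userId to look up the user's limit entry for this API, UserLimitEntity, in Redis. Both UserLimitEntity and RedisPath.ALL_URL are cached from MySQL into Redis; the URL set is loaded by the InitHandler class below.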
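The order of checks in AuthFilter can be reproduced without Redis. This is a minimal sketch under stated assumptions: in-memory maps stand in for the three Redis lookups (RedisPath.ALL_URL, PERMISSION + token, USER_INFO + userId), AuthCheck is a hypothetical class, and md5Hex assumes StringUtil.toMD5 produces a lowercase hex MD5 of the path, which the article does not show. check returns the failure-code name, or null when the request may continue.

```java
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

/** In-memory sketch of AuthFilter's checks; the collections replace the Redis lookups. */
final class AuthCheck {
    private final Set<String> enabledApis;                     // like the set RedisPath.ALL_URL
    private final Map<String, String> tokenToUser;             // like PERMISSION + token -> userId
    private final Map<String, Map<String, String>> userLimits; // like USER_INFO + userId -> {md5Path: limits}

    AuthCheck(Set<String> enabledApis, Map<String, String> tokenToUser,
              Map<String, Map<String, String>> userLimits) {
        this.enabledApis = enabledApis;
        this.tokenToUser = tokenToUser;
        this.userLimits = userLimits;
    }

    /** Assumed behavior of StringUtil.toMD5: lowercase hex MD5 of the UTF-8 bytes. */
    static String md5Hex(String s) {
        try {
            byte[] digest = MessageDigest.getInstance("MD5").digest(s.getBytes(StandardCharsets.UTF_8));
            return String.format("%032x", new BigInteger(1, digest));
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 not available", e);
        }
    }

    /** Returns the failure-code name, or null if the request may go on to the next filter. */
    String check(String path, List<String> accessTokens) {
        String md5Path = md5Hex(path);
        if (!enabledApis.contains(md5Path)) {
            return "NULL_API";
        }
        if (accessTokens == null || accessTokens.isEmpty()) {
            return "MISS_TOKEN";
        }
        String userId = tokenToUser.get(accessTokens.get(0));
        if (userId == null || userId.isEmpty()) {
            return "EXPIRED_TOKEN";
        }
        Map<String, String> limits = userLimits.getOrDefault(userId, Collections.emptyMap());
        if (!limits.containsKey(md5Path)) {
            return "IMPERMISSIBLE_API";
        }
        return null;
    }
}
```

Checking the cheap, request-independent condition (is the API enabled at all) before touching any user data mirrors the filter above.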

Order:1

@Component
@Order(1)
public class InitHandler implements CommandLineRunner {

    private static final Logger LOG = LoggerFactory.getLogger(InitHandler.class);

    private RedisTemplate<String, String> redisTemplate;
    private TokenBucketClient tokenBucketClient;
    private UrlMapper urlMapper;

    @Autowired
    public InitHandler(RedisTemplate<String, String> redisTemplate, TokenBucketClient tokenBucketClient, UrlMapper urlMapper) {
        this.redisTemplate = redisTemplate;
        this.tokenBucketClient = tokenBucketClient;
        this.urlMapper = urlMapper;
    }

    @Override
    public void run(String... args) throws Exception {
        initEnum();
        List<Url> urls = urlMapper.getAll();
        initUrl(urls);
        initUrlLimit(urls);
    }

    private static void initEnum() throws Exception {
        LOG.info("############ Loading enums: start ############");
        Resource resource = new ClassPathResource("FailureResult.properties");
        Map<String, Object[]> enumValues = new HashMap<>(20);
        // Read as UTF-8 explicitly (the messages may be non-ASCII) and close the stream afterwards
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(resource.getInputStream(), StandardCharsets.UTF_8))) {
            String tmp;
            while ((tmp = reader.readLine()) != null) {
                String line = tmp.trim();
                // Skip blank (or whitespace-only) lines and comments
                if (line.isEmpty() || line.startsWith("#")) {
                    continue;
                }
                // Format: NAME=code,message; split on the first '=' only
                String[] nameAndCodeMsg = line.split("=", 2);
                String name = nameAndCodeMsg[0].trim();
                String cav = nameAndCodeMsg[1].trim();
                String code = cav.substring(0, cav.indexOf(","));
                String msg = cav.substring(cav.indexOf(",") + 1);
                Object[] value = new Object[2];
                value[0] = Integer.parseInt(code.trim());
                value[1] = msg.trim();
                enumValues.put(name, value);
            }
        }
        EnumUtil.addEnum(FailureResult.class, new Class[]{int.class, String.class}, enumValues);
        LOG.info("############ Loading enums: end ############");
    }

    /**
     * Cache the IDs of all available URLs in Redis
     */
    private void initUrl(List<Url> urls) {
        LOG.info("############ Caching available URLs in Redis (total: {}): start ############", urls.size());
        String[] urlArr = new String[urls.size()];
        for (int i = 0; i < urls.size(); i++) {
            urlArr[i] = urls.get(i).getId();
        }
        // SADD needs at least one member, so skip the call when there are no URLs
        Long result = urlArr.length == 0 ? null : redisTemplate.opsForSet().add(RedisPath.ALL_URL, urlArr);
        LOG.info("############ Caching available URLs in Redis (added: {}): end ############", result == null ? 0 : result);
    }

    /**
     * Initialize the Redis token bucket of every URL that has an access-rate limit
     */
    private void initUrlLimit(List<Url> urls) {
        LOG.info("############ Initializing token buckets (total: {}): start ############", urls.size());
        int total = 0;
        for (Url url : urls) {
            // A bucket size or QPS of 0 means the URL is not rate limited
            if (url.getBucketMaxSize() == 0 || url.getQps() == 0) {
                continue;
            }
            tokenBucketClient.init(RedisPath.URL_LIMIT + url.getId(), url.getBucketMaxSize(), url.getQps());
            total++;
        }
        LOG.info("############ Initializing token buckets (initialized: {}): end ############", total);
    }
}

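As you can see, initUrl() in InitHandler caches all existing URLs (that is, the APIs already stored in the database) into Redis. The MySQL view:

The Redis view (only the URL IDs are cached):

initEnum() loads the enum constants for the error types used in responses, and initUrlLimit() initializes a token bucket for each URL that has a rate limit configured; when those tokens are consumed is covered later.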
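The parsing done by initEnum() can be sketched on its own. This minimal sketch assumes each non-comment line has the form NAME=code,message, as the parsing code implies; FailureCodeLoader is a hypothetical name, and the sample entries and codes in the test are invented for illustration. It splits on the first '=' only and skips blank lines, comments, and lines without a comma.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.util.LinkedHashMap;
import java.util.Map;

final class FailureCodeLoader {
    /** Parses NAME=code,message lines into name -> {Integer code, String message}. */
    static Map<String, Object[]> parse(Reader source) throws IOException {
        Map<String, Object[]> values = new LinkedHashMap<>();
        try (BufferedReader reader = new BufferedReader(source)) {
            String raw;
            while ((raw = reader.readLine()) != null) {
                String line = raw.trim();
                if (line.isEmpty() || line.startsWith("#")) {
                    continue; // blank line or comment
                }
                // Split on the first '=' only, so a message may itself contain '='
                String[] nameAndRest = line.split("=", 2);
                if (nameAndRest.length < 2) {
                    continue; // no '=' at all
                }
                String rest = nameAndRest[1].trim();
                int comma = rest.indexOf(',');
                if (comma < 0) {
                    continue; // malformed: no "code,message" pair
                }
                int code = Integer.parseInt(rest.substring(0, comma).trim());
                String message = rest.substring(comma + 1).trim();
                values.put(nameAndRest[0].trim(), new Object[]{code, message});
            }
        }
        return values;
    }
}
```

In the real loader the Reader would be an InputStreamReader over the classpath resource, opened with an explicit UTF-8 charset so that non-ASCII messages survive regardless of the platform default.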

User rate filter

Order:Integer.MIN_VALUE+1

Next comes the filter that controls how fast each user may call an API. Its control logic uses Redis plus Lua; here is the code:

@Component
public class UserRateFilter implements MyFilter {

    private UserRateClient userRateClient;

    @Autowired
    public UserRateFilter(UserRateClient userRateClient) {
        this.userRateClient = userRateClient;
    }

    @Override
    public Mono<Void> noPassFilter(ServerWebExchange exchange, GatewayFilterChain chain) {

        UserLimitEntity userLimitEntity = exchange.getAttribute("userLimitEntity");
        if (userLimitEntity.notRateLimit()) {
            return chain.filter(exchange);
        }

        String userId = exchange.getAttribute("userId");
        String md5Path = exchange.getAttribute("md5Path");
        // One rate-limit hash per user and API
        String redisPath = RedisPath.USER_RATE_LIMIT + userId + ":" + md5Path;

        return acquire(redisPath, userLimitEntity.getRate(), exchange, chain);
    }

    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 1;
    }


    private Mono<Void> acquire(String redisPath, Map<Integer, Integer> limitInfo, ServerWebExchange exchange, GatewayFilterChain chain) {
        LuaScriptResult acquire = userRateClient.acquire(redisPath);
        if (acquire == LuaScriptResult.NEED_INIT) {
            // First call for this user and API: create the windows, then try again
            userRateClient.init(redisPath, limitInfo);
            acquire = userRateClient.acquire(redisPath);
        }

        if (acquire == LuaScriptResult.SUCCESS) {
            return chain.filter(exchange);
        } else if (acquire == LuaScriptResult.LIMITED) {
            Result frequentRequest = new Result(FailureResult.valueOf("FREQUENT_REQUEST"));
            return ResponseUtil.newResponse(exchange, frequentRequest);
        } else {
            // Signal the error through the Mono so the global error handler renders it
            return Mono.error(new IllegalStateException("user limit acquire error"));
        }
    }
}

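UserRateFilter: if the request was not whitelisted by AuthFilter, its logic runs in noPassFilter(), defined by the custom MyFilter interface. Like the data loaded in InitHandler above, the UserLimitEntity data is cached from MySQL into Redis, in this case by the user management system (which is separate from this gateway), as shown in the figure:

The fields are USED (number of calls already made), TOTAL (total number of calls allowed) and RATE (access rate). For example, {60:10,3600:300} expresses two limits: at most 10 calls within 60 seconds and at most 300 calls within 3600 seconds; as many of these pairs as needed can be added. The acquire() method checks the user's rate limits and decides whether the request passes; UserRateClient runs the user_rate_limit.lua script to decide whether the user may make the call. First, the UserRateClient class: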

@Component
public class UserRateClient {

    /**
     * Redis client
     */
    private StringRedisTemplate redisTemplate;

    /**
     * Lua script for per-user rate limiting
     */
    @Resource
    @Qualifier("userRateLua")
    private RedisScript<Long> userRateScript;

    @Autowired
    public UserRateClient(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    public LuaScriptResult init(String path, Map<Integer, Integer> limit) {
        if (limit == null || limit.isEmpty()) {
            // The array literals below need at least one window
            return LuaScriptResult.ERROR;
        }
        try {
            int max = 0;
            StringBuilder intervalSb = new StringBuilder("{");
            StringBuilder countSb = new StringBuilder("{");
            for (Map.Entry<Integer, Integer> entry : limit.entrySet()) {
                max = Math.max(max, entry.getKey());
                // Windows are configured in seconds; the Lua script works in milliseconds
                intervalSb.append(entry.getKey() * 1000).append(",");
                countSb.append(entry.getValue()).append(",");
            }
            intervalSb.deleteCharAt(intervalSb.length() - 1).append("}");
            countSb.deleteCharAt(countSb.length() - 1).append("}");

            // ARGV: {intervals in ms}, {max counts}, longest window in seconds (key TTL), number of windows
            return exec(path, UserRateLimitMethod.init, intervalSb.toString(), countSb.toString(),
                    String.valueOf(max), String.valueOf(limit.size()));
        } catch (Exception e) {
            e.printStackTrace();
            return LuaScriptResult.ERROR;
        }
    }

    public LuaScriptResult acquire(String path) {
        return exec(path, UserRateLimitMethod.acquire);
    }

    private LuaScriptResult exec(String path, UserRateLimitMethod method, Object... params) {
        try {
            // KEYS[1] = hash key for this user and API, KEYS[2] = method name (init / acquire)
            List<String> keys = Arrays.asList(path, method.name());
            Long result = redisTemplate.execute(userRateScript, keys, params);
            return LuaScriptResult.getResult(result);
        } catch (Exception e) {
            e.printStackTrace();
            return LuaScriptResult.ERROR;
        }
    }
}
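The conversion from a RATE value to the arguments of init can be sketched as follows. The parse method assumes the RATE field is stored literally in the {60:10,3600:300} form from the example above (UserLimitEntity's own parser is not shown), and RateSpec is a hypothetical name. A TreeMap keeps the windows in a fixed order; what matters to the script is only that both arrays use the same order.

```java
import java.util.Map;
import java.util.StringJoiner;
import java.util.TreeMap;

final class RateSpec {
    /** Parses "{60:10,3600:300}" into window seconds -> max calls (assumed storage format). */
    static Map<Integer, Integer> parse(String spec) {
        Map<Integer, Integer> limits = new TreeMap<>();
        String body = spec.trim();
        if (body.startsWith("{")) {
            body = body.substring(1);
        }
        if (body.endsWith("}")) {
            body = body.substring(0, body.length() - 1);
        }
        for (String pair : body.split(",")) {
            if (pair.trim().isEmpty()) {
                continue;
            }
            String[] kv = pair.split(":");
            if (kv.length != 2) {
                throw new IllegalArgumentException("bad window definition: " + pair);
            }
            limits.put(Integer.parseInt(kv[0].trim()), Integer.parseInt(kv[1].trim()));
        }
        return limits;
    }

    /** Builds ARGV like UserRateClient.init: {ms intervals}, {counts}, longest window (s), window count. */
    static String[] toInitArgs(Map<Integer, Integer> limits) {
        if (limits.isEmpty()) {
            throw new IllegalArgumentException("at least one window is required");
        }
        StringJoiner intervals = new StringJoiner(",", "{", "}");
        StringJoiner counts = new StringJoiner(",", "{", "}");
        int maxWindow = 0;
        for (Map.Entry<Integer, Integer> e : limits.entrySet()) {
            intervals.add(String.valueOf(e.getKey() * 1000L)); // seconds -> milliseconds
            counts.add(String.valueOf(e.getValue()));
            maxWindow = Math.max(maxWindow, e.getKey());
        }
        return new String[]{intervals.toString(), counts.toString(),
                String.valueOf(maxWindow), String.valueOf(limits.size())};
    }
}
```

StringJoiner with a prefix and suffix avoids the append-then-delete-the-last-comma step, which is also what breaks when the map is empty.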

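In UserRateClient.init(), the call exec(path, UserRateLimitMethod.init, intervalSb.toString(), countSb.toString(), String.valueOf(max), String.valueOf(limit.size())) passes the limits to Lua as two parallel array literals: intervalSb holds the window lengths in milliseconds and countSb the maximum number of calls per window (for {60:10,3600:300}, the arrays contain 60000 and 3600000, and 10 and 300, in matching order), followed by the longest window in seconds and the number of windows. The script bean qualified as "userRateLua" is loaded from user_rate_limit.lua through a Spring configuration class: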

redis.replicate_commands()

local method = KEYS[2]

local curr_time_arr = redis.call('TIME')
local curr_time = curr_time_arr[1] * 1000 + math.floor(curr_time_arr[2] / 1000)

local Result = {SUCCESS = 1, DEFEAT = 0, LIMITED = -1, NEEDINIT = 99}

local limit_info = redis.pcall('HMGET', KEYS[1], 'interval', 'max_count', 'first_time', 'count')

-- Serialize a Lua array as a table literal such as {1,2,3}
local function arr2str(arr)
    local str = '{'
    for i = 1, #arr, 1 do
        str = str .. arr[i]
        if i < #arr then
            str = str .. ','
        end
    end
    return str .. '}'
end

if method == 'init' then
    -- Already initialized, for example by a concurrent request
    if (type(limit_info[1]) ~= 'boolean' and limit_info[1] ~= nil) then
        return Result.DEFEAT
    end

    -- ARGV[4] is the number of windows; every window starts now with a count of 0
    local window_count = tonumber(ARGV[4])
    local curr_time_str = '{'
    local count_str = '{'
    for i = 1, window_count, 1 do
        curr_time_str = curr_time_str .. curr_time
        count_str = count_str .. 0
        if i < window_count then
            curr_time_str = curr_time_str .. ','
            count_str = count_str .. ','
        end
    end
    curr_time_str = curr_time_str .. '}'
    count_str = count_str .. '}'

    redis.pcall('HMSET', KEYS[1],
        'interval', ARGV[1],
        'max_count', ARGV[2],
        'first_time', curr_time_str,
        'count', count_str)
    -- ARGV[3] is the longest window in seconds, used as the key's TTL
    redis.pcall('EXPIRE', KEYS[1], ARGV[3])
    return Result.SUCCESS
end

if method == 'acquire' then
    -- Not initialized yet: the Java client calls init and retries
    if (type(limit_info[1]) == 'boolean' or limit_info[1] == nil) then
        return Result.NEEDINIT
    end
    -- Turn the stored table literals back into Lua arrays
    local interval = loadstring('return ' .. limit_info[1])()
    local max_count = loadstring('return ' .. limit_info[2])()
    local first_time = loadstring('return ' .. limit_info[3])()
    local count = loadstring('return ' .. limit_info[4])()

    for i = 1, #interval, 1 do
        -- Window expired: start a new one
        if curr_time - first_time[i] > interval[i] then
            first_time[i] = curr_time
            count[i] = 0
        end
        -- Window already full: reject without saving anything
        if count[i] >= max_count[i] then
            return Result.LIMITED
        end
    end

    -- Every window has room: count this call in all of them
    for i = 1, #interval, 1 do
        count[i] = count[i] + 1
    end

    redis.pcall('HSET', KEYS[1], 'first_time', arr2str(first_time), 'count', arr2str(count))

    return Result.SUCCESS
end

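If you are not familiar with Lua syntax, it is worth reading up on it first. Redis runs each Lua script atomically, which makes scripting the tool Redis itself offers for atomic operations such as distributed locks, and I highly recommend it in practice. Back to the script above. init stores the user's rate-limit parameters: interval (the window lengths), max_count (the maximum number of calls per window), first_time (the start time of each window, set to the current time) and count (set to 0 for each window). Arrays are stored as Lua table literals such as {0,0}, which is why the strings are built by concatenation. When acquire finds that the key has not been initialized in Redis, it returns NEEDINIT, and UserRateClient calls init and then acquire again. The Redis view: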

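In acquire, for each window: if the current time minus first_time exceeds the interval, the window's count is reset to 0 and curr_time (the current time) becomes the new first_time. If any window's count has already reached its max_count, Result.LIMITED is returned. Otherwise the count of every window is incremented and Result.SUCCESS is returned, so the request passes UserRateFilter and moves on to the next filter.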
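The acquire branch can be mirrored in plain Java to experiment with the window logic. This is a single-process sketch, not a replacement for the script: MultiWindowLimiter is a hypothetical class, synchronized stands in for the atomicity Redis gives a Lua script, the current time is passed in so the behavior is deterministic, and, as in the script, nothing is saved when a call is rejected.

```java
import java.util.Arrays;

/** Single-process sketch of the multi-window counter in user_rate_limit.lua (times in ms). */
final class MultiWindowLimiter {
    private final long[] intervalMs; // window lengths, like the interval field
    private final long[] maxCount;   // per-window limits, like max_count
    private final long[] firstTime;  // window start times, like first_time
    private final long[] count;      // calls counted in each window, like count

    MultiWindowLimiter(long[] intervalMs, long[] maxCount, long nowMs) {
        if (intervalMs.length != maxCount.length) {
            throw new IllegalArgumentException("interval and max count arrays differ in length");
        }
        this.intervalMs = intervalMs.clone();
        this.maxCount = maxCount.clone();
        this.firstTime = new long[intervalMs.length];
        Arrays.fill(this.firstTime, nowMs);       // like init: every window starts now
        this.count = new long[intervalMs.length]; // like init: every count starts at 0
    }

    /** Returns true for SUCCESS and false for LIMITED. */
    synchronized boolean acquire(long nowMs) {
        // Work on copies so that a rejected call changes nothing, as in the script
        long[] newFirst = firstTime.clone();
        long[] newCount = count.clone();
        for (int i = 0; i < intervalMs.length; i++) {
            if (nowMs - newFirst[i] > intervalMs[i]) { // window expired: start a new one
                newFirst[i] = nowMs;
                newCount[i] = 0;
            }
            if (newCount[i] >= maxCount[i]) {          // window full: reject
                return false;
            }
        }
        for (int i = 0; i < newCount.length; i++) {
            newCount[i]++;                             // count the call in every window
        }
        System.arraycopy(newFirst, 0, firstTime, 0, newFirst.length);
        System.arraycopy(newCount, 0, count, 0, newCount.length);
        return true;
    }
}
```

With windows of 1 s (2 calls) and 10 s (3 calls), the third call within the first second is rejected by the short window, and a call just after the first second is then rejected by the long window once it has counted three calls.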

API rate filter

Order:Integer.MIN_VALUE+2

UrlRateFilter controls how fast each API may be called:

@Component
public class UrlRateFilter implements MyFilter {

    private TokenBucketClient tokenBucketClient;

    @Autowired
    public UrlRateFilter(TokenBucketClient tokenBucketClient) {
        this.tokenBucketClient = tokenBucketClient;
    }

    @Override
    public Mono<Void> noPassFilter(ServerWebExchange exchange, GatewayFilterChain chain) {

        String md5Path = exchange.getAttribute("md5Path");
        // Each call consumes one token from this API's bucket
        LuaScriptResult acquire = tokenBucketClient.acquire(RedisPath.URL_LIMIT + md5Path);

        if (acquire == LuaScriptResult.SUCCESS) {
            return chain.filter(exchange);
        } else if (acquire == LuaScriptResult.LIMITED) {
            Result limitedApi = new Result(FailureResult.valueOf("LIMITED_API"));
            return ResponseUtil.newResponse(exchange, limitedApi);
        } else {
            // Signal the error through the Mono so the global error handler renders it
            return Mono.error(new IllegalStateException("api limit acquire error"));
        }

    }

    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 2;
    }
}

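The filter delegates to TokenBucketClient. The logic differs a little from the user rate control above: here the tokens are initialized up front by initUrlLimit() in the InitHandler class and cached in Redis, as shown in the figure: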

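The fields are bucket_max_size (maximum number of tokens in the bucket), interval (milliseconds between two generated tokens, i.e. 1000/qps), token (current number of tokens) and timestamp (time of the last update). If the token bucket concept is unfamiliar, see an example implementation of it first.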

@Component
public class TokenBucketClient {

    /**
     * Redis client
     */
    private StringRedisTemplate redisTemplate;

    /**
     * Token bucket Lua script
     */
    @Resource
    @Qualifier("tokenBucketLua")
    private RedisScript<Long> rateLimitScript;

    @Autowired
    public TokenBucketClient(StringRedisTemplate redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    public LuaScriptResult init(String path, int bucketMaxSize, int qps) {
        return init(path, bucketMaxSize, getInterval(qps));
    }

    public LuaScriptResult init(String path, int bucketMaxSize, float interval) {
        // The bucket starts full: the initial token count equals bucketMaxSize
        return exec(path, TbRateLimitMethod.init, bucketMaxSize, interval, bucketMaxSize);
    }

    public LuaScriptResult modify(String path, int bucketMaxSize, int qps) {
        return modify(path, bucketMaxSize, getInterval(qps));
    }

    public LuaScriptResult modify(String path, int bucketMaxSize, float interval) {
        return exec(path, TbRateLimitMethod.modify, bucketMaxSize, interval);
    }

    public LuaScriptResult delete(String path) {
        return exec(path, TbRateLimitMethod.delete);
    }

    public LuaScriptResult acquire(String path) {
        return acquire(path, 1);
    }

    public LuaScriptResult acquire(String path, int acquireToken) {
        return exec(path, TbRateLimitMethod.acquire, acquireToken);
    }

    private LuaScriptResult exec(String path, TbRateLimitMethod method, Object... params) {
        try {
            // ARGV[1] is the method name, ARGV[2..] its parameters; KEYS[1] is the bucket key
            String[] allParams = new String[params.length + 1];
            allParams[0] = method.name();
            for (int index = 0; index < params.length; index++) {
                allParams[1 + index] = params[index].toString();
            }
            Long result = redisTemplate.execute(rateLimitScript,
                    Collections.singletonList(path),
                    allParams);
            return LuaScriptResult.getResult(result);
        } catch (Exception e) {
            e.printStackTrace();
            return LuaScriptResult.ERROR;
        }
    }

    /**
     * Milliseconds between two generated tokens
     */
    private static float getInterval(int qps) {
        return 1000.0f / qps;
    }
}

The Lua script that controls token generation:

redis.replicate_commands()

local Result = {SUCCESS = 1, DEFEAT = 0, LIMITED = -1}

local token_bucket_info = redis.pcall('HMGET', KEYS[1], 'bucket_max_size', 'interval', 'token', 'timestamp')

local bucket_max_size = tonumber(token_bucket_info[1])
local interval = tonumber(token_bucket_info[2])
local token = tonumber(token_bucket_info[3])
local timestamp = token_bucket_info[4]

local method = ARGV[1]

local curr_time_arr = redis.call('TIME')
local curr_timestamp = curr_time_arr[1] * 1000 + math.floor(curr_time_arr[2] / 1000)

if method == 'init' then
    -- Already initialized: keep the existing bucket
    if (type(timestamp) ~= 'boolean' and timestamp ~= nil) then
        return Result.SUCCESS
    end
    redis.pcall('HMSET', KEYS[1],
        'bucket_max_size', ARGV[2],
        'interval', ARGV[3],
        'token', ARGV[4],
        'timestamp', curr_timestamp)
    return Result.SUCCESS
end


if method == 'modify' then
    -- The bucket must exist before it can be modified
    if (type(timestamp) == 'boolean' or timestamp == nil) then
        return Result.DEFEAT
    end
    redis.pcall('HMSET', KEYS[1],
        'bucket_max_size', ARGV[2],
        'interval', ARGV[3])
    return Result.SUCCESS
end

if method == 'delete' then
    if (type(timestamp) == 'boolean' or timestamp == nil) then
        return Result.SUCCESS
    end
    redis.pcall('DEL', KEYS[1])
    return Result.SUCCESS
end

if method == 'acquire' then
    -- No bucket for this API: it is not rate limited
    if (type(timestamp) == 'boolean' or timestamp == nil) then
        return Result.SUCCESS
    end
    -- Number of tokens this request consumes
    local acquire_token = tonumber(ARGV[2])
    -- Number of tokens generated since the last update
    local reserve_token = math.max(0, math.floor((curr_timestamp - tonumber(timestamp)) / interval))
    -- Tokens beyond the bucket capacity are discarded
    local curr_token = math.min(bucket_max_size, token + reserve_token)
    local result = Result.LIMITED
    -- Let the request through if the bucket holds enough tokens
    if curr_token >= acquire_token then
        result = Result.SUCCESS
        curr_token = curr_token - acquire_token
    end
    -- Store the current number of tokens
    redis.pcall('HSET', KEYS[1], 'token', curr_token)
    -- Update the timestamp only if tokens were added this time
    if reserve_token > 0 then
        redis.pcall('HSET', KEYS[1], 'timestamp', curr_timestamp)
    end
    return result
end

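In acquire, acquire_token is 1 (each call consumes one token). The time elapsed since the last update divided by interval, rounded down, gives the number of newly generated tokens, and the total is capped at bucket_max_size. If the bucket holds at least acquire_token tokens, the request passes and Result.SUCCESS is returned, which completes the control of API call rates.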
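The same refill arithmetic in plain Java, as a single-process sketch with the time passed in explicitly. TokenBucket is a hypothetical class; the bucket starts full, just as init sets token to bucket_max_size, and synchronized stands in for the atomicity of the Lua script.

```java
/** Single-process sketch of the token bucket in the Lua script (times in ms). */
final class TokenBucket {
    private final long bucketMaxSize; // like bucket_max_size
    private final double intervalMs;  // like interval: milliseconds per token, 1000 / qps
    private long tokens;              // like token
    private long timestamp;           // like timestamp: last time tokens were added

    TokenBucket(long bucketMaxSize, int qps, long nowMs) {
        if (bucketMaxSize <= 0 || qps <= 0) {
            throw new IllegalArgumentException("bucketMaxSize and qps must be positive");
        }
        this.bucketMaxSize = bucketMaxSize;
        this.intervalMs = 1000.0 / qps;
        this.tokens = bucketMaxSize; // init fills the bucket
        this.timestamp = nowMs;
    }

    /** Returns true (SUCCESS) if acquireTokens tokens were available, false (LIMITED) otherwise. */
    synchronized boolean acquire(long nowMs, long acquireTokens) {
        // Tokens generated since the last update, never negative
        long reserve = Math.max(0L, (long) Math.floor((nowMs - timestamp) / intervalMs));
        // Tokens beyond the capacity are discarded
        long current = Math.min(bucketMaxSize, tokens + reserve);
        boolean allowed = current >= acquireTokens;
        if (allowed) {
            current -= acquireTokens;
        }
        tokens = current;
        if (reserve > 0) {
            timestamp = nowMs; // as in the script: move the clock only when tokens were added
        }
        return allowed;
    }
}
```

Like the script, this resets the timestamp to the current time whenever tokens are added, which drops the fractional remainder of an interval (150 ms elapsed at 100 ms per token yields one token, and the remaining 50 ms is lost), so the sustained rate can fall slightly below qps. Advancing the timestamp by reserve_token × interval instead would keep that remainder.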

User total-count filter

Order:Integer.MIN_VALUE+3

@Component
public class UserCountFilter implements MyFilter {


    private RedisTemplate<String, String> redisTemplate;

    @Autowired
    public UserCountFilter(RedisTemplate<String, String> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Override
    public Mono<Void> noPassFilter(ServerWebExchange exchange, GatewayFilterChain chain) {

        ZSetOperations<String, String> zSet = redisTemplate.opsForZSet();

        UserLimitEntity userLimitEntity = exchange.getAttribute("userLimitEntity");

        if (userLimitEntity.notCountLimit()) {
            return chain.filter(exchange);
        }

        // One sorted set per user: member = md5Path, score = number of calls made
        String redisPath = RedisPath.USER_COUNT_LIMIT + exchange.getAttribute("userId");
        String md5Path = exchange.getAttribute("md5Path");

        Double count = zSet.incrementScore(redisPath, md5Path, 1);
        if (count != null && count == 1 && userLimitEntity.getUsed() != 0) {
            // First hit since the Redis counter was created: seed it with the usage persisted in MySQL,
            // and keep the result so the seeded value is also checked against the total below
            count = zSet.incrementScore(redisPath, md5Path, userLimitEntity.getUsed());
        }
        if (count != null && count > userLimitEntity.getTotal()) {
            // Over the limit: undo this call's increment and reject it
            zSet.incrementScore(redisPath, md5Path, -1);

            Result limitedCount = new Result(FailureResult.valueOf("LIMITED_COUNT"));
            return ResponseUtil.newResponse(exchange, limitedCount);
        }

        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        return Integer.MIN_VALUE + 3;
    }
}

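UserLimitEntity has a total field, the maximum number of calls the user may make. The filter above compares the call count against it and, if the count exceeds total, returns FailureResult.valueOf("LIMITED_COUNT"). Error results are enum constants created from the custom error codes in a configuration file (shown below), which are loaded by the InitHandler class mentioned earlier.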
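The counting rule, including seeding from the persisted USED value and the refund that ResultFilter performs for failed calls, can be sketched with an in-memory map standing in for the Redis sorted set (member md5Path, score = call count). UserCountLimiter is a hypothetical class. Note that the seeded value is also compared with the total, so a user whose USED already equals TOTAL is rejected on the very first call.

```java
import java.util.HashMap;
import java.util.Map;

/** In-memory sketch of UserCountFilter's counter for one user; the map replaces the sorted set. */
final class UserCountLimiter {
    private final Map<String, Double> scores = new HashMap<>(); // md5Path -> call count

    // Like ZSetOperations.incrementScore: add delta and return the new score
    private double incrementScore(String member, double delta) {
        return scores.merge(member, delta, Double::sum);
    }

    /** Returns true if the call is allowed; used is the count persisted in MySQL, total the limit. */
    synchronized boolean tryConsume(String md5Path, long used, long total) {
        double count = incrementScore(md5Path, 1);
        if (count == 1 && used != 0) {
            // First hit since the counter was created: seed it with the persisted usage
            count = incrementScore(md5Path, used);
        }
        if (count > total) {
            incrementScore(md5Path, -1); // undo this call's increment and reject it
            return false;
        }
        return true;
    }

    /** Gives one call back, as ResultFilter.returnApiCount does after a failed call. */
    synchronized void refund(String md5Path) {
        incrementScore(md5Path, -1);
    }

    synchronized double current(String md5Path) {
        return scores.getOrDefault(md5Path, 0.0);
    }
}
```

Undoing the increment on rejection keeps the stored count equal to the number of calls actually let through, which is what the later synchronization back to MySQL relies on.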

FailureResult.properties:

Reaching the business layer through the gateway

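With the filters above, the authentication and rate-limiting design is complete, and the request then proceeds to the business logic. That concludes this article.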

GitLab: project repository