Spring Boot 2 + Spring Security 集成 WebSocket

项目要求前端实现自动刷新功能,最简单粗暴(low)的办法就是由前端轮询。但随着产品对实时通知的需求越来越多,轮询渐渐无法满足要求,只能改用 WebSocket 长连接。

但是百度、Google 之后发现,大多数方案都是使用 SockJS 实现的,而前端使用的 React Native 只能使用原生 WebSocket。没有办法,只能自己查阅资料实现了。

引入 Maven 依赖
1
2
3
4
5
6
7
8
9
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-websocket</artifactId>
</dependency>
<!-- 兼容spring security的socket -->
<dependency>
<groupId>org.springframework.security</groupId>
<artifactId>spring-security-messaging</artifactId>
</dependency>
websocket 配置
  1. 项目中使用 Spring Security 做登录和权限校验,所以要先在 security 配置中放行 socket 路径,改由我们自己在 WebSocket 握手阶段做权限校验
1
2
3
4
5
6
7
8
/**
 * Spring Security HTTP configuration (excerpt; the "..." lines stand for rules omitted from
 * this snippet).
 *
 * <p>The WebSocket endpoint is opened to everyone at the HTTP-security level because the
 * handshake interceptor authenticates the connection itself, using the {@code token} query
 * parameter.
 */
@Override
protected void configure(HttpSecurity http) throws Exception {
http.csrf().disable()
...
// Let the WebSocket endpoint (and any sub-path) through; the token is checked during the handshake.
.antMatchers("/socket/**").permitAll()
.antMatchers("/socket").permitAll()
...
}
  1. 注册 /socket 作为 WebSocket 访问路径,并允许所有来源(Origin)访问
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
/**
 * Registers the plain WebSocket endpoint {@code /socket} (no SockJS fallback is configured) and
 * attaches the token-checking handshake interceptor to it.
 *
 * <p>Every origin is accepted ({@code "*"}); access control is left to the token validated during
 * the handshake.
 */
@Configuration
@EnableWebSocket
public class WebSocketConfig implements WebSocketConfigurer {

    @Autowired
    private WebSocketMessageHandler webSocketMessageHandler;

    @Autowired
    private GuguHandshakeInterceptor handshakeInterceptor;

    /**
     * Maps the message handler to {@code /socket}, wiring in the handshake interceptor and the
     * allowed origins.
     *
     * @param registry the registry supplied by Spring
     */
    @Override
    public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
        String[] permittedOrigins = {"*"};
        registry
                .addHandler(webSocketMessageHandler, "/socket")
                .addInterceptors(handshakeInterceptor)
                .setAllowedOrigins(permittedOrigins);
    }
}
开始实现 websocket 的接入
websocket 的握手

项目中使用了 Spring Security 管理登录用户,要求只有已登录的用户才能与服务端建立 WebSocket 连接。因此原来放在 HTTP 请求头中的 token 也要带到 WebSocket 中;但 WebSocket 客户端一般无法设置自定义请求头,只能把 token 放在连接地址的查询参数中,类似于 GET 请求,如 ws://localhost:8080/socket?token={token}。所以要在 WebSocket 与服务器建立连接的握手阶段校验 token 的合法性。注意:URL 中的 token 可能被代理或访问日志记录下来,生产环境应使用 wss://,并避免在日志中打印完整 token。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
package com.gugu.boy.spring.websocket.interceptor;

import com.gugu.boy.spring.security.GuguUserDetails;
import com.gugu.boy.spring.security.authentication.storage.TokenStorage;
import com.gugu.boy.spring.security.authentication.storage.UserCache;
import com.gugu.boy.spring.security.authentication.token.TokenAuthenticationHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;

import java.util.ArrayList;
import java.util.Map;

/**
 * WebSocket handshake interceptor that only lets logged-in users connect.
 *
 * <p>The client passes its login token as the {@code token} query parameter
 * (e.g. {@code ws://host/socket?token=...}). The token is resolved to a user id, the cached user is
 * looked up in {@link TokenStorage}, and the presented token must equal the cached one. On success
 * the user's id, phone and user type are copied into the WebSocket session attributes and a
 * {@link UsernamePasswordAuthenticationToken} is placed into the {@link SecurityContextHolder}.
 *
 * @author zmq
 * @date 2019/10/14
 */
@Slf4j
@Component
public class GuguHandshakeInterceptor extends HttpSessionHandshakeInterceptor {

    @Autowired
    private TokenAuthenticationHandler tokenUtils;
    @Autowired
    private TokenStorage tokenStorage;

    /**
     * Validates the {@code token} query parameter before the handshake proceeds.
     *
     * @return {@code true} to continue the handshake (delegating to the parent interceptor),
     *     {@code false} to reject it
     */
    @Override
    public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws Exception {
        if (!(request instanceof ServletServerHttpRequest)) {
            // The token is read from servlet request parameters; reject other request types
            // instead of failing with a ClassCastException.
            log.warn("握手失败: 不支持的请求类型 {}", request.getClass().getName());
            return false;
        }
        ServletServerHttpRequest serverRequest = (ServletServerHttpRequest) request;
        // WebSocket clients cannot set custom request headers, so the token travels as a query parameter.
        String token = serverRequest.getServletRequest().getParameter("token");
        if (this.authentication(token, attributes)) {
            log.info("握手成功");
            return super.beforeHandshake(request, response, wsHandler, attributes);
        }
        log.warn("握手失败");
        return false;
    }

    @Override
    public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Exception exception) {
    }

    /**
     * Checks {@code token} and, on success, fills {@code attributes} and the security context.
     *
     * @param token      the raw token from the query string; may be {@code null}
     * @param attributes the WebSocket session attributes to populate
     * @return whether the token belongs to a currently logged-in user
     */
    private boolean authentication(String token, Map<String, Object> attributes) {
        if (StringUtils.isBlank(token)) {
            log.warn("缺失参数: token");
            return false;
        }
        String id;
        try {
            id = tokenUtils.getFromToken(TokenAuthenticationHandler.CLAIM_KEY_ID, token);
        } catch (RuntimeException e) {
            // A malformed or unparsable token is an authentication failure, not a server error.
            log.warn("token解析失败: {}, token信息:{}", e.getMessage(), mask(token));
            return false;
        }
        log.info("用户ID【{}】尝试进行握手", id);
        // Check the id before the storage lookup so a token without an id never reaches getUserById(null).
        UserCache userCache = id == null ? null : tokenStorage.getUserById(id);
        if (userCache == null) {
            log.warn("token信息:{}", mask(token));
            return false;
        }
        if (!tokensMatch(token, userCache.getToken())) {
            log.warn("用户{}可能在其他设备上被登录了", id);
            log.warn("token信息:{}", mask(token));
            return false;
        }
        GuguUserDetails userDetails = toUserDetails(userCache);
        // The token matches the stored one: register the user in the SecurityContextHolder,
        // which is equivalent to a login for this request.
        // NOTE(review): presumably this is what makes the WebSocket session principal a
        // UsernamePasswordAuthenticationToken (read back by WebSocketContainer.getCurrentUser) - confirm.
        UsernamePasswordAuthenticationToken authentication = new UsernamePasswordAuthenticationToken(userDetails.getPhone(), null, new ArrayList<>(0));
        authentication.setDetails(userDetails);
        attributes.put("id", userDetails.getId());
        attributes.put("username", userDetails.getPhone());
        attributes.put("userType", userDetails.getUserType());
        SecurityContextHolder.getContext().setAuthentication(authentication);
        return true;
    }

    /**
     * Compares tokens in constant time so response timing does not reveal matching prefixes.
     *
     * @param presented the token sent by the client (non-null)
     * @param expected  the token stored for the user; may be {@code null}
     * @return {@code true} only if both tokens are equal
     */
    private static boolean tokensMatch(String presented, String expected) {
        if (expected == null) {
            return false;
        }
        return java.security.MessageDigest.isEqual(
                presented.getBytes(java.nio.charset.StandardCharsets.UTF_8),
                expected.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /**
     * Masks a token for logging; the token is a credential and must not be logged in full.
     *
     * @param token the token to mask
     * @return the first few characters followed by {@code ***}
     */
    private static String mask(String token) {
        return StringUtils.left(token, 6) + "***";
    }

    /**
     * Converts the cached user into a {@link GuguUserDetails} by copying same-named properties.
     *
     * @param source the cached user
     * @return a new GuguUserDetails
     * @author zhoumeiqin
     * @date 2019-07-30 10:21
     **/
    private GuguUserDetails toUserDetails(UserCache source) {
        GuguUserDetails target = new GuguUserDetails();
        BeanUtils.copyProperties(source, target);
        return target;
    }
}
服务端处理前端通过 WebSocket 长连接发送的消息
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
package com.gugu.boy.spring.websocket.handler;

import com.alibaba.fastjson.JSON;
import com.gugu.boy.spring.security.GuguUserDetails;
import com.gugu.boy.spring.websocket.MessagePayload;
import com.gugu.boy.spring.websocket.PayloadConst;
import com.gugu.boy.spring.websocket.WebSocketContainer;
import com.gugu.boy.spring.websocket.WebSocketTemplate;
import com.gugu.core.exception.LogicException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;

/**
* @author zhoumeiqin
* @date 2019/10/14
*/
@Slf4j
@Component
public class WebSocketMessageHandler extends TextWebSocketHandler {
@Autowired
private WebSocketContainer container;
@Autowired
private WebSocketTemplate webSocketTemplate;

@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
container.connect(session);
super.afterConnectionEstablished(session);
}

@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
container.close(session, status);
super.afterConnectionClosed(session, status);
}

@Override
public void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
GuguUserDetails user = container.getCurrentUser(session);
MessagePayload payload = JSON.parseObject(message.getPayload(), MessagePayload.class);
if (payload.getEvent() == PayloadConst.Event.HEARTBEAT) {
log.info("{} heartbeat", user.getPhone());
session.sendMessage(message);
} else {
try {
webSocketTemplate.sendMessage(payload);
} catch (LogicException e) {
log.warn("触发业务异常: {} - {}", e.getErrorCode(), e.getMessage());
} catch (Exception e) {
log.error(",errMsg", e);
}
}
}
}

因为项目中有向指定用户推送消息的业务需求,所以我对 WebSocketSession 进行了集中管理。
首先创建 WebSocketContainer 组件,将所有的 WebSocketSession 集中存储到 ConcurrentHashMap 中,以 userId 作为 key、WebSocketSession 作为 value,这样就可以通过 userId 快速找到对应的 WebSocketSession,从而把消息发送出去。
当客户端建立连接或断开连接时,分别添加或移除对应的 session。
同时,启动一个只有 1 个线程的定时线程池,定期在日志中打印当前在线人数和在线用户。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
package com.gugu.boy.spring.websocket;

import com.alibaba.fastjson.JSON;
import com.gugu.boy.spring.security.GuguUserDetails;
import com.gugu.boy.spring.websocket.handler.PayloadProcessorContext;
import com.gugu.boy.spring.websocket.handler.processor.PayloadProcessor;
import com.gugu.core.constants.SystemConstant;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.concurrent.BasicThreadFactory;
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.WebSocketSession;

import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.io.IOException;
import java.security.Principal;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeUnit;

/**
 * Registry of the online surveyor WebSocket sessions, keyed by user id.
 *
 * <p>At most one session is kept per user: a new connection replaces the previous entry and closes
 * the old session. A single daemon thread ({@code socket-mon-%d}) logs the online count and the
 * online user ids every 5 minutes, starting 1 minute after startup.
 *
 * <p>Both {@code sendMessage} methods are {@code synchronized} on this instance, so sends issued
 * through this class are serialized.
 *
 * @author zhoumeiqin
 * @date 2019/10/21
 */
@Slf4j
@Component
public class WebSocketContainer {
    private ScheduledExecutorService executor;
    /**
     * Online sessions of surveyors (查勘员).
     * <p>key = user id </p>
     */
    private static final Map<Long, WebSocketSession> surveySessions = new ConcurrentHashMap<>();

    /** Starts the periodic online-user monitor. */
    @PostConstruct
    public void init() {
        this.monitor();
    }

    /** Stops the monitor thread on shutdown. */
    @PreDestroy
    public void destroy() {
        if (this.executor != null) {
            log.info("销毁定时器");
            this.executor.shutdown();
        }
    }

    /**
     * Schedules a task that logs the number of online users and their ids every 5 minutes.
     */
    private void monitor() {
        log.info("定时检测在线人数");
        Runnable command = () -> {
            log.info("在线人数: {}", surveySessions.size());
            log.info("在线人员:{}", surveySessions.keySet());
        };
        executor = new ScheduledThreadPoolExecutor(1,
                new BasicThreadFactory.Builder().namingPattern("socket-mon-%d").daemon(true).build());
        executor.scheduleAtFixedRate(command, 1, 5, TimeUnit.MINUTES);
    }

    /**
     * Registers {@code session} as the surveyor's current connection, closing any previous one.
     *
     * @param user    the authenticated user
     * @param session the newly opened session
     * @throws IOException kept for backward compatibility; a failure to close the previous
     *                     session is logged rather than thrown
     */
    public void connect(GuguUserDetails user, WebSocketSession session) throws IOException {
        // put() swaps the entry atomically and returns the previous session, so two concurrent
        // connects for the same user cannot leave an untracked open session behind.
        WebSocketSession history = surveySessions.put(user.getId(), session);
        if (history != null && history != session) {
            log.info("用户[{}]重复连接,关闭已有连接", user.getPhone());
            try {
                history.close(WebSocketConst.Close.NORMAL);
            } catch (IOException e) {
                // The new session is already registered; failing to close the stale one must not
                // abort the new connection.
                log.warn("关闭用户[{}]的旧连接失败, sessionId: {}", user.getPhone(), history.getId(), e);
            }
        }
        log.info("[{}]登录了, sessionId: {},当前在线人数:{}", user.getPhone(), session.getId(), surveySessions.size());
    }

    /**
     * Registers a newly opened session and notifies its user that the connection is established.
     *
     * @param session the newly opened session
     * @throws Exception if sending the connect notification fails
     */
    public void connect(WebSocketSession session) throws Exception {
        GuguUserDetails user = this.getCurrentUser(session);
        if (user != null) {
            this.connect(user, session);
            this.afterSocketConnect(user.getId());
        } else {
            log.warn("当前用户未知");
        }
    }

    /**
     * Sends a {@code SOCKET_CONNECT} payload to the user, telling the client its socket is connected.
     *
     * @param userId the connected user's id
     * @throws Exception if sending fails
     */
    public void afterSocketConnect(Long userId) throws Exception {
        MessagePayload payload = new MessagePayload();
        payload.setEvent(PayloadConst.Event.TO_SINGLE);
        payload.setType(PayloadConst.Type.SOCKET_CONNECT);
        payload.setToUser(userId);
        payload.setToUserType(SystemConstant.SURVEYOR);
        this.sendMessage(userId, payload);
    }

    /**
     * Closes the user's session with the default status and unregisters it.
     *
     * @param userId the user id
     * @throws IOException if closing the session fails (the session is unregistered regardless)
     */
    public void close(Long userId) throws IOException {
        this.closeSession(userId, WebSocketConst.Close.NORMAL);
    }

    /**
     * Closes the user's session with the default status (4000: an active close, after which the
     * app does not retry the connection).
     *
     * @param user the user; {@code null} is logged and ignored
     * @throws IOException if closing the session fails
     */
    public void close(GuguUserDetails user) throws IOException {
        this.close(user, WebSocketConst.Close.NORMAL);
    }

    /**
     * Closes the user's session with {@code status} and unregisters it.
     *
     * @param user   the user; {@code null} is logged and ignored
     * @param status the close status sent to the client
     * @throws IOException if closing the session fails
     */
    public void close(GuguUserDetails user, CloseStatus status) throws IOException {
        if (user == null) {
            log.warn("当前用户未知");
            return;
        }
        log.info("连接关闭:{}, 关闭状态:{}", user.getPhone(), status);
        this.closeSession(user.getId(), status);
    }

    /**
     * Unregisters {@code session} (called when a connection has closed) and closes it if it is
     * still open.
     *
     * @param session the session being closed
     * @param status  the close status
     * @throws IOException if closing a still-open session fails
     */
    public void close(WebSocketSession session, CloseStatus status) throws IOException {
        GuguUserDetails user = this.getCurrentUser(session);
        if (user == null) {
            log.warn("当前用户未知");
            return;
        }
        log.info("连接关闭:{}, 关闭状态:{}", user.getPhone(), status);
        // Only drop the entry if it still refers to THIS session. On reconnect, connect() has
        // already replaced the entry and closed this session; looking up by user id here would
        // evict and close the user's new connection.
        if (!surveySessions.remove(user.getId(), session)) {
            log.info("用户[{}]的连接已被替换或已移除, sessionId: {}", user.getPhone(), session.getId());
            return;
        }
        if (session.isOpen()) {
            session.close(status);
        }
    }

    /**
     * Removes the user's session from the registry and closes it with {@code status}.
     *
     * @param userId the user id
     * @param status the close status
     * @throws IOException if closing the session fails
     */
    private void closeSession(Long userId, CloseStatus status) throws IOException {
        // Remove before closing so the entry is dropped even if close() throws.
        WebSocketSession session = surveySessions.remove(userId);
        if (session == null) {
            log.warn("用户[{}]不在线", userId);
            return;
        }
        session.close(status);
    }

    /**
     * Returns the user's registered session.
     *
     * @param userId the user id
     * @return the session, or {@code null} if the user is not online
     */
    public WebSocketSession getSession(Long userId) {
        return surveySessions.get(userId);
    }

    /**
     * Sends {@code payload} to a single user through the processor registered for its type.
     *
     * @param userId  the recipient's id
     * @param payload the message
     * @throws Exception if the processor fails
     */
    public synchronized void sendMessage(Long userId, MessagePayload payload) throws Exception {
        log.info("发送消息给用户:{}, 数据:{}", userId, JSON.toJSONString(payload));
        WebSocketSession session = this.getSession(userId);
        if (session == null) {
            log.warn("当前用户:{}不存在或者不在线", userId);
            return;
        }
        PayloadProcessor processor = PayloadProcessorContext.getInstance(payload.getType());
        if (processor == null) {
            log.warn("未知的消息类型:{}", payload.getType());
            return;
        }
        processor.process(session, payload);
    }

    /**
     * Sends {@code payload} to every online user; a failure for one user is logged and does not
     * stop delivery to the others.
     *
     * @param payload the message
     */
    public synchronized void sendMessage(MessagePayload payload) {
        log.info("发送消息给所有人.: {}", JSON.toJSONString(payload));
        PayloadProcessor processor = PayloadProcessorContext.getInstance(payload.getType());
        if (processor == null) {
            log.warn("未知的消息类型:{}", payload.getType());
            return;
        }
        surveySessions.forEach((userId, session) -> {
            try {
                processor.process(session, payload);
            } catch (Exception e) {
                // Log the id only; serializing the principal would dump the user details into the log.
                log.error("发送给用户消息失败: {}", userId, e);
            }
        });
    }

    /**
     * Returns the user attached to {@code session} during the handshake.
     *
     * @param session the WebSocket session
     * @return the user, or {@code null} if the principal is not a token carrying GuguUserDetails
     */
    public GuguUserDetails getCurrentUser(WebSocketSession session) {
        Principal principal = session.getPrincipal();
        if (principal instanceof UsernamePasswordAuthenticationToken) {
            Object details = ((UsernamePasswordAuthenticationToken) principal).getDetails();
            if (details instanceof GuguUserDetails) {
                return (GuguUserDetails) details;
            }
        }
        log.warn("未知的连接:{}", JSON.toJSONString(session.getPrincipal()));
        return null;
    }

    /**
     * Tells whether the user currently has a registered session.
     *
     * @param userId the user id
     * @return {@code true} if a session is registered
     */
    public boolean isOnline(Long userId) {
        return surveySessions.containsKey(userId);
    }
}

最后再增加一个工具类 WebSocketTemplate,对消息发送进行统一封装。目前功能比较简单,都是同步发送,后续可以引入线程池改为异步发送。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
package com.gugu.boy.spring.websocket;

import com.alibaba.fastjson.JSON;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.IOException;

/**
 * Facade that business code uses to push {@link MessagePayload}s to connected clients.
 *
 * <p>Delivery is synchronous: each call returns once {@link WebSocketContainer} has handled it.
 *
 * @author zhoumeiqin
 * @date 2019/10/21
 */
@Slf4j
@Component
public class WebSocketTemplate {

    @Autowired
    private WebSocketContainer container;

    /**
     * Routes a payload according to its event: {@code TO_SINGLE} goes to
     * {@code payload.getToUser()}, {@code TO_ALL} is broadcast to every online user, and any other
     * event is silently ignored.
     *
     * @param payload the message to deliver
     * @throws Exception if delivery to a single user fails
     */
    public void sendMessage(MessagePayload payload) throws Exception {
        if (payload.getEvent() == PayloadConst.Event.TO_SINGLE) {
            container.sendMessage(payload.getToUser(), payload);
            return;
        }
        if (payload.getEvent() == PayloadConst.Event.TO_ALL) {
            container.sendMessage(payload);
        }
    }

    /**
     * Disconnects the given user's socket.
     *
     * @param userId the user whose connection should be closed
     * @throws IOException if closing the session fails
     */
    public void close(Long userId) throws IOException {
        log.info("断开用户{}的socket连接", userId);
        container.close(userId);
    }
}