当前位置:网站首页>websocket校验token:使用threadlocal存放和获取当前登录用户

websocket校验token:使用threadlocal存放和获取当前登录用户

2022-08-10 18:15:00 march of Time

众所周知,ThreadLocal 可以实现线程之间的变量隔离:登录时可以把当前用户放入其中,之后在同一线程内再从中获取当前登录用户。下面是一个使用示例。
用户实体类(这里使用的是 MyBatis-Plus 的 @TableName、@TableLogic 注解,而不是 JPA):

/**
 * User entity mapped to the {@code sys_user} table.
 *
 * <p>Accessors, {@code toString}, {@code equals} and {@code hashCode} are generated by Lombok
 * {@code @Data}; {@code callSuper = false} means the generated {@code equals}/{@code hashCode}
 * ignore any state inherited from {@code SuperEntity}. {@code @TableName}/{@code @TableLogic}
 * look like MyBatis-Plus mapping annotations (imports not shown) — confirm.
 */
@Data
@EqualsAndHashCode(callSuper = false)
@TableName("sys_user")
public class SysUser extends SuperEntity {
    
	private static final long serialVersionUID = -5886012896705137070L;

	private String username;
	// NOTE(review): presumably a password hash rather than plaintext — confirm. Lombok's generated
	// toString() includes this field, so avoid logging SysUser instances as-is.
	private String password;
	private String nickname;
	private String headImgUrl;
	private String mobile;
	// NOTE(review): the meaning of the sex codes is not defined here — document the allowed values.
	private Integer sex;
	private Boolean enabled;
	private String type;
	private String openId;
	// Logical-delete flag (presumably MyBatis-Plus soft delete via @TableLogic).
	// NOTE(review): for a primitive boolean named "isDel", Lombok generates isDel()/setDel(),
	// so bean-property based tools see a property called "del", not "isDel".
	@TableLogic
	private boolean isDel;
}

ThreadLocal 持有类(这里使用 TransmittableThreadLocal 实现):

/**
 * Static holder for the {@link SysUser} bound to the current thread.
 *
 * <p>Backed by a {@code TransmittableThreadLocal} (presumably Alibaba's transmittable-thread-local
 * library — import not shown). Like any thread-local, the value stays on the thread until it is
 * overwritten or {@link #clear()} is called, so code that calls {@link #setUser(SysUser)} on a
 * pooled thread should clear it when the unit of work ends.
 */
public final class LoginUserContextHolder {

    private static final ThreadLocal<SysUser> CONTEXT = new TransmittableThreadLocal<>();

    private LoginUserContextHolder() {
        // Static holder only; never instantiated.
    }

    /**
     * Binds {@code user} to the current thread, replacing any previously bound user.
     *
     * @param user the login user to bind (may be {@code null})
     */
    public static void setUser(SysUser user) {
        CONTEXT.set(user);
    }

    /**
     * Returns the user bound to the current thread.
     *
     * @return the bound user, or {@code null} if none was set or it has been cleared
     */
    public static SysUser getUser() {
        return CONTEXT.get();
    }

    /** Removes the current thread's user so a reused thread does not see a stale value. */
    public static void clear() {
        CONTEXT.remove();
    }
}

将用户存放到 ThreadLocal 中:

import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.common.exceptions.InvalidTokenException;
import org.springframework.security.oauth2.common.exceptions.UnapprovedClientAuthenticationException;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.TokenStore;

import javax.servlet.http.HttpServletRequest;
import java.nio.charset.StandardCharsets;
import java.util.*;

public class AuthUtils {
    
   /**
    * Validates the access token carried by {@code request}.
    *
    * @param request the incoming HTTP request; the token is located by {@code extractToken}
    * @return the user resolved from the token by {@code checkAccessToken(String)}
    */
    public static SysUser checkAccessToken(HttpServletRequest request) {
        return checkAccessToken(extractToken(request));
    }

    /**
     * Validates {@code accessTokenValue} against the {@link TokenStore} bean and, on success,
     * binds the stored authentication to the current thread via {@code setContext}.
     *
     * <p>An expired token is removed from the store before the exception is thrown.
     *
     * @param accessTokenValue raw access token; {@code null} or blank values are rejected
     * @return the user resolved from the token's stored authentication
     * @throws InvalidTokenException if the token is missing, unknown, expired, or has no stored
     *     authentication
     */
    public static SysUser checkAccessToken(String accessTokenValue) {
        // Reject a missing token up front: a null lookup key makes some TokenStore implementations
        // (e.g. ConcurrentHashMap-backed ones) throw NullPointerException instead of a token error.
        if (accessTokenValue == null || accessTokenValue.trim().isEmpty()) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        TokenStore tokenStore = SpringUtil.getBean(TokenStore.class);
        OAuth2AccessToken accessToken = tokenStore.readAccessToken(accessTokenValue);
        if (accessToken == null || accessToken.getValue() == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        if (accessToken.isExpired()) {
            // Drop the expired token from the store before rejecting it.
            tokenStore.removeAccessToken(accessToken);
            throw new InvalidTokenException("Access token expired: " + accessTokenValue);
        }
        OAuth2Authentication result = tokenStore.readAuthentication(accessToken);
        if (result == null) {
            throw new InvalidTokenException("Invalid access token: " + accessTokenValue);
        }
        return setContext(result);
    }
    /**
     * Publishes {@code authentication} to Spring Security's {@code SecurityContextHolder}, then
     * stores the user derived from it in {@code LoginUserContextHolder}.
     *
     * @param authentication the authenticated principal to expose to the current thread
     * @return the user extracted from {@code authentication} by {@code getUser}
     */
    public static SysUser setContext(Authentication authentication) {
        // Security context first: getUser may fail, and the original order sets it regardless.
        SecurityContextHolder.getContext().setAuthentication(authentication);
        final SysUser loginUser = getUser(authentication);
        LoginUserContextHolder.setUser(loginUser);
        return loginUser;
    }

    /**
     * Extracts {@code clientId} and {@code clientSecret} from the HTTP Basic {@code Authorization}
     * header of {@code request}.
     *
     * @param request the incoming HTTP request
     * @return the credentials as decoded by {@code extractHeaderClient}
     * @throws UnapprovedClientAuthenticationException if the header is missing or does not use the
     *     {@code BASIC_} scheme prefix
     */
    public static String[] extractClient(HttpServletRequest request) {
        String header = request.getHeader("Authorization");
        // The auth-scheme token is case-insensitive (RFC 7235 / RFC 7617), so compare ignoring case.
        if (header == null || !header.regionMatches(true, 0, BASIC_, 0, BASIC_.length())) {
            throw new UnapprovedClientAuthenticationException("请求头中client信息为空");
        }
        return extractHeaderClient(header);
    }
   /**
    * Decodes the {@code clientId:clientSecret} pair from a Basic {@code Authorization} header value.
    *
    * @param header the raw header value; must start with the {@code BASIC_} prefix (checked by
    *     callers such as {@code extractClient})
    * @return a two-element array {@code [clientId, clientSecret]}
    * @throws RuntimeException if the credentials are not valid Base64 or contain no {@code ':'}
    */
    public static String[] extractHeaderClient(String header) {
        // Trim so whitespace around the credentials does not break Base64 decoding.
        String base64Client = header.substring(BASIC_.length()).trim();
        byte[] decoded;
        try {
            decoded = Base64.getDecoder().decode(base64Client.getBytes(StandardCharsets.UTF_8));
        } catch (IllegalArgumentException e) {
            throw new RuntimeException("Failed to decode basic authentication token", e);
        }
        String clientStr = new String(decoded, StandardCharsets.UTF_8);
        // RFC 7617: split at the FIRST colon only; the secret itself may contain ':'.
        int delim = clientStr.indexOf(':');
        if (delim == -1) {
            throw new RuntimeException("Invalid basic authentication token");
        }
        return new String[] {clientStr.substring(0, delim), clientStr.substring(delim + 1)};
    }

获取当前登录人:

    /**
     * Returns the current login user's username, read from {@code LoginUserContextHolder}.
     *
     * @return {@code "auth2:"} followed by the bound user's username
     */
    @GetMapping("/test/auth2")
    public String auth() {
        // NOTE(review): getUser() returns null when no user is bound to this thread, which makes
        // this endpoint fail with a NullPointerException — consider an explicit 401 instead.
        final String username = LoginUserContextHolder.getUser().getUsername();
        return "auth2:" + username;
    }

websocket鉴权:

import lombok.extern.slf4j.Slf4j;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.websocket.server.ServerEndpointConfig;
public class WcAuthConfigurator extends ServerEndpointConfig.Configurator {
    
//checkOrigin:校验token
    @Override
    public boolean checkOrigin(String originHeaderValue) {
    
        ServletRequestAttributes servletRequestAttributes = (ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
        try {
    
            //检查token有效性
            AuthUtils.checkAccessToken(servletRequestAttributes.getRequest());
        } catch (Exception e) {
    
            log.error("WebSocket-auth-error", e);
            return false;
        }
        return super.checkOrigin(originHeaderValue);
    }
}

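AuthUtils.checkAccessToken 方法内部最终通过 LoginUserContextHolder.setUser 调用了 ThreadLocal 的 set 方法。注意:ThreadLocal 中的用户只对执行握手的线程可见,WebSocket 的 @OnOpen、@OnMessage 等回调并不保证在同一线程中执行;而且 Web 容器的线程会被复用,用完后应调用 LoginUserContextHolder.clear() 清理,避免用户信息串到后续请求中。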
使用WcAuthConfigurator :
@ServerEndpoint 注解的作用是将当前类定义为一个 WebSocket 服务端。注解的值是客户端连接时访问的 URL 路径,客户端通过这个 URL 连接到 WebSocket 服务端。这里把 configurator 属性配置为刚才编写的配置类。


@Slf4j
@Component
@ServerEndpoint(value = "/websocket/test", configurator = WcAuthConfigurator.class)
public class TestWebSocketController {

    /** Acknowledgement text sent to every client right after its connection opens. */
    private static final String OPEN_ACK = "TestWebSocketController-ok";

    /**
     * Sends the acknowledgement to a newly opened connection. The handshake has already gone
     * through {@code WcAuthConfigurator} (set as this endpoint's configurator) before this runs.
     *
     * @param session the newly opened WebSocket session
     * @throws IOException if the acknowledgement cannot be sent
     */
    @OnOpen
    public void onOpen(Session session) throws IOException {
        session.getBasicRemote().sendText(OPEN_ACK);
    }
}
原网站

版权声明
本文为[march of Time]所创,转载请带上原文链接,感谢
https://blog.csdn.net/qq_41358574/article/details/126216978