diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 60353d5a..0c2a3a16 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -324,19 +324,34 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } - scheme := "ws" + uri := url.URL{ + Scheme: "ws", + Host: net.JoinHostPort(c.Host, c.Port), + Path: u.Path, + RawQuery: u.RawQuery, + } + if c.TLS { - scheme = "wss" + uri.Scheme = "wss" + config := c.TLSConfig + if config == nil { // The config cannot be nil + config = &tls.Config{NextProtos: []string{"http/1.1"}} + } + if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config. + config = config.Clone() + config.ServerName = uri.Host + } + if len(c.ClientFingerprint) != 0 { if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists { - utlsConn := tlsC.UClient(conn, c.TLSConfig, fingerprint) + utlsConn := tlsC.UClient(conn, config, fingerprint) if err = utlsConn.BuildWebsocketHandshakeState(); err != nil { return nil, fmt.Errorf("parse url %s error: %w", c.Path, err) } conn = utlsConn } } else { - conn = tls.Client(conn, c.TLSConfig) + conn = tls.Client(conn, config) } if tlsConn, ok := conn.(interface { @@ -348,13 +363,6 @@ func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, } } - uri := url.URL{ - Scheme: scheme, - Host: net.JoinHostPort(c.Host, c.Port), - Path: u.Path, - RawQuery: u.RawQuery, - } - request := &http.Request{ Method: http.MethodGet, URL: &uri,