diff --git a/test/vmess_test.go b/test/vmess_test.go index 6070354a..b696fcea 100644 --- a/test/vmess_test.go +++ b/test/vmess_test.go @@ -348,13 +348,13 @@ func TestClash_VmessGrpc(t *testing.T) { func TestClash_VmessWebsocket0RTT(t *testing.T) { cfg := &container.Config{ - Image: ImageXray, + Image: ImageVmess, ExposedPorts: defaultExposedPorts, } hostCfg := &container.HostConfig{ PortBindings: defaultPortBindings, Binds: []string{ - fmt.Sprintf("%s:/etc/xray/config.json", C.Path.Resolve("vmess-ws-0rtt.json")), + fmt.Sprintf("%s:/etc/v2ray/config.json", C.Path.Resolve("vmess-ws-0rtt.json")), }, } @@ -387,6 +387,46 @@ func TestClash_VmessWebsocket0RTT(t *testing.T) { testSuit(t, proxy) } +func TestClash_VmessWebsocketXray0RTT(t *testing.T) { + cfg := &container.Config{ + Image: ImageXray, + ExposedPorts: defaultExposedPorts, + } + hostCfg := &container.HostConfig{ + PortBindings: defaultPortBindings, + Binds: []string{ + fmt.Sprintf("%s:/etc/xray/config.json", C.Path.Resolve("vmess-ws-0rtt.json")), + }, + } + + id, err := startContainer(cfg, hostCfg, "vmess-xray-ws-0rtt") + if err != nil { + assert.FailNow(t, err.Error()) + } + defer cleanContainer(id) + + proxy, err := outbound.NewVmess(outbound.VmessOption{ + Name: "vmess", + Server: localIP.String(), + Port: 10002, + UUID: "b831381d-6324-4d53-ad4f-8cda48b30811", + Cipher: "auto", + AlterID: 32, + Network: "ws", + UDP: true, + ServerName: "example.org", + WSOpts: outbound.WSOptions{ + Path: "/?ed=2048", + }, + }) + if err != nil { + assert.FailNow(t, err.Error()) + } + + time.Sleep(waitTime) + testSuit(t, proxy) +} + func Benchmark_Vmess(b *testing.B) { configPath := C.Path.Resolve("vmess-aead.json") diff --git a/transport/vmess/websocket.go b/transport/vmess/websocket.go index 7d7a75bb..f00e4c4b 100644 --- a/transport/vmess/websocket.go +++ b/transport/vmess/websocket.go @@ -130,14 +130,12 @@ func (wsc *websocketConn) SetWriteDeadline(t time.Time) error { } func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { - earlyDataBuf := bytes.NewBuffer(nil) - base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, earlyDataBuf) + base64DataBuf := &bytes.Buffer{} + base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf) - earlydata := bytes.NewReader(earlyData) - limitedEarlyDatareader := io.LimitReader(earlydata, int64(wsedc.config.MaxEarlyData)) - n, encerr := io.Copy(base64EarlyDataEncoder, limitedEarlyDatareader) - if encerr != nil { - return errors.New("failed to encode early data: " + encerr.Error()) + earlyDataBuf := bytes.NewBuffer(earlyData) + if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil { + return errors.New("failed to encode early data: " + err.Error()) } if errc := base64EarlyDataEncoder.Close(); errc != nil { @@ -145,15 +143,14 @@ func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error { } var err error - if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, earlyDataBuf); err != nil { + if wsedc.Conn, err = streamWebsocketConn(wsedc.underlay, wsedc.config, base64DataBuf); err != nil { wsedc.Close() return errors.New("failed to dial WebSocket: " + err.Error()) } wsedc.dialed <- true - - if n != int64(len(earlyData)) { - _, err = wsedc.Conn.Write(earlyData[n:]) + if earlyDataBuf.Len() != 0 { + _, err = wsedc.Conn.Write(earlyDataBuf.Bytes()) } return err