diff --git a/common/sniff/sniff.go b/common/sniff/sniff.go index 3e21b69f..2d2999a8 100644 --- a/common/sniff/sniff.go +++ b/common/sniff/sniff.go @@ -40,22 +40,28 @@ func PeekStream(ctx context.Context, conn net.Conn, buffer *buf.Buffer, timeout } deadline := time.Now().Add(timeout) var errors []error - err := conn.SetReadDeadline(deadline) - if err != nil { - return nil, E.Cause(err, "set read deadline") - } - defer conn.SetReadDeadline(time.Time{}) - var metadata *adapter.InboundContext - for _, sniffer := range sniffers { - if buffer.IsEmpty() { - metadata, err = sniffer(ctx, io.TeeReader(conn, buffer)) - } else { - metadata, err = sniffer(ctx, io.MultiReader(bytes.NewReader(buffer.Bytes()), io.TeeReader(conn, buffer))) + for i := 0; ; i++ { + err := conn.SetReadDeadline(deadline) + if err != nil { + return nil, E.Cause(err, "set read deadline") } - if metadata != nil { - return metadata, nil + _, err = buffer.ReadOnceFrom(conn) + _ = conn.SetReadDeadline(time.Time{}) + if err != nil { + if i > 0 { + break + } + return nil, E.Cause(err, "read payload") + } + errors = nil + var metadata *adapter.InboundContext + for _, sniffer := range sniffers { + metadata, err = sniffer(ctx, bytes.NewReader(buffer.Bytes())) + if metadata != nil { + return metadata, nil + } + errors = append(errors, err) } - errors = append(errors, err) } return nil, E.Errors(errors...) }