From 97ab9bb194a513aac5c03efc97f1949586167084 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 24 Sep 2023 12:00:00 +0800 Subject: [PATCH] Fix shadow-tls user context --- inbound/hysteria2.go | 8 ++++++-- inbound/shadowtls.go | 13 ++++++++++++- inbound/tuic.go | 8 ++++++-- 3 files changed, 24 insertions(+), 5 deletions(-) diff --git a/inbound/hysteria2.go b/inbound/hysteria2.go index 82e2f0b7..fd650ae9 100644 --- a/inbound/hysteria2.go +++ b/inbound/hysteria2.go @@ -116,11 +116,13 @@ func NewHysteria2(ctx context.Context, router adapter.Router, logger log.Context func (h *Hysteria2) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { ctx = log.ContextWithNewID(ctx) - h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) metadata = h.createMetadata(conn, metadata) userID, _ := auth.UserFromContext[int](ctx) if userName := h.userNameList[userID]; userName != "" { metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) } return h.router.RouteConnection(ctx, conn, metadata) } @@ -131,8 +133,10 @@ func (h *Hysteria2) newPacketConnection(ctx context.Context, conn N.PacketConn, userID, _ := auth.UserFromContext[int](ctx) if userName := h.userNameList[userID]; userName != "" { metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) } - h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) return h.router.RoutePacketConnection(ctx, conn, metadata) } diff --git a/inbound/shadowtls.go b/inbound/shadowtls.go index 59b7c107..358564f2 100644 --- a/inbound/shadowtls.go +++ b/inbound/shadowtls.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-shadowtls" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/auth" N "github.com/sagernet/sing/common/network" ) @@ -66,7 +67,7 @@ func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.Context }, HandshakeForServerName: handshakeForServerName, StrictMode: options.StrictMode, - Handler: inbound.upstreamContextHandler(), + Handler: adapter.NewUpstreamContextHandler(inbound.newConnection, nil, inbound), Logger: logger, }) if err != nil { @@ -80,3 +81,13 @@ func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.Context func (h *ShadowTLS) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { return h.service.NewConnection(adapter.WithContext(log.ContextWithNewID(ctx), &metadata), conn, adapter.UpstreamMetadata(metadata)) } + +func (h *ShadowTLS) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + if userName, _ := auth.UserFromContext[string](ctx); userName != "" { + metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) + } + return h.router.RouteConnection(ctx, conn, metadata) +} diff --git a/inbound/tuic.go b/inbound/tuic.go index e6714f0d..19b4b3d7 100644 --- a/inbound/tuic.go +++ b/inbound/tuic.go @@ -88,11 +88,13 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge func (h *TUIC) newConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { ctx = log.ContextWithNewID(ctx) - h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) metadata = h.createMetadata(conn, metadata) userID, _ := auth.UserFromContext[int](ctx) if userName := h.userNameList[userID]; userName != "" { metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound connection to ", metadata.Destination) } return h.router.RouteConnection(ctx, conn, metadata) } @@ -103,8 +105,10 @@ func (h *TUIC) newPacketConnection(ctx context.Context, conn N.PacketConn, metad userID, _ := auth.UserFromContext[int](ctx) if userName := h.userNameList[userID]; userName != "" { metadata.User = userName + h.logger.InfoContext(ctx, "[", userName, "] inbound packet connection to ", metadata.Destination) + } else { + h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) } - h.logger.InfoContext(ctx, "inbound packet connection to ", metadata.Destination) return h.router.RoutePacketConnection(ctx, conn, metadata) }