From 5b343d4c72d370e1937f72f03f038bb93f586a7f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 29 Aug 2023 13:43:42 +0800 Subject: [PATCH] Improve ECH support --- adapter/router.go | 11 +- box.go | 2 + common/tls/client.go | 14 +- common/tls/ech_client.go | 29 +-- common/tls/ech_server.go | 330 +++++++++++++++++++++++++++++++++++ common/tls/ech_stub.go | 10 +- common/tls/reality_client.go | 5 +- common/tls/reality_server.go | 7 +- common/tls/reality_stub.go | 3 +- common/tls/server.go | 11 +- common/tls/std_client.go | 7 +- common/tls/std_server.go | 7 +- common/tls/utls_client.go | 6 +- common/tls/utls_stub.go | 7 +- inbound/http.go | 2 +- inbound/hysteria.go | 2 +- inbound/naive.go | 2 +- inbound/trojan.go | 2 +- inbound/tuic.go | 3 +- inbound/vless.go | 2 +- inbound/vmess.go | 2 +- option/tls.go | 22 ++- outbound/builder.go | 2 +- outbound/http.go | 4 +- outbound/hysteria.go | 2 +- outbound/shadowsocks.go | 2 +- outbound/shadowtls.go | 2 +- outbound/trojan.go | 2 +- outbound/tuic.go | 2 +- outbound/vless.go | 2 +- outbound/vmess.go | 2 +- transport/sip003/obfs.go | 2 +- transport/sip003/plugin.go | 6 +- transport/sip003/v2ray.go | 4 +- 34 files changed, 434 insertions(+), 84 deletions(-) create mode 100644 common/tls/ech_server.go diff --git a/adapter/router.go b/adapter/router.go index ca6cc3a3..164a8dd3 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -10,6 +10,7 @@ import ( "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" mdns "github.com/miekg/dns" ) @@ -56,18 +57,12 @@ type Router interface { ResetNetwork() error } -type routerContextKey struct{} - func ContextWithRouter(ctx context.Context, router Router) context.Context { - return context.WithValue(ctx, (*routerContextKey)(nil), router) + return service.ContextWith(ctx, router) } func RouterFromContext(ctx context.Context) Router { - metadata := ctx.Value((*routerContextKey)(nil)) - if metadata == nil { - return nil - } - return metadata.(Router) + return service.FromContext[Router](ctx) } type Rule interface { diff --git a/box.go b/box.go index 73a400b7..0a66c88c 100644 --- a/box.go +++ b/box.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" + "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" ) @@ -47,6 +48,7 @@ func New(options Options) (*Box, error) { if ctx == nil { ctx = context.Background() } + ctx = service.ContextWithDefaultRegistry(ctx) ctx = pause.ContextWithDefaultManager(ctx) createdAt := time.Now() experimentalOptions := common.PtrValueOrDefault(options.Experimental) diff --git a/common/tls/client.go b/common/tls/client.go index a019a91f..d1c9475a 100644 --- a/common/tls/client.go +++ b/common/tls/client.go @@ -13,29 +13,29 @@ import ( aTLS "github.com/sagernet/sing/common/tls" ) -func NewDialerFromOptions(router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) { +func NewDialerFromOptions(ctx context.Context, router adapter.Router, dialer N.Dialer, serverAddress string, options option.OutboundTLSOptions) (N.Dialer, error) { if !options.Enabled { return dialer, nil } - config, err := NewClient(router, serverAddress, options) + config, err := NewClient(ctx, serverAddress, options) if err != nil { return nil, err } return NewDialer(dialer, config), nil } -func NewClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { if !options.Enabled { return nil, nil } if options.ECH != nil && options.ECH.Enabled { - return NewECHClient(router, serverAddress, options) + return NewECHClient(ctx, serverAddress, options) } else if options.Reality != nil && options.Reality.Enabled { - return NewRealityClient(router, serverAddress, options) + return NewRealityClient(ctx, serverAddress, options) } else if options.UTLS != nil && options.UTLS.Enabled { - return NewUTLSClient(router, serverAddress, options) + return NewUTLSClient(ctx, serverAddress, options) } else { - return NewSTDClient(router, serverAddress, options) + return NewSTDClient(ctx, serverAddress, options) } } diff --git a/common/tls/ech_client.go b/common/tls/ech_client.go index 57b9ca9a..5b4d72da 100644 --- a/common/tls/ech_client.go +++ b/common/tls/ech_client.go @@ -7,15 +7,18 @@ import ( "crypto/tls" "crypto/x509" "encoding/base64" + "encoding/pem" "net" "net/netip" "os" + "strings" cftls "github.com/sagernet/cloudflare-tls" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ntp" mDNS "github.com/miekg/dns" ) @@ -80,7 +83,7 @@ func (c *echConnWrapper) Upstream() any { return c.Conn } -func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { var serverName string if options.ServerName != "" { serverName = options.ServerName @@ -94,7 +97,7 @@ func NewECHClient(router adapter.Router, serverAddress string, options option.Ou } var tlsConfig cftls.Config - tlsConfig.Time = router.TimeFunc() + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) if options.DisableSNI { tlsConfig.ServerName = "127.0.0.1" } else { @@ -168,24 +171,24 @@ func NewECHClient(router adapter.Router, serverAddress string, options option.Ou tlsConfig.ECHEnabled = true tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled - if options.ECH.Config != "" { - clientConfigContent, err := base64.StdEncoding.DecodeString(options.ECH.Config) - if err != nil { - return nil, err + if len(options.ECH.Config) > 0 { + block, rest := pem.Decode([]byte(strings.Join(options.ECH.Config, "\n"))) + if block == nil || block.Type != "ECH CONFIGS" || len(rest) > 0 { + return nil, E.New("invalid ECH configs pem") } - clientConfig, err := cftls.UnmarshalECHConfigs(clientConfigContent) + echConfigs, err := cftls.UnmarshalECHConfigs(block.Bytes) if err != nil { - return nil, err + return nil, E.Cause(err, "parse ECH configs") } - tlsConfig.ClientECHConfigs = clientConfig + tlsConfig.ClientECHConfigs = echConfigs } else { - tlsConfig.GetClientECHConfigs = fetchECHClientConfig(router) + tlsConfig.GetClientECHConfigs = fetchECHClientConfig(ctx) } return &ECHClientConfig{&tlsConfig}, nil } -func fetchECHClientConfig(router adapter.Router) func(ctx context.Context, serverName string) ([]cftls.ECHConfig, error) { - return func(ctx context.Context, serverName string) ([]cftls.ECHConfig, error) { +func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) { + return func(_ context.Context, serverName string) ([]cftls.ECHConfig, error) { message := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, @@ -198,7 +201,7 @@ func fetchECHClientConfig(router adapter.Router) func(ctx context.Context, serve }, }, } - response, err := router.Exchange(ctx, message) + response, err := adapter.RouterFromContext(ctx).Exchange(ctx, message) if err != nil { return nil, err } diff --git a/common/tls/ech_server.go b/common/tls/ech_server.go new file mode 100644 index 00000000..412599ed --- /dev/null +++ b/common/tls/ech_server.go @@ -0,0 +1,330 @@ +//go:build with_ech + +package tls + +import ( + "context" + "crypto/tls" + "encoding/pem" + "net" + "os" + "strings" + + cftls "github.com/sagernet/cloudflare-tls" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ntp" + + "github.com/fsnotify/fsnotify" +) + +type echServerConfig struct { + config *cftls.Config + logger log.Logger + certificate []byte + key []byte + certificatePath string + keyPath string + watcher *fsnotify.Watcher + echKeyPath string + echWatcher *fsnotify.Watcher +} + +func (c *echServerConfig) ServerName() string { + return c.config.ServerName +} + +func (c *echServerConfig) SetServerName(serverName string) { + c.config.ServerName = serverName +} + +func (c *echServerConfig) NextProtos() []string { + return c.config.NextProtos +} + +func (c *echServerConfig) SetNextProtos(nextProto []string) { + c.config.NextProtos = nextProto +} + +func (c *echServerConfig) Config() (*STDConfig, error) { + return nil, E.New("unsupported usage for ECH") +} + +func (c *echServerConfig) Client(conn net.Conn) (Conn, error) { + return &echConnWrapper{cftls.Client(conn, c.config)}, nil +} + +func (c *echServerConfig) Server(conn net.Conn) (Conn, error) { + return &echConnWrapper{cftls.Server(conn, c.config)}, nil +} + +func (c *echServerConfig) Clone() Config { + return &echServerConfig{ + config: c.config.Clone(), + } +} + +func (c *echServerConfig) Start() error { + if c.certificatePath != "" && c.keyPath != "" { + err := c.startWatcher() + if err != nil { + c.logger.Warn("create fsnotify watcher: ", err) + } + } + if c.echKeyPath != "" { + err := c.startECHWatcher() + if err != nil { + c.logger.Warn("create fsnotify watcher: ", err) + } + } + return nil +} + +func (c *echServerConfig) startWatcher() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + if c.certificatePath != "" { + err = watcher.Add(c.certificatePath) + if err != nil { + return err + } + } + if c.keyPath != "" { + err = watcher.Add(c.keyPath) + if err != nil { + return err + } + } + c.watcher = watcher + go c.loopUpdate() + return nil +} + +func (c *echServerConfig) loopUpdate() { + for { + select { + case event, ok := <-c.watcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write != fsnotify.Write { + continue + } + err := c.reloadKeyPair() + if err != nil { + c.logger.Error(E.Cause(err, "reload TLS key pair")) + } + case err, ok := <-c.watcher.Errors: + if !ok { + return + } + c.logger.Error(E.Cause(err, "fsnotify error")) + } + } +} + +func (c *echServerConfig) reloadKeyPair() error { + if c.certificatePath != "" { + certificate, err := os.ReadFile(c.certificatePath) + if err != nil { + return E.Cause(err, "reload certificate from ", c.certificatePath) + } + c.certificate = certificate + } + if c.keyPath != "" { + key, err := os.ReadFile(c.keyPath) + if err != nil { + return E.Cause(err, "reload key from ", c.keyPath) + } + c.key = key + } + keyPair, err := cftls.X509KeyPair(c.certificate, c.key) + if err != nil { + return E.Cause(err, "reload key pair") + } + c.config.Certificates = []cftls.Certificate{keyPair} + c.logger.Info("reloaded TLS certificate") + return nil +} + +func (c *echServerConfig) startECHWatcher() error { + watcher, err := fsnotify.NewWatcher() + if err != nil { + return err + } + err = watcher.Add(c.echKeyPath) + if err != nil { + return err + } + c.watcher = watcher + go c.loopECHUpdate() + return nil +} + +func (c *echServerConfig) loopECHUpdate() { + for { + select { + case event, ok := <-c.echWatcher.Events: + if !ok { + return + } + if event.Op&fsnotify.Write != fsnotify.Write { + continue + } + err := c.reloadECHKey() + if err != nil { + c.logger.Error(E.Cause(err, "reload ECH key")) + } + case err, ok := <-c.watcher.Errors: + if !ok { + return + } + c.logger.Error(E.Cause(err, "fsnotify error")) + } + } +} + +func (c *echServerConfig) reloadECHKey() error { + echKeyContent, err := os.ReadFile(c.echKeyPath) + if err != nil { + return err + } + block, rest := pem.Decode(echKeyContent) + if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { + return E.New("invalid ECH keys pem") + } + echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes) + if err != nil { + return E.Cause(err, "parse ECH keys") + } + echKeySet, err := cftls.EXP_NewECHKeySet(echKeys) + if err != nil { + return E.Cause(err, "create ECH key set") + } + c.config.ServerECHProvider = echKeySet + c.logger.Info("reloaded ECH keys") + return nil +} + +func (c *echServerConfig) Close() error { + var err error + if c.watcher != nil { + err = E.Append(err, c.watcher.Close(), func(err error) error { + return E.Cause(err, "close certificate watcher") + }) + } + if c.echWatcher != nil { + err = E.Append(err, c.echWatcher.Close(), func(err error) error { + return E.Cause(err, "close ECH key watcher") + }) + } + return err +} + +func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { + if !options.Enabled { + return nil, nil + } + var tlsConfig cftls.Config + if options.ACME != nil && len(options.ACME.Domain) > 0 { + return nil, E.New("acme is unavailable in ech") + } + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) + if options.ServerName != "" { + tlsConfig.ServerName = options.ServerName + } + if len(options.ALPN) > 0 { + tlsConfig.NextProtos = append(options.ALPN, tlsConfig.NextProtos...) + } + if options.MinVersion != "" { + minVersion, err := ParseTLSVersion(options.MinVersion) + if err != nil { + return nil, E.Cause(err, "parse min_version") + } + tlsConfig.MinVersion = minVersion + } + if options.MaxVersion != "" { + maxVersion, err := ParseTLSVersion(options.MaxVersion) + if err != nil { + return nil, E.Cause(err, "parse max_version") + } + tlsConfig.MaxVersion = maxVersion + } + if options.CipherSuites != nil { + find: + for _, cipherSuite := range options.CipherSuites { + for _, tlsCipherSuite := range tls.CipherSuites() { + if cipherSuite == tlsCipherSuite.Name { + tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, tlsCipherSuite.ID) + continue find + } + } + return nil, E.New("unknown cipher_suite: ", cipherSuite) + } + } + var certificate []byte + var key []byte + if len(options.Certificate) > 0 { + certificate = []byte(strings.Join(options.Certificate, "\n")) + } else if options.CertificatePath != "" { + content, err := os.ReadFile(options.CertificatePath) + if err != nil { + return nil, E.Cause(err, "read certificate") + } + certificate = content + } + if len(options.Key) > 0 { + key = []byte(strings.Join(options.Key, "")) + } else if options.KeyPath != "" { + content, err := os.ReadFile(options.KeyPath) + if err != nil { + return nil, E.Cause(err, "read key") + } + key = content + } + + if certificate == nil { + return nil, E.New("missing certificate") + } else if key == nil { + return nil, E.New("missing key") + } + + keyPair, err := cftls.X509KeyPair(certificate, key) + if err != nil { + return nil, E.Cause(err, "parse x509 key pair") + } + tlsConfig.Certificates = []cftls.Certificate{keyPair} + + block, rest := pem.Decode([]byte(strings.Join(options.ECH.Key, "\n"))) + if block == nil || block.Type != "ECH KEYS" || len(rest) > 0 { + return nil, E.New("invalid ECH keys pem") + } + + echKeys, err := cftls.EXP_UnmarshalECHKeys(block.Bytes) + if err != nil { + return nil, E.Cause(err, "parse ECH keys") + } + + echKeySet, err := cftls.EXP_NewECHKeySet(echKeys) + if err != nil { + return nil, E.Cause(err, "create ECH key set") + } + + tlsConfig.ECHEnabled = true + tlsConfig.PQSignatureSchemesEnabled = options.ECH.PQSignatureSchemesEnabled + tlsConfig.DynamicRecordSizingDisabled = options.ECH.DynamicRecordSizingDisabled + tlsConfig.ServerECHProvider = echKeySet + + return &echServerConfig{ + config: &tlsConfig, + logger: logger, + certificate: certificate, + key: key, + certificatePath: options.CertificatePath, + keyPath: options.KeyPath, + echKeyPath: options.ECH.KeyPath, + }, nil +} diff --git a/common/tls/ech_stub.go b/common/tls/ech_stub.go index a24a5ff7..0aab700e 100644 --- a/common/tls/ech_stub.go +++ b/common/tls/ech_stub.go @@ -3,11 +3,17 @@ package tls import ( - "github.com/sagernet/sing-box/adapter" + "context" + + "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" ) -func NewECHClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewECHServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { + return nil, E.New(`ECH is not included in this build, rebuild with -tags with_ech`) +} + +func NewECHClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { return nil, E.New(`ECH is not included in this build, rebuild with -tags with_ech`) } diff --git a/common/tls/reality_client.go b/common/tls/reality_client.go index 7742b5b4..afbd3e3e 100644 --- a/common/tls/reality_client.go +++ b/common/tls/reality_client.go @@ -26,7 +26,6 @@ import ( "time" "unsafe" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common/debug" E "github.com/sagernet/sing/common/exceptions" @@ -45,12 +44,12 @@ type RealityClientConfig struct { shortID [8]byte } -func NewRealityClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) { +func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*RealityClientConfig, error) { if options.UTLS == nil || !options.UTLS.Enabled { return nil, E.New("uTLS is required by reality client") } - uClient, err := NewUTLSClient(router, serverAddress, options) + uClient, err := NewUTLSClient(ctx, serverAddress, options) if err != nil { return nil, err } diff --git a/common/tls/reality_server.go b/common/tls/reality_server.go index fd1a6815..f6e30827 100644 --- a/common/tls/reality_server.go +++ b/common/tls/reality_server.go @@ -19,6 +19,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/ntp" ) var _ ServerConfigCompat = (*RealityServerConfig)(nil) @@ -27,13 +28,13 @@ type RealityServerConfig struct { config *reality.Config } -func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) { +func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (*RealityServerConfig, error) { var tlsConfig reality.Config if options.ACME != nil && len(options.ACME.Domain) > 0 { return nil, E.New("acme is unavailable in reality") } - tlsConfig.Time = router.TimeFunc() + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) if options.ServerName != "" { tlsConfig.ServerName = options.ServerName } @@ -101,7 +102,7 @@ func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Log tlsConfig.ShortIds[shortID] = true } - handshakeDialer, err := dialer.New(router, options.Reality.Handshake.DialerOptions) + handshakeDialer, err := dialer.New(adapter.RouterFromContext(ctx), options.Reality.Handshake.DialerOptions) if err != nil { return nil, err } diff --git a/common/tls/reality_stub.go b/common/tls/reality_stub.go index 766c01df..8d394f7b 100644 --- a/common/tls/reality_stub.go +++ b/common/tls/reality_stub.go @@ -5,12 +5,11 @@ package tls import ( "context" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" ) -func NewRealityServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { +func NewRealityServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { return nil, E.New(`reality server is not included in this build, rebuild with -tags with_reality_server`) } diff --git a/common/tls/server.go b/common/tls/server.go index bacb4cce..ac6d0a2e 100644 --- a/common/tls/server.go +++ b/common/tls/server.go @@ -4,21 +4,22 @@ import ( "context" "net" - "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" aTLS "github.com/sagernet/sing/common/tls" ) -func NewServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { +func NewServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { if !options.Enabled { return nil, nil } - if options.Reality != nil && options.Reality.Enabled { - return NewRealityServer(ctx, router, logger, options) + if options.ECH != nil && options.ECH.Enabled { + return NewECHServer(ctx, logger, options) + } else if options.Reality != nil && options.Reality.Enabled { + return NewRealityServer(ctx, logger, options) } else { - return NewSTDServer(ctx, router, logger, options) + return NewSTDServer(ctx, logger, options) } } diff --git a/common/tls/std_client.go b/common/tls/std_client.go index 85df7b74..4aa50b94 100644 --- a/common/tls/std_client.go +++ b/common/tls/std_client.go @@ -1,15 +1,16 @@ package tls import ( + "context" "crypto/tls" "crypto/x509" "net" "net/netip" "os" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ntp" ) type STDClientConfig struct { @@ -44,7 +45,7 @@ func (s *STDClientConfig) Clone() Config { return &STDClientConfig{s.config.Clone()} } -func NewSTDClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewSTDClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { var serverName string if options.ServerName != "" { serverName = options.ServerName @@ -58,7 +59,7 @@ func NewSTDClient(router adapter.Router, serverAddress string, options option.Ou } var tlsConfig tls.Config - tlsConfig.Time = router.TimeFunc() + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) if options.DisableSNI { tlsConfig.ServerName = "127.0.0.1" } else { diff --git a/common/tls/std_server.go b/common/tls/std_server.go index f2bf56eb..33d92232 100644 --- a/common/tls/std_server.go +++ b/common/tls/std_server.go @@ -11,6 +11,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ntp" "github.com/fsnotify/fsnotify" ) @@ -156,7 +157,7 @@ func (c *STDServerConfig) Close() error { return nil } -func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { +func NewSTDServer(ctx context.Context, logger log.Logger, options option.InboundTLSOptions) (ServerConfig, error) { if !options.Enabled { return nil, nil } @@ -175,7 +176,7 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, } else { tlsConfig = &tls.Config{} } - tlsConfig.Time = router.TimeFunc() + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) if options.ServerName != "" { tlsConfig.ServerName = options.ServerName } @@ -231,7 +232,7 @@ func NewSTDServer(ctx context.Context, router adapter.Router, logger log.Logger, } if certificate == nil && key == nil && options.Insecure { tlsConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - return GenerateKeyPair(router.TimeFunc(), info.ServerName) + return GenerateKeyPair(ntp.TimeFuncFromContext(ctx), info.ServerName) } } else { if certificate == nil { diff --git a/common/tls/utls_client.go b/common/tls/utls_client.go index b5f38eab..d190a1f7 100644 --- a/common/tls/utls_client.go +++ b/common/tls/utls_client.go @@ -11,9 +11,9 @@ import ( "net/netip" "os" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/ntp" utls "github.com/sagernet/utls" "golang.org/x/net/http2" @@ -113,7 +113,7 @@ func (c *utlsALPNWrapper) HandshakeContext(ctx context.Context) error { return c.UConn.HandshakeContext(ctx) } -func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) { +func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (*UTLSClientConfig, error) { var serverName string if options.ServerName != "" { serverName = options.ServerName @@ -127,7 +127,7 @@ func NewUTLSClient(router adapter.Router, serverAddress string, options option.O } var tlsConfig utls.Config - tlsConfig.Time = router.TimeFunc() + tlsConfig.Time = ntp.TimeFuncFromContext(ctx) if options.DisableSNI { tlsConfig.ServerName = "127.0.0.1" } else { diff --git a/common/tls/utls_stub.go b/common/tls/utls_stub.go index 7d0ce80e..d015611a 100644 --- a/common/tls/utls_stub.go +++ b/common/tls/utls_stub.go @@ -3,15 +3,16 @@ package tls import ( - "github.com/sagernet/sing-box/adapter" + "context" + "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" ) -func NewUTLSClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewUTLSClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { return nil, E.New(`uTLS is not included in this build, rebuild with -tags with_utls`) } -func NewRealityClient(router adapter.Router, serverAddress string, options option.OutboundTLSOptions) (Config, error) { +func NewRealityClient(ctx context.Context, serverAddress string, options option.OutboundTLSOptions) (Config, error) { return nil, E.New(`uTLS, which is required by reality client is not included in this build, rebuild with -tags with_utls`) } diff --git a/inbound/http.go b/inbound/http.go index 9fe8ac80..14a614b1 100644 --- a/inbound/http.go +++ b/inbound/http.go @@ -44,7 +44,7 @@ func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogge authenticator: auth.NewAuthenticator(options.Users), } if options.TLS != nil { - tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/inbound/hysteria.go b/inbound/hysteria.go index 0e13bdcd..6e94e7d0 100644 --- a/inbound/hysteria.go +++ b/inbound/hysteria.go @@ -126,7 +126,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL if len(options.TLS.ALPN) == 0 { options.TLS.ALPN = []string{hysteria.DefaultALPN} } - tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/inbound/naive.go b/inbound/naive.go index 63342cdf..7542ae0c 100644 --- a/inbound/naive.go +++ b/inbound/naive.go @@ -60,7 +60,7 @@ func NewNaive(ctx context.Context, router adapter.Router, logger log.ContextLogg return nil, E.New("missing users") } if options.TLS != nil { - tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/inbound/trojan.go b/inbound/trojan.go index 0ec65498..84ef9bcf 100644 --- a/inbound/trojan.go +++ b/inbound/trojan.go @@ -49,7 +49,7 @@ func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLog users: options.Users, } if options.TLS != nil { - tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/inbound/tuic.go b/inbound/tuic.go index 9b5c4ef3..65b6e3b3 100644 --- a/inbound/tuic.go +++ b/inbound/tuic.go @@ -34,7 +34,7 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired } - tlsConfig, err := tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } @@ -115,6 +115,7 @@ func (h *TUIC) Start() error { func (h *TUIC) Close() error { return common.Close( &h.myInboundAdapter, + h.tlsConfig, common.PtrOrNil(h.server), ) } diff --git a/inbound/vless.go b/inbound/vless.go index d4c00f76..efb73551 100644 --- a/inbound/vless.go +++ b/inbound/vless.go @@ -61,7 +61,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg inbound.service = service var err error if options.TLS != nil { - inbound.tlsConfig, err = tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + inbound.tlsConfig, err = tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/inbound/vmess.go b/inbound/vmess.go index 69552162..e65c05a0 100644 --- a/inbound/vmess.go +++ b/inbound/vmess.go @@ -69,7 +69,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg return nil, err } if options.TLS != nil { - inbound.tlsConfig, err = tls.NewServer(ctx, router, logger, common.PtrValueOrDefault(options.TLS)) + inbound.tlsConfig, err = tls.NewServer(ctx, logger, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/option/tls.go b/option/tls.go index 2ff5f2e4..1f9f5746 100644 --- a/option/tls.go +++ b/option/tls.go @@ -8,11 +8,12 @@ type InboundTLSOptions struct { MinVersion string `json:"min_version,omitempty"` MaxVersion string `json:"max_version,omitempty"` CipherSuites Listable[string] `json:"cipher_suites,omitempty"` - Certificate string `json:"certificate,omitempty"` + Certificate Listable[string] `json:"certificate,omitempty"` CertificatePath string `json:"certificate_path,omitempty"` - Key string `json:"key,omitempty"` + Key Listable[string] `json:"key,omitempty"` KeyPath string `json:"key_path,omitempty"` ACME *InboundACMEOptions `json:"acme,omitempty"` + ECH *InboundECHOptions `json:"ech,omitempty"` Reality *InboundRealityOptions `json:"reality,omitempty"` } @@ -45,11 +46,20 @@ type InboundRealityHandshakeOptions struct { DialerOptions } +type InboundECHOptions struct { + Enabled bool `json:"enabled,omitempty"` + PQSignatureSchemesEnabled bool `json:"pq_signature_schemes_enabled,omitempty"` + DynamicRecordSizingDisabled bool `json:"dynamic_record_sizing_disabled,omitempty"` + Key Listable[string] `json:"ech_keys,omitempty"` + KeyPath string `json:"ech_keys_path,omitempty"` +} + type OutboundECHOptions struct { - Enabled bool `json:"enabled,omitempty"` - PQSignatureSchemesEnabled bool `json:"pq_signature_schemes_enabled,omitempty"` - DynamicRecordSizingDisabled bool `json:"dynamic_record_sizing_disabled,omitempty"` - Config string `json:"config,omitempty"` + Enabled bool `json:"enabled,omitempty"` + PQSignatureSchemesEnabled bool `json:"pq_signature_schemes_enabled,omitempty"` + DynamicRecordSizingDisabled bool `json:"dynamic_record_sizing_disabled,omitempty"` + Config Listable[string] `json:"config,omitempty"` + ConfigPath string `json:"config_path,omitempty"` } type OutboundUTLSOptions struct { diff --git a/outbound/builder.go b/outbound/builder.go index 1324fdbf..92bdef27 100644 --- a/outbound/builder.go +++ b/outbound/builder.go @@ -30,7 +30,7 @@ func New(ctx context.Context, router adapter.Router, logger log.ContextLogger, t case C.TypeSOCKS: return NewSocks(router, logger, tag, options.SocksOptions) case C.TypeHTTP: - return NewHTTP(router, logger, tag, options.HTTPOptions) + return NewHTTP(ctx, router, logger, tag, options.HTTPOptions) case C.TypeShadowsocks: return NewShadowsocks(ctx, router, logger, tag, options.ShadowsocksOptions) case C.TypeVMess: diff --git a/outbound/http.go b/outbound/http.go index 2e265da1..cd9dc959 100644 --- a/outbound/http.go +++ b/outbound/http.go @@ -25,12 +25,12 @@ type HTTP struct { client *sHTTP.Client } -func NewHTTP(router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (*HTTP, error) { +func NewHTTP(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (*HTTP, error) { outboundDialer, err := dialer.New(router, options.DialerOptions) if err != nil { return nil, err } - detour, err := tls.NewDialerFromOptions(router, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS)) + detour, err := tls.NewDialerFromOptions(ctx, router, outboundDialer, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/hysteria.go b/outbound/hysteria.go index 9ed5b6d8..456ace50 100644 --- a/outbound/hysteria.go +++ b/outbound/hysteria.go @@ -52,7 +52,7 @@ func NewHysteria(ctx context.Context, router adapter.Router, logger log.ContextL if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired } - abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + abstractTLSConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/shadowsocks.go b/outbound/shadowsocks.go index c8a9b0a8..e436e124 100644 --- a/outbound/shadowsocks.go +++ b/outbound/shadowsocks.go @@ -57,7 +57,7 @@ func NewShadowsocks(ctx context.Context, router adapter.Router, logger log.Conte serverAddr: options.ServerOptions.Build(), } if options.Plugin != "" { - outbound.plugin, err = sip003.CreatePlugin(options.Plugin, options.PluginOptions, router, outbound.dialer, outbound.serverAddr) + outbound.plugin, err = sip003.CreatePlugin(ctx, options.Plugin, options.PluginOptions, router, outbound.dialer, outbound.serverAddr) if err != nil { return nil, err } diff --git a/outbound/shadowtls.go b/outbound/shadowtls.go index 9f02124d..4427301c 100644 --- a/outbound/shadowtls.go +++ b/outbound/shadowtls.go @@ -47,7 +47,7 @@ func NewShadowTLS(ctx context.Context, router adapter.Router, logger log.Context options.TLS.MinVersion = "1.2" options.TLS.MaxVersion = "1.2" } - tlsConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + tlsConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/trojan.go b/outbound/trojan.go index db11d105..e8ccb6ae 100644 --- a/outbound/trojan.go +++ b/outbound/trojan.go @@ -51,7 +51,7 @@ func NewTrojan(ctx context.Context, router adapter.Router, logger log.ContextLog key: trojan.Key(options.Password), } if options.TLS != nil { - outbound.tlsConfig, err = tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/tuic.go b/outbound/tuic.go index 71148aca..e20cd411 100644 --- a/outbound/tuic.go +++ b/outbound/tuic.go @@ -41,7 +41,7 @@ func NewTUIC(ctx context.Context, router adapter.Router, logger log.ContextLogge if options.TLS == nil || !options.TLS.Enabled { return nil, C.ErrTLSRequired } - abstractTLSConfig, err := tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + abstractTLSConfig, err := tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/vless.go b/outbound/vless.go index 4a130403..a78bddcd 100644 --- a/outbound/vless.go +++ b/outbound/vless.go @@ -53,7 +53,7 @@ func NewVLESS(ctx context.Context, router adapter.Router, logger log.ContextLogg serverAddr: options.ServerOptions.Build(), } if options.TLS != nil { - outbound.tlsConfig, err = tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/outbound/vmess.go b/outbound/vmess.go index b07f13d7..48c9d818 100644 --- a/outbound/vmess.go +++ b/outbound/vmess.go @@ -52,7 +52,7 @@ func NewVMess(ctx context.Context, router adapter.Router, logger log.ContextLogg serverAddr: options.ServerOptions.Build(), } if options.TLS != nil { - outbound.tlsConfig, err = tls.NewClient(router, options.Server, common.PtrValueOrDefault(options.TLS)) + outbound.tlsConfig, err = tls.NewClient(ctx, options.Server, common.PtrValueOrDefault(options.TLS)) if err != nil { return nil, err } diff --git a/transport/sip003/obfs.go b/transport/sip003/obfs.go index bb6ca502..129e7756 100644 --- a/transport/sip003/obfs.go +++ b/transport/sip003/obfs.go @@ -18,7 +18,7 @@ func init() { RegisterPlugin("obfs-local", newObfsLocal) } -func newObfsLocal(pluginOpts Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { +func newObfsLocal(ctx context.Context, pluginOpts Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { plugin := &ObfsLocal{ dialer: dialer, serverAddr: serverAddr, diff --git a/transport/sip003/plugin.go b/transport/sip003/plugin.go index e6a0ba77..546a3f36 100644 --- a/transport/sip003/plugin.go +++ b/transport/sip003/plugin.go @@ -10,7 +10,7 @@ import ( N "github.com/sagernet/sing/common/network" ) -type PluginConstructor func(pluginArgs Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) +type PluginConstructor func(ctx context.Context, pluginArgs Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) type Plugin interface { DialContext(ctx context.Context) (net.Conn, error) @@ -25,7 +25,7 @@ func RegisterPlugin(name string, constructor PluginConstructor) { plugins[name] = constructor } -func CreatePlugin(name string, pluginArgs string, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { +func CreatePlugin(ctx context.Context, name string, pluginArgs string, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { pluginOptions, err := ParsePluginOptions(pluginArgs) if err != nil { return nil, E.Cause(err, "parse plugin_opts") @@ -34,5 +34,5 @@ func CreatePlugin(name string, pluginArgs string, router adapter.Router, dialer if !loaded { return nil, E.New("plugin not found: ", name) } - return constructor(pluginOptions, router, dialer, serverAddr) + return constructor(ctx, pluginOptions, router, dialer, serverAddr) } diff --git a/transport/sip003/v2ray.go b/transport/sip003/v2ray.go index 6ff8b695..29054c1a 100644 --- a/transport/sip003/v2ray.go +++ b/transport/sip003/v2ray.go @@ -20,7 +20,7 @@ func init() { RegisterPlugin("v2ray-plugin", newV2RayPlugin) } -func newV2RayPlugin(pluginOpts Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { +func newV2RayPlugin(ctx context.Context, pluginOpts Args, router adapter.Router, dialer N.Dialer, serverAddr M.Socksaddr) (Plugin, error) { var tlsOptions option.OutboundTLSOptions if _, loaded := pluginOpts.Get("tls"); loaded { tlsOptions.Enabled = true @@ -54,7 +54,7 @@ func newV2RayPlugin(pluginOpts Args, router adapter.Router, dialer N.Dialer, ser var tlsClient tls.Config var err error if tlsOptions.Enabled { - tlsClient, err = tls.NewClient(router, serverAddr.AddrString(), tlsOptions) + tlsClient, err = tls.NewClient(ctx, serverAddr.AddrString(), tlsOptions) if err != nil { return nil, err }