diff --git a/adapter/inbound.go b/adapter/inbound.go index 6e478ba3..356a3200 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -27,7 +27,7 @@ type InjectableInbound interface { type InboundContext struct { Inbound string InboundType string - IPVersion int + IPVersion uint8 Network string Source M.Socksaddr Destination M.Socksaddr diff --git a/adapter/outbound.go b/adapter/outbound.go index a45c27fd..257a525e 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -4,6 +4,7 @@ import ( "context" "net" + tun "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" ) @@ -17,3 +18,8 @@ type Outbound interface { NewConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error } + +type IPOutbound interface { + Outbound + NewIPConnection(ctx context.Context, conn tun.RouteContext, metadata InboundContext) (tun.DirectDestination, error) +} diff --git a/adapter/router.go b/adapter/router.go index e1807747..7fc6eaf9 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -23,6 +23,9 @@ type Router interface { RouteConnection(ctx context.Context, conn net.Conn, metadata InboundContext) error RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata InboundContext) error + RouteIPConnection(ctx context.Context, conn tun.RouteContext, metadata InboundContext) tun.RouteAction + + NatRequired(outbound string) bool GeoIPReader() *geoip.Reader LoadGeosite(code string) (Rule, error) @@ -39,7 +42,9 @@ type Router interface { NetworkMonitor() tun.NetworkUpdateMonitor InterfaceMonitor() tun.DefaultInterfaceMonitor PackageManager() tun.PackageManager + Rules() []Rule + IPRules() []IPRule TimeService @@ -78,6 +83,11 @@ type DNSRule interface { DisableCache() bool } +type IPRule interface { + Rule + Action() tun.ActionType +} + type InterfaceUpdateListener interface { InterfaceUpdated() error } diff --git a/box.go b/box.go index fa4393ca..424cc110 100644 --- a/box.go +++ b/box.go @@ -238,6 +238,7 @@ func (s *Box) Start() error { func (s *Box) preStart() error { for serviceName, service := range s.preServices { + s.logger.Trace("pre-starting ", serviceName) err := adapter.PreStart(service) if err != nil { return E.Cause(err, "pre-start ", serviceName) @@ -245,14 +246,15 @@ func (s *Box) preStart() error { } for i, out := range s.outbounds { if starter, isStarter := out.(common.Starter); isStarter { + var tag string + if out.Tag() == "" { + tag = F.ToString(i) + } else { + tag = out.Tag() + } + s.logger.Trace("initializing outbound ", tag) err := starter.Start() if err != nil { - var tag string - if out.Tag() == "" { - tag = F.ToString(i) - } else { - tag = out.Tag() - } return E.Cause(err, "initialize outbound/", out.Type(), "[", tag, "]") } } @@ -266,27 +268,30 @@ func (s *Box) start() error { return err } for serviceName, service := range s.preServices { + s.logger.Trace("starting ", serviceName) err = service.Start() if err != nil { return E.Cause(err, "start ", serviceName) } } for i, in := range s.inbounds { + var tag string + if in.Tag() == "" { + tag = F.ToString(i) + } else { + tag = in.Tag() + } + s.logger.Trace("initializing inbound ", tag) err = in.Start() if err != nil { - var tag string - if in.Tag() == "" { - tag = F.ToString(i) - } else { - tag = in.Tag() - } return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]") } } for serviceName, service := range s.postServices { + s.logger.Trace("start ", serviceName) err = service.Start() if err != nil { - return E.Cause(err, "start ", serviceName) + return E.Cause(err, "starting ", serviceName) } } return nil @@ -302,29 +307,47 @@ func (s *Box) Close() error { var errors error for serviceName, service := range s.postServices { errors = E.Append(errors, service.Close(), func(err error) error { + s.logger.Trace("closing ", serviceName) return E.Cause(err, "close ", serviceName) }) } for i, in := range s.inbounds { + var tag string + if in.Tag() == "" { + tag = F.ToString(i) + } else { + tag = in.Tag() + } + s.logger.Trace("closing inbound ", tag) errors = E.Append(errors, in.Close(), func(err error) error { return E.Cause(err, "close inbound/", in.Type(), "[", i, "]") }) } for i, out := range s.outbounds { + var tag string + if out.Tag() == "" { + tag = F.ToString(i) + } else { + tag = out.Tag() + } + s.logger.Trace("closing outbound ", tag) errors = E.Append(errors, common.Close(out), func(err error) error { return E.Cause(err, "close inbound/", out.Type(), "[", i, "]") }) } + s.logger.Trace("closing router") if err := common.Close(s.router); err != nil { errors = E.Append(errors, err, func(err error) error { return E.Cause(err, "close router") }) } for serviceName, service := range s.preServices { + s.logger.Trace("closing ", serviceName) errors = E.Append(errors, service.Close(), func(err error) error { return E.Cause(err, "close ", serviceName) }) } + s.logger.Trace("closing logger") if err := common.Close(s.logFactory); err != nil { errors = E.Append(errors, err, func(err error) error { return E.Cause(err, "close log factory") diff --git a/common/dialer/tfo.go b/common/dialer/tfo.go index 70f97386..b05e01f4 100644 --- a/common/dialer/tfo.go +++ b/common/dialer/tfo.go @@ -27,7 +27,12 @@ type slowOpenConn struct { func DialSlowContext(dialer *tfo.Dialer, ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if dialer.DisableTFO || N.NetworkName(network) != N.NetworkTCP { - return dialer.DialContext(ctx, network, destination.String(), nil) + switch N.NetworkName(network) { + case N.NetworkTCP, N.NetworkUDP: + return dialer.Dialer.DialContext(ctx, network, destination.String()) + default: + return dialer.Dialer.DialContext(ctx, network, destination.AddrString()) + } } return &slowOpenConn{ dialer: dialer, diff --git a/go.mod b/go.mod index 275154b0..6f8b075f 100644 --- a/go.mod +++ b/go.mod @@ -25,11 +25,11 @@ require ( github.com/sagernet/gomobile v0.0.0-20221130124640-349ebaa752ca github.com/sagernet/quic-go v0.0.0-20230202071646-a8c8afb18b32 github.com/sagernet/reality v0.0.0-20230312150606-35ea9af0e0b8 - github.com/sagernet/sing v0.2.1-0.20230318094614-4bbf5f2c3046 + github.com/sagernet/sing v0.2.1-0.20230321172705-3e60222a1a7d github.com/sagernet/sing-dns v0.1.4 github.com/sagernet/sing-shadowsocks v0.2.0 github.com/sagernet/sing-shadowtls v0.1.0 - github.com/sagernet/sing-tun v0.1.3-0.20230315134716-fe89bbded22d + github.com/sagernet/sing-tun v0.1.3-0.20230321172818-56bedd2f0558 github.com/sagernet/sing-vmess v0.1.3 github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 github.com/sagernet/tfo-go v0.0.0-20230303015439-ffcfd8c41cf9 diff --git a/go.sum b/go.sum index 77b322dc..dad38bca 100644 --- a/go.sum +++ b/go.sum @@ -111,16 +111,16 @@ github.com/sagernet/reality v0.0.0-20230312150606-35ea9af0e0b8 h1:4M3+0/kqvJuTsi github.com/sagernet/reality v0.0.0-20230312150606-35ea9af0e0b8/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.0.0-20220817130738-ce854cda8522/go.mod h1:QVsS5L/ZA2Q5UhQwLrn0Trw+msNd/NPGEhBKR/ioWiY= github.com/sagernet/sing v0.1.8/go.mod h1:jt1w2u7lJQFFSGLiRrRIs5YWmx4kAPfWuOejuDW9qMk= -github.com/sagernet/sing v0.2.1-0.20230318094614-4bbf5f2c3046 h1:/+ZWbxRvQmco9ES2qT5Eh/x/IiQRjAcUyRG/vQ4dpxc= -github.com/sagernet/sing v0.2.1-0.20230318094614-4bbf5f2c3046/go.mod h1:9uHswk2hITw8leDbiLS/xn0t9nzBcbePxzm9PJhwdlw= +github.com/sagernet/sing v0.2.1-0.20230321172705-3e60222a1a7d h1:ktk03rtgPqTDyUd2dWg1uzyr5RnptX8grSMvIzedJlQ= +github.com/sagernet/sing v0.2.1-0.20230321172705-3e60222a1a7d/go.mod h1:9uHswk2hITw8leDbiLS/xn0t9nzBcbePxzm9PJhwdlw= github.com/sagernet/sing-dns v0.1.4 h1:7VxgeoSCiiazDSaXXQVcvrTBxFpOePPq/4XdgnUDN+0= github.com/sagernet/sing-dns v0.1.4/go.mod h1:1+6pCa48B1AI78lD+/i/dLgpw4MwfnsSpZo0Ds8wzzk= github.com/sagernet/sing-shadowsocks v0.2.0 h1:ILDWL7pwWfkPLEbviE/MyCgfjaBmJY/JVVY+5jhSb58= github.com/sagernet/sing-shadowsocks v0.2.0/go.mod h1:ysYzszRLpNzJSorvlWRMuzU6Vchsp7sd52q+JNY4axw= github.com/sagernet/sing-shadowtls v0.1.0 h1:05MYce8aR5xfKIn+y7xRFsdKhKt44QZTSEQW+lG5IWQ= github.com/sagernet/sing-shadowtls v0.1.0/go.mod h1:Kn1VUIprdkwCgkS6SXYaLmIpKzQbqBIKJBMY+RvBhYc= -github.com/sagernet/sing-tun v0.1.3-0.20230315134716-fe89bbded22d h1:1gt4Hu2fHCrmL2NZYCNJ3nCgeczuhK09oCMni9mZmZk= -github.com/sagernet/sing-tun v0.1.3-0.20230315134716-fe89bbded22d/go.mod h1:KnRkwaDHbb06zgeNPu0LQ8A+vA9myMxKEgHN1brCPHg= +github.com/sagernet/sing-tun v0.1.3-0.20230321172818-56bedd2f0558 h1:c5Rm6BTOclEeayS6G9+1rI1kTeilCsn0ALSFbOdlgRE= +github.com/sagernet/sing-tun v0.1.3-0.20230321172818-56bedd2f0558/go.mod h1:cqnZEm+2ArgP4Gq1NcQfVFm9CZaeGw21mG9AcnYOTiU= github.com/sagernet/sing-vmess v0.1.3 h1:q/+tsF46dvvapL6CpQBgPHJ6nQrDUZqEtLHCbsjO7iM= github.com/sagernet/sing-vmess v0.1.3/go.mod h1:GVXqAHwe9U21uS+Voh4YBIrADQyE4F9v0ayGSixSQAE= github.com/sagernet/smux v0.0.0-20230312102458-337ec2a5af37 h1:HuE6xSwco/Xed8ajZ+coeYLmioq0Qp1/Z2zczFaV8as= diff --git a/inbound/tun.go b/inbound/tun.go index 1791ca19..985adba6 100644 --- a/inbound/tun.go +++ b/inbound/tun.go @@ -21,7 +21,10 @@ import ( "github.com/sagernet/sing/common/ranges" ) -var _ adapter.Inbound = (*Tun)(nil) +var ( + _ adapter.Inbound = (*Tun)(nil) + _ tun.Router = (*Tun)(nil) +) type Tun struct { tag string @@ -40,10 +43,6 @@ type Tun struct { } func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TunInboundOptions, platformInterface platform.Interface) (*Tun, error) { - tunName := options.InterfaceName - if tunName == "" { - tunName = tun.CalculateInterfaceName("") - } tunMTU := options.MTU if tunMTU == 0 { tunMTU = 9000 @@ -77,7 +76,7 @@ func NewTun(ctx context.Context, router adapter.Router, logger log.ContextLogger logger: logger, inboundOptions: options.InboundOptions, tunOptions: tun.Options{ - Name: tunName, + Name: options.InterfaceName, MTU: tunMTU, Inet4Address: common.Map(options.Inet4Address, option.ListenPrefix.Build), Inet6Address: common.Map(options.Inet6Address, option.ListenPrefix.Build), @@ -143,12 +142,17 @@ func (t *Tun) Tag() string { func (t *Tun) Start() error { if C.IsAndroid && t.platformInterface == nil { + t.logger.Trace("building android rules") t.tunOptions.BuildAndroidRules(t.router.PackageManager(), t) } + if t.tunOptions.Name == "" { + t.tunOptions.Name = tun.CalculateInterfaceName("") + } var ( tunInterface tun.Tun err error ) + t.logger.Trace("opening interface") if t.platformInterface != nil { tunInterface, err = t.platformInterface.OpenTun(t.tunOptions, t.platformOptions) } else { @@ -157,7 +161,12 @@ func (t *Tun) Start() error { if err != nil { return E.Cause(err, "configure tun interface") } + t.logger.Trace("creating stack") t.tunIf = tunInterface + var tunRouter tun.Router + if len(t.router.IPRules()) > 0 { + tunRouter = t + } t.tunStack, err = tun.NewStack(t.stack, tun.StackOptions{ Context: t.ctx, Tun: tunInterface, @@ -167,6 +176,7 @@ func (t *Tun) Start() error { Inet6Address: t.tunOptions.Inet6Address, EndpointIndependentNat: t.endpointIndependentNat, UDPTimeout: t.udpTimeout, + Router: tunRouter, Handler: t, Logger: t.logger, UnderPlatform: t.platformInterface != nil, @@ -174,6 +184,7 @@ func (t *Tun) Start() error { if err != nil { return err } + t.logger.Trace("starting stack") err = t.tunStack.Start() if err != nil { return err @@ -189,6 +200,21 @@ func (t *Tun) Close() error { ) } +func (t *Tun) RouteConnection(session tun.RouteSession, conn tun.RouteContext) tun.RouteAction { + ctx := log.ContextWithNewID(t.ctx) + var metadata adapter.InboundContext + metadata.Inbound = t.tag + metadata.InboundType = C.TypeTun + metadata.IPVersion = session.IPVersion + metadata.Network = tun.NetworkName(session.Network) + metadata.Source = M.SocksaddrFromNetIP(session.Source) + metadata.Destination = M.SocksaddrFromNetIP(session.Destination) + metadata.InboundOptions = t.inboundOptions + t.logger.DebugContext(ctx, "incoming connection from ", metadata.Source) + t.logger.DebugContext(ctx, "incoming connection to ", metadata.Destination) + return t.router.RouteIPConnection(ctx, conn, metadata) +} + func (t *Tun) NewConnection(ctx context.Context, conn net.Conn, upstreamMetadata M.Metadata) error { ctx = log.ContextWithNewID(ctx) var metadata adapter.InboundContext diff --git a/option/dns.go b/option/dns.go index 637f2c05..85a4cfcc 100644 --- a/option/dns.go +++ b/option/dns.go @@ -1,14 +1,5 @@ package option -import ( - "reflect" - - "github.com/sagernet/sing-box/common/json" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" -) - type DNSOptions struct { Servers []DNSServerOptions `json:"servers,omitempty"` Rules []DNSRule `json:"rules,omitempty"` @@ -31,97 +22,3 @@ type DNSServerOptions struct { Strategy DomainStrategy `json:"strategy,omitempty"` Detour string `json:"detour,omitempty"` } - -type _DNSRule struct { - Type string `json:"type,omitempty"` - DefaultOptions DefaultDNSRule `json:"-"` - LogicalOptions LogicalDNSRule `json:"-"` -} - -type DNSRule _DNSRule - -func (r DNSRule) MarshalJSON() ([]byte, error) { - var v any - switch r.Type { - case C.RuleTypeDefault: - r.Type = "" - v = r.DefaultOptions - case C.RuleTypeLogical: - v = r.LogicalOptions - default: - return nil, E.New("unknown rule type: " + r.Type) - } - return MarshallObjects((_DNSRule)(r), v) -} - -func (r *DNSRule) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_DNSRule)(r)) - if err != nil { - return err - } - var v any - switch r.Type { - case "", C.RuleTypeDefault: - r.Type = C.RuleTypeDefault - v = &r.DefaultOptions - case C.RuleTypeLogical: - v = &r.LogicalOptions - default: - return E.New("unknown rule type: " + r.Type) - } - err = UnmarshallExcluded(bytes, (*_DNSRule)(r), v) - if err != nil { - return E.Cause(err, "dns route rule") - } - return nil -} - -type DefaultDNSRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - IPVersion int `json:"ip_version,omitempty"` - QueryType Listable[DNSQueryType] `json:"query_type,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - SourcePortRange Listable[string] `json:"source_port_range,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - PortRange Listable[string] `json:"port_range,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - ProcessPath Listable[string] `json:"process_path,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - Outbound Listable[string] `json:"outbound,omitempty"` - ClashMode string `json:"clash_mode,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` -} - -func (r DefaultDNSRule) IsValid() bool { - var defaultValue DefaultDNSRule - defaultValue.Invert = r.Invert - defaultValue.Server = r.Server - defaultValue.DisableCache = r.DisableCache - return !reflect.DeepEqual(r, defaultValue) -} - -type LogicalDNSRule struct { - Mode string `json:"mode"` - Rules []DefaultDNSRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Server string `json:"server,omitempty"` - DisableCache bool `json:"disable_cache,omitempty"` -} - -func (r LogicalDNSRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) -} diff --git a/option/route.go b/option/route.go index 308c4802..b32d4b3f 100644 --- a/option/route.go +++ b/option/route.go @@ -1,17 +1,9 @@ package option -import ( - "reflect" - - "github.com/sagernet/sing-box/common/json" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" -) - type RouteOptions struct { GeoIP *GeoIPOptions `json:"geoip,omitempty"` Geosite *GeositeOptions `json:"geosite,omitempty"` + IPRules []IPRule `json:"ip_rules,omitempty"` Rules []Rule `json:"rules,omitempty"` Final string `json:"final,omitempty"` FindProcess bool `json:"find_process,omitempty"` @@ -32,94 +24,3 @@ type GeositeOptions struct { DownloadURL string `json:"download_url,omitempty"` DownloadDetour string `json:"download_detour,omitempty"` } - -type _Rule struct { - Type string `json:"type,omitempty"` - DefaultOptions DefaultRule `json:"-"` - LogicalOptions LogicalRule `json:"-"` -} - -type Rule _Rule - -func (r Rule) MarshalJSON() ([]byte, error) { - var v any - switch r.Type { - case C.RuleTypeDefault: - r.Type = "" - v = r.DefaultOptions - case C.RuleTypeLogical: - v = r.LogicalOptions - default: - return nil, E.New("unknown rule type: " + r.Type) - } - return MarshallObjects((_Rule)(r), v) -} - -func (r *Rule) UnmarshalJSON(bytes []byte) error { - err := json.Unmarshal(bytes, (*_Rule)(r)) - if err != nil { - return err - } - var v any - switch r.Type { - case "", C.RuleTypeDefault: - r.Type = C.RuleTypeDefault - v = &r.DefaultOptions - case C.RuleTypeLogical: - v = &r.LogicalOptions - default: - return E.New("unknown rule type: " + r.Type) - } - err = UnmarshallExcluded(bytes, (*_Rule)(r), v) - if err != nil { - return E.Cause(err, "route rule") - } - return nil -} - -type DefaultRule struct { - Inbound Listable[string] `json:"inbound,omitempty"` - IPVersion int `json:"ip_version,omitempty"` - Network string `json:"network,omitempty"` - AuthUser Listable[string] `json:"auth_user,omitempty"` - Protocol Listable[string] `json:"protocol,omitempty"` - Domain Listable[string] `json:"domain,omitempty"` - DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` - DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` - DomainRegex Listable[string] `json:"domain_regex,omitempty"` - Geosite Listable[string] `json:"geosite,omitempty"` - SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` - GeoIP Listable[string] `json:"geoip,omitempty"` - SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` - IPCIDR Listable[string] `json:"ip_cidr,omitempty"` - SourcePort Listable[uint16] `json:"source_port,omitempty"` - SourcePortRange Listable[string] `json:"source_port_range,omitempty"` - Port Listable[uint16] `json:"port,omitempty"` - PortRange Listable[string] `json:"port_range,omitempty"` - ProcessName Listable[string] `json:"process_name,omitempty"` - ProcessPath Listable[string] `json:"process_path,omitempty"` - PackageName Listable[string] `json:"package_name,omitempty"` - User Listable[string] `json:"user,omitempty"` - UserID Listable[int32] `json:"user_id,omitempty"` - ClashMode string `json:"clash_mode,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` -} - -func (r DefaultRule) IsValid() bool { - var defaultValue DefaultRule - defaultValue.Invert = r.Invert - defaultValue.Outbound = r.Outbound - return !reflect.DeepEqual(r, defaultValue) -} - -type LogicalRule struct { - Mode string `json:"mode"` - Rules []DefaultRule `json:"rules,omitempty"` - Invert bool `json:"invert,omitempty"` - Outbound string `json:"outbound,omitempty"` -} - -func (r LogicalRule) IsValid() bool { - return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) -} diff --git a/option/rule.go b/option/rule.go new file mode 100644 index 00000000..f78a752d --- /dev/null +++ b/option/rule.go @@ -0,0 +1,101 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _Rule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultRule `json:"-"` + LogicalOptions LogicalRule `json:"-"` +} + +type Rule _Rule + +func (r Rule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_Rule)(r), v) +} + +func (r *Rule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_Rule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_Rule)(r), v) + if err != nil { + return E.Cause(err, "route rule") + } + return nil +} + +type DefaultRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + Network Listable[string] `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + GeoIP Listable[string] `json:"geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + IPCIDR Listable[string] `json:"ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + ProcessPath Listable[string] `json:"process_path,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + ClashMode string `json:"clash_mode,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r DefaultRule) IsValid() bool { + var defaultValue DefaultRule + defaultValue.Invert = r.Invert + defaultValue.Outbound = r.Outbound + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalRule struct { + Mode string `json:"mode"` + Rules []DefaultRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r LogicalRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultRule.IsValid) +} diff --git a/option/rule_dns.go b/option/rule_dns.go new file mode 100644 index 00000000..98caf0a8 --- /dev/null +++ b/option/rule_dns.go @@ -0,0 +1,104 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _DNSRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultDNSRule `json:"-"` + LogicalOptions LogicalDNSRule `json:"-"` +} + +type DNSRule _DNSRule + +func (r DNSRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_DNSRule)(r), v) +} + +func (r *DNSRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_DNSRule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_DNSRule)(r), v) + if err != nil { + return E.Cause(err, "dns route rule") + } + return nil +} + +type DefaultDNSRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + QueryType Listable[DNSQueryType] `json:"query_type,omitempty"` + Network Listable[string] `json:"network,omitempty"` + AuthUser Listable[string] `json:"auth_user,omitempty"` + Protocol Listable[string] `json:"protocol,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + ProcessName Listable[string] `json:"process_name,omitempty"` + ProcessPath Listable[string] `json:"process_path,omitempty"` + PackageName Listable[string] `json:"package_name,omitempty"` + User Listable[string] `json:"user,omitempty"` + UserID Listable[int32] `json:"user_id,omitempty"` + Outbound Listable[string] `json:"outbound,omitempty"` + ClashMode string `json:"clash_mode,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` +} + +func (r DefaultDNSRule) IsValid() bool { + var defaultValue DefaultDNSRule + defaultValue.Invert = r.Invert + defaultValue.Server = r.Server + defaultValue.DisableCache = r.DisableCache + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalDNSRule struct { + Mode string `json:"mode"` + Rules []DefaultDNSRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Server string `json:"server,omitempty"` + DisableCache bool `json:"disable_cache,omitempty"` +} + +func (r LogicalDNSRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultDNSRule.IsValid) +} diff --git a/option/rule_ip.go b/option/rule_ip.go new file mode 100644 index 00000000..c0aa5cfc --- /dev/null +++ b/option/rule_ip.go @@ -0,0 +1,125 @@ +package option + +import ( + "reflect" + + "github.com/sagernet/sing-box/common/json" + C "github.com/sagernet/sing-box/constant" + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +type _IPRule struct { + Type string `json:"type,omitempty"` + DefaultOptions DefaultIPRule `json:"-"` + LogicalOptions LogicalIPRule `json:"-"` +} + +type IPRule _IPRule + +func (r IPRule) MarshalJSON() ([]byte, error) { + var v any + switch r.Type { + case C.RuleTypeDefault: + r.Type = "" + v = r.DefaultOptions + case C.RuleTypeLogical: + v = r.LogicalOptions + default: + return nil, E.New("unknown rule type: " + r.Type) + } + return MarshallObjects((_IPRule)(r), v) +} + +func (r *IPRule) UnmarshalJSON(bytes []byte) error { + err := json.Unmarshal(bytes, (*_IPRule)(r)) + if err != nil { + return err + } + var v any + switch r.Type { + case "", C.RuleTypeDefault: + r.Type = C.RuleTypeDefault + v = &r.DefaultOptions + case C.RuleTypeLogical: + v = &r.LogicalOptions + default: + return E.New("unknown rule type: " + r.Type) + } + err = UnmarshallExcluded(bytes, (*_IPRule)(r), v) + if err != nil { + return E.Cause(err, "ip route rule") + } + return nil +} + +type DefaultIPRule struct { + Inbound Listable[string] `json:"inbound,omitempty"` + IPVersion int `json:"ip_version,omitempty"` + Network Listable[string] `json:"network,omitempty"` + Domain Listable[string] `json:"domain,omitempty"` + DomainSuffix Listable[string] `json:"domain_suffix,omitempty"` + DomainKeyword Listable[string] `json:"domain_keyword,omitempty"` + DomainRegex Listable[string] `json:"domain_regex,omitempty"` + Geosite Listable[string] `json:"geosite,omitempty"` + SourceGeoIP Listable[string] `json:"source_geoip,omitempty"` + SourceIPCIDR Listable[string] `json:"source_ip_cidr,omitempty"` + SourcePort Listable[uint16] `json:"source_port,omitempty"` + SourcePortRange Listable[string] `json:"source_port_range,omitempty"` + Port Listable[uint16] `json:"port,omitempty"` + PortRange Listable[string] `json:"port_range,omitempty"` + Invert bool `json:"invert,omitempty"` + Action RouteAction `json:"action,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +type RouteAction tun.ActionType + +func (a RouteAction) MarshalJSON() ([]byte, error) { + switch tun.ActionType(a) { + case tun.ActionTypeReject, tun.ActionTypeDirect: + default: + return nil, E.New("unknown action: ", a) + } + return json.Marshal(tun.ActionTypeName(tun.ActionType(a))) +} + +func (a *RouteAction) UnmarshalJSON(bytes []byte) error { + var value string + err := json.Unmarshal(bytes, &value) + if err != nil { + return err + } + actionType, err := tun.ParseActionType(value) + if err != nil { + return err + } + switch actionType { + case tun.ActionTypeReject, tun.ActionTypeDirect: + default: + return E.New("unknown action: ", a) + } + *a = RouteAction(actionType) + return nil +} + +func (r DefaultIPRule) IsValid() bool { + var defaultValue DefaultIPRule + defaultValue.Invert = r.Invert + defaultValue.Action = r.Action + defaultValue.Outbound = r.Outbound + return !reflect.DeepEqual(r, defaultValue) +} + +type LogicalIPRule struct { + Mode string `json:"mode"` + Rules []DefaultIPRule `json:"rules,omitempty"` + Invert bool `json:"invert,omitempty"` + Action RouteAction `json:"action,omitempty"` + Outbound string `json:"outbound,omitempty"` +} + +func (r LogicalIPRule) IsValid() bool { + return len(r.Rules) > 0 && common.All(r.Rules, DefaultIPRule.IsValid) +} diff --git a/option/wireguard.go b/option/wireguard.go index ee6e1053..15639474 100644 --- a/option/wireguard.go +++ b/option/wireguard.go @@ -13,4 +13,5 @@ type WireGuardOutboundOptions struct { Workers int `json:"workers,omitempty"` MTU uint32 `json:"mtu,omitempty"` Network NetworkList `json:"network,omitempty"` + IPRewrite bool `json:"ip_rewrite,omitempty"` } diff --git a/outbound/wireguard.go b/outbound/wireguard.go index dd194729..1674eb72 100644 --- a/outbound/wireguard.go +++ b/outbound/wireguard.go @@ -8,7 +8,9 @@ import ( "encoding/hex" "fmt" "net" + "os" "strings" + "syscall" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -26,7 +28,7 @@ import ( ) var ( - _ adapter.Outbound = (*WireGuard)(nil) + _ adapter.IPOutbound = (*WireGuard)(nil) _ adapter.InterfaceUpdateListener = (*WireGuard)(nil) ) @@ -34,6 +36,7 @@ type WireGuard struct { myOutboundAdapter bind *wireguard.ClientBind device *device.Device + natDevice wireguard.NatDevice tunDevice wireguard.Device } @@ -106,17 +109,25 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context if mtu == 0 { mtu = 1408 } - var wireTunDevice wireguard.Device + var tunDevice wireguard.Device var err error if !options.SystemInterface && tun.WithGVisor { - wireTunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu) + tunDevice, err = wireguard.NewStackDevice(localPrefixes, mtu, options.IPRewrite) } else { - wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu) + tunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, localPrefixes, mtu) } if err != nil { return nil, E.Cause(err, "create WireGuard device") } - wgDevice := device.NewDevice(wireTunDevice, outbound.bind, &device.Logger{ + natDevice, isNatDevice := tunDevice.(wireguard.NatDevice) + if !isNatDevice && router.NatRequired(tag) { + natDevice = wireguard.NewNATDevice(tunDevice, options.IPRewrite) + } + deviceInput := tunDevice + if natDevice != nil { + deviceInput = natDevice + } + wgDevice := device.NewDevice(deviceInput, outbound.bind, &device.Logger{ Verbosef: func(format string, args ...interface{}) { logger.Debug(fmt.Sprintf(strings.ToLower(format), args...)) }, @@ -132,7 +143,8 @@ func NewWireGuard(ctx context.Context, router adapter.Router, logger log.Context return nil, E.Cause(err, "setup wireguard") } outbound.device = wgDevice - outbound.tunDevice = wireTunDevice + outbound.natDevice = natDevice + outbound.tunDevice = tunDevice return outbound, nil } @@ -171,6 +183,27 @@ func (w *WireGuard) NewPacketConnection(ctx context.Context, conn N.PacketConn, return NewPacketConnection(ctx, w, conn, metadata) } +func (w *WireGuard) NewIPConnection(ctx context.Context, conn tun.RouteContext, metadata adapter.InboundContext) (tun.DirectDestination, error) { + if w.natDevice == nil { + return nil, os.ErrInvalid + } + session := tun.RouteSession{ + IPVersion: metadata.IPVersion, + Network: tun.NetworkFromName(metadata.Network), + Source: metadata.Source.AddrPort(), + Destination: metadata.Destination.AddrPort(), + } + switch session.Network { + case syscall.IPPROTO_TCP: + w.logger.InfoContext(ctx, "linked connection to ", metadata.Destination) + case syscall.IPPROTO_UDP: + w.logger.InfoContext(ctx, "linked packet connection to ", metadata.Destination) + default: + w.logger.InfoContext(ctx, "linked ", metadata.Network, " connection to ", metadata.Destination.AddrString()) + } + return w.natDevice.CreateDestination(session, conn), nil +} + func (w *WireGuard) Start() error { return w.tunDevice.Start() } diff --git a/route/router.go b/route/router.go index 6dbb1922..5063f1c2 100644 --- a/route/router.go +++ b/route/router.go @@ -2,14 +2,11 @@ package route import ( "context" - "io" "net" - "net/http" "net/netip" "net/url" "os" "os/user" - "path/filepath" "strings" "time" @@ -37,7 +34,6 @@ import ( F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/common/rw" "github.com/sagernet/sing/common/uot" ) @@ -72,6 +68,7 @@ type Router struct { outbounds []adapter.Outbound outboundByTag map[string]adapter.Outbound rules []adapter.Rule + ipRules []adapter.IPRule defaultDetour string defaultOutboundForConnection adapter.Outbound defaultOutboundForPacketConnection adapter.Outbound @@ -128,6 +125,7 @@ func NewRouter( dnsLogger: logFactory.NewLogger("dns"), outboundByTag: make(map[string]adapter.Outbound), rules: make([]adapter.Rule, 0, len(options.Rules)), + ipRules: make([]adapter.IPRule, 0, len(options.IPRules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), @@ -149,6 +147,13 @@ func NewRouter( } router.rules = append(router.rules, routeRule) } + for i, ipRuleOptions := range options.IPRules { + ipRule, err := NewIPRule(router, router.logger, ipRuleOptions) + if err != nil { + return nil, E.Cause(err, "parse ip rule[", i, "]") + } + router.ipRules = append(router.ipRules, ipRule) + } for i, dnsRuleOptions := range dnsOptions.Rules { dnsRule, err := NewDNSRule(router, router.logger, dnsRuleOptions) if err != nil { @@ -156,6 +161,7 @@ func NewRouter( } router.dnsRules = append(router.dnsRules, dnsRule) } + transports := make([]dns.Transport, len(dnsOptions.Servers)) dummyTransportMap := make(map[string]dns.Transport) transportMap := make(map[string]dns.Transport) @@ -516,27 +522,6 @@ func (r *Router) Close() error { ) } -func (r *Router) GeoIPReader() *geoip.Reader { - return r.geoIPReader -} - -func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { - rule, cached := r.geositeCache[code] - if cached { - return rule, nil - } - items, err := r.geositeReader.Read(code) - if err != nil { - return nil, err - } - rule, err = NewDefaultRule(r, nil, geosite.Compile(items)) - if err != nil { - return nil, err - } - r.geositeCache[code] = rule - return rule, nil -} - func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { outbound, loaded := r.outboundByTag[tag] return outbound, loaded @@ -811,6 +796,10 @@ func (r *Router) Rules() []adapter.Rule { return r.rules } +func (r *Router) IPRules() []adapter.IPRule { + return r.ipRules +} + func (r *Router) NetworkMonitor() tun.NetworkUpdateMonitor { return r.networkMonitor } @@ -846,239 +835,6 @@ func (r *Router) SetV2RayServer(server adapter.V2RayServer) { r.v2rayServer = server } -func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { - for _, rule := range rules { - switch rule.Type { - case C.RuleTypeDefault: - if cond(rule.DefaultOptions) { - return true - } - case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } - } - } - } - return false -} - -func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { - for _, rule := range rules { - switch rule.Type { - case C.RuleTypeDefault: - if cond(rule.DefaultOptions) { - return true - } - case C.RuleTypeLogical: - for _, subRule := range rule.LogicalOptions.Rules { - if cond(subRule) { - return true - } - } - } - } - return false -} - -func isGeoIPRule(rule option.DefaultRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) -} - -func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) -} - -func isGeositeRule(rule option.DefaultRule) bool { - return len(rule.Geosite) > 0 -} - -func isGeositeDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.Geosite) > 0 -} - -func isProcessRule(rule option.DefaultRule) bool { - return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 -} - -func isProcessDNSRule(rule option.DefaultDNSRule) bool { - return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 -} - -func notPrivateNode(code string) bool { - return code != "private" -} - -func (r *Router) prepareGeoIPDatabase() error { - var geoPath string - if r.geoIPOptions.Path != "" { - geoPath = r.geoIPOptions.Path - } else { - geoPath = "geoip.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - geoPath = C.BasePath(geoPath) - if !rw.FileExists(geoPath) { - r.logger.Warn("geoip database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeoIPDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geoip database: ", err) - os.Remove(geoPath) - // time.Sleep(10 * time.Second) - } - if err != nil { - return err - } - } - geoReader, codes, err := geoip.Open(geoPath) - if err != nil { - return E.Cause(err, "open geoip database") - } - r.logger.Info("loaded geoip database: ", len(codes), " codes") - r.geoIPReader = geoReader - return nil -} - -func (r *Router) prepareGeositeDatabase() error { - var geoPath string - if r.geositeOptions.Path != "" { - geoPath = r.geositeOptions.Path - } else { - geoPath = "geosite.db" - if foundPath, loaded := C.FindPath(geoPath); loaded { - geoPath = foundPath - } - } - geoPath = C.BasePath(geoPath) - if !rw.FileExists(geoPath) { - r.logger.Warn("geosite database not exists: ", geoPath) - var err error - for attempts := 0; attempts < 3; attempts++ { - err = r.downloadGeositeDatabase(geoPath) - if err == nil { - break - } - r.logger.Error("download geosite database: ", err) - os.Remove(geoPath) - // time.Sleep(10 * time.Second) - } - if err != nil { - return err - } - } - geoReader, codes, err := geosite.Open(geoPath) - if err == nil { - r.logger.Info("loaded geosite database: ", len(codes), " codes") - r.geositeReader = geoReader - } else { - return E.Cause(err, "open geosite database") - } - return nil -} - -func (r *Router) downloadGeoIPDatabase(savePath string) error { - var downloadURL string - if r.geoIPOptions.DownloadURL != "" { - downloadURL = r.geoIPOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db" - } - r.logger.Info("downloading geoip database") - var detour adapter.Outbound - if r.geoIPOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geoIPOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.defaultOutboundForConnection - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - os.MkdirAll(parentDir, 0o755) - } - - saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - defer saveFile.Close() - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - response, err := httpClient.Get(downloadURL) - if err != nil { - return err - } - defer response.Body.Close() - _, err = io.Copy(saveFile, response.Body) - return err -} - -func (r *Router) downloadGeositeDatabase(savePath string) error { - var downloadURL string - if r.geositeOptions.DownloadURL != "" { - downloadURL = r.geositeOptions.DownloadURL - } else { - downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db" - } - r.logger.Info("downloading geosite database") - var detour adapter.Outbound - if r.geositeOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geositeOptions.DownloadDetour) - if !loaded { - return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) - } - detour = outbound - } else { - detour = r.defaultOutboundForConnection - } - - if parentDir := filepath.Dir(savePath); parentDir != "" { - os.MkdirAll(parentDir, 0o755) - } - - saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) - if err != nil { - return E.Cause(err, "open output file: ", downloadURL) - } - defer saveFile.Close() - - httpClient := &http.Client{ - Transport: &http.Transport{ - ForceAttemptHTTP2: true, - TLSHandshakeTimeout: 5 * time.Second, - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) - }, - }, - } - defer httpClient.CloseIdleConnections() - response, err := httpClient.Get(downloadURL) - if err != nil { - return err - } - defer response.Body.Close() - _, err = io.Copy(saveFile, response.Body) - return err -} - func (r *Router) OnPackagesUpdated(packages int, sharedUsers int) { r.logger.Info("updated packages list: ", packages, " packages, ", sharedUsers, " shared users") } diff --git a/route/router_geo_resources.go b/route/router_geo_resources.go new file mode 100644 index 00000000..a72b4bad --- /dev/null +++ b/route/router_geo_resources.go @@ -0,0 +1,283 @@ +package route + +import ( + "context" + "io" + "net" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/geoip" + "github.com/sagernet/sing-box/common/geosite" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/rw" +) + +func (r *Router) GeoIPReader() *geoip.Reader { + return r.geoIPReader +} + +func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { + rule, cached := r.geositeCache[code] + if cached { + return rule, nil + } + items, err := r.geositeReader.Read(code) + if err != nil { + return nil, err + } + rule, err = NewDefaultRule(r, nil, geosite.Compile(items)) + if err != nil { + return nil, err + } + r.geositeCache[code] = rule + return rule, nil +} + +func (r *Router) prepareGeoIPDatabase() error { + var geoPath string + if r.geoIPOptions.Path != "" { + geoPath = r.geoIPOptions.Path + } else { + geoPath = "geoip.db" + if foundPath, loaded := C.FindPath(geoPath); loaded { + geoPath = foundPath + } + } + geoPath = C.BasePath(geoPath) + if rw.FileExists(geoPath) { + geoReader, codes, err := geoip.Open(geoPath) + if err == nil { + r.logger.Info("loaded geoip database: ", len(codes), " codes") + r.geoIPReader = geoReader + return nil + } + } + if !rw.FileExists(geoPath) { + r.logger.Warn("geoip database not exists: ", geoPath) + var err error + for attempts := 0; attempts < 3; attempts++ { + err = r.downloadGeoIPDatabase(geoPath) + if err == nil { + break + } + r.logger.Error("download geoip database: ", err) + os.Remove(geoPath) + // time.Sleep(10 * time.Second) + } + if err != nil { + return err + } + } + geoReader, codes, err := geoip.Open(geoPath) + if err != nil { + return E.Cause(err, "open geoip database") + } + r.logger.Info("loaded geoip database: ", len(codes), " codes") + r.geoIPReader = geoReader + return nil +} + +func (r *Router) prepareGeositeDatabase() error { + var geoPath string + if r.geositeOptions.Path != "" { + geoPath = r.geositeOptions.Path + } else { + geoPath = "geosite.db" + if foundPath, loaded := C.FindPath(geoPath); loaded { + geoPath = foundPath + } + } + geoPath = C.BasePath(geoPath) + if !rw.FileExists(geoPath) { + r.logger.Warn("geosite database not exists: ", geoPath) + var err error + for attempts := 0; attempts < 3; attempts++ { + err = r.downloadGeositeDatabase(geoPath) + if err == nil { + break + } + r.logger.Error("download geosite database: ", err) + os.Remove(geoPath) + // time.Sleep(10 * time.Second) + } + if err != nil { + return err + } + } + geoReader, codes, err := geosite.Open(geoPath) + if err == nil { + r.logger.Info("loaded geosite database: ", len(codes), " codes") + r.geositeReader = geoReader + } else { + return E.Cause(err, "open geosite database") + } + return nil +} + +func (r *Router) downloadGeoIPDatabase(savePath string) error { + var downloadURL string + if r.geoIPOptions.DownloadURL != "" { + downloadURL = r.geoIPOptions.DownloadURL + } else { + downloadURL = "https://github.com/SagerNet/sing-geoip/releases/latest/download/geoip.db" + } + r.logger.Info("downloading geoip database") + var detour adapter.Outbound + if r.geoIPOptions.DownloadDetour != "" { + outbound, loaded := r.Outbound(r.geoIPOptions.DownloadDetour) + if !loaded { + return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) + } + detour = outbound + } else { + detour = r.defaultOutboundForConnection + } + + if parentDir := filepath.Dir(savePath); parentDir != "" { + os.MkdirAll(parentDir, 0o755) + } + + saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return E.Cause(err, "open output file: ", downloadURL) + } + defer saveFile.Close() + + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 5 * time.Second, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + defer httpClient.CloseIdleConnections() + response, err := httpClient.Get(downloadURL) + if err != nil { + return err + } + defer response.Body.Close() + _, err = io.Copy(saveFile, response.Body) + return err +} + +func (r *Router) downloadGeositeDatabase(savePath string) error { + var downloadURL string + if r.geositeOptions.DownloadURL != "" { + downloadURL = r.geositeOptions.DownloadURL + } else { + downloadURL = "https://github.com/SagerNet/sing-geosite/releases/latest/download/geosite.db" + } + r.logger.Info("downloading geosite database") + var detour adapter.Outbound + if r.geositeOptions.DownloadDetour != "" { + outbound, loaded := r.Outbound(r.geositeOptions.DownloadDetour) + if !loaded { + return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) + } + detour = outbound + } else { + detour = r.defaultOutboundForConnection + } + + if parentDir := filepath.Dir(savePath); parentDir != "" { + os.MkdirAll(parentDir, 0o755) + } + + saveFile, err := os.OpenFile(savePath, os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return E.Cause(err, "open output file: ", downloadURL) + } + defer saveFile.Close() + + httpClient := &http.Client{ + Transport: &http.Transport{ + ForceAttemptHTTP2: true, + TLSHandshakeTimeout: 5 * time.Second, + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return detour.DialContext(ctx, network, M.ParseSocksaddr(addr)) + }, + }, + } + defer httpClient.CloseIdleConnections() + response, err := httpClient.Get(downloadURL) + if err != nil { + return err + } + defer response.Body.Close() + _, err = io.Copy(saveFile, response.Body) + return err +} + +func hasRule(rules []option.Rule, cond func(rule option.DefaultRule) bool) bool { + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if cond(rule.DefaultOptions) { + return true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + if cond(subRule) { + return true + } + } + } + } + return false +} + +func hasDNSRule(rules []option.DNSRule, cond func(rule option.DefaultDNSRule) bool) bool { + for _, rule := range rules { + switch rule.Type { + case C.RuleTypeDefault: + if cond(rule.DefaultOptions) { + return true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + if cond(subRule) { + return true + } + } + } + } + return false +} + +func isGeoIPRule(rule option.DefaultRule) bool { + return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) || len(rule.GeoIP) > 0 && common.Any(rule.GeoIP, notPrivateNode) +} + +func isGeoIPDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.SourceGeoIP) > 0 && common.Any(rule.SourceGeoIP, notPrivateNode) +} + +func isGeositeRule(rule option.DefaultRule) bool { + return len(rule.Geosite) > 0 +} + +func isGeositeDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.Geosite) > 0 +} + +func isProcessRule(rule option.DefaultRule) bool { + return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 +} + +func isProcessDNSRule(rule option.DefaultDNSRule) bool { + return len(rule.ProcessName) > 0 || len(rule.ProcessPath) > 0 || len(rule.PackageName) > 0 || len(rule.User) > 0 || len(rule.UserID) > 0 +} + +func notPrivateNode(code string) bool { + return code != "private" +} diff --git a/route/router_ip.go b/route/router_ip.go new file mode 100644 index 00000000..660aaaac --- /dev/null +++ b/route/router_ip.go @@ -0,0 +1,47 @@ +package route + +import ( + "context" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-tun" +) + +func (r *Router) RouteIPConnection(ctx context.Context, conn tun.RouteContext, metadata adapter.InboundContext) tun.RouteAction { + for i, rule := range r.ipRules { + if rule.Match(&metadata) { + if rule.Action() == tun.ActionTypeReject { + r.logger.InfoContext(ctx, "match[", i, "] ", rule.String(), " => reject") + return (*tun.ActionReject)(nil) + } + detour := rule.Outbound() + r.logger.InfoContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) + outbound, loaded := r.Outbound(detour) + if !loaded { + r.logger.ErrorContext(ctx, "outbound not found: ", detour) + break + } + ipOutbound, loaded := outbound.(adapter.IPOutbound) + if !loaded { + r.logger.ErrorContext(ctx, "outbound have no ip connection support: ", detour) + break + } + destination, err := ipOutbound.NewIPConnection(ctx, conn, metadata) + if err != nil { + r.logger.ErrorContext(ctx, err) + break + } + return &tun.ActionDirect{DirectDestination: destination} + } + } + return (*tun.ActionReturn)(nil) +} + +func (r *Router) NatRequired(outbound string) bool { + for _, ipRule := range r.ipRules { + if ipRule.Outbound() == outbound { + return true + } + } + return false +} diff --git a/route/rule_abstract.go b/route/rule_abstract.go new file mode 100644 index 00000000..832a3198 --- /dev/null +++ b/route/rule_abstract.go @@ -0,0 +1,199 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" +) + +type abstractDefaultRule struct { + items []RuleItem + sourceAddressItems []RuleItem + sourcePortItems []RuleItem + destinationAddressItems []RuleItem + destinationPortItems []RuleItem + allItems []RuleItem + invert bool + outbound string +} + +func (r *abstractDefaultRule) Type() string { + return C.RuleTypeDefault +} + +func (r *abstractDefaultRule) Start() error { + for _, item := range r.allItems { + err := common.Start(item) + if err != nil { + return err + } + } + return nil +} + +func (r *abstractDefaultRule) Close() error { + for _, item := range r.allItems { + err := common.Close(item) + if err != nil { + return err + } + } + return nil +} + +func (r *abstractDefaultRule) UpdateGeosite() error { + for _, item := range r.allItems { + if geositeItem, isSite := item.(*GeositeItem); isSite { + err := geositeItem.Update() + if err != nil { + return err + } + } + } + return nil +} + +func (r *abstractDefaultRule) Match(metadata *adapter.InboundContext) bool { + for _, item := range r.items { + if !item.Match(metadata) { + return r.invert + } + } + + if len(r.sourceAddressItems) > 0 { + var sourceAddressMatch bool + for _, item := range r.sourceAddressItems { + if item.Match(metadata) { + sourceAddressMatch = true + break + } + } + if !sourceAddressMatch { + return r.invert + } + } + + if len(r.sourcePortItems) > 0 { + var sourcePortMatch bool + for _, item := range r.sourcePortItems { + if item.Match(metadata) { + sourcePortMatch = true + break + } + } + if !sourcePortMatch { + return r.invert + } + } + + if len(r.destinationAddressItems) > 0 { + var destinationAddressMatch bool + for _, item := range r.destinationAddressItems { + if item.Match(metadata) { + destinationAddressMatch = true + break + } + } + if !destinationAddressMatch { + return r.invert + } + } + + if len(r.destinationPortItems) > 0 { + var destinationPortMatch bool + for _, item := range r.destinationPortItems { + if item.Match(metadata) { + destinationPortMatch = true + break + } + } + if !destinationPortMatch { + return r.invert + } + } + + return !r.invert +} + +func (r *abstractDefaultRule) Outbound() string { + return r.outbound +} + +func (r *abstractDefaultRule) String() string { + return strings.Join(F.MapToString(r.allItems), " ") +} + +type abstractLogicalRule struct { + rules []adapter.Rule + mode string + invert bool + outbound string +} + +func (r *abstractLogicalRule) Type() string { + return C.RuleTypeLogical +} + +func (r *abstractLogicalRule) UpdateGeosite() error { + for _, rule := range r.rules { + err := rule.UpdateGeosite() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Start() error { + for _, rule := range r.rules { + err := rule.Start() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Close() error { + for _, rule := range r.rules { + err := rule.Close() + if err != nil { + return err + } + } + return nil +} + +func (r *abstractLogicalRule) Match(metadata *adapter.InboundContext) bool { + if r.mode == C.LogicalTypeAnd { + return common.All(r.rules, func(it adapter.Rule) bool { + return it.Match(metadata) + }) != r.invert + } else { + return common.Any(r.rules, func(it adapter.Rule) bool { + return it.Match(metadata) + }) != r.invert + } +} + +func (r *abstractLogicalRule) Outbound() string { + return r.outbound +} + +func (r *abstractLogicalRule) String() string { + var op string + switch r.mode { + case C.LogicalTypeAnd: + op = "&&" + case C.LogicalTypeOr: + op = "||" + } + if !r.invert { + return strings.Join(F.MapToString(r.rules), " "+op+" ") + } else { + return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" + } +} diff --git a/route/rule.go b/route/rule_default.go similarity index 61% rename from route/rule.go rename to route/rule_default.go index 6f2f3baa..01322c13 100644 --- a/route/rule.go +++ b/route/rule_default.go @@ -1,16 +1,11 @@ package route import ( - "strings" - "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - N "github.com/sagernet/sing/common/network" ) func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rule) (adapter.Rule, error) { @@ -39,14 +34,7 @@ func NewRule(router adapter.Router, logger log.ContextLogger, options option.Rul var _ adapter.Rule = (*DefaultRule)(nil) type DefaultRule struct { - items []RuleItem - sourceAddressItems []RuleItem - sourcePortItems []RuleItem - destinationAddressItems []RuleItem - destinationPortItems []RuleItem - allItems []RuleItem - invert bool - outbound string + abstractDefaultRule } type RuleItem interface { @@ -56,8 +44,10 @@ type RuleItem interface { func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { rule := &DefaultRule{ - invert: options.Invert, - outbound: options.Outbound, + abstractDefaultRule{ + invert: options.Invert, + outbound: options.Outbound, + }, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -74,15 +64,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return nil, E.New("invalid ip version: ", options.IPVersion) } } - if options.Network != "" { - switch options.Network { - case N.NetworkTCP, N.NetworkUDP: - item := NewNetworkItem(options.Network) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - default: - return nil, E.New("invalid network: ", options.Network) - } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.AuthUser) > 0 { item := NewAuthUserItem(options.AuthUser) @@ -202,130 +187,19 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt return rule, nil } -func (r *DefaultRule) Type() string { - return C.RuleTypeDefault -} - -func (r *DefaultRule) Start() error { - for _, item := range r.allItems { - err := common.Start(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultRule) Close() error { - for _, item := range r.allItems { - err := common.Close(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultRule) UpdateGeosite() error { - for _, item := range r.allItems { - if geositeItem, isSite := item.(*GeositeItem); isSite { - err := geositeItem.Update() - if err != nil { - return err - } - } - } - return nil -} - -func (r *DefaultRule) Match(metadata *adapter.InboundContext) bool { - for _, item := range r.items { - if !item.Match(metadata) { - return r.invert - } - } - - if len(r.sourceAddressItems) > 0 { - var sourceAddressMatch bool - for _, item := range r.sourceAddressItems { - if item.Match(metadata) { - sourceAddressMatch = true - break - } - } - if !sourceAddressMatch { - return r.invert - } - } - - if len(r.sourcePortItems) > 0 { - var sourcePortMatch bool - for _, item := range r.sourcePortItems { - if item.Match(metadata) { - sourcePortMatch = true - break - } - } - if !sourcePortMatch { - return r.invert - } - } - - if len(r.destinationAddressItems) > 0 { - var destinationAddressMatch bool - for _, item := range r.destinationAddressItems { - if item.Match(metadata) { - destinationAddressMatch = true - break - } - } - if !destinationAddressMatch { - return r.invert - } - } - - if len(r.destinationPortItems) > 0 { - var destinationPortMatch bool - for _, item := range r.destinationPortItems { - if item.Match(metadata) { - destinationPortMatch = true - break - } - } - if !destinationPortMatch { - return r.invert - } - } - - return !r.invert -} - -func (r *DefaultRule) Outbound() string { - return r.outbound -} - -func (r *DefaultRule) String() string { - if !r.invert { - return strings.Join(F.MapToString(r.allItems), " ") - } else { - return "!(" + strings.Join(F.MapToString(r.allItems), " ") + ")" - } -} - var _ adapter.Rule = (*LogicalRule)(nil) type LogicalRule struct { - mode string - rules []*DefaultRule - invert bool - outbound string + abstractLogicalRule } func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { r := &LogicalRule{ - rules: make([]*DefaultRule, len(options.Rules)), - invert: options.Invert, - outbound: options.Outbound, + abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Outbound, + }, } switch options.Mode { case C.LogicalTypeAnd: @@ -344,68 +218,3 @@ func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options opt } return r, nil } - -func (r *LogicalRule) Type() string { - return C.RuleTypeLogical -} - -func (r *LogicalRule) UpdateGeosite() error { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Start() error { - for _, rule := range r.rules { - err := rule.Start() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Close() error { - for _, rule := range r.rules { - err := rule.Close() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it *DefaultRule) bool { - return it.Match(metadata) - }) != r.invert - } -} - -func (r *LogicalRule) Outbound() string { - return r.outbound -} - -func (r *LogicalRule) String() string { - var op string - switch r.mode { - case C.LogicalTypeAnd: - op = "&&" - case C.LogicalTypeOr: - op = "||" - } - if !r.invert { - return strings.Join(F.MapToString(r.rules), " "+op+" ") - } else { - return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" - } -} diff --git a/route/rule_dns.go b/route/rule_dns.go index 3bfdb729..027a2f91 100644 --- a/route/rule_dns.go +++ b/route/rule_dns.go @@ -1,16 +1,11 @@ package route import ( - "strings" - "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" - N "github.com/sagernet/sing/common/network" ) func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) { @@ -39,21 +34,16 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option. var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { - items []RuleItem - sourceAddressItems []RuleItem - sourcePortItems []RuleItem - destinationAddressItems []RuleItem - destinationPortItems []RuleItem - allItems []RuleItem - invert bool - outbound string - disableCache bool + abstractDefaultRule + disableCache bool } func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ - invert: options.Invert, - outbound: options.Server, + abstractDefaultRule: abstractDefaultRule{ + invert: options.Invert, + outbound: options.Server, + }, disableCache: options.DisableCache, } if len(options.Inbound) > 0 { @@ -76,15 +66,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } - if options.Network != "" { - switch options.Network { - case N.NetworkTCP, N.NetworkUDP: - item := NewNetworkItem(options.Network) - rule.items = append(rule.items, item) - rule.allItems = append(rule.allItems, item) - default: - return nil, E.New("invalid network: ", options.Network) - } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) } if len(options.AuthUser) > 0 { item := NewAuthUserItem(options.AuthUser) @@ -196,131 +181,24 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options return rule, nil } -func (r *DefaultDNSRule) Type() string { - return C.RuleTypeDefault -} - -func (r *DefaultDNSRule) Start() error { - for _, item := range r.allItems { - err := common.Start(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultDNSRule) Close() error { - for _, item := range r.allItems { - err := common.Close(item) - if err != nil { - return err - } - } - return nil -} - -func (r *DefaultDNSRule) UpdateGeosite() error { - for _, item := range r.allItems { - if geositeItem, isSite := item.(*GeositeItem); isSite { - err := geositeItem.Update() - if err != nil { - return err - } - } - } - return nil -} - -func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { - for _, item := range r.items { - if !item.Match(metadata) { - return r.invert - } - } - - if len(r.sourceAddressItems) > 0 { - var sourceAddressMatch bool - for _, item := range r.sourceAddressItems { - if item.Match(metadata) { - sourceAddressMatch = true - break - } - } - if !sourceAddressMatch { - return r.invert - } - } - - if len(r.sourcePortItems) > 0 { - var sourcePortMatch bool - for _, item := range r.sourcePortItems { - if item.Match(metadata) { - sourcePortMatch = true - break - } - } - if !sourcePortMatch { - return r.invert - } - } - - if len(r.destinationAddressItems) > 0 { - var destinationAddressMatch bool - for _, item := range r.destinationAddressItems { - if item.Match(metadata) { - destinationAddressMatch = true - break - } - } - if !destinationAddressMatch { - return r.invert - } - } - - if len(r.destinationPortItems) > 0 { - var destinationPortMatch bool - for _, item := range r.destinationPortItems { - if item.Match(metadata) { - destinationPortMatch = true - break - } - } - if !destinationPortMatch { - return r.invert - } - } - - return !r.invert -} - -func (r *DefaultDNSRule) Outbound() string { - return r.outbound -} - func (r *DefaultDNSRule) DisableCache() bool { return r.disableCache } -func (r *DefaultDNSRule) String() string { - return strings.Join(F.MapToString(r.allItems), " ") -} - var _ adapter.DNSRule = (*LogicalDNSRule)(nil) type LogicalDNSRule struct { - mode string - rules []*DefaultDNSRule - invert bool - outbound string + abstractLogicalRule disableCache bool } func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ - rules: make([]*DefaultDNSRule, len(options.Rules)), - invert: options.Invert, - outbound: options.Server, + abstractLogicalRule: abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Server, + }, disableCache: options.DisableCache, } switch options.Mode { @@ -341,71 +219,6 @@ func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options return r, nil } -func (r *LogicalDNSRule) Type() string { - return C.RuleTypeLogical -} - -func (r *LogicalDNSRule) UpdateGeosite() error { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Start() error { - for _, rule := range r.rules { - err := rule.Start() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Close() error { - for _, rule := range r.rules { - err := rule.Close() - if err != nil { - return err - } - } - return nil -} - -func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { - if r.mode == C.LogicalTypeAnd { - return common.All(r.rules, func(it *DefaultDNSRule) bool { - return it.Match(metadata) - }) != r.invert - } else { - return common.Any(r.rules, func(it *DefaultDNSRule) bool { - return it.Match(metadata) - }) != r.invert - } -} - -func (r *LogicalDNSRule) Outbound() string { - return r.outbound -} - func (r *LogicalDNSRule) DisableCache() bool { return r.disableCache } - -func (r *LogicalDNSRule) String() string { - var op string - switch r.mode { - case C.LogicalTypeAnd: - op = "&&" - case C.LogicalTypeOr: - op = "||" - } - if !r.invert { - return strings.Join(F.MapToString(r.rules), " "+op+" ") - } else { - return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" - } -} diff --git a/route/rule_ip.go b/route/rule_ip.go new file mode 100644 index 00000000..274a4344 --- /dev/null +++ b/route/rule_ip.go @@ -0,0 +1,176 @@ +package route + +import ( + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing-box/option" + tun "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +func NewIPRule(router adapter.Router, logger log.ContextLogger, options option.IPRule) (adapter.IPRule, error) { + switch options.Type { + case "", C.RuleTypeDefault: + if !options.DefaultOptions.IsValid() { + return nil, E.New("missing conditions") + } + if common.IsEmpty(options.DefaultOptions.Action) { + return nil, E.New("missing action") + } + return NewDefaultIPRule(router, logger, options.DefaultOptions) + case C.RuleTypeLogical: + if !options.LogicalOptions.IsValid() { + return nil, E.New("missing conditions") + } + if common.IsEmpty(options.DefaultOptions.Action) { + return nil, E.New("missing action") + } + return NewLogicalIPRule(router, logger, options.LogicalOptions) + default: + return nil, E.New("unknown rule type: ", options.Type) + } +} + +var _ adapter.IPRule = (*DefaultIPRule)(nil) + +type DefaultIPRule struct { + abstractDefaultRule + action tun.ActionType +} + +func NewDefaultIPRule(router adapter.Router, logger log.ContextLogger, options option.DefaultIPRule) (*DefaultIPRule, error) { + rule := &DefaultIPRule{ + abstractDefaultRule: abstractDefaultRule{ + invert: options.Invert, + outbound: options.Outbound, + }, + action: tun.ActionType(options.Action), + } + if len(options.Inbound) > 0 { + item := NewInboundRule(options.Inbound) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if options.IPVersion > 0 { + switch options.IPVersion { + case 4, 6: + item := NewIPVersionItem(options.IPVersion == 6) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + default: + return nil, E.New("invalid ip version: ", options.IPVersion) + } + } + if len(options.Network) > 0 { + item := NewNetworkItem(options.Network) + rule.items = append(rule.items, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Domain) > 0 || len(options.DomainSuffix) > 0 { + item := NewDomainItem(options.Domain, options.DomainSuffix) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainKeyword) > 0 { + item := NewDomainKeywordItem(options.DomainKeyword) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.DomainRegex) > 0 { + item, err := NewDomainRegexItem(options.DomainRegex) + if err != nil { + return nil, E.Cause(err, "domain_regex") + } + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Geosite) > 0 { + item := NewGeositeItem(router, logger, options.Geosite) + rule.destinationAddressItems = append(rule.destinationAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceGeoIP) > 0 { + item := NewGeoIPItem(router, logger, true, options.SourceGeoIP) + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourceIPCIDR) > 0 { + item, err := NewIPCIDRItem(true, options.SourceIPCIDR) + if err != nil { + return nil, E.Cause(err, "source_ipcidr") + } + rule.sourceAddressItems = append(rule.sourceAddressItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePort) > 0 { + item := NewPortItem(true, options.SourcePort) + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.SourcePortRange) > 0 { + item, err := NewPortRangeItem(true, options.SourcePortRange) + if err != nil { + return nil, E.Cause(err, "source_port_range") + } + rule.sourcePortItems = append(rule.sourcePortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.Port) > 0 { + item := NewPortItem(false, options.Port) + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + if len(options.PortRange) > 0 { + item, err := NewPortRangeItem(false, options.PortRange) + if err != nil { + return nil, E.Cause(err, "port_range") + } + rule.destinationPortItems = append(rule.destinationPortItems, item) + rule.allItems = append(rule.allItems, item) + } + return rule, nil +} + +func (r *DefaultIPRule) Action() tun.ActionType { + return r.action +} + +var _ adapter.IPRule = (*LogicalIPRule)(nil) + +type LogicalIPRule struct { + abstractLogicalRule + action tun.ActionType +} + +func NewLogicalIPRule(router adapter.Router, logger log.ContextLogger, options option.LogicalIPRule) (*LogicalIPRule, error) { + r := &LogicalIPRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: make([]adapter.Rule, len(options.Rules)), + invert: options.Invert, + outbound: options.Outbound, + }, + action: tun.ActionType(options.Action), + } + switch options.Mode { + case C.LogicalTypeAnd: + r.mode = C.LogicalTypeAnd + case C.LogicalTypeOr: + r.mode = C.LogicalTypeOr + default: + return nil, E.New("unknown logical mode: ", options.Mode) + } + for i, subRule := range options.Rules { + rule, err := NewDefaultIPRule(router, logger, subRule) + if err != nil { + return nil, E.Cause(err, "sub rule[", i, "]") + } + r.rules[i] = rule + } + return r, nil +} + +func (r *LogicalIPRule) Action() tun.ActionType { + return r.action +} diff --git a/route/rule_auth_user.go b/route/rule_item_auth_user.go similarity index 100% rename from route/rule_auth_user.go rename to route/rule_item_auth_user.go diff --git a/route/rule_cidr.go b/route/rule_item_cidr.go similarity index 100% rename from route/rule_cidr.go rename to route/rule_item_cidr.go diff --git a/route/rule_clash_mode.go b/route/rule_item_clash_mode.go similarity index 100% rename from route/rule_clash_mode.go rename to route/rule_item_clash_mode.go diff --git a/route/rule_domain.go b/route/rule_item_domain.go similarity index 100% rename from route/rule_domain.go rename to route/rule_item_domain.go diff --git a/route/rule_domain_keyword.go b/route/rule_item_domain_keyword.go similarity index 100% rename from route/rule_domain_keyword.go rename to route/rule_item_domain_keyword.go diff --git a/route/rule_domain_regex.go b/route/rule_item_domain_regex.go similarity index 100% rename from route/rule_domain_regex.go rename to route/rule_item_domain_regex.go diff --git a/route/rule_geoip.go b/route/rule_item_geoip.go similarity index 100% rename from route/rule_geoip.go rename to route/rule_item_geoip.go diff --git a/route/rule_geosite.go b/route/rule_item_geosite.go similarity index 100% rename from route/rule_geosite.go rename to route/rule_item_geosite.go diff --git a/route/rule_inbound.go b/route/rule_item_inbound.go similarity index 100% rename from route/rule_inbound.go rename to route/rule_item_inbound.go diff --git a/route/rule_ipversion.go b/route/rule_item_ipversion.go similarity index 100% rename from route/rule_ipversion.go rename to route/rule_item_ipversion.go diff --git a/route/rule_item_network.go b/route/rule_item_network.go new file mode 100644 index 00000000..fc54f425 --- /dev/null +++ b/route/rule_item_network.go @@ -0,0 +1,42 @@ +package route + +import ( + "strings" + + "github.com/sagernet/sing-box/adapter" + F "github.com/sagernet/sing/common/format" +) + +var _ RuleItem = (*NetworkItem)(nil) + +type NetworkItem struct { + networks []string + networkMap map[string]bool +} + +func NewNetworkItem(networks []string) *NetworkItem { + networkMap := make(map[string]bool) + for _, network := range networks { + networkMap[network] = true + } + return &NetworkItem{ + networks: networks, + networkMap: networkMap, + } +} + +func (r *NetworkItem) Match(metadata *adapter.InboundContext) bool { + return r.networkMap[metadata.Network] +} + +func (r *NetworkItem) String() string { + description := "network=" + + pLen := len(r.networks) + if pLen == 1 { + description += F.ToString(r.networks[0]) + } else { + description += "[" + strings.Join(F.MapToString(r.networks), " ") + "]" + } + return description +} diff --git a/route/rule_outbound.go b/route/rule_item_outbound.go similarity index 100% rename from route/rule_outbound.go rename to route/rule_item_outbound.go diff --git a/route/rule_package_name.go b/route/rule_item_package_name.go similarity index 100% rename from route/rule_package_name.go rename to route/rule_item_package_name.go diff --git a/route/rule_port.go b/route/rule_item_port.go similarity index 100% rename from route/rule_port.go rename to route/rule_item_port.go diff --git a/route/rule_port_range.go b/route/rule_item_port_range.go similarity index 100% rename from route/rule_port_range.go rename to route/rule_item_port_range.go diff --git a/route/rule_process_name.go b/route/rule_item_process_name.go similarity index 100% rename from route/rule_process_name.go rename to route/rule_item_process_name.go diff --git a/route/rule_process_path.go b/route/rule_item_process_path.go similarity index 100% rename from route/rule_process_path.go rename to route/rule_item_process_path.go diff --git a/route/rule_protocol.go b/route/rule_item_protocol.go similarity index 100% rename from route/rule_protocol.go rename to route/rule_item_protocol.go diff --git a/route/rule_query_type.go b/route/rule_item_query_type.go similarity index 100% rename from route/rule_query_type.go rename to route/rule_item_query_type.go diff --git a/route/rule_user.go b/route/rule_item_user.go similarity index 100% rename from route/rule_user.go rename to route/rule_item_user.go diff --git a/route/rule_user_id.go b/route/rule_item_user_id.go similarity index 100% rename from route/rule_user_id.go rename to route/rule_item_user_id.go diff --git a/route/rule_network.go b/route/rule_network.go deleted file mode 100644 index 0346cb13..00000000 --- a/route/rule_network.go +++ /dev/null @@ -1,23 +0,0 @@ -package route - -import ( - "github.com/sagernet/sing-box/adapter" -) - -var _ RuleItem = (*NetworkItem)(nil) - -type NetworkItem struct { - network string -} - -func NewNetworkItem(network string) *NetworkItem { - return &NetworkItem{network} -} - -func (r *NetworkItem) Match(metadata *adapter.InboundContext) bool { - return r.network == metadata.Network -} - -func (r *NetworkItem) String() string { - return "network=" + r.network -} diff --git a/transport/wireguard/client_bind.go b/transport/wireguard/client_bind.go index 91a30f24..570b2831 100644 --- a/transport/wireguard/client_bind.go +++ b/transport/wireguard/client_bind.go @@ -100,14 +100,10 @@ func (c *ClientBind) receive(b []byte) (n int, ep conn.Endpoint, err error) { } func (c *ClientBind) Reset() { - c.connAccess.Lock() - defer c.connAccess.Unlock() common.Close(common.PtrOrNil(c.conn)) } func (c *ClientBind) Close() error { - c.connAccess.Lock() - defer c.connAccess.Unlock() common.Close(common.PtrOrNil(c.conn)) if c.done == nil { c.done = make(chan struct{}) diff --git a/transport/wireguard/device.go b/transport/wireguard/device.go index 14e04bf5..9fb750b0 100644 --- a/transport/wireguard/device.go +++ b/transport/wireguard/device.go @@ -1,13 +1,23 @@ package wireguard import ( + "net/netip" + + "github.com/sagernet/sing-tun" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + wgTun "github.com/sagernet/wireguard-go/tun" ) type Device interface { - tun.Device + wgTun.Device N.Dialer Start() error + Inet4Address() netip.Addr + Inet6Address() netip.Addr // NewEndpoint() (stack.LinkEndpoint, error) } + +type NatDevice interface { + Device + CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination +} diff --git a/transport/wireguard/device_nat.go b/transport/wireguard/device_nat.go new file mode 100644 index 00000000..72201bb5 --- /dev/null +++ b/transport/wireguard/device_nat.go @@ -0,0 +1,75 @@ +package wireguard + +import ( + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" +) + +var _ Device = (*natDeviceWrapper)(nil) + +type natDeviceWrapper struct { + Device + outbound chan *buf.Buffer + mapping *tun.NatMapping + writer *tun.NatWriter +} + +func NewNATDevice(upstream Device, ipRewrite bool) NatDevice { + wrapper := &natDeviceWrapper{ + Device: upstream, + outbound: make(chan *buf.Buffer, 256), + mapping: tun.NewNatMapping(ipRewrite), + } + if ipRewrite { + wrapper.writer = tun.NewNatWriter(upstream.Inet4Address(), upstream.Inet6Address()) + } + return wrapper +} + +func (d *natDeviceWrapper) Read(p []byte, offset int) (int, error) { + select { + case packet := <-d.outbound: + defer packet.Release() + return copy(p[offset:], packet.Bytes()), nil + default: + } + return d.Device.Read(p, offset) +} + +func (d *natDeviceWrapper) Write(p []byte, offset int) (int, error) { + packet := p[offset:] + handled, err := d.mapping.WritePacket(packet) + if handled { + return len(packet), err + } + return d.Device.Write(p, offset) +} + +func (d *natDeviceWrapper) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { + d.mapping.CreateSession(session, conn) + return &natDestinationWrapper{d, session} +} + +var _ tun.DirectDestination = (*natDestinationWrapper)(nil) + +type natDestinationWrapper struct { + device *natDeviceWrapper + session tun.RouteSession +} + +func (d *natDestinationWrapper) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.outbound <- buffer + return nil +} + +func (d *natDestinationWrapper) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *natDestinationWrapper) Timeout() bool { + return false +} diff --git a/transport/wireguard/device_nat_gvisor.go b/transport/wireguard/device_nat_gvisor.go new file mode 100644 index 00000000..6c55ec96 --- /dev/null +++ b/transport/wireguard/device_nat_gvisor.go @@ -0,0 +1,27 @@ +//go:build with_gvisor + +package wireguard + +import ( + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func (d *natDestinationWrapper) WritePacketBuffer(buffer *stack.PacketBuffer) error { + defer buffer.DecRef() + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(buffer) + } + var packetLen int + for _, slice := range buffer.AsSlices() { + packetLen += len(slice) + } + packet := buf.NewSize(packetLen) + for _, slice := range buffer.AsSlices() { + common.Must1(packet.Write(slice)) + } + d.device.outbound <- packet + return nil +} diff --git a/transport/wireguard/device_stack.go b/transport/wireguard/device_stack.go index b2981e36..56f8f4a5 100644 --- a/transport/wireguard/device_stack.go +++ b/transport/wireguard/device_stack.go @@ -8,10 +8,12 @@ import ( "net/netip" "os" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/buf" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/wireguard-go/tun" + wgTun "github.com/sagernet/wireguard-go/tun" "gvisor.dev/gvisor/pkg/bufferv2" "gvisor.dev/gvisor/pkg/tcpip" @@ -25,33 +27,38 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/transport/udp" ) -var _ Device = (*StackDevice)(nil) +var _ NatDevice = (*StackDevice)(nil) const defaultNIC tcpip.NICID = 1 type StackDevice struct { - stack *stack.Stack - mtu uint32 - events chan tun.Event - outbound chan *stack.PacketBuffer - done chan struct{} - dispatcher stack.NetworkDispatcher - addr4 tcpip.Address - addr6 tcpip.Address + stack *stack.Stack + mtu uint32 + events chan wgTun.Event + outbound chan *stack.PacketBuffer + packetOutbound chan *buf.Buffer + done chan struct{} + dispatcher stack.NetworkDispatcher + addr4 tcpip.Address + addr6 tcpip.Address + mapping *tun.NatMapping + writer *tun.NatWriter } -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, error) { +func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (*StackDevice, error) { ipStack := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol, icmp.NewProtocol4, icmp.NewProtocol6}, HandleLocal: true, }) tunDevice := &StackDevice{ - stack: ipStack, - mtu: mtu, - events: make(chan tun.Event, 1), - outbound: make(chan *stack.PacketBuffer, 256), - done: make(chan struct{}), + stack: ipStack, + mtu: mtu, + events: make(chan wgTun.Event, 1), + outbound: make(chan *stack.PacketBuffer, 256), + packetOutbound: make(chan *buf.Buffer, 256), + done: make(chan struct{}), + mapping: tun.NewNatMapping(ipRewrite), } err := ipStack.CreateNIC(defaultNIC, (*wireEndpoint)(tunDevice)) if err != nil { @@ -77,6 +84,9 @@ func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (*StackDevice, er return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", err.String()) } } + if ipRewrite { + tunDevice.writer = tun.NewNatWriter(tunDevice.Inet4Address(), tunDevice.Inet6Address()) + } sOpt := tcpip.TCPSACKEnabled(true) ipStack.SetTransportProtocolOption(tcp.ProtocolNumber, &sOpt) cOpt := tcpip.CongestionControlOption("cubic") @@ -144,8 +154,16 @@ func (w *StackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) return udpConn, nil } +func (w *StackDevice) Inet4Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr4)) +} + +func (w *StackDevice) Inet6Address() netip.Addr { + return M.AddrFromIP(net.IP(w.addr6)) +} + func (w *StackDevice) Start() error { - w.events <- tun.EventUp + w.events <- wgTun.EventUp return nil } @@ -165,6 +183,10 @@ func (w *StackDevice) Read(p []byte, offset int) (n int, err error) { n += copy(p[n:], slice) } return + case packet := <-w.packetOutbound: + defer packet.Release() + n = copy(p[offset:], packet.Bytes()) + return case <-w.done: return 0, os.ErrClosed } @@ -175,6 +197,10 @@ func (w *StackDevice) Write(p []byte, offset int) (n int, err error) { if len(p) == 0 { return } + handled, err := w.mapping.WritePacket(p) + if handled { + return len(p), err + } var networkProtocol tcpip.NetworkProtocolNumber switch header.IPVersion(p) { case header.IPv4Version: @@ -203,7 +229,7 @@ func (w *StackDevice) Name() (string, error) { return "sing-box", nil } -func (w *StackDevice) Events() chan tun.Event { +func (w *StackDevice) Events() chan wgTun.Event { return w.events } @@ -222,6 +248,44 @@ func (w *StackDevice) Close() error { return nil } +func (w *StackDevice) CreateDestination(session tun.RouteSession, conn tun.RouteContext) tun.DirectDestination { + w.mapping.CreateSession(session, conn) + return &stackNatDestination{ + device: w, + session: session, + } +} + +type stackNatDestination struct { + device *StackDevice + session tun.RouteSession +} + +func (d *stackNatDestination) WritePacket(buffer *buf.Buffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacket(buffer.Bytes()) + } + d.device.packetOutbound <- buffer + return nil +} + +func (d *stackNatDestination) WritePacketBuffer(buffer *stack.PacketBuffer) error { + if d.device.writer != nil { + d.device.writer.RewritePacketBuffer(buffer) + } + d.device.outbound <- buffer + return nil +} + +func (d *stackNatDestination) Close() error { + d.device.mapping.DeleteSession(d.session) + return nil +} + +func (d *stackNatDestination) Timeout() bool { + return false +} + var _ stack.LinkEndpoint = (*wireEndpoint)(nil) type wireEndpoint StackDevice diff --git a/transport/wireguard/device_stack_stub.go b/transport/wireguard/device_stack_stub.go index b383ab38..5d2fc1dc 100644 --- a/transport/wireguard/device_stack_stub.go +++ b/transport/wireguard/device_stack_stub.go @@ -8,6 +8,6 @@ import ( "github.com/sagernet/sing-tun" ) -func NewStackDevice(localAddresses []netip.Prefix, mtu uint32) (Device, error) { +func NewStackDevice(localAddresses []netip.Prefix, mtu uint32, ipRewrite bool) (Device, error) { return nil, tun.ErrGVisorNotIncluded } diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index d4316422..faca3023 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -23,6 +23,8 @@ type SystemDevice struct { name string mtu int events chan wgTun.Event + addr4 netip.Addr + addr6 netip.Addr } /*func (w *SystemDevice) NewEndpoint() (stack.LinkEndpoint, error) { @@ -55,11 +57,24 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes if err != nil { return nil, err } + var inet4Address netip.Addr + var inet6Address netip.Addr + if len(inet4Addresses) > 0 { + inet4Address = inet4Addresses[0].Addr() + } + if len(inet6Addresses) > 0 { + inet6Address = inet6Addresses[0].Addr() + } return &SystemDevice{ - dialer.NewDefault(router, option.DialerOptions{ + dialer: dialer.NewDefault(router, option.DialerOptions{ BindInterface: interfaceName, }), - tunInterface, interfaceName, int(mtu), make(chan wgTun.Event), + device: tunInterface, + name: interfaceName, + mtu: int(mtu), + events: make(chan wgTun.Event), + addr4: inet4Address, + addr6: inet6Address, }, nil } @@ -71,6 +86,14 @@ func (w *SystemDevice) ListenPacket(ctx context.Context, destination M.Socksaddr return w.dialer.ListenPacket(ctx, destination) } +func (w *SystemDevice) Inet4Address() netip.Addr { + return w.addr4 +} + +func (w *SystemDevice) Inet6Address() netip.Addr { + return w.addr6 +} + func (w *SystemDevice) Start() error { w.events <- wgTun.EventUp return nil @@ -80,12 +103,12 @@ func (w *SystemDevice) File() *os.File { return nil } -func (w *SystemDevice) Read(bytes []byte, index int) (int, error) { - return w.device.Read(bytes[index-tun.PacketOffset:]) +func (w *SystemDevice) Read(p []byte, offset int) (int, error) { + return w.device.Read(p[offset-tun.PacketOffset:]) } -func (w *SystemDevice) Write(bytes []byte, index int) (int, error) { - return w.device.Write(bytes[index:]) +func (w *SystemDevice) Write(p []byte, offset int) (int, error) { + return w.device.Write(p[offset:]) } func (w *SystemDevice) Flush() error {