feat: add sniffer port whitelist, when empty will add all ports

This commit is contained in:
gVisor bot 2022-04-21 07:06:08 -07:00
parent 113d84b438
commit 03a014957f
6 changed files with 112 additions and 29 deletions

44
common/utils/range.go Normal file
View File

@ -0,0 +1,44 @@
package utils
import (
"golang.org/x/exp/constraints"
)
type Range[T constraints.Ordered] struct {
start T
end T
}
func NewRange[T constraints.Ordered](start, end T) *Range[T] {
if start > end {
return &Range[T]{
start: end,
end: start,
}
}
return &Range[T]{
start: start,
end: end,
}
}
func (r *Range[T]) Contains(t T) bool {
return t >= r.start && t <= r.end
}
func (r *Range[T]) LeftContains(t T) bool {
return t >= r.start && t < r.end
}
func (r *Range[T]) RightContains(t T) bool {
return t > r.start && t <= r.end
}
func (r *Range[T]) Start() T {
return r.start
}
func (r *Range[T]) End() T {
return r.end
}

View File

@ -2,11 +2,14 @@ package sniffer
import ( import (
"errors" "errors"
"github.com/Dreamacro/clash/component/trie"
"net" "net"
"net/netip" "net/netip"
"strconv"
"github.com/Dreamacro/clash/component/trie"
CN "github.com/Dreamacro/clash/common/net" CN "github.com/Dreamacro/clash/common/net"
"github.com/Dreamacro/clash/common/utils"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/log"
@ -26,6 +29,7 @@ type SnifferDispatcher struct {
foreDomain *trie.DomainTrie[bool] foreDomain *trie.DomainTrie[bool]
skipSNI *trie.DomainTrie[bool] skipSNI *trie.DomainTrie[bool]
portRanges *[]utils.Range[uint16]
} }
func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) { func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
@ -35,6 +39,18 @@ func (sd *SnifferDispatcher) TCPSniff(conn net.Conn, metadata *C.Metadata) {
} }
if metadata.Host == "" || sd.foreDomain.Search(metadata.Host) != nil { if metadata.Host == "" || sd.foreDomain.Search(metadata.Host) != nil {
port, err := strconv.ParseUint(metadata.DstPort, 10, 16)
if err != nil {
log.Debugln("[Sniffer] Dst port is error")
return
}
for _, portRange := range *sd.portRanges {
if !portRange.Contains(uint16(port)) {
return
}
}
if host, err := sd.sniffDomain(bufConn, metadata); err != nil { if host, err := sd.sniffDomain(bufConn, metadata); err != nil {
log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort) log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%s] to [%s:%s]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
return return
@ -102,11 +118,13 @@ func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
return &dispatcher, nil return &dispatcher, nil
} }
func NewSnifferDispatcher(needSniffer []C.SnifferType, forceDomain *trie.DomainTrie[bool], skipSNI *trie.DomainTrie[bool]) (*SnifferDispatcher, error) { func NewSnifferDispatcher(needSniffer []C.SnifferType, forceDomain *trie.DomainTrie[bool],
skipSNI *trie.DomainTrie[bool], ports *[]utils.Range[uint16]) (*SnifferDispatcher, error) {
dispatcher := SnifferDispatcher{ dispatcher := SnifferDispatcher{
enable: true, enable: true,
foreDomain: forceDomain, foreDomain: forceDomain,
skipSNI: skipSNI, skipSNI: skipSNI,
portRanges: ports,
} }
for _, snifferName := range needSniffer { for _, snifferName := range needSniffer {

View File

@ -4,16 +4,19 @@ import (
"container/list" "container/list"
"errors" "errors"
"fmt" "fmt"
R "github.com/Dreamacro/clash/rule"
RP "github.com/Dreamacro/clash/rule/provider"
"net" "net"
"net/netip" "net/netip"
"net/url" "net/url"
"os" "os"
"runtime" "runtime"
"strconv"
"strings" "strings"
"time" "time"
"github.com/Dreamacro/clash/common/utils"
R "github.com/Dreamacro/clash/rule"
RP "github.com/Dreamacro/clash/rule/provider"
"github.com/Dreamacro/clash/adapter" "github.com/Dreamacro/clash/adapter"
"github.com/Dreamacro/clash/adapter/outbound" "github.com/Dreamacro/clash/adapter/outbound"
"github.com/Dreamacro/clash/adapter/outboundgroup" "github.com/Dreamacro/clash/adapter/outboundgroup"
@ -127,6 +130,7 @@ type Sniffer struct {
Reverses *trie.DomainTrie[bool] Reverses *trie.DomainTrie[bool]
ForceDomain *trie.DomainTrie[bool] ForceDomain *trie.DomainTrie[bool]
SkipSNI *trie.DomainTrie[bool] SkipSNI *trie.DomainTrie[bool]
Ports *[]utils.Range[uint16]
} }
// Experimental config // Experimental config
@ -224,6 +228,7 @@ type SnifferRaw struct {
Reverse []string `yaml:"reverses" json:"reverses"` Reverse []string `yaml:"reverses" json:"reverses"`
ForceDomain []string `yaml:"force-domain" json:"force-domain"` ForceDomain []string `yaml:"force-domain" json:"force-domain"`
SkipSNI []string `yaml:"skip-sni" json:"skip-sni"` SkipSNI []string `yaml:"skip-sni" json:"skip-sni"`
Ports []string `yaml:"port-whitelist" json:"port-whitelist"`
} }
// Parse config // Parse config
@ -298,6 +303,7 @@ func UnmarshalRawConfig(buf []byte) (*RawConfig, error) {
Reverse: []string{}, Reverse: []string{},
ForceDomain: []string{}, ForceDomain: []string{},
SkipSNI: []string{}, SkipSNI: []string{},
Ports: []string{},
}, },
Profile: Profile{ Profile: Profile{
StoreSelected: true, StoreSelected: true,
@ -914,6 +920,33 @@ func parseSniffer(snifferRaw SnifferRaw) (*Sniffer, error) {
Force: snifferRaw.Force, Force: snifferRaw.Force,
} }
ports := []utils.Range[uint16]{}
if len(snifferRaw.Ports) == 0 {
ports = append(ports, *utils.NewRange[uint16](0, 65535))
} else {
for _, portRange := range snifferRaw.Ports {
portRaws := strings.Split(portRange, "-")
if len(portRaws) > 1 {
p, err := strconv.ParseUint(portRaws[0], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
start := uint16(p)
p, err = strconv.ParseUint(portRaws[0], 10, 16)
if err != nil {
return nil, fmt.Errorf("%s format error", portRange)
}
end := uint16(p)
ports = append(ports, *utils.NewRange(start, end))
}
}
}
sniffer.Ports = &ports
loadSniffer := make(map[C.SnifferType]struct{}) loadSniffer := make(map[C.SnifferType]struct{})
for _, snifferName := range snifferRaw.Sniffing { for _, snifferName := range snifferRaw.Sniffing {

3
go.mod
View File

@ -21,7 +21,8 @@ require (
go.uber.org/atomic v1.9.0 go.uber.org/atomic v1.9.0
go.uber.org/automaxprocs v1.5.1 go.uber.org/automaxprocs v1.5.1
golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4 golang.org/x/crypto v0.0.0-20220411220226-7b82a4e95df4
golang.org/x/net v0.0.0-20220418201149-a630d4f3e7a2 golang.org/x/exp v0.0.0-20220414153411-bcd21879b8fd
golang.org/x/net v0.0.0-20220412020605-290c469a71a5
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sync v0.0.0-20210220032951-036812b2e83c
golang.org/x/sys v0.0.0-20220412211240-33da011f77ad golang.org/x/sys v0.0.0-20220412211240-33da011f77ad
golang.org/x/time v0.0.0-20220411224347-583f2d630306 golang.org/x/time v0.0.0-20220411224347-583f2d630306

View File

@ -222,7 +222,7 @@ func updateTun(tun *config.Tun, dns *config.DNS) {
func updateSniffer(sniffer *config.Sniffer) { func updateSniffer(sniffer *config.Sniffer) {
if sniffer.Enable { if sniffer.Enable {
dispatcher, err := SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipSNI) dispatcher, err := SNI.NewSnifferDispatcher(sniffer.Sniffers, sniffer.ForceDomain, sniffer.SkipSNI, sniffer.Ports)
if err != nil { if err != nil {
log.Warnln("initial sniffer failed, err:%v", err) log.Warnln("initial sniffer failed, err:%v", err)
} }

View File

@ -5,20 +5,16 @@ import (
"strconv" "strconv"
"strings" "strings"
"github.com/Dreamacro/clash/common/utils"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
) )
type portReal struct {
portStart int
portEnd int
}
type Port struct { type Port struct {
*Base *Base
adapter string adapter string
port string port string
isSource bool isSource bool
portList []portReal portList []utils.Range[uint16]
} }
func (p *Port) RuleType() C.RuleType { func (p *Port) RuleType() C.RuleType {
@ -45,17 +41,13 @@ func (p *Port) Payload() string {
func (p *Port) matchPortReal(portRef string) bool { func (p *Port) matchPortReal(portRef string) bool {
port, _ := strconv.Atoi(portRef) port, _ := strconv.Atoi(portRef)
var rs bool
for _, pr := range p.portList { for _, pr := range p.portList {
if pr.portEnd == -1 { if pr.Contains(uint16(port)) {
rs = port == pr.portStart
} else {
rs = port >= pr.portStart && port <= pr.portEnd
}
if rs {
return true return true
} }
} }
return false return false
} }
@ -65,7 +57,7 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) {
return nil, fmt.Errorf("%s, too many ports to use, maximum support 28 ports", errPayload.Error()) return nil, fmt.Errorf("%s, too many ports to use, maximum support 28 ports", errPayload.Error())
} }
var portList []portReal var portRange []utils.Range[uint16]
for _, p := range ports { for _, p := range ports {
if p == "" { if p == "" {
continue continue
@ -84,23 +76,18 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) {
switch subPortsLen { switch subPortsLen {
case 1: case 1:
portList = append(portList, portReal{int(portStart), -1}) portRange = append(portRange, *utils.NewRange(uint16(portStart), uint16(portStart)))
case 2: case 2:
portEnd, err := strconv.ParseUint(strings.Trim(subPorts[1], "[ ]"), 10, 16) portEnd, err := strconv.ParseUint(strings.Trim(subPorts[1], "[ ]"), 10, 16)
if err != nil { if err != nil {
return nil, errPayload return nil, errPayload
} }
shouldReverse := portStart > portEnd portRange = append(portRange, *utils.NewRange(uint16(portStart), uint16(portEnd)))
if shouldReverse {
portList = append(portList, portReal{int(portEnd), int(portStart)})
} else {
portList = append(portList, portReal{int(portStart), int(portEnd)})
}
} }
} }
if len(portList) == 0 { if len(portRange) == 0 {
return nil, errPayload return nil, errPayload
} }
@ -109,7 +96,7 @@ func NewPort(port string, adapter string, isSource bool) (*Port, error) {
adapter: adapter, adapter: adapter,
port: port, port: port,
isSource: isSource, isSource: isSource,
portList: portList, portList: portRange,
}, nil }, nil
} }