refactor: rule-set and its provider

This commit is contained in:
gVisor bot 2022-03-26 18:34:15 +08:00
parent 4e2e6879cd
commit 211638a9a9
4 changed files with 186 additions and 138 deletions

View File

@ -0,0 +1,54 @@
package provider
import (
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
type classicalStrategy struct {
rules []C.Rule
count int
shouldResolveIP bool
}
func (c *classicalStrategy) Match(metadata *C.Metadata) bool {
for _, rule := range c.rules {
if rule.Match(metadata) {
return true
}
}
return false
}
func (c *classicalStrategy) Count() int {
return c.count
}
func (c *classicalStrategy) ShouldResolveIP() bool {
return c.shouldResolveIP
}
func (c *classicalStrategy) OnUpdate(rules []string) {
var classicalRules []C.Rule
shouldResolveIP := false
for _, rawRule := range rules {
ruleType, rule, params := ruleParse(rawRule)
r, err := parseRule(ruleType, rule, "", params)
if err != nil {
log.Warnln("parse rule error:[%s]", err.Error())
}
if !shouldResolveIP {
shouldResolveIP = shouldResolveIP || r.ShouldResolveIP()
}
classicalRules = append(classicalRules, r)
}
c.rules = classicalRules
}
func NewClassicalStrategy() *classicalStrategy {
return &classicalStrategy{rules: []C.Rule{}}
}

View File

@ -0,0 +1,57 @@
package provider
import (
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
"strings"
)
type domainStrategy struct {
shouldResolveIP bool
count int
domainRules *trie.DomainTrie
}
func (d *domainStrategy) Match(metadata *C.Metadata) bool {
return d.domainRules != nil && d.domainRules.Search(metadata.Host) != nil
}
func (d *domainStrategy) Count() int {
return d.count
}
func (d *domainStrategy) ShouldResolveIP() bool {
return d.shouldResolveIP
}
func (d *domainStrategy) OnUpdate(rules []string) {
domainTrie := trie.New()
for _, rule := range rules {
err := domainTrie.Insert(rule, "")
if err != nil {
log.Warnln("invalid domain:[%s]", rule)
} else {
d.count++
}
}
d.domainRules = domainTrie
}
func ruleParse(ruleRaw string) (string, string, []string) {
item := strings.Split(ruleRaw, ",")
if len(item) == 1 {
return "", item[0], nil
} else if len(item) == 2 {
return item[0], item[1], nil
} else if len(item) > 2 {
return item[0], item[1], item[2:]
}
return "", "", nil
}
func NewDomainStrategy() *domainStrategy {
return &domainStrategy{shouldResolveIP: false}
}

View File

@ -0,0 +1,44 @@
package provider
import (
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/log"
)
type ipcidrStrategy struct {
count int
shouldResolveIP bool
trie *trie.IpCidrTrie
}
func (i *ipcidrStrategy) Match(metadata *C.Metadata) bool {
return i.trie != nil && i.trie.IsContain(metadata.DstIP)
}
func (i *ipcidrStrategy) Count() int {
return i.count
}
func (i *ipcidrStrategy) ShouldResolveIP() bool {
return i.shouldResolveIP
}
func (i *ipcidrStrategy) OnUpdate(rules []string) {
ipCidrTrie := trie.NewIpCidrTrie()
for _, rule := range rules {
err := ipCidrTrie.AddIpCidrForString(rule)
if err != nil {
log.Warnln("invalid Ipcidr:[%s]", rule)
} else {
i.count++
}
}
i.trie = ipCidrTrie
i.shouldResolveIP = i.count > 0
}
func NewIPCidrStrategy() *ipcidrStrategy {
return &ipcidrStrategy{}
}

View File

@ -2,14 +2,10 @@ package provider
import (
"encoding/json"
"errors"
"github.com/Dreamacro/clash/component/trie"
C "github.com/Dreamacro/clash/constant"
P "github.com/Dreamacro/clash/constant/provider"
"github.com/Dreamacro/clash/log"
"gopkg.in/yaml.v2"
"runtime"
"strings"
"time"
)
@ -19,12 +15,8 @@ var (
type ruleSetProvider struct {
*fetcher
behavior P.RuleType
shouldResolveIP bool
count int
DomainRules *trie.DomainTrie
IPCIDRRules *trie.IpCidrTrie
ClassicalRules []C.Rule
behavior P.RuleType
strategy ruleStrategy
}
type RuleSetProvider struct {
@ -39,6 +31,13 @@ type RulePayload struct {
Rules []string `yaml:"payload"`
}
type ruleStrategy interface {
Match(metadata *C.Metadata) bool
Count() int
ShouldResolveIP() bool
OnUpdate(rules []string)
}
func RuleProviders() map[string]P.RuleProvider {
return ruleProviders
}
@ -76,30 +75,11 @@ func (rp *ruleSetProvider) Behavior() P.RuleType {
}
func (rp *ruleSetProvider) Match(metadata *C.Metadata) bool {
if rp.count == 0 {
return false
}
switch rp.behavior {
case P.Domain:
return rp.DomainRules != nil && rp.DomainRules.Search(metadata.Host) != nil
case P.IPCIDR:
return rp.IPCIDRRules != nil && rp.IPCIDRRules.IsContain(metadata.DstIP)
case P.Classical:
for _, rule := range rp.ClassicalRules {
if rule.Match(metadata) {
return true
}
}
return false
default:
return false
}
return rp.strategy != nil && rp.strategy.Match(metadata)
}
func (rp *ruleSetProvider) ShouldResolveIP() bool {
return rp.shouldResolveIP
return rp.strategy.ShouldResolveIP()
}
func (rp *ruleSetProvider) AsRule(adaptor string) C.Rule {
@ -111,7 +91,7 @@ func (rp *ruleSetProvider) MarshalJSON() ([]byte, error) {
map[string]interface{}{
"behavior": rp.behavior.String(),
"name": rp.Name(),
"ruleCount": rp.count,
"ruleCount": rp.strategy.Count(),
"type": rp.Type().String(),
"updatedAt": rp.updatedAt,
"vehicleType": rp.VehicleType().String(),
@ -125,23 +105,14 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration
onUpdate := func(elm interface{}) error {
rulesRaw := elm.([]string)
rules, err := constructRules(rp.behavior, rulesRaw)
if err != nil {
return err
}
if rp.behavior == P.Classical {
rp.count = len(rules.([]C.Rule))
} else {
rp.count = len(rulesRaw)
}
rp.setRules(rules)
rp.strategy.OnUpdate(rulesRaw)
return nil
}
fetcher := newFetcher(name, interval, vehicle, rulesParse, onUpdate)
rp.fetcher = fetcher
rp.strategy = newStrategy(behavior)
wrapper := &RuleSetProvider{
rp,
}
@ -150,6 +121,22 @@ func NewRuleSetProvider(name string, behavior P.RuleType, interval time.Duration
return wrapper
}
func newStrategy(behavior P.RuleType) ruleStrategy {
switch behavior {
case P.Domain:
strategy := NewDomainStrategy()
return strategy
case P.IPCIDR:
strategy := NewIPCidrStrategy()
return strategy
case P.Classical:
strategy := NewClassicalStrategy()
return strategy
default:
return nil
}
}
func rulesParse(buf []byte) (interface{}, error) {
rulePayload := RulePayload{}
err := yaml.Unmarshal(buf, &rulePayload)
@ -159,97 +146,3 @@ func rulesParse(buf []byte) (interface{}, error) {
return rulePayload.Rules, nil
}
func constructRules(behavior P.RuleType, rules []string) (interface{}, error) {
switch behavior {
case P.Domain:
return handleDomainRules(rules)
case P.IPCIDR:
return handleIpCidrRules(rules)
case P.Classical:
return handleClassicalRules(rules)
default:
return nil, errors.New("unknown behavior type")
}
}
func handleDomainRules(rules []string) (interface{}, error) {
domainRules := trie.New()
for _, rawRule := range rules {
ruleType, rule, _ := ruleParse(rawRule)
if ruleType != "" {
return nil, errors.New("error format of domain")
}
if err := domainRules.Insert(rule, ""); err != nil {
return nil, err
}
}
return domainRules, nil
}
func handleIpCidrRules(rules []string) (interface{}, error) {
ipCidrRules := trie.NewIpCidrTrie()
for _, rawRule := range rules {
ruleType, rule, _ := ruleParse(rawRule)
if ruleType != "" {
return nil, errors.New("error format of ip-cidr")
}
if err := ipCidrRules.AddIpCidrForString(rule); err != nil {
return nil, err
}
}
return ipCidrRules, nil
}
func handleClassicalRules(rules []string) (interface{}, error) {
var classicalRules []C.Rule
for _, rawRule := range rules {
ruleType, rule, params := ruleParse(rawRule)
r, err := parseRule(ruleType, rule, "", params)
if err != nil {
//return nil, err
log.Warnln("%s", err)
continue
}
classicalRules = append(classicalRules, r)
}
return classicalRules, nil
}
func ruleParse(ruleRaw string) (string, string, []string) {
item := strings.Split(ruleRaw, ",")
if len(item) == 1 {
return "", item[0], nil
} else if len(item) == 2 {
return item[0], item[1], nil
} else if len(item) > 2 {
return item[0], item[1], item[2:]
}
return "", "", nil
}
func (rp *ruleSetProvider) setRules(rules interface{}) {
switch rp.behavior {
case P.Domain:
rp.DomainRules = rules.(*trie.DomainTrie)
rp.shouldResolveIP = false
case P.Classical:
rp.ClassicalRules = rules.([]C.Rule)
for i := range rp.ClassicalRules {
if rp.ClassicalRules[i].ShouldResolveIP() {
rp.shouldResolveIP = true
break
}
}
case P.IPCIDR:
rp.IPCIDRRules = rules.(*trie.IpCidrTrie)
rp.shouldResolveIP = true
default:
rp.shouldResolveIP = false
}
}