diff --git a/rule/common/network_type.go b/rule/common/network_type.go index 93abeade..107df91e 100644 --- a/rule/common/network_type.go +++ b/rule/common/network_type.go @@ -13,22 +13,23 @@ type NetworkType struct { } func NewNetworkType(network, adapter string) (*NetworkType, error) { - var netType C.NetWork + ntType := NetworkType{ + Base: &Base{}, + } + + ntType.adapter = adapter switch strings.ToUpper(network) { case "TCP": - netType = C.TCP + ntType.network = C.TCP break case "UDP": - netType = C.UDP + ntType.network = C.UDP break default: return nil, fmt.Errorf("unsupported network type, only TCP/UDP") } - return &NetworkType{ - Base: &Base{}, - network: netType, - adapter: adapter, - }, nil + + return &ntType, nil } func (n *NetworkType) RuleType() C.RuleType { diff --git a/rule/logic/and.go b/rule/logic/and.go index 1d650e02..1d4c99f7 100644 --- a/rule/logic/and.go +++ b/rule/logic/and.go @@ -1,8 +1,6 @@ package logic import ( - "fmt" - C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/rule/common" ) @@ -21,16 +19,12 @@ func (A *AND) ShouldFindProcess() bool { func NewAND(payload string, adapter string) (*AND, error) { and := &AND{Base: &common.Base{}, payload: payload, adapter: adapter} - rules, err := parseRuleByPayload(payload, true) + rules, err := parseRuleByPayload(payload) if err != nil { return nil, err } and.rules = rules - if len(and.rules) == 0 { - return nil, fmt.Errorf("And rule is error, may be format error or not contain least one rule") - } - for _, rule := range rules { if rule.ShouldResolveIP() { and.needIP = true diff --git a/rule/logic/common.go b/rule/logic/common.go index 2966ae7b..75e3c319 100644 --- a/rule/logic/common.go +++ b/rule/logic/common.go @@ -2,20 +2,19 @@ package logic import ( "fmt" - "io" - "net/http" - "os" - "regexp" - "strings" - "github.com/Dreamacro/clash/common/collections" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/log" RC "github.com/Dreamacro/clash/rule/common" "github.com/Dreamacro/clash/rule/provider" + "io" + "net/http" + "os" + "regexp" + "strings" ) -func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) { +func parseRuleByPayload(payload string) ([]C.Rule, error) { regex, err := regexp.Compile("\\(.*\\)") if err != nil { return nil, err @@ -28,7 +27,7 @@ func parseRuleByPayload(payload string, skip bool) ([]C.Rule, error) { } rules := make([]C.Rule, 0, len(subAllRanges)) - subRanges := findSubRuleRange(payload, subAllRanges, skip) + subRanges := findSubRuleRange(payload, subAllRanges) for _, subRange := range subRanges { subPayload := payload[subRange.start+1 : subRange.end] @@ -53,7 +52,7 @@ func containRange(r Range, preStart, preEnd int) bool { func payloadToRule(subPayload string) (C.Rule, error) { splitStr := strings.SplitN(subPayload, ",", 2) if len(splitStr) < 2 { - return nil, fmt.Errorf("The logic rule contain a rule of error format") + return nil, fmt.Errorf("[%s] format is error", subPayload) } tp := splitStr[0] @@ -91,9 +90,9 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) { parsed, parseErr = RC.NewGEOIP(payload, "", noResolve) case "IP-CIDR", "IP-CIDR6": noResolve := RC.HasNoResolve(params) - parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRNoResolve(noResolve)) + parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRNoResolve(noResolve)) case "SRC-IP-CIDR": - parsed, parseErr = RC.NewIPCIDR(payload, "", nil, RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true)) + parsed, parseErr = RC.NewIPCIDR(payload, "", RC.WithIPCIDRSourceIP(true), RC.WithIPCIDRNoResolve(true)) case "SRC-PORT": parsed, parseErr = RC.NewPort(payload, "", true) case "DST-PORT": @@ -113,7 +112,7 @@ func parseRule(tp, payload string, params []string) (C.Rule, error) { case "NETWORK": parsed, parseErr = RC.NewNetworkType(payload, "") default: - parseErr = fmt.Errorf("unsupported rule type %s", tp) + parsed, parseErr = nil, fmt.Errorf("unsupported rule type %s", tp) } if parseErr != nil { @@ -151,6 +150,10 @@ func format(payload string) ([]Range, error) { num++ stack.Push(sr) } else if c == ')' { + if stack.Len() == 0 { + return nil, fmt.Errorf("missing '('") + } + sr := stack.Pop().(Range) sr.end = i subRanges = append(subRanges, sr) @@ -169,11 +172,11 @@ func format(payload string) ([]Range, error) { return sortResult, nil } -func findSubRuleRange(payload string, ruleRanges []Range, skip bool) []Range { +func findSubRuleRange(payload string, ruleRanges []Range) []Range { payloadLen := len(payload) subRuleRange := make([]Range, 0) for _, rr := range ruleRanges { - if rr.start == 0 && rr.end == payloadLen-1 && skip { + if rr.start == 0 && rr.end == payloadLen-1 { // 最大范围跳过 continue } diff --git a/rule/logic/logic_test.go b/rule/logic/logic_test.go new file mode 100644 index 00000000..c279995b --- /dev/null +++ b/rule/logic/logic_test.go @@ -0,0 +1,49 @@ +package logic + +import ( + "github.com/Dreamacro/clash/constant" + "github.com/stretchr/testify/assert" + "testing" +) + +func TestAND(t *testing.T) { + and, err := NewAND("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT") + assert.Equal(t, nil, err) + assert.Equal(t, "DIRECT", and.adapter) + assert.Equal(t, false, and.ShouldResolveIP()) + assert.Equal(t, true, and.Match(&constant.Metadata{ + Host: "baidu.com", + AddrType: constant.AtypDomainName, + NetWork: constant.TCP, + DstPort: "20000", + })) + + and, err = NewAND("(DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT") + assert.NotEqual(t, nil, err) + + and, err = NewAND("((AND,(DOMAIN,baidu.com),(NETWORK,TCP)),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT") + assert.Equal(t, nil, err) +} + +func TestNOT(t *testing.T) { + not, err := NewNOT("((DST-PORT,6000-6500))", "REJECT") + assert.Equal(t, nil, err) + assert.Equal(t, false, not.Match(&constant.Metadata{ + DstPort: "6100", + })) + + _, err = NewNOT("((DST-PORT,5600-6666),(DOMAIN,baidu.com))", "DIRECT") + assert.NotEqual(t, nil, err) + + _, err = NewNOT("(())", "DIRECT") + assert.NotEqual(t, nil, err) +} + +func TestOR(t *testing.T) { + or, err := NewOR("((DOMAIN,baidu.com),(NETWORK,TCP),(DST-PORT,10001-65535))", "DIRECT") + assert.Equal(t, nil, err) + assert.Equal(t, true, or.Match(&constant.Metadata{ + NetWork: constant.TCP, + })) + assert.Equal(t, false, or.ShouldResolveIP()) +} diff --git a/rule/logic/not.go b/rule/logic/not.go index b04d7bc7..4d56bd2a 100644 --- a/rule/logic/not.go +++ b/rule/logic/not.go @@ -19,16 +19,19 @@ func (not *NOT) ShouldFindProcess() bool { func NewNOT(payload string, adapter string) (*NOT, error) { not := &NOT{Base: &common.Base{}, payload: payload, adapter: adapter} - rule, err := parseRuleByPayload(payload, false) + rule, err := parseRuleByPayload(payload) if err != nil { return nil, err } - if len(rule) < 1 { - return nil, fmt.Errorf("NOT rule have not a rule") + if len(rule) > 1 { + return nil, fmt.Errorf("not rule can contain at most one rule") + } + + if len(rule) > 0 { + not.rule = rule[0] } - not.rule = rule[0] return not, nil } @@ -37,7 +40,7 @@ func (not *NOT) RuleType() C.RuleType { } func (not *NOT) Match(metadata *C.Metadata) bool { - return !not.rule.Match(metadata) + return not.rule == nil || !not.rule.Match(metadata) } func (not *NOT) Adapter() string { @@ -49,5 +52,5 @@ func (not *NOT) Payload() string { } func (not *NOT) ShouldResolveIP() bool { - return not.rule.ShouldResolveIP() + return not.rule != nil && not.rule.ShouldResolveIP() } diff --git a/rule/logic/or.go b/rule/logic/or.go index 9afb73d4..05ad6f91 100644 --- a/rule/logic/or.go +++ b/rule/logic/or.go @@ -1,8 +1,6 @@ package logic import ( - "fmt" - C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/rule/common" ) @@ -47,16 +45,12 @@ func (or *OR) ShouldResolveIP() bool { func NewOR(payload string, adapter string) (*OR, error) { or := &OR{Base: &common.Base{}, payload: payload, adapter: adapter} - rules, err := parseRuleByPayload(payload, true) + rules, err := parseRuleByPayload(payload) if err != nil { return nil, err } or.rules = rules - if len(or.rules) == 0 { - return nil, fmt.Errorf("Or rule is error, may be format error or not contain least one rule") - } - for _, rule := range rules { if rule.ShouldResolveIP() { or.needIP = true