diff --git a/v2/pkg/protocols/common/compare/compare.go b/v2/pkg/protocols/common/compare/compare.go deleted file mode 100644 index 6e4970451..000000000 --- a/v2/pkg/protocols/common/compare/compare.go +++ /dev/null @@ -1,39 +0,0 @@ -package compare - -import ( - "strings" -) - -// StringSlice compares two string slices for equality -func StringSlice(a, b []string) bool { - // If one is nil, the other must also be nil. - if (a == nil) != (b == nil) { - return false - } - if len(a) != len(b) { - return false - } - for i := range a { - if !strings.EqualFold(a[i], b[i]) { - return false - } - } - return true -} - -// StringMap compares two string maps for equality -func StringMap(a, b map[string]string) bool { - // If one is nil, the other must also be nil. - if (a == nil) != (b == nil) { - return false - } - if len(a) != len(b) { - return false - } - for k, v := range a { - if w, ok := b[k]; !ok || !strings.EqualFold(v, w) { - return false - } - } - return true -} diff --git a/v2/pkg/protocols/http/cluster.go b/v2/pkg/protocols/http/cluster.go index c50fb49dc..d2d37190c 100644 --- a/v2/pkg/protocols/http/cluster.go +++ b/v2/pkg/protocols/http/cluster.go @@ -1,7 +1,8 @@ package http import ( - "github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/compare" + sliceutil "github.com/projectdiscovery/utils/slice" + "golang.org/x/exp/maps" ) // CanCluster returns true if the request can be clustered. @@ -19,10 +20,10 @@ func (request *Request) CanCluster(other *Request) bool { request.Redirects != other.Redirects { return false } - if !compare.StringSlice(request.Path, other.Path) { + if !sliceutil.Equal(request.Path, other.Path) { return false } - if !compare.StringMap(request.Headers, other.Headers) { + if !maps.Equal(request.Headers, other.Headers) { return false } return true diff --git a/v2/pkg/templates/cluster.go b/v2/pkg/templates/cluster.go index 1619a159c..0e80bf8d3 100644 --- a/v2/pkg/templates/cluster.go +++ b/v2/pkg/templates/cluster.go @@ -14,6 +14,7 @@ import ( "github.com/projectdiscovery/nuclei/v2/pkg/protocols/common/helpers/writer" "github.com/projectdiscovery/nuclei/v2/pkg/templates/types" cryptoutil "github.com/projectdiscovery/utils/crypto" + mapsutil "github.com/projectdiscovery/utils/maps" ) // Cluster clusters a list of templates into a lesser number if possible based @@ -40,19 +41,25 @@ import ( // If multiple requests are identified as identical, they are appended to a slice. // Finally, the engine creates a single executer with a clusteredexecuter for all templates // in a cluster. -func Cluster(list map[string]*Template) [][]*Template { +func Cluster(list []*Template) [][]*Template { final := [][]*Template{} + skip := mapsutil.NewSyncLockMap[string, struct{}]() - // Each protocol that can be clustered should be handled here. - for key, template := range list { - // We only cluster http and dns requests as of now. + for _, template := range list { + key := template.Path + + if skip.Has(key) { + continue + } + + // We only cluster http, dns and ssl requests as of now. // Take care of requests that can't be clustered first. if len(template.RequestsHTTP) == 0 && len(template.RequestsDNS) == 0 && len(template.RequestsSSL) == 0 { - delete(list, key) + _ = skip.Set(key, struct{}{}) final = append(final, []*Template{template}) continue } - delete(list, key) // delete element first so it's not found later. + _ = skip.Set(key, struct{}{}) var templateType types.ProtocolType switch { @@ -67,27 +74,33 @@ func Cluster(list map[string]*Template) [][]*Template { // Find any/all similar matching request that is identical to // this one and cluster them together for http protocol only. cluster := []*Template{} - for otherKey, other := range list { + for _, other := range list { + otherKey := other.Path + + if skip.Has(otherKey) { + continue + } + switch templateType { case types.DNSProtocol: - if len(other.RequestsDNS) == 0 || len(other.RequestsDNS) > 1 { + if len(other.RequestsDNS) != 1 { continue } else if template.RequestsDNS[0].CanCluster(other.RequestsDNS[0]) { - delete(list, otherKey) + _ = skip.Set(otherKey, struct{}{}) cluster = append(cluster, other) } case types.HTTPProtocol: - if len(other.RequestsHTTP) == 0 || len(other.RequestsHTTP) > 1 { + if len(other.RequestsHTTP) != 1 { continue } else if template.RequestsHTTP[0].CanCluster(other.RequestsHTTP[0]) { - delete(list, otherKey) + _ = skip.Set(otherKey, struct{}{}) cluster = append(cluster, other) } case types.SSLProtocol: - if len(other.RequestsSSL) == 0 || len(other.RequestsSSL) > 1 { + if len(other.RequestsSSL) != 1 { continue } else if template.RequestsSSL[0].CanCluster(other.RequestsSSL[0]) { - delete(list, otherKey) + _ = skip.Set(otherKey, struct{}{}) cluster = append(cluster, other) } } @@ -95,9 +108,9 @@ func Cluster(list map[string]*Template) [][]*Template { if len(cluster) > 0 { cluster = append(cluster, template) final = append(final, cluster) - continue + } else { + final = append(final, []*Template{template}) } - final = append(final, []*Template{template}) } return final } @@ -118,14 +131,10 @@ func ClusterTemplates(templatesList []*Template, options protocols.ExecutorOptio return templatesList, 0 } - templatesMap := make(map[string]*Template) - for _, v := range templatesList { - templatesMap[v.Path] = v - } - clusterCount := 0 + var clusterCount int finalTemplatesList := make([]*Template, 0, len(templatesList)) - clusters := Cluster(templatesMap) + clusters := Cluster(templatesList) for _, cluster := range clusters { if len(cluster) > 1 { executerOpts := options diff --git a/v2/pkg/templates/cluster_test.go b/v2/pkg/templates/cluster_test.go index 31e54a6a3..bca59e835 100644 --- a/v2/pkg/templates/cluster_test.go +++ b/v2/pkg/templates/cluster_test.go @@ -9,50 +9,30 @@ import ( ) func TestClusterTemplates(t *testing.T) { - tests := []struct { - name string - templates map[string]*Template - expected [][]*Template - }{ - { - name: "http-cluster-get", - templates: map[string]*Template{ - "first.yaml": {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}}, - "second.yaml": {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}}, - }, - expected: [][]*Template{{ - {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}}, - {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}}, - }}, - }, - { - name: "no-http-cluster", - templates: map[string]*Template{ - "first.yaml": {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/random"}}}}, - "second.yaml": {RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/another"}}}}, - }, - expected: [][]*Template{ - {{RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/random"}}}}}, - {{RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/another"}}}}}, - }, - }, - { - name: "dns-cluster", - templates: map[string]*Template{ - "first.yaml": {RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}}, - "second.yaml": {RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}}, - }, - expected: [][]*Template{{ - {RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}}, - {RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}}, - }}, - }, - } - - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - returned := Cluster(test.templates) - require.ElementsMatch(t, returned, test.expected, "could not get cluster results") - }) - } + t.Run("http-cluster-get", func(t *testing.T) { + tp1 := &Template{Path: "first.yaml", RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}} + tp2 := &Template{Path: "second.yaml", RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}"}}}} + tpls := []*Template{tp1, tp2} + // cluster 0 + expected := []*Template{tp1, tp2} + got := Cluster(tpls)[0] + require.ElementsMatchf(t, expected, got, "different %v %v", len(expected), len(got)) + }) + t.Run("no-http-cluster", func(t *testing.T) { + tp1 := &Template{Path: "first.yaml", RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/random"}}}} + tp2 := &Template{Path: "second.yaml", RequestsHTTP: []*http.Request{{Path: []string{"{{BaseURL}}/another"}}}} + tpls := []*Template{tp1, tp2} + expected := [][]*Template{{tp1}, {tp2}} + got := Cluster(tpls) + require.ElementsMatch(t, expected, got) + }) + t.Run("dns-cluster", func(t *testing.T) { + tp1 := &Template{Path: "first.yaml", RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}} + tp2 := &Template{Path: "second.yaml", RequestsDNS: []*dns.Request{{Name: "{{Hostname}}"}}} + tpls := []*Template{tp1, tp2} + // cluster 0 + expected := []*Template{tp1, tp2} + got := Cluster(tpls)[0] + require.ElementsMatch(t, got, expected) + }) }