diff --git a/go.mod b/go.mod index 55b96ac5f..0203e57fb 100644 --- a/go.mod +++ b/go.mod @@ -81,6 +81,7 @@ require ( github.com/kitabisa/go-ci v1.0.3 github.com/labstack/echo/v4 v4.13.3 github.com/leslie-qiwa/flat v0.0.0-20230424180412-f9d1cf014baa + github.com/lib/pq v1.10.9 github.com/mattn/go-sqlite3 v1.14.22 github.com/mholt/archives v0.1.0 github.com/microsoft/go-mssqldb v1.6.0 @@ -198,7 +199,6 @@ require ( github.com/klauspost/compress v1.17.11 // indirect github.com/klauspost/pgzip v1.2.6 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/lib/pq v1.10.9 // indirect github.com/logrusorgru/aurora/v4 v4.0.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/mackerelio/go-osstat v0.2.4 // indirect diff --git a/pkg/catalog/loader/remote_loader.go b/pkg/catalog/loader/remote_loader.go index 4b27b6b01..ccd5c27f0 100644 --- a/pkg/catalog/loader/remote_loader.go +++ b/pkg/catalog/loader/remote_loader.go @@ -5,13 +5,16 @@ import ( "fmt" "net/url" "strings" + "sync" "github.com/pkg/errors" "github.com/projectdiscovery/nuclei/v3/pkg/templates/extensions" "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/projectdiscovery/retryablehttp-go" + sliceutil "github.com/projectdiscovery/utils/slice" stringsutil "github.com/projectdiscovery/utils/strings" + syncutil "github.com/projectdiscovery/utils/sync" ) type ContentType string @@ -28,67 +31,73 @@ type RemoteContent struct { } func getRemoteTemplatesAndWorkflows(templateURLs, workflowURLs, remoteTemplateDomainList []string) ([]string, []string, error) { - remoteContentChannel := make(chan RemoteContent) + var ( + err error + muErr sync.Mutex + ) + remoteTemplateList := sliceutil.NewSyncSlice[string]() + remoteWorkFlowList := sliceutil.NewSyncSlice[string]() - for _, templateURL := range templateURLs { - go getRemoteContent(templateURL, remoteTemplateDomainList, remoteContentChannel, Template) - } - for _, workflowURL := range workflowURLs { - go getRemoteContent(workflowURL, remoteTemplateDomainList, remoteContentChannel, Workflow) + awg, errAwg := syncutil.New(syncutil.WithSize(50)) + if errAwg != nil { + return nil, nil, errAwg } - var remoteTemplateList []string - var remoteWorkFlowList []string - var err error - for i := 0; i < (len(templateURLs) + len(workflowURLs)); i++ { - remoteContent := <-remoteContentChannel + loadItem := func(URL string, contentType ContentType) { + defer awg.Done() + + remoteContent := getRemoteContent(URL, remoteTemplateDomainList, contentType) if remoteContent.Error != nil { + muErr.Lock() if err != nil { err = errors.New(remoteContent.Error.Error() + ": " + err.Error()) } else { err = remoteContent.Error } + muErr.Unlock() } else { switch remoteContent.Type { case Template: - remoteTemplateList = append(remoteTemplateList, remoteContent.Content...) + remoteTemplateList.Append(remoteContent.Content...) case Workflow: - remoteWorkFlowList = append(remoteWorkFlowList, remoteContent.Content...) + remoteWorkFlowList.Append(remoteContent.Content...) } } } - return remoteTemplateList, remoteWorkFlowList, err + + for _, templateURL := range templateURLs { + awg.Add() + go loadItem(templateURL, Template) + } + for _, workflowURL := range workflowURLs { + awg.Add() + go loadItem(workflowURL, Workflow) + } + + awg.Wait() + + return remoteTemplateList.Slice, remoteWorkFlowList.Slice, err } -func getRemoteContent(URL string, remoteTemplateDomainList []string, remoteContentChannel chan<- RemoteContent, contentType ContentType) { +func getRemoteContent(URL string, remoteTemplateDomainList []string, contentType ContentType) RemoteContent { if err := validateRemoteTemplateURL(URL, remoteTemplateDomainList); err != nil { - remoteContentChannel <- RemoteContent{ - Error: err, - } - return + return RemoteContent{Error: err} } if strings.HasPrefix(URL, "http") && stringsutil.HasSuffixAny(URL, extensions.YAML) { - remoteContentChannel <- RemoteContent{ + return RemoteContent{ Content: []string{URL}, Type: contentType, } - return } response, err := retryablehttp.DefaultClient().Get(URL) if err != nil { - remoteContentChannel <- RemoteContent{ - Error: err, - } - return + return RemoteContent{Error: err} } defer func() { _ = response.Body.Close() }() if response.StatusCode < 200 || response.StatusCode > 299 { - remoteContentChannel <- RemoteContent{ - Error: fmt.Errorf("get \"%s\": unexpect status %d", URL, response.StatusCode), - } - return + return RemoteContent{Error: fmt.Errorf("get \"%s\": unexpect status %d", URL, response.StatusCode)} } scanner := bufio.NewScanner(response.Body) @@ -100,23 +109,17 @@ func getRemoteContent(URL string, remoteTemplateDomainList []string, remoteConte } if utils.IsURL(text) { if err := validateRemoteTemplateURL(text, remoteTemplateDomainList); err != nil { - remoteContentChannel <- RemoteContent{ - Error: err, - } - return + return RemoteContent{Error: err} } } templateList = append(templateList, text) } if err := scanner.Err(); err != nil { - remoteContentChannel <- RemoteContent{ - Error: errors.Wrap(err, "get \"%s\""), - } - return + return RemoteContent{Error: errors.Wrap(err, "get \"%s\"")} } - remoteContentChannel <- RemoteContent{ + return RemoteContent{ Content: templateList, Type: contentType, }