From 5b89811b906f086e47fdf5357a5ee367f8f1daef Mon Sep 17 00:00:00 2001 From: HD Moore Date: Fri, 18 Jul 2025 13:40:58 -0500 Subject: [PATCH] Support concurrent Nuclei engines in the same process (#6322) * support for concurrent nuclei engines * clarify LfaAllowed race * remove unused mutex * update LfaAllowed logic to prevent races until it can be reworked for per-execution ID * Update pkg/templates/parser.go Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> * debug tests * debug gh action * fixig gh template test * using atomic * using synclockmap * restore tests concurrency * lint * wiring executionId in js fs --------- Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com> Co-authored-by: Mzack9999 --- lib/config.go | 2 +- lib/sdk_private.go | 19 ++++ pkg/external/customtemplates/github_test.go | 21 ++-- pkg/installer/template.go | 2 +- pkg/js/libs/fs/fs.go | 21 ++-- pkg/protocols/common/protocolstate/file.go | 51 ++++++++- .../common/protocolstate/headless.go | 12 --- pkg/protocols/common/protocolstate/state.go | 4 +- pkg/protocols/dns/dnsclientpool/clientpool.go | 29 ++--- .../http/httpclientpool/clientpool.go | 12 +-- pkg/protocols/protocols.go | 61 ++++++++--- .../whois/rdapclientpool/clientpool.go | 10 +- pkg/templates/compile.go | 101 ++++++++++++------ pkg/templates/parser.go | 98 +++-------------- pkg/utils/capture_writer.go | 16 +++ 15 files changed, 279 insertions(+), 180 deletions(-) create mode 100644 pkg/utils/capture_writer.go diff --git a/lib/config.go b/lib/config.go index 5e96352b5..125442898 100644 --- a/lib/config.go +++ b/lib/config.go @@ -537,7 +537,7 @@ func WithResumeFile(file string) NucleiSDKOptions { } } -// WithLogger allows setting gologger instance +// WithLogger allows setting a shared gologger instance func WithLogger(logger *gologger.Logger) NucleiSDKOptions { return func(e *NucleiEngine) error { e.Logger = logger diff --git a/lib/sdk_private.go b/lib/sdk_private.go index 659187b20..d80a0fd06 100644 --- a/lib/sdk_private.go +++ b/lib/sdk_private.go @@ -231,6 +231,25 @@ func (e *NucleiEngine) init(ctx context.Context) error { } } + // Handle the case where the user passed an existing parser that we can use as a cache + if e.opts.Parser != nil { + if cachedParser, ok := e.opts.Parser.(*templates.Parser); ok { + e.parser = cachedParser + e.opts.Parser = cachedParser + e.executerOpts.Parser = cachedParser + e.executerOpts.Options.Parser = cachedParser + } + } + + // Create a new parser if necessary + if e.parser == nil { + op := templates.NewParser() + e.parser = op + e.opts.Parser = op + e.executerOpts.Parser = op + e.executerOpts.Options.Parser = op + } + e.engine = core.New(e.opts) e.engine.SetExecuterOptions(e.executerOpts) diff --git a/pkg/external/customtemplates/github_test.go b/pkg/external/customtemplates/github_test.go index 972706af1..4e429c020 100644 --- a/pkg/external/customtemplates/github_test.go +++ b/pkg/external/customtemplates/github_test.go @@ -1,23 +1,25 @@ package customtemplates import ( + "bytes" "context" "path/filepath" + "strings" "testing" "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/gologger/levels" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" "github.com/projectdiscovery/nuclei/v3/pkg/testutils" - osutils "github.com/projectdiscovery/utils/os" + "github.com/projectdiscovery/nuclei/v3/pkg/utils" "github.com/stretchr/testify/require" ) func TestDownloadCustomTemplatesFromGitHub(t *testing.T) { - if osutils.IsOSX() { - 
t.Skip("skipping on macos due to unknown failure (works locally)") - } - - gologger.DefaultLogger.SetWriter(&testutils.NoopWriter{}) + // Capture output to check for rate limit errors + outputBuffer := &bytes.Buffer{} + gologger.DefaultLogger.SetWriter(&utils.CaptureWriter{Buffer: outputBuffer}) + gologger.DefaultLogger.SetMaxLevel(levels.LevelDebug) templatesDirectory := t.TempDir() config.DefaultConfig.SetTemplatesDir(templatesDirectory) @@ -29,5 +31,12 @@ func TestDownloadCustomTemplatesFromGitHub(t *testing.T) { require.Nil(t, err, "could not create custom templates manager") ctm.Download(context.Background()) + + // Check if output contains rate limit error and skip test if so + output := outputBuffer.String() + if strings.Contains(output, "API rate limit exceeded") { + t.Skip("GitHub API rate limit exceeded, skipping test") + } + require.DirExists(t, filepath.Join(templatesDirectory, "github", "projectdiscovery", "nuclei-templates-test"), "cloned directory does not exists") } diff --git a/pkg/installer/template.go b/pkg/installer/template.go index 4ee784477..9e56f12a1 100644 --- a/pkg/installer/template.go +++ b/pkg/installer/template.go @@ -53,7 +53,7 @@ func (t *templateUpdateResults) String() string { }, } table := tablewriter.NewWriter(&buff) - table.Header("Total", "Added", "Modified", "Removed") + table.Header([]string{"Total", "Added", "Modified", "Removed"}) for _, v := range data { _ = table.Append(v) } diff --git a/pkg/js/libs/fs/fs.go b/pkg/js/libs/fs/fs.go index e3a3fd7bd..a5f77e875 100644 --- a/pkg/js/libs/fs/fs.go +++ b/pkg/js/libs/fs/fs.go @@ -1,6 +1,7 @@ package fs import ( + "context" "os" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" @@ -27,8 +28,9 @@ import ( // // when no itemType is provided, it will return both files and directories // const items = fs.ListDir('/tmp'); // ``` -func ListDir(path string, itemType string) ([]string, error) { - finalPath, err := protocolstate.NormalizePath(path) +func ListDir(ctx context.Context, path string, itemType string) ([]string, error) { + executionId := ctx.Value("executionId").(string) + finalPath, err := protocolstate.NormalizePathWithExecutionId(executionId, path) if err != nil { return nil, err } @@ -57,8 +59,9 @@ func ListDir(path string, itemType string) ([]string, error) { // // here permitted directories are $HOME/nuclei-templates/* // const content = fs.ReadFile('helpers/usernames.txt'); // ``` -func ReadFile(path string) ([]byte, error) { - finalPath, err := protocolstate.NormalizePath(path) +func ReadFile(ctx context.Context, path string) ([]byte, error) { + executionId := ctx.Value("executionId").(string) + finalPath, err := protocolstate.NormalizePathWithExecutionId(executionId, path) if err != nil { return nil, err } @@ -74,8 +77,8 @@ func ReadFile(path string) ([]byte, error) { // // here permitted directories are $HOME/nuclei-templates/* // const content = fs.ReadFileAsString('helpers/usernames.txt'); // ``` -func ReadFileAsString(path string) (string, error) { - bin, err := ReadFile(path) +func ReadFileAsString(ctx context.Context, path string) (string, error) { + bin, err := ReadFile(ctx, path) if err != nil { return "", err } @@ -91,14 +94,14 @@ func ReadFileAsString(path string) (string, error) { // const contents = fs.ReadFilesFromDir('helpers/ssh-keys'); // log(contents); // ``` -func ReadFilesFromDir(dir string) ([]string, error) { - files, err := ListDir(dir, "file") +func ReadFilesFromDir(ctx context.Context, dir string) ([]string, error) { + files, err := ListDir(ctx, 
dir, "file") if err != nil { return nil, err } var results []string for _, file := range files { - content, err := ReadFileAsString(dir + "/" + file) + content, err := ReadFileAsString(ctx, dir+"/"+file) if err != nil { return nil, err } diff --git a/pkg/protocols/common/protocolstate/file.go b/pkg/protocols/common/protocolstate/file.go index 9475aac0f..180d5a0b5 100644 --- a/pkg/protocols/common/protocolstate/file.go +++ b/pkg/protocols/common/protocolstate/file.go @@ -4,22 +4,65 @@ import ( "strings" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" + "github.com/projectdiscovery/nuclei/v3/pkg/types" errorutil "github.com/projectdiscovery/utils/errors" fileutil "github.com/projectdiscovery/utils/file" + mapsutil "github.com/projectdiscovery/utils/maps" ) var ( // LfaAllowed means local file access is allowed - LfaAllowed bool + LfaAllowed *mapsutil.SyncLockMap[string, bool] ) +func init() { + LfaAllowed = mapsutil.NewSyncLockMap[string, bool]() +} + +// IsLfaAllowed returns whether local file access is allowed +func IsLfaAllowed(options *types.Options) bool { + if GetLfaAllowed(options) { + return true + } + + // Otherwise look into dialers + dialers, ok := dialers.Get(options.ExecutionId) + if ok && dialers != nil { + dialers.Lock() + defer dialers.Unlock() + + return dialers.LocalFileAccessAllowed + } + + // otherwise just return option value + return options.AllowLocalFileAccess +} + +func SetLfaAllowed(options *types.Options) { + _ = LfaAllowed.Set(options.ExecutionId, options.AllowLocalFileAccess) +} + +func GetLfaAllowed(options *types.Options) bool { + allowed, ok := LfaAllowed.Get(options.ExecutionId) + + return ok && allowed +} + +func NormalizePathWithExecutionId(executionId string, filePath string) (string, error) { + options := &types.Options{ + ExecutionId: executionId, + } + return NormalizePath(options, filePath) +} + // Normalizepath normalizes path and returns absolute path // it returns error if path is not allowed // this respects the sandbox rules and only loads files from // allowed directories -func NormalizePath(filePath string) (string, error) { - // TODO: this should be tied to executionID - if LfaAllowed { +func NormalizePath(options *types.Options, filePath string) (string, error) { + // TODO: this should be tied to executionID using *types.Options + if IsLfaAllowed(options) { + // if local file access is allowed, we can return the absolute path return filePath, nil } cleaned, err := fileutil.ResolveNClean(filePath, config.DefaultConfig.GetTemplateDir()) diff --git a/pkg/protocols/common/protocolstate/headless.go b/pkg/protocols/common/protocolstate/headless.go index 4012e2da6..1d9970119 100644 --- a/pkg/protocols/common/protocolstate/headless.go +++ b/pkg/protocols/common/protocolstate/headless.go @@ -74,18 +74,6 @@ func InitHeadless(options *types.Options) { } } -// AllowLocalFileAccess returns whether local file access is allowed -func IsLfaAllowed(options *types.Options) bool { - dialers, ok := dialers.Get(options.ExecutionId) - if ok && dialers != nil { - dialers.Lock() - defer dialers.Unlock() - - return dialers.LocalFileAccessAllowed - } - return false -} - func IsRestrictLocalNetworkAccess(options *types.Options) bool { dialers, ok := dialers.Get(options.ExecutionId) if ok && dialers != nil { diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index 9f9a96a06..f72122c19 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -200,9 +200,7 @@ 
func initDialers(options *types.Options) error { StartActiveMemGuardian(context.Background()) - // TODO: this should be tied to executionID - // overidde global settings with latest options - LfaAllowed = options.AllowLocalFileAccess + SetLfaAllowed(options) return nil } diff --git a/pkg/protocols/dns/dnsclientpool/clientpool.go b/pkg/protocols/dns/dnsclientpool/clientpool.go index c1805be1c..ccbc1bc5d 100644 --- a/pkg/protocols/dns/dnsclientpool/clientpool.go +++ b/pkg/protocols/dns/dnsclientpool/clientpool.go @@ -8,12 +8,14 @@ import ( "github.com/pkg/errors" "github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/retryabledns" + mapsutil "github.com/projectdiscovery/utils/maps" ) var ( - poolMutex *sync.RWMutex + clientPool *mapsutil.SyncLockMap[string, *retryabledns.Client] + normalClient *retryabledns.Client - clientPool map[string]*retryabledns.Client + m sync.Mutex ) // defaultResolvers contains the list of resolvers known to be trusted. @@ -26,12 +28,14 @@ var defaultResolvers = []string{ // Init initializes the client pool implementation func Init(options *types.Options) error { + m.Lock() + defer m.Unlock() + // Don't create clients if already created in the past. if normalClient != nil { return nil } - poolMutex = &sync.RWMutex{} - clientPool = make(map[string]*retryabledns.Client) + clientPool = mapsutil.NewSyncLockMap[string, *retryabledns.Client]() resolvers := defaultResolvers if len(options.InternalResolversList) > 0 { @@ -45,6 +49,12 @@ func Init(options *types.Options) error { return nil } +func getNormalClient() *retryabledns.Client { + m.Lock() + defer m.Unlock() + return normalClient +} + // Configuration contains the custom configuration options for a client type Configuration struct { // Retries contains the retries for the dns client @@ -71,15 +81,12 @@ func (c *Configuration) Hash() string { // Get creates or gets a client for the protocol based on custom configuration func Get(options *types.Options, configuration *Configuration) (*retryabledns.Client, error) { if (configuration.Retries <= 1) && len(configuration.Resolvers) == 0 { - return normalClient, nil + return getNormalClient(), nil } hash := configuration.Hash() - poolMutex.RLock() - if client, ok := clientPool[hash]; ok { - poolMutex.RUnlock() + if client, ok := clientPool.Get(hash); ok { return client, nil } - poolMutex.RUnlock() resolvers := defaultResolvers if len(options.InternalResolversList) > 0 { @@ -95,9 +102,7 @@ func Get(options *types.Options, configuration *Configuration) (*retryabledns.Cl if err != nil { return nil, errors.Wrap(err, "could not create dns client") } + _ = clientPool.Set(hash, client) - poolMutex.Lock() - clientPool[hash] = client - poolMutex.Unlock() return client, nil } diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 940ac3886..4fa0790a5 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -154,16 +154,16 @@ func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client { return dialers.RawHTTPClient } - rawHttpOptions := rawhttp.DefaultOptions + rawHttpOptionsCopy := *rawhttp.DefaultOptions if options.Options.AliveHttpProxy != "" { - rawHttpOptions.Proxy = options.Options.AliveHttpProxy + rawHttpOptionsCopy.Proxy = options.Options.AliveHttpProxy } else if options.Options.AliveSocksProxy != "" { - rawHttpOptions.Proxy = options.Options.AliveSocksProxy + rawHttpOptionsCopy.Proxy = options.Options.AliveSocksProxy } else 
if dialers.Fastdialer != nil { - rawHttpOptions.FastDialer = dialers.Fastdialer + rawHttpOptionsCopy.FastDialer = dialers.Fastdialer } - rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout - dialers.RawHTTPClient = rawhttp.NewClient(rawHttpOptions) + rawHttpOptionsCopy.Timeout = options.Options.GetTimeouts().HttpTimeout + dialers.RawHTTPClient = rawhttp.NewClient(&rawHttpOptionsCopy) return dialers.RawHTTPClient } diff --git a/pkg/protocols/protocols.go b/pkg/protocols/protocols.go index 30443eee6..197d79e0a 100644 --- a/pkg/protocols/protocols.go +++ b/pkg/protocols/protocols.go @@ -3,7 +3,6 @@ package protocols import ( "context" "encoding/base64" - "sync" "sync/atomic" "github.com/projectdiscovery/fastdialer/fastdialer" @@ -139,8 +138,6 @@ type ExecutorOptions struct { Logger *gologger.Logger // CustomFastdialer is a fastdialer dialer instance CustomFastdialer *fastdialer.Dialer - - m sync.Mutex } // todo: centralizing components is not feasible with current clogged architecture @@ -198,6 +195,11 @@ func (e *ExecutorOptions) HasTemplateCtx(input *contextargs.MetaInput) bool { // GetTemplateCtx returns template context for given input func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contextargs.Context { scanId := input.GetScanHash(e.TemplateID) + if e.templateCtxStore == nil { + // if template context store is not initialized create it + e.CreateTemplateCtxStore() + } + // get template context from store templateCtx, ok := e.templateCtxStore.Get(scanId) if !ok { // if template context does not exist create new and add it to store and return it @@ -444,14 +446,49 @@ func (e *ExecutorOptions) ApplyNewEngineOptions(n *ExecutorOptions) { if e == nil || n == nil || n.Options == nil { return } - execID := n.Options.GetExecutionID() - e.SetExecutionID(execID) -} -// ApplyNewEngineOptions updates an existing ExecutorOptions with options from a new engine. This -// handles things like the ExecutionID that need to be updated. 
-func (e *ExecutorOptions) SetExecutionID(executorId string) { - e.m.Lock() - defer e.m.Unlock() - e.Options.SetExecutionID(executorId) + // The types.Options include the ExecutionID among other things + e.Options = n.Options.Copy() + + // Keep the template-specific fields, but replace the rest + /* + e.TemplateID = n.TemplateID + e.TemplatePath = n.TemplatePath + e.TemplateInfo = n.TemplateInfo + e.TemplateVerifier = n.TemplateVerifier + e.RawTemplate = n.RawTemplate + e.Variables = n.Variables + e.Constants = n.Constants + */ + e.Output = n.Output + e.Options = n.Options + e.IssuesClient = n.IssuesClient + e.Progress = n.Progress + e.RateLimiter = n.RateLimiter + e.Catalog = n.Catalog + e.ProjectFile = n.ProjectFile + e.Browser = n.Browser + e.Interactsh = n.Interactsh + e.HostErrorsCache = n.HostErrorsCache + e.StopAtFirstMatch = n.StopAtFirstMatch + e.ExcludeMatchers = n.ExcludeMatchers + e.InputHelper = n.InputHelper + e.FuzzParamsFrequency = n.FuzzParamsFrequency + e.FuzzStatsDB = n.FuzzStatsDB + e.DoNotCache = n.DoNotCache + e.Colorizer = n.Colorizer + e.WorkflowLoader = n.WorkflowLoader + e.ResumeCfg = n.ResumeCfg + e.ProtocolType = n.ProtocolType + e.Flow = n.Flow + e.IsMultiProtocol = n.IsMultiProtocol + e.templateCtxStore = n.templateCtxStore + e.JsCompiler = n.JsCompiler + e.AuthProvider = n.AuthProvider + e.TemporaryDirectory = n.TemporaryDirectory + e.Parser = n.Parser + e.ExportReqURLPattern = n.ExportReqURLPattern + e.GlobalMatchers = n.GlobalMatchers + e.Logger = n.Logger + e.CustomFastdialer = n.CustomFastdialer } diff --git a/pkg/protocols/whois/rdapclientpool/clientpool.go b/pkg/protocols/whois/rdapclientpool/clientpool.go index 81da1c578..f2d4f2316 100644 --- a/pkg/protocols/whois/rdapclientpool/clientpool.go +++ b/pkg/protocols/whois/rdapclientpool/clientpool.go @@ -30,6 +30,12 @@ func Init(options *types.Options) error { return nil } +func getNormalClient() *rdap.Client { + m.Lock() + defer m.Unlock() + return normalClient +} + // Configuration contains the custom configuration options for a client - placeholder type Configuration struct{} @@ -40,7 +46,5 @@ func (c *Configuration) Hash() string { // Get creates or gets a client for the protocol based on custom configuration func Get(options *types.Options, configuration *Configuration) (*rdap.Client, error) { - m.Lock() - defer m.Unlock() - return normalClient, nil + return getNormalClient(), nil } diff --git a/pkg/templates/compile.go b/pkg/templates/compile.go index a0d99a768..fdb612a96 100644 --- a/pkg/templates/compile.go +++ b/pkg/templates/compile.go @@ -56,49 +56,90 @@ func Parse(filePath string, preprocessor Preprocessor, options *protocols.Execut } if !options.DoNotCache { if value, _, _ := parser.compiledTemplatesCache.Has(filePath); value != nil { - // Update the template to use the current options for the calling engine - // TODO: This may be require additional work for robustness - t := *value - t.Options.ApplyNewEngineOptions(options) - if t.CompiledWorkflow != nil { - t.CompiledWorkflow.Options.ApplyNewEngineOptions(options) - for _, w := range t.CompiledWorkflow.Workflows { + // Copy the template, apply new options, and recompile requests + tplCopy := *value + newBase := options.Copy() + newBase.TemplateID = tplCopy.Options.TemplateID + newBase.TemplatePath = tplCopy.Options.TemplatePath + newBase.TemplateInfo = tplCopy.Options.TemplateInfo + newBase.TemplateVerifier = tplCopy.Options.TemplateVerifier + newBase.RawTemplate = tplCopy.Options.RawTemplate + tplCopy.Options = newBase + + 
tplCopy.Options.ApplyNewEngineOptions(options) + if tplCopy.CompiledWorkflow != nil { + tplCopy.CompiledWorkflow.Options.ApplyNewEngineOptions(options) + for _, w := range tplCopy.CompiledWorkflow.Workflows { for _, ex := range w.Executers { ex.Options.ApplyNewEngineOptions(options) } } } - for _, r := range t.RequestsDNS { - r.UpdateOptions(t.Options) + + // TODO: Reconsider whether to recompile requests. Compiling these is just as slow + // as not using a cache at all, but may be necessary. + + for i, r := range tplCopy.RequestsDNS { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsDNS[i] = &rCopy } - for _, r := range t.RequestsHTTP { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsHTTP { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsHTTP[i] = &rCopy } - for _, r := range t.RequestsCode { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsCode { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsCode[i] = &rCopy } - for _, r := range t.RequestsFile { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsFile { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsFile[i] = &rCopy } - for _, r := range t.RequestsHeadless { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsHeadless { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsHeadless[i] = &rCopy } - for _, r := range t.RequestsNetwork { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsNetwork { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsNetwork[i] = &rCopy } - for _, r := range t.RequestsJavascript { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsJavascript { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + //rCopy.Compile(tplCopy.Options) + tplCopy.RequestsJavascript[i] = &rCopy } - for _, r := range t.RequestsSSL { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsSSL { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsSSL[i] = &rCopy } - for _, r := range t.RequestsWHOIS { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsWHOIS { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsWHOIS[i] = &rCopy } - for _, r := range t.RequestsWebsocket { - r.UpdateOptions(t.Options) + for i, r := range tplCopy.RequestsWebsocket { + rCopy := *r + rCopy.UpdateOptions(tplCopy.Options) + // rCopy.Compile(tplCopy.Options) + tplCopy.RequestsWebsocket[i] = &rCopy } - template := t + template := &tplCopy if template.isGlobalMatchersEnabled() { item := &globalmatchers.Item{ @@ -119,8 +160,8 @@ func Parse(filePath string, preprocessor Preprocessor, options *protocols.Execut template.CompiledWorkflow = compiled template.CompiledWorkflow.Options = options } - - return &template, nil + // options.Logger.Error().Msgf("returning cached template %s after recompiling %d requests", tplCopy.Options.TemplateID, tplCopy.Requests()) + return template, nil } } diff --git a/pkg/templates/parser.go b/pkg/templates/parser.go index b99529916..56b64c237 100644 --- a/pkg/templates/parser.go +++ b/pkg/templates/parser.go @@ -49,6 +49,23 @@ func (p *Parser) Cache() *Cache { return 
p.parsedTemplatesCache } +// CompiledCache returns the compiled templates cache +func (p *Parser) CompiledCache() *Cache { + return p.compiledTemplatesCache +} + +func (p *Parser) ParsedCount() int { + p.Lock() + defer p.Unlock() + return len(p.parsedTemplatesCache.items.Map) +} + +func (p *Parser) CompiledCount() int { + p.Lock() + defer p.Unlock() + return len(p.compiledTemplatesCache.items.Map) +} + func checkOpenFileError(err error) bool { if err != nil && strings.Contains(err.Error(), "too many open files") { panic(err) @@ -171,84 +188,3 @@ func (p *Parser) LoadWorkflow(templatePath string, catalog catalog.Catalog) (boo return false, nil } - -// CloneForExecutionId creates a clone with updated execution IDs -func (p *Parser) CloneForExecutionId(xid string) *Parser { - p.Lock() - defer p.Unlock() - - newParser := &Parser{ - ShouldValidate: p.ShouldValidate, - NoStrictSyntax: p.NoStrictSyntax, - parsedTemplatesCache: NewCache(), - compiledTemplatesCache: NewCache(), - } - - for k, tpl := range p.parsedTemplatesCache.items.Map { - newTemplate := templateUpdateExecutionId(tpl.template, xid) - newParser.parsedTemplatesCache.Store(k, newTemplate, []byte(tpl.raw), tpl.err) - } - - for k, tpl := range p.compiledTemplatesCache.items.Map { - newTemplate := templateUpdateExecutionId(tpl.template, xid) - newParser.compiledTemplatesCache.Store(k, newTemplate, []byte(tpl.raw), tpl.err) - } - - return newParser -} - -func templateUpdateExecutionId(tpl *Template, xid string) *Template { - // TODO: This is a no-op today since options are patched in elsewhere, but we're keeping this - // for future work where we may need additional tweaks per template instance. - return tpl - - /* - templateBase := *tpl - var newOpts *protocols.ExecutorOptions - // Swap out the types.Options execution ID attached to the template - if templateBase.Options != nil { - optionsBase := *templateBase.Options //nolint - templateBase.Options = &optionsBase - if templateBase.Options.Options != nil { - optionsOptionsBase := *templateBase.Options.Options //nolint - templateBase.Options.Options = &optionsOptionsBase - templateBase.Options.Options.ExecutionId = xid - newOpts = templateBase.Options - } - } - if newOpts == nil { - return &templateBase - } - for _, r := range templateBase.RequestsDNS { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsHTTP { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsCode { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsFile { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsHeadless { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsNetwork { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsJavascript { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsSSL { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsWHOIS { - r.UpdateOptions(newOpts) - } - for _, r := range templateBase.RequestsWebsocket { - r.UpdateOptions(newOpts) - } - return &templateBase - */ -} diff --git a/pkg/utils/capture_writer.go b/pkg/utils/capture_writer.go new file mode 100644 index 000000000..29986a5aa --- /dev/null +++ b/pkg/utils/capture_writer.go @@ -0,0 +1,16 @@ +package utils + +import ( + "bytes" + + "github.com/projectdiscovery/gologger/levels" +) + +// CaptureWriter captures log output for testing +type CaptureWriter struct { + Buffer *bytes.Buffer +} + +func (w *CaptureWriter) Write(data []byte, level levels.Level) { + 
w.Buffer.Write(data) +}
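
Usage note (editorial, not part of the patch): the change above is meant to let several Nuclei engines run side by side in one process without sharing mutable protocol state (per-execution-ID client pools, LFA flags, dialers). Below is a minimal sketch of that usage, assuming the existing public SDK entry points (NewNucleiEngineCtx, WithTemplateFilters, LoadTargets, ExecuteWithCallback, Close) keep their current signatures; the goroutine layout, target URLs, and tag filter are illustrative only, not taken from this PR.

    package main

    import (
        "context"
        "log"
        "sync"

        nuclei "github.com/projectdiscovery/nuclei/v3/lib"
        "github.com/projectdiscovery/nuclei/v3/pkg/output"
    )

    func main() {
        // Illustrative target batches; each batch gets its own engine.
        batches := [][]string{
            {"http://127.0.0.1:8080"},
            {"http://127.0.0.1:9090"},
        }

        var wg sync.WaitGroup
        for _, batch := range batches {
            wg.Add(1)
            go func(targets []string) {
                defer wg.Done()

                // Each engine carries its own execution ID, so with this patch
                // concurrent engines should no longer trample shared state.
                engine, err := nuclei.NewNucleiEngineCtx(
                    context.Background(),
                    nuclei.WithTemplateFilters(nuclei.TemplateFilters{Tags: []string{"tech"}}),
                )
                if err != nil {
                    log.Printf("engine init failed: %v", err)
                    return
                }
                defer engine.Close()

                engine.LoadTargets(targets, false)
                if err := engine.ExecuteWithCallback(func(event *output.ResultEvent) {
                    // Handle results per engine here.
                    log.Printf("[%s] %s", event.TemplateID, event.Matched)
                }); err != nil {
                    log.Printf("scan failed: %v", err)
                }
            }(batch)
        }
        wg.Wait()
    }

The sdk_private.go hunk also suggests an already-built parser can be reused across engines as a template cache; the exact SDK option for wiring that in is not shown here, so it is omitted from the sketch.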