Support concurrent Nuclei engines in the same process (#6322)

* support for concurrent nuclei engines

* clarify LfaAllowed race

* remove unused mutex

* update LfaAllowed logic to prevent races until it can be reworked for per-execution ID

* Update pkg/templates/parser.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* debug tests

* debug gh action

* fixig gh template test

* using atomic

* using synclockmap

* restore tests concurrency

* lint

* wiring executionId in js fs

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Mzack9999 <mzack9999@protonmail.com>
This commit is contained in:
HD Moore 2025-07-18 13:40:58 -05:00 committed by GitHub
parent 3e9bee7400
commit 5b89811b90
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 279 additions and 180 deletions

View File

@ -537,7 +537,7 @@ func WithResumeFile(file string) NucleiSDKOptions {
} }
} }
// WithLogger allows setting gologger instance // WithLogger allows setting a shared gologger instance
func WithLogger(logger *gologger.Logger) NucleiSDKOptions { func WithLogger(logger *gologger.Logger) NucleiSDKOptions {
return func(e *NucleiEngine) error { return func(e *NucleiEngine) error {
e.Logger = logger e.Logger = logger

View File

@ -231,6 +231,25 @@ func (e *NucleiEngine) init(ctx context.Context) error {
} }
} }
// Handle the case where the user passed an existing parser that we can use as a cache
if e.opts.Parser != nil {
if cachedParser, ok := e.opts.Parser.(*templates.Parser); ok {
e.parser = cachedParser
e.opts.Parser = cachedParser
e.executerOpts.Parser = cachedParser
e.executerOpts.Options.Parser = cachedParser
}
}
// Create a new parser if necessary
if e.parser == nil {
op := templates.NewParser()
e.parser = op
e.opts.Parser = op
e.executerOpts.Parser = op
e.executerOpts.Options.Parser = op
}
e.engine = core.New(e.opts) e.engine = core.New(e.opts)
e.engine.SetExecuterOptions(e.executerOpts) e.engine.SetExecuterOptions(e.executerOpts)

View File

@ -1,23 +1,25 @@
package customtemplates package customtemplates
import ( import (
"bytes"
"context" "context"
"path/filepath" "path/filepath"
"strings"
"testing" "testing"
"github.com/projectdiscovery/gologger" "github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/gologger/levels"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/testutils" "github.com/projectdiscovery/nuclei/v3/pkg/testutils"
osutils "github.com/projectdiscovery/utils/os" "github.com/projectdiscovery/nuclei/v3/pkg/utils"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func TestDownloadCustomTemplatesFromGitHub(t *testing.T) { func TestDownloadCustomTemplatesFromGitHub(t *testing.T) {
if osutils.IsOSX() { // Capture output to check for rate limit errors
t.Skip("skipping on macos due to unknown failure (works locally)") outputBuffer := &bytes.Buffer{}
} gologger.DefaultLogger.SetWriter(&utils.CaptureWriter{Buffer: outputBuffer})
gologger.DefaultLogger.SetMaxLevel(levels.LevelDebug)
gologger.DefaultLogger.SetWriter(&testutils.NoopWriter{})
templatesDirectory := t.TempDir() templatesDirectory := t.TempDir()
config.DefaultConfig.SetTemplatesDir(templatesDirectory) config.DefaultConfig.SetTemplatesDir(templatesDirectory)
@ -29,5 +31,12 @@ func TestDownloadCustomTemplatesFromGitHub(t *testing.T) {
require.Nil(t, err, "could not create custom templates manager") require.Nil(t, err, "could not create custom templates manager")
ctm.Download(context.Background()) ctm.Download(context.Background())
// Check if output contains rate limit error and skip test if so
output := outputBuffer.String()
if strings.Contains(output, "API rate limit exceeded") {
t.Skip("GitHub API rate limit exceeded, skipping test")
}
require.DirExists(t, filepath.Join(templatesDirectory, "github", "projectdiscovery", "nuclei-templates-test"), "cloned directory does not exists") require.DirExists(t, filepath.Join(templatesDirectory, "github", "projectdiscovery", "nuclei-templates-test"), "cloned directory does not exists")
} }

View File

@ -53,7 +53,7 @@ func (t *templateUpdateResults) String() string {
}, },
} }
table := tablewriter.NewWriter(&buff) table := tablewriter.NewWriter(&buff)
table.Header("Total", "Added", "Modified", "Removed") table.Header([]string{"Total", "Added", "Modified", "Removed"})
for _, v := range data { for _, v := range data {
_ = table.Append(v) _ = table.Append(v)
} }

View File

@ -1,6 +1,7 @@
package fs package fs
import ( import (
"context"
"os" "os"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
@ -27,8 +28,9 @@ import (
// // when no itemType is provided, it will return both files and directories // // when no itemType is provided, it will return both files and directories
// const items = fs.ListDir('/tmp'); // const items = fs.ListDir('/tmp');
// ``` // ```
func ListDir(path string, itemType string) ([]string, error) { func ListDir(ctx context.Context, path string, itemType string) ([]string, error) {
finalPath, err := protocolstate.NormalizePath(path) executionId := ctx.Value("executionId").(string)
finalPath, err := protocolstate.NormalizePathWithExecutionId(executionId, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -57,8 +59,9 @@ func ListDir(path string, itemType string) ([]string, error) {
// // here permitted directories are $HOME/nuclei-templates/* // // here permitted directories are $HOME/nuclei-templates/*
// const content = fs.ReadFile('helpers/usernames.txt'); // const content = fs.ReadFile('helpers/usernames.txt');
// ``` // ```
func ReadFile(path string) ([]byte, error) { func ReadFile(ctx context.Context, path string) ([]byte, error) {
finalPath, err := protocolstate.NormalizePath(path) executionId := ctx.Value("executionId").(string)
finalPath, err := protocolstate.NormalizePathWithExecutionId(executionId, path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -74,8 +77,8 @@ func ReadFile(path string) ([]byte, error) {
// // here permitted directories are $HOME/nuclei-templates/* // // here permitted directories are $HOME/nuclei-templates/*
// const content = fs.ReadFileAsString('helpers/usernames.txt'); // const content = fs.ReadFileAsString('helpers/usernames.txt');
// ``` // ```
func ReadFileAsString(path string) (string, error) { func ReadFileAsString(ctx context.Context, path string) (string, error) {
bin, err := ReadFile(path) bin, err := ReadFile(ctx, path)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -91,14 +94,14 @@ func ReadFileAsString(path string) (string, error) {
// const contents = fs.ReadFilesFromDir('helpers/ssh-keys'); // const contents = fs.ReadFilesFromDir('helpers/ssh-keys');
// log(contents); // log(contents);
// ``` // ```
func ReadFilesFromDir(dir string) ([]string, error) { func ReadFilesFromDir(ctx context.Context, dir string) ([]string, error) {
files, err := ListDir(dir, "file") files, err := ListDir(ctx, dir, "file")
if err != nil { if err != nil {
return nil, err return nil, err
} }
var results []string var results []string
for _, file := range files { for _, file := range files {
content, err := ReadFileAsString(dir + "/" + file) content, err := ReadFileAsString(ctx, dir+"/"+file)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -4,22 +4,65 @@ import (
"strings" "strings"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config" "github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
errorutil "github.com/projectdiscovery/utils/errors" errorutil "github.com/projectdiscovery/utils/errors"
fileutil "github.com/projectdiscovery/utils/file" fileutil "github.com/projectdiscovery/utils/file"
mapsutil "github.com/projectdiscovery/utils/maps"
) )
var ( var (
// LfaAllowed means local file access is allowed // LfaAllowed means local file access is allowed
LfaAllowed bool LfaAllowed *mapsutil.SyncLockMap[string, bool]
) )
func init() {
LfaAllowed = mapsutil.NewSyncLockMap[string, bool]()
}
// IsLfaAllowed returns whether local file access is allowed
func IsLfaAllowed(options *types.Options) bool {
if GetLfaAllowed(options) {
return true
}
// Otherwise look into dialers
dialers, ok := dialers.Get(options.ExecutionId)
if ok && dialers != nil {
dialers.Lock()
defer dialers.Unlock()
return dialers.LocalFileAccessAllowed
}
// otherwise just return option value
return options.AllowLocalFileAccess
}
func SetLfaAllowed(options *types.Options) {
_ = LfaAllowed.Set(options.ExecutionId, options.AllowLocalFileAccess)
}
func GetLfaAllowed(options *types.Options) bool {
allowed, ok := LfaAllowed.Get(options.ExecutionId)
return ok && allowed
}
func NormalizePathWithExecutionId(executionId string, filePath string) (string, error) {
options := &types.Options{
ExecutionId: executionId,
}
return NormalizePath(options, filePath)
}
// Normalizepath normalizes path and returns absolute path // Normalizepath normalizes path and returns absolute path
// it returns error if path is not allowed // it returns error if path is not allowed
// this respects the sandbox rules and only loads files from // this respects the sandbox rules and only loads files from
// allowed directories // allowed directories
func NormalizePath(filePath string) (string, error) { func NormalizePath(options *types.Options, filePath string) (string, error) {
// TODO: this should be tied to executionID // TODO: this should be tied to executionID using *types.Options
if LfaAllowed { if IsLfaAllowed(options) {
// if local file access is allowed, we can return the absolute path
return filePath, nil return filePath, nil
} }
cleaned, err := fileutil.ResolveNClean(filePath, config.DefaultConfig.GetTemplateDir()) cleaned, err := fileutil.ResolveNClean(filePath, config.DefaultConfig.GetTemplateDir())

View File

@ -74,18 +74,6 @@ func InitHeadless(options *types.Options) {
} }
} }
// AllowLocalFileAccess returns whether local file access is allowed
func IsLfaAllowed(options *types.Options) bool {
dialers, ok := dialers.Get(options.ExecutionId)
if ok && dialers != nil {
dialers.Lock()
defer dialers.Unlock()
return dialers.LocalFileAccessAllowed
}
return false
}
func IsRestrictLocalNetworkAccess(options *types.Options) bool { func IsRestrictLocalNetworkAccess(options *types.Options) bool {
dialers, ok := dialers.Get(options.ExecutionId) dialers, ok := dialers.Get(options.ExecutionId)
if ok && dialers != nil { if ok && dialers != nil {

View File

@ -200,9 +200,7 @@ func initDialers(options *types.Options) error {
StartActiveMemGuardian(context.Background()) StartActiveMemGuardian(context.Background())
// TODO: this should be tied to executionID SetLfaAllowed(options)
// overidde global settings with latest options
LfaAllowed = options.AllowLocalFileAccess
return nil return nil
} }

View File

@ -8,12 +8,14 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/retryabledns" "github.com/projectdiscovery/retryabledns"
mapsutil "github.com/projectdiscovery/utils/maps"
) )
var ( var (
poolMutex *sync.RWMutex clientPool *mapsutil.SyncLockMap[string, *retryabledns.Client]
normalClient *retryabledns.Client normalClient *retryabledns.Client
clientPool map[string]*retryabledns.Client m sync.Mutex
) )
// defaultResolvers contains the list of resolvers known to be trusted. // defaultResolvers contains the list of resolvers known to be trusted.
@ -26,12 +28,14 @@ var defaultResolvers = []string{
// Init initializes the client pool implementation // Init initializes the client pool implementation
func Init(options *types.Options) error { func Init(options *types.Options) error {
m.Lock()
defer m.Unlock()
// Don't create clients if already created in the past. // Don't create clients if already created in the past.
if normalClient != nil { if normalClient != nil {
return nil return nil
} }
poolMutex = &sync.RWMutex{} clientPool = mapsutil.NewSyncLockMap[string, *retryabledns.Client]()
clientPool = make(map[string]*retryabledns.Client)
resolvers := defaultResolvers resolvers := defaultResolvers
if len(options.InternalResolversList) > 0 { if len(options.InternalResolversList) > 0 {
@ -45,6 +49,12 @@ func Init(options *types.Options) error {
return nil return nil
} }
func getNormalClient() *retryabledns.Client {
m.Lock()
defer m.Unlock()
return normalClient
}
// Configuration contains the custom configuration options for a client // Configuration contains the custom configuration options for a client
type Configuration struct { type Configuration struct {
// Retries contains the retries for the dns client // Retries contains the retries for the dns client
@ -71,15 +81,12 @@ func (c *Configuration) Hash() string {
// Get creates or gets a client for the protocol based on custom configuration // Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryabledns.Client, error) { func Get(options *types.Options, configuration *Configuration) (*retryabledns.Client, error) {
if (configuration.Retries <= 1) && len(configuration.Resolvers) == 0 { if (configuration.Retries <= 1) && len(configuration.Resolvers) == 0 {
return normalClient, nil return getNormalClient(), nil
} }
hash := configuration.Hash() hash := configuration.Hash()
poolMutex.RLock() if client, ok := clientPool.Get(hash); ok {
if client, ok := clientPool[hash]; ok {
poolMutex.RUnlock()
return client, nil return client, nil
} }
poolMutex.RUnlock()
resolvers := defaultResolvers resolvers := defaultResolvers
if len(options.InternalResolversList) > 0 { if len(options.InternalResolversList) > 0 {
@ -95,9 +102,7 @@ func Get(options *types.Options, configuration *Configuration) (*retryabledns.Cl
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not create dns client") return nil, errors.Wrap(err, "could not create dns client")
} }
_ = clientPool.Set(hash, client)
poolMutex.Lock()
clientPool[hash] = client
poolMutex.Unlock()
return client, nil return client, nil
} }

View File

@ -154,16 +154,16 @@ func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
return dialers.RawHTTPClient return dialers.RawHTTPClient
} }
rawHttpOptions := rawhttp.DefaultOptions rawHttpOptionsCopy := *rawhttp.DefaultOptions
if options.Options.AliveHttpProxy != "" { if options.Options.AliveHttpProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveHttpProxy rawHttpOptionsCopy.Proxy = options.Options.AliveHttpProxy
} else if options.Options.AliveSocksProxy != "" { } else if options.Options.AliveSocksProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveSocksProxy rawHttpOptionsCopy.Proxy = options.Options.AliveSocksProxy
} else if dialers.Fastdialer != nil { } else if dialers.Fastdialer != nil {
rawHttpOptions.FastDialer = dialers.Fastdialer rawHttpOptionsCopy.FastDialer = dialers.Fastdialer
} }
rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout rawHttpOptionsCopy.Timeout = options.Options.GetTimeouts().HttpTimeout
dialers.RawHTTPClient = rawhttp.NewClient(rawHttpOptions) dialers.RawHTTPClient = rawhttp.NewClient(&rawHttpOptionsCopy)
return dialers.RawHTTPClient return dialers.RawHTTPClient
} }

View File

@ -3,7 +3,6 @@ package protocols
import ( import (
"context" "context"
"encoding/base64" "encoding/base64"
"sync"
"sync/atomic" "sync/atomic"
"github.com/projectdiscovery/fastdialer/fastdialer" "github.com/projectdiscovery/fastdialer/fastdialer"
@ -139,8 +138,6 @@ type ExecutorOptions struct {
Logger *gologger.Logger Logger *gologger.Logger
// CustomFastdialer is a fastdialer dialer instance // CustomFastdialer is a fastdialer dialer instance
CustomFastdialer *fastdialer.Dialer CustomFastdialer *fastdialer.Dialer
m sync.Mutex
} }
// todo: centralizing components is not feasible with current clogged architecture // todo: centralizing components is not feasible with current clogged architecture
@ -198,6 +195,11 @@ func (e *ExecutorOptions) HasTemplateCtx(input *contextargs.MetaInput) bool {
// GetTemplateCtx returns template context for given input // GetTemplateCtx returns template context for given input
func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contextargs.Context { func (e *ExecutorOptions) GetTemplateCtx(input *contextargs.MetaInput) *contextargs.Context {
scanId := input.GetScanHash(e.TemplateID) scanId := input.GetScanHash(e.TemplateID)
if e.templateCtxStore == nil {
// if template context store is not initialized create it
e.CreateTemplateCtxStore()
}
// get template context from store
templateCtx, ok := e.templateCtxStore.Get(scanId) templateCtx, ok := e.templateCtxStore.Get(scanId)
if !ok { if !ok {
// if template context does not exist create new and add it to store and return it // if template context does not exist create new and add it to store and return it
@ -444,14 +446,49 @@ func (e *ExecutorOptions) ApplyNewEngineOptions(n *ExecutorOptions) {
if e == nil || n == nil || n.Options == nil { if e == nil || n == nil || n.Options == nil {
return return
} }
execID := n.Options.GetExecutionID()
e.SetExecutionID(execID)
}
// ApplyNewEngineOptions updates an existing ExecutorOptions with options from a new engine. This // The types.Options include the ExecutionID among other things
// handles things like the ExecutionID that need to be updated. e.Options = n.Options.Copy()
func (e *ExecutorOptions) SetExecutionID(executorId string) {
e.m.Lock() // Keep the template-specific fields, but replace the rest
defer e.m.Unlock() /*
e.Options.SetExecutionID(executorId) e.TemplateID = n.TemplateID
e.TemplatePath = n.TemplatePath
e.TemplateInfo = n.TemplateInfo
e.TemplateVerifier = n.TemplateVerifier
e.RawTemplate = n.RawTemplate
e.Variables = n.Variables
e.Constants = n.Constants
*/
e.Output = n.Output
e.Options = n.Options
e.IssuesClient = n.IssuesClient
e.Progress = n.Progress
e.RateLimiter = n.RateLimiter
e.Catalog = n.Catalog
e.ProjectFile = n.ProjectFile
e.Browser = n.Browser
e.Interactsh = n.Interactsh
e.HostErrorsCache = n.HostErrorsCache
e.StopAtFirstMatch = n.StopAtFirstMatch
e.ExcludeMatchers = n.ExcludeMatchers
e.InputHelper = n.InputHelper
e.FuzzParamsFrequency = n.FuzzParamsFrequency
e.FuzzStatsDB = n.FuzzStatsDB
e.DoNotCache = n.DoNotCache
e.Colorizer = n.Colorizer
e.WorkflowLoader = n.WorkflowLoader
e.ResumeCfg = n.ResumeCfg
e.ProtocolType = n.ProtocolType
e.Flow = n.Flow
e.IsMultiProtocol = n.IsMultiProtocol
e.templateCtxStore = n.templateCtxStore
e.JsCompiler = n.JsCompiler
e.AuthProvider = n.AuthProvider
e.TemporaryDirectory = n.TemporaryDirectory
e.Parser = n.Parser
e.ExportReqURLPattern = n.ExportReqURLPattern
e.GlobalMatchers = n.GlobalMatchers
e.Logger = n.Logger
e.CustomFastdialer = n.CustomFastdialer
} }

View File

@ -30,6 +30,12 @@ func Init(options *types.Options) error {
return nil return nil
} }
func getNormalClient() *rdap.Client {
m.Lock()
defer m.Unlock()
return normalClient
}
// Configuration contains the custom configuration options for a client - placeholder // Configuration contains the custom configuration options for a client - placeholder
type Configuration struct{} type Configuration struct{}
@ -40,7 +46,5 @@ func (c *Configuration) Hash() string {
// Get creates or gets a client for the protocol based on custom configuration // Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*rdap.Client, error) { func Get(options *types.Options, configuration *Configuration) (*rdap.Client, error) {
m.Lock() return getNormalClient(), nil
defer m.Unlock()
return normalClient, nil
} }

View File

@ -56,49 +56,90 @@ func Parse(filePath string, preprocessor Preprocessor, options *protocols.Execut
} }
if !options.DoNotCache { if !options.DoNotCache {
if value, _, _ := parser.compiledTemplatesCache.Has(filePath); value != nil { if value, _, _ := parser.compiledTemplatesCache.Has(filePath); value != nil {
// Update the template to use the current options for the calling engine // Copy the template, apply new options, and recompile requests
// TODO: This may be require additional work for robustness tplCopy := *value
t := *value newBase := options.Copy()
t.Options.ApplyNewEngineOptions(options) newBase.TemplateID = tplCopy.Options.TemplateID
if t.CompiledWorkflow != nil { newBase.TemplatePath = tplCopy.Options.TemplatePath
t.CompiledWorkflow.Options.ApplyNewEngineOptions(options) newBase.TemplateInfo = tplCopy.Options.TemplateInfo
for _, w := range t.CompiledWorkflow.Workflows { newBase.TemplateVerifier = tplCopy.Options.TemplateVerifier
newBase.RawTemplate = tplCopy.Options.RawTemplate
tplCopy.Options = newBase
tplCopy.Options.ApplyNewEngineOptions(options)
if tplCopy.CompiledWorkflow != nil {
tplCopy.CompiledWorkflow.Options.ApplyNewEngineOptions(options)
for _, w := range tplCopy.CompiledWorkflow.Workflows {
for _, ex := range w.Executers { for _, ex := range w.Executers {
ex.Options.ApplyNewEngineOptions(options) ex.Options.ApplyNewEngineOptions(options)
} }
} }
} }
for _, r := range t.RequestsDNS {
r.UpdateOptions(t.Options) // TODO: Reconsider whether to recompile requests. Compiling these is just as slow
// as not using a cache at all, but may be necessary.
for i, r := range tplCopy.RequestsDNS {
rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsDNS[i] = &rCopy
} }
for _, r := range t.RequestsHTTP { for i, r := range tplCopy.RequestsHTTP {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsHTTP[i] = &rCopy
} }
for _, r := range t.RequestsCode { for i, r := range tplCopy.RequestsCode {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsCode[i] = &rCopy
} }
for _, r := range t.RequestsFile { for i, r := range tplCopy.RequestsFile {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsFile[i] = &rCopy
} }
for _, r := range t.RequestsHeadless { for i, r := range tplCopy.RequestsHeadless {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsHeadless[i] = &rCopy
} }
for _, r := range t.RequestsNetwork { for i, r := range tplCopy.RequestsNetwork {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsNetwork[i] = &rCopy
} }
for _, r := range t.RequestsJavascript { for i, r := range tplCopy.RequestsJavascript {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
//rCopy.Compile(tplCopy.Options)
tplCopy.RequestsJavascript[i] = &rCopy
} }
for _, r := range t.RequestsSSL { for i, r := range tplCopy.RequestsSSL {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsSSL[i] = &rCopy
} }
for _, r := range t.RequestsWHOIS { for i, r := range tplCopy.RequestsWHOIS {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsWHOIS[i] = &rCopy
} }
for _, r := range t.RequestsWebsocket { for i, r := range tplCopy.RequestsWebsocket {
r.UpdateOptions(t.Options) rCopy := *r
rCopy.UpdateOptions(tplCopy.Options)
// rCopy.Compile(tplCopy.Options)
tplCopy.RequestsWebsocket[i] = &rCopy
} }
template := t template := &tplCopy
if template.isGlobalMatchersEnabled() { if template.isGlobalMatchersEnabled() {
item := &globalmatchers.Item{ item := &globalmatchers.Item{
@ -119,8 +160,8 @@ func Parse(filePath string, preprocessor Preprocessor, options *protocols.Execut
template.CompiledWorkflow = compiled template.CompiledWorkflow = compiled
template.CompiledWorkflow.Options = options template.CompiledWorkflow.Options = options
} }
// options.Logger.Error().Msgf("returning cached template %s after recompiling %d requests", tplCopy.Options.TemplateID, tplCopy.Requests())
return &template, nil return template, nil
} }
} }

View File

@ -49,6 +49,23 @@ func (p *Parser) Cache() *Cache {
return p.parsedTemplatesCache return p.parsedTemplatesCache
} }
// CompiledCache returns the compiled templates cache
func (p *Parser) CompiledCache() *Cache {
return p.compiledTemplatesCache
}
func (p *Parser) ParsedCount() int {
p.Lock()
defer p.Unlock()
return len(p.parsedTemplatesCache.items.Map)
}
func (p *Parser) CompiledCount() int {
p.Lock()
defer p.Unlock()
return len(p.compiledTemplatesCache.items.Map)
}
func checkOpenFileError(err error) bool { func checkOpenFileError(err error) bool {
if err != nil && strings.Contains(err.Error(), "too many open files") { if err != nil && strings.Contains(err.Error(), "too many open files") {
panic(err) panic(err)
@ -171,84 +188,3 @@ func (p *Parser) LoadWorkflow(templatePath string, catalog catalog.Catalog) (boo
return false, nil return false, nil
} }
// CloneForExecutionId creates a clone with updated execution IDs
func (p *Parser) CloneForExecutionId(xid string) *Parser {
p.Lock()
defer p.Unlock()
newParser := &Parser{
ShouldValidate: p.ShouldValidate,
NoStrictSyntax: p.NoStrictSyntax,
parsedTemplatesCache: NewCache(),
compiledTemplatesCache: NewCache(),
}
for k, tpl := range p.parsedTemplatesCache.items.Map {
newTemplate := templateUpdateExecutionId(tpl.template, xid)
newParser.parsedTemplatesCache.Store(k, newTemplate, []byte(tpl.raw), tpl.err)
}
for k, tpl := range p.compiledTemplatesCache.items.Map {
newTemplate := templateUpdateExecutionId(tpl.template, xid)
newParser.compiledTemplatesCache.Store(k, newTemplate, []byte(tpl.raw), tpl.err)
}
return newParser
}
func templateUpdateExecutionId(tpl *Template, xid string) *Template {
// TODO: This is a no-op today since options are patched in elsewhere, but we're keeping this
// for future work where we may need additional tweaks per template instance.
return tpl
/*
templateBase := *tpl
var newOpts *protocols.ExecutorOptions
// Swap out the types.Options execution ID attached to the template
if templateBase.Options != nil {
optionsBase := *templateBase.Options //nolint
templateBase.Options = &optionsBase
if templateBase.Options.Options != nil {
optionsOptionsBase := *templateBase.Options.Options //nolint
templateBase.Options.Options = &optionsOptionsBase
templateBase.Options.Options.ExecutionId = xid
newOpts = templateBase.Options
}
}
if newOpts == nil {
return &templateBase
}
for _, r := range templateBase.RequestsDNS {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsHTTP {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsCode {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsFile {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsHeadless {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsNetwork {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsJavascript {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsSSL {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsWHOIS {
r.UpdateOptions(newOpts)
}
for _, r := range templateBase.RequestsWebsocket {
r.UpdateOptions(newOpts)
}
return &templateBase
*/
}

View File

@ -0,0 +1,16 @@
package utils
import (
"bytes"
"github.com/projectdiscovery/gologger/levels"
)
// CaptureWriter captures log output for testing
type CaptureWriter struct {
Buffer *bytes.Buffer
}
func (w *CaptureWriter) Write(data []byte, level levels.Level) {
w.Buffer.Write(data)
}