From a87b310e11ce79418a16cb8927f77eccd9ec6fb1 Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Mon, 5 May 2025 22:15:44 +0200 Subject: [PATCH] introducing execution id --- pkg/protocols/common/protocolinit/init.go | 10 +-- pkg/protocols/common/protocolstate/context.go | 46 ++++++++++++++ .../common/protocolstate/headless.go | 63 +++++++++++++------ pkg/protocols/common/protocolstate/state.go | 57 ++++++++++------- 4 files changed, 132 insertions(+), 44 deletions(-) create mode 100644 pkg/protocols/common/protocolstate/context.go diff --git a/pkg/protocols/common/protocolinit/init.go b/pkg/protocols/common/protocolinit/init.go index 20b7b7a10..f0ba77177 100644 --- a/pkg/protocols/common/protocolinit/init.go +++ b/pkg/protocols/common/protocolinit/init.go @@ -1,6 +1,8 @@ package protocolinit import ( + "context" + "github.com/projectdiscovery/nuclei/v3/pkg/js/compiler" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns/dnsclientpool" @@ -13,8 +15,8 @@ import ( ) // Init initializes the client pools for the protocols -func Init(options *types.Options) error { - if err := protocolstate.Init(options); err != nil { +func Init(ctx context.Context, options *types.Options) error { + if err := protocolstate.Init(ctx, options); err != nil { return err } if err := dnsclientpool.Init(options); err != nil { @@ -38,6 +40,6 @@ func Init(options *types.Options) error { return nil } -func Close() { - protocolstate.Close() +func Close(ctx context.Context) { + protocolstate.Close(ctx) } diff --git a/pkg/protocols/common/protocolstate/context.go b/pkg/protocols/common/protocolstate/context.go new file mode 100644 index 000000000..a6dbb46fb --- /dev/null +++ b/pkg/protocols/common/protocolstate/context.go @@ -0,0 +1,46 @@ +package protocolstate + +import ( + "context" + + "github.com/rs/xid" +) + +// contextKey is a type for context keys +type ContextKey string + +type ExecutionContext struct { + ExecutionID string +} + +// executionIDKey is the key used to store execution ID in context +const executionIDKey ContextKey = "execution_id" + +// WithExecutionID adds an execution ID to the context +func WithExecutionID(ctx context.Context, executionContext *ExecutionContext) context.Context { + return context.WithValue(ctx, executionIDKey, executionContext) +} + +// HasExecutionID checks if the context has an execution ID +func HasExecutionContext(ctx context.Context) bool { + _, ok := ctx.Value(executionIDKey).(*ExecutionContext) + return ok +} + +// GetExecutionID retrieves the execution ID from the context +// Returns empty string if no execution ID is set +func GetExecutionContext(ctx context.Context) *ExecutionContext { + if id, ok := ctx.Value(executionIDKey).(*ExecutionContext); ok { + return id + } + return nil +} + +// WithAutoExecutionContext creates a new context with an automatically generated execution ID +// If the input context already has an execution ID, it will be preserved +func WithAutoExecutionContext(ctx context.Context) context.Context { + if HasExecutionContext(ctx) { + return ctx + } + return WithExecutionID(ctx, &ExecutionContext{ExecutionID: xid.New().String()}) +} diff --git a/pkg/protocols/common/protocolstate/headless.go b/pkg/protocols/common/protocolstate/headless.go index 755d367b9..1b4e7b932 100644 --- a/pkg/protocols/common/protocolstate/headless.go +++ b/pkg/protocols/common/protocolstate/headless.go @@ -1,6 +1,7 @@ package protocolstate import ( + "context" "net" "strings" @@ -8,6 +9,7 @@ import ( "github.com/go-rod/rod/lib/proto" "github.com/projectdiscovery/networkpolicy" errorutil "github.com/projectdiscovery/utils/errors" + mapsutil "github.com/projectdiscovery/utils/maps" stringsutil "github.com/projectdiscovery/utils/strings" urlutil "github.com/projectdiscovery/utils/url" "go.uber.org/multierr" @@ -18,13 +20,25 @@ import ( var ( ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v") ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy") - NetworkPolicy *networkpolicy.NetworkPolicy + networkPolicies = mapsutil.NewSyncLockMap[string, *networkpolicy.NetworkPolicy]() allowLocalFileAccess bool ) +func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy { + execCtx := GetExecutionContext(ctx) + if execCtx == nil { + return nil + } + np, ok := networkPolicies.Get(execCtx.ExecutionID) + if !ok || np == nil { + return nil + } + return np +} + // ValidateNFailRequest validates and fails request // if the request does not respect the rules, it will be canceled with reason -func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error { +func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchRequestPaused) error { reqURL := e.Request.URL normalized := strings.ToLower(reqURL) // normalize url to lowercase normalized = strings.TrimSpace(normalized) // trim leading & trailing whitespaces @@ -36,7 +50,7 @@ func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error { if stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:") { return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "protocol blocked by network policy")) } - if !isValidHost(reqURL) { + if !isValidHost(ctx, reqURL) { return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "address blocked by network policy")) } return nil @@ -52,54 +66,67 @@ func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error { } // InitHeadless initializes headless protocol state -func InitHeadless(localFileAccess bool, np *networkpolicy.NetworkPolicy) { +func InitHeadless(ctx context.Context, localFileAccess bool, np *networkpolicy.NetworkPolicy) { allowLocalFileAccess = localFileAccess if np != nil { - NetworkPolicy = np + execCtx := GetExecutionContext(ctx) + if execCtx != nil { + networkPolicies.Set(execCtx.ExecutionID, np) + } } } // isValidHost checks if the host is valid (only limited to http/https protocols) -func isValidHost(targetUrl string) bool { +func isValidHost(ctx context.Context, targetUrl string) bool { if !stringsutil.HasPrefixAny(targetUrl, "http:", "https:") { return true } - if NetworkPolicy == nil { + + execCtx := GetExecutionContext(ctx) + if execCtx == nil { return true } + + np, ok := networkPolicies.Get(execCtx.ExecutionID) + if !ok || np == nil { + return true + } + urlx, err := urlutil.Parse(targetUrl) if err != nil { // not a valid url return false } targetUrl = urlx.Hostname() - _, ok := NetworkPolicy.ValidateHost(targetUrl) + _, ok = np.ValidateHost(targetUrl) return ok } // IsHostAllowed checks if the host is allowed by network policy -func IsHostAllowed(targetUrl string) bool { - if NetworkPolicy == nil { +func IsHostAllowed(ctx context.Context, targetUrl string) bool { + execCtx := GetExecutionContext(ctx) + if execCtx == nil { return true } + + np, ok := networkPolicies.Get(execCtx.ExecutionID) + if !ok || np == nil { + return true + } + sepCount := strings.Count(targetUrl, ":") if sepCount > 1 { // most likely a ipv6 address (parse url and validate host) - return NetworkPolicy.Validate(targetUrl) + return np.Validate(targetUrl) } if sepCount == 1 { host, _, _ := net.SplitHostPort(targetUrl) - if _, ok := NetworkPolicy.ValidateHost(host); !ok { + if _, ok := np.ValidateHost(host); !ok { return false } return true - // portInt, _ := strconv.Atoi(port) - // fixme: broken port validation logic in networkpolicy - // if !NetworkPolicy.ValidatePort(portInt) { - // return false - // } } // just a hostname or ip without port - _, ok := NetworkPolicy.ValidateHost(targetUrl) + _, ok = np.ValidateHost(targetUrl) return ok } diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index 89c5eb355..aef024e3c 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -5,7 +5,6 @@ import ( "fmt" "net" "net/url" - "sync" "github.com/go-sql-driver/mysql" "github.com/pkg/errors" @@ -16,28 +15,36 @@ import ( "github.com/projectdiscovery/networkpolicy" "github.com/projectdiscovery/nuclei/v3/pkg/types" "github.com/projectdiscovery/nuclei/v3/pkg/utils/expand" + mapsutil "github.com/projectdiscovery/utils/maps" ) // Dialer is a shared fastdialer instance for host DNS resolution var ( - muDialer sync.RWMutex - Dialer *fastdialer.Dialer + dialers *mapsutil.SyncLockMap[string, *fastdialer.Dialer] ) -func GetDialer() *fastdialer.Dialer { - muDialer.RLock() - defer muDialer.RUnlock() - - return Dialer +func GetDialer(ctx context.Context) *fastdialer.Dialer { + executionContext := GetExecutionContext(ctx) + dialer, ok := dialers.Get(executionContext.ExecutionID) + if !ok { + return nil + } + return dialer } -func ShouldInit() bool { - return Dialer == nil +func ShouldInit(ctx context.Context) bool { + executionContext := GetExecutionContext(ctx) + dialer, ok := dialers.Get(executionContext.ExecutionID) + if !ok { + return false + } + return dialer == nil } // Init creates the Dialer instance based on user configuration -func Init(options *types.Options) error { - if Dialer != nil { +func Init(ctx context.Context, options *types.Options) error { + executionContext := GetExecutionContext(ctx) + if GetDialer(ctx) != nil { return nil } @@ -66,8 +73,8 @@ func Init(options *types.Options) error { DenyList: expandedDenyList, } opts.WithNetworkPolicyOptions = npOptions - NetworkPolicy, _ = networkpolicy.New(*npOptions) - InitHeadless(options.AllowLocalFileAccess, NetworkPolicy) + networkPolicy, _ := networkpolicy.New(*npOptions) + InitHeadless(ctx, options.AllowLocalFileAccess, networkPolicy) switch { case options.SourceIP != "" && options.Interface != "": @@ -152,7 +159,7 @@ func Init(options *types.Options) error { if err != nil { return errors.Wrap(err, "could not create dialer") } - Dialer = dialer + dialers.Set(executionContext.ExecutionID, dialer) // Set a custom dialer for the "nucleitcp" protocol. This is just plain TCP, but it's registered // with a different name so that we do not clobber the "tcp" dialer in the event that nuclei is @@ -164,6 +171,7 @@ func Init(options *types.Options) error { addr += ":3306" } + // TODO: find a way to get dialer from context return Dialer.Dial(ctx, "tcp", addr) }) @@ -226,13 +234,18 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) { } // Close closes the global shared fastdialer -func Close() { - muDialer.Lock() - defer muDialer.Unlock() - - if Dialer != nil { - Dialer.Close() - Dialer = nil +func Close(ctx context.Context) { + executionContext := GetExecutionContext(ctx) + dialer, ok := dialers.Get(executionContext.ExecutionID) + if !ok { + return } + + if dialer != nil { + dialer.Close() + } + + dialers.Delete(executionContext.ExecutionID) + StopActiveMemGuardian() }