introducing execution id

This commit is contained in:
Mzack9999 2025-05-05 22:15:44 +02:00
parent b9d0f2585f
commit a87b310e11
4 changed files with 132 additions and 44 deletions

View File

@ -1,6 +1,8 @@
package protocolinit
import (
"context"
"github.com/projectdiscovery/nuclei/v3/pkg/js/compiler"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns/dnsclientpool"
@ -13,8 +15,8 @@ import (
)
// Init initializes the client pools for the protocols
func Init(options *types.Options) error {
if err := protocolstate.Init(options); err != nil {
func Init(ctx context.Context, options *types.Options) error {
if err := protocolstate.Init(ctx, options); err != nil {
return err
}
if err := dnsclientpool.Init(options); err != nil {
@ -38,6 +40,6 @@ func Init(options *types.Options) error {
return nil
}
func Close() {
protocolstate.Close()
func Close(ctx context.Context) {
protocolstate.Close(ctx)
}

View File

@ -0,0 +1,46 @@
package protocolstate
import (
"context"
"github.com/rs/xid"
)
// contextKey is a type for context keys
type ContextKey string
type ExecutionContext struct {
ExecutionID string
}
// executionIDKey is the key used to store execution ID in context
const executionIDKey ContextKey = "execution_id"
// WithExecutionID adds an execution ID to the context
func WithExecutionID(ctx context.Context, executionContext *ExecutionContext) context.Context {
return context.WithValue(ctx, executionIDKey, executionContext)
}
// HasExecutionID checks if the context has an execution ID
func HasExecutionContext(ctx context.Context) bool {
_, ok := ctx.Value(executionIDKey).(*ExecutionContext)
return ok
}
// GetExecutionID retrieves the execution ID from the context
// Returns empty string if no execution ID is set
func GetExecutionContext(ctx context.Context) *ExecutionContext {
if id, ok := ctx.Value(executionIDKey).(*ExecutionContext); ok {
return id
}
return nil
}
// WithAutoExecutionContext creates a new context with an automatically generated execution ID
// If the input context already has an execution ID, it will be preserved
func WithAutoExecutionContext(ctx context.Context) context.Context {
if HasExecutionContext(ctx) {
return ctx
}
return WithExecutionID(ctx, &ExecutionContext{ExecutionID: xid.New().String()})
}

View File

@ -1,6 +1,7 @@
package protocolstate
import (
"context"
"net"
"strings"
@ -8,6 +9,7 @@ import (
"github.com/go-rod/rod/lib/proto"
"github.com/projectdiscovery/networkpolicy"
errorutil "github.com/projectdiscovery/utils/errors"
mapsutil "github.com/projectdiscovery/utils/maps"
stringsutil "github.com/projectdiscovery/utils/strings"
urlutil "github.com/projectdiscovery/utils/url"
"go.uber.org/multierr"
@ -18,13 +20,25 @@ import (
var (
ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v")
ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy")
NetworkPolicy *networkpolicy.NetworkPolicy
networkPolicies = mapsutil.NewSyncLockMap[string, *networkpolicy.NetworkPolicy]()
allowLocalFileAccess bool
)
func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy {
execCtx := GetExecutionContext(ctx)
if execCtx == nil {
return nil
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
if !ok || np == nil {
return nil
}
return np
}
// ValidateNFailRequest validates and fails request
// if the request does not respect the rules, it will be canceled with reason
func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error {
func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchRequestPaused) error {
reqURL := e.Request.URL
normalized := strings.ToLower(reqURL) // normalize url to lowercase
normalized = strings.TrimSpace(normalized) // trim leading & trailing whitespaces
@ -36,7 +50,7 @@ func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error {
if stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:") {
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "protocol blocked by network policy"))
}
if !isValidHost(reqURL) {
if !isValidHost(ctx, reqURL) {
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "address blocked by network policy"))
}
return nil
@ -52,54 +66,67 @@ func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error {
}
// InitHeadless initializes headless protocol state
func InitHeadless(localFileAccess bool, np *networkpolicy.NetworkPolicy) {
func InitHeadless(ctx context.Context, localFileAccess bool, np *networkpolicy.NetworkPolicy) {
allowLocalFileAccess = localFileAccess
if np != nil {
NetworkPolicy = np
execCtx := GetExecutionContext(ctx)
if execCtx != nil {
networkPolicies.Set(execCtx.ExecutionID, np)
}
}
}
// isValidHost checks if the host is valid (only limited to http/https protocols)
func isValidHost(targetUrl string) bool {
func isValidHost(ctx context.Context, targetUrl string) bool {
if !stringsutil.HasPrefixAny(targetUrl, "http:", "https:") {
return true
}
if NetworkPolicy == nil {
execCtx := GetExecutionContext(ctx)
if execCtx == nil {
return true
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
if !ok || np == nil {
return true
}
urlx, err := urlutil.Parse(targetUrl)
if err != nil {
// not a valid url
return false
}
targetUrl = urlx.Hostname()
_, ok := NetworkPolicy.ValidateHost(targetUrl)
_, ok = np.ValidateHost(targetUrl)
return ok
}
// IsHostAllowed checks if the host is allowed by network policy
func IsHostAllowed(targetUrl string) bool {
if NetworkPolicy == nil {
func IsHostAllowed(ctx context.Context, targetUrl string) bool {
execCtx := GetExecutionContext(ctx)
if execCtx == nil {
return true
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
if !ok || np == nil {
return true
}
sepCount := strings.Count(targetUrl, ":")
if sepCount > 1 {
// most likely a ipv6 address (parse url and validate host)
return NetworkPolicy.Validate(targetUrl)
return np.Validate(targetUrl)
}
if sepCount == 1 {
host, _, _ := net.SplitHostPort(targetUrl)
if _, ok := NetworkPolicy.ValidateHost(host); !ok {
if _, ok := np.ValidateHost(host); !ok {
return false
}
return true
// portInt, _ := strconv.Atoi(port)
// fixme: broken port validation logic in networkpolicy
// if !NetworkPolicy.ValidatePort(portInt) {
// return false
// }
}
// just a hostname or ip without port
_, ok := NetworkPolicy.ValidateHost(targetUrl)
_, ok = np.ValidateHost(targetUrl)
return ok
}

View File

@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/url"
"sync"
"github.com/go-sql-driver/mysql"
"github.com/pkg/errors"
@ -16,28 +15,36 @@ import (
"github.com/projectdiscovery/networkpolicy"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/nuclei/v3/pkg/utils/expand"
mapsutil "github.com/projectdiscovery/utils/maps"
)
// Dialer is a shared fastdialer instance for host DNS resolution
var (
muDialer sync.RWMutex
Dialer *fastdialer.Dialer
dialers *mapsutil.SyncLockMap[string, *fastdialer.Dialer]
)
func GetDialer() *fastdialer.Dialer {
muDialer.RLock()
defer muDialer.RUnlock()
return Dialer
func GetDialer(ctx context.Context) *fastdialer.Dialer {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
if !ok {
return nil
}
return dialer
}
func ShouldInit() bool {
return Dialer == nil
func ShouldInit(ctx context.Context) bool {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
if !ok {
return false
}
return dialer == nil
}
// Init creates the Dialer instance based on user configuration
func Init(options *types.Options) error {
if Dialer != nil {
func Init(ctx context.Context, options *types.Options) error {
executionContext := GetExecutionContext(ctx)
if GetDialer(ctx) != nil {
return nil
}
@ -66,8 +73,8 @@ func Init(options *types.Options) error {
DenyList: expandedDenyList,
}
opts.WithNetworkPolicyOptions = npOptions
NetworkPolicy, _ = networkpolicy.New(*npOptions)
InitHeadless(options.AllowLocalFileAccess, NetworkPolicy)
networkPolicy, _ := networkpolicy.New(*npOptions)
InitHeadless(ctx, options.AllowLocalFileAccess, networkPolicy)
switch {
case options.SourceIP != "" && options.Interface != "":
@ -152,7 +159,7 @@ func Init(options *types.Options) error {
if err != nil {
return errors.Wrap(err, "could not create dialer")
}
Dialer = dialer
dialers.Set(executionContext.ExecutionID, dialer)
// Set a custom dialer for the "nucleitcp" protocol. This is just plain TCP, but it's registered
// with a different name so that we do not clobber the "tcp" dialer in the event that nuclei is
@ -164,6 +171,7 @@ func Init(options *types.Options) error {
addr += ":3306"
}
// TODO: find a way to get dialer from context
return Dialer.Dial(ctx, "tcp", addr)
})
@ -226,13 +234,18 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) {
}
// Close closes the global shared fastdialer
func Close() {
muDialer.Lock()
defer muDialer.Unlock()
if Dialer != nil {
Dialer.Close()
Dialer = nil
func Close(ctx context.Context) {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
if !ok {
return
}
if dialer != nil {
dialer.Close()
}
dialers.Delete(executionContext.ExecutionID)
StopActiveMemGuardian()
}