mirror of
https://github.com/projectdiscovery/nuclei.git
synced 2025-12-24 23:25:24 +00:00
introducing execution id
This commit is contained in:
parent
b9d0f2585f
commit
a87b310e11
@ -1,6 +1,8 @@
|
||||
package protocolinit
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/js/compiler"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns/dnsclientpool"
|
||||
@ -13,8 +15,8 @@ import (
|
||||
)
|
||||
|
||||
// Init initializes the client pools for the protocols
|
||||
func Init(options *types.Options) error {
|
||||
if err := protocolstate.Init(options); err != nil {
|
||||
func Init(ctx context.Context, options *types.Options) error {
|
||||
if err := protocolstate.Init(ctx, options); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := dnsclientpool.Init(options); err != nil {
|
||||
@ -38,6 +40,6 @@ func Init(options *types.Options) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func Close() {
|
||||
protocolstate.Close()
|
||||
func Close(ctx context.Context) {
|
||||
protocolstate.Close(ctx)
|
||||
}
|
||||
|
||||
46
pkg/protocols/common/protocolstate/context.go
Normal file
46
pkg/protocols/common/protocolstate/context.go
Normal file
@ -0,0 +1,46 @@
|
||||
package protocolstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/rs/xid"
|
||||
)
|
||||
|
||||
// contextKey is a type for context keys
|
||||
type ContextKey string
|
||||
|
||||
type ExecutionContext struct {
|
||||
ExecutionID string
|
||||
}
|
||||
|
||||
// executionIDKey is the key used to store execution ID in context
|
||||
const executionIDKey ContextKey = "execution_id"
|
||||
|
||||
// WithExecutionID adds an execution ID to the context
|
||||
func WithExecutionID(ctx context.Context, executionContext *ExecutionContext) context.Context {
|
||||
return context.WithValue(ctx, executionIDKey, executionContext)
|
||||
}
|
||||
|
||||
// HasExecutionID checks if the context has an execution ID
|
||||
func HasExecutionContext(ctx context.Context) bool {
|
||||
_, ok := ctx.Value(executionIDKey).(*ExecutionContext)
|
||||
return ok
|
||||
}
|
||||
|
||||
// GetExecutionID retrieves the execution ID from the context
|
||||
// Returns empty string if no execution ID is set
|
||||
func GetExecutionContext(ctx context.Context) *ExecutionContext {
|
||||
if id, ok := ctx.Value(executionIDKey).(*ExecutionContext); ok {
|
||||
return id
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// WithAutoExecutionContext creates a new context with an automatically generated execution ID
|
||||
// If the input context already has an execution ID, it will be preserved
|
||||
func WithAutoExecutionContext(ctx context.Context) context.Context {
|
||||
if HasExecutionContext(ctx) {
|
||||
return ctx
|
||||
}
|
||||
return WithExecutionID(ctx, &ExecutionContext{ExecutionID: xid.New().String()})
|
||||
}
|
||||
@ -1,6 +1,7 @@
|
||||
package protocolstate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
@ -8,6 +9,7 @@ import (
|
||||
"github.com/go-rod/rod/lib/proto"
|
||||
"github.com/projectdiscovery/networkpolicy"
|
||||
errorutil "github.com/projectdiscovery/utils/errors"
|
||||
mapsutil "github.com/projectdiscovery/utils/maps"
|
||||
stringsutil "github.com/projectdiscovery/utils/strings"
|
||||
urlutil "github.com/projectdiscovery/utils/url"
|
||||
"go.uber.org/multierr"
|
||||
@ -18,13 +20,25 @@ import (
|
||||
var (
|
||||
ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v")
|
||||
ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy")
|
||||
NetworkPolicy *networkpolicy.NetworkPolicy
|
||||
networkPolicies = mapsutil.NewSyncLockMap[string, *networkpolicy.NetworkPolicy]()
|
||||
allowLocalFileAccess bool
|
||||
)
|
||||
|
||||
func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy {
|
||||
execCtx := GetExecutionContext(ctx)
|
||||
if execCtx == nil {
|
||||
return nil
|
||||
}
|
||||
np, ok := networkPolicies.Get(execCtx.ExecutionID)
|
||||
if !ok || np == nil {
|
||||
return nil
|
||||
}
|
||||
return np
|
||||
}
|
||||
|
||||
// ValidateNFailRequest validates and fails request
|
||||
// if the request does not respect the rules, it will be canceled with reason
|
||||
func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error {
|
||||
func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchRequestPaused) error {
|
||||
reqURL := e.Request.URL
|
||||
normalized := strings.ToLower(reqURL) // normalize url to lowercase
|
||||
normalized = strings.TrimSpace(normalized) // trim leading & trailing whitespaces
|
||||
@ -36,7 +50,7 @@ func ValidateNFailRequest(page *rod.Page, e *proto.FetchRequestPaused) error {
|
||||
if stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:") {
|
||||
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "protocol blocked by network policy"))
|
||||
}
|
||||
if !isValidHost(reqURL) {
|
||||
if !isValidHost(ctx, reqURL) {
|
||||
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "address blocked by network policy"))
|
||||
}
|
||||
return nil
|
||||
@ -52,54 +66,67 @@ func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error {
|
||||
}
|
||||
|
||||
// InitHeadless initializes headless protocol state
|
||||
func InitHeadless(localFileAccess bool, np *networkpolicy.NetworkPolicy) {
|
||||
func InitHeadless(ctx context.Context, localFileAccess bool, np *networkpolicy.NetworkPolicy) {
|
||||
allowLocalFileAccess = localFileAccess
|
||||
if np != nil {
|
||||
NetworkPolicy = np
|
||||
execCtx := GetExecutionContext(ctx)
|
||||
if execCtx != nil {
|
||||
networkPolicies.Set(execCtx.ExecutionID, np)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isValidHost checks if the host is valid (only limited to http/https protocols)
|
||||
func isValidHost(targetUrl string) bool {
|
||||
func isValidHost(ctx context.Context, targetUrl string) bool {
|
||||
if !stringsutil.HasPrefixAny(targetUrl, "http:", "https:") {
|
||||
return true
|
||||
}
|
||||
if NetworkPolicy == nil {
|
||||
|
||||
execCtx := GetExecutionContext(ctx)
|
||||
if execCtx == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
np, ok := networkPolicies.Get(execCtx.ExecutionID)
|
||||
if !ok || np == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
urlx, err := urlutil.Parse(targetUrl)
|
||||
if err != nil {
|
||||
// not a valid url
|
||||
return false
|
||||
}
|
||||
targetUrl = urlx.Hostname()
|
||||
_, ok := NetworkPolicy.ValidateHost(targetUrl)
|
||||
_, ok = np.ValidateHost(targetUrl)
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsHostAllowed checks if the host is allowed by network policy
|
||||
func IsHostAllowed(targetUrl string) bool {
|
||||
if NetworkPolicy == nil {
|
||||
func IsHostAllowed(ctx context.Context, targetUrl string) bool {
|
||||
execCtx := GetExecutionContext(ctx)
|
||||
if execCtx == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
np, ok := networkPolicies.Get(execCtx.ExecutionID)
|
||||
if !ok || np == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
sepCount := strings.Count(targetUrl, ":")
|
||||
if sepCount > 1 {
|
||||
// most likely a ipv6 address (parse url and validate host)
|
||||
return NetworkPolicy.Validate(targetUrl)
|
||||
return np.Validate(targetUrl)
|
||||
}
|
||||
if sepCount == 1 {
|
||||
host, _, _ := net.SplitHostPort(targetUrl)
|
||||
if _, ok := NetworkPolicy.ValidateHost(host); !ok {
|
||||
if _, ok := np.ValidateHost(host); !ok {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
// portInt, _ := strconv.Atoi(port)
|
||||
// fixme: broken port validation logic in networkpolicy
|
||||
// if !NetworkPolicy.ValidatePort(portInt) {
|
||||
// return false
|
||||
// }
|
||||
}
|
||||
// just a hostname or ip without port
|
||||
_, ok := NetworkPolicy.ValidateHost(targetUrl)
|
||||
_, ok = np.ValidateHost(targetUrl)
|
||||
return ok
|
||||
}
|
||||
|
||||
@ -5,7 +5,6 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pkg/errors"
|
||||
@ -16,28 +15,36 @@ import (
|
||||
"github.com/projectdiscovery/networkpolicy"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/types"
|
||||
"github.com/projectdiscovery/nuclei/v3/pkg/utils/expand"
|
||||
mapsutil "github.com/projectdiscovery/utils/maps"
|
||||
)
|
||||
|
||||
// Dialer is a shared fastdialer instance for host DNS resolution
|
||||
var (
|
||||
muDialer sync.RWMutex
|
||||
Dialer *fastdialer.Dialer
|
||||
dialers *mapsutil.SyncLockMap[string, *fastdialer.Dialer]
|
||||
)
|
||||
|
||||
func GetDialer() *fastdialer.Dialer {
|
||||
muDialer.RLock()
|
||||
defer muDialer.RUnlock()
|
||||
|
||||
return Dialer
|
||||
func GetDialer(ctx context.Context) *fastdialer.Dialer {
|
||||
executionContext := GetExecutionContext(ctx)
|
||||
dialer, ok := dialers.Get(executionContext.ExecutionID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return dialer
|
||||
}
|
||||
|
||||
func ShouldInit() bool {
|
||||
return Dialer == nil
|
||||
func ShouldInit(ctx context.Context) bool {
|
||||
executionContext := GetExecutionContext(ctx)
|
||||
dialer, ok := dialers.Get(executionContext.ExecutionID)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return dialer == nil
|
||||
}
|
||||
|
||||
// Init creates the Dialer instance based on user configuration
|
||||
func Init(options *types.Options) error {
|
||||
if Dialer != nil {
|
||||
func Init(ctx context.Context, options *types.Options) error {
|
||||
executionContext := GetExecutionContext(ctx)
|
||||
if GetDialer(ctx) != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -66,8 +73,8 @@ func Init(options *types.Options) error {
|
||||
DenyList: expandedDenyList,
|
||||
}
|
||||
opts.WithNetworkPolicyOptions = npOptions
|
||||
NetworkPolicy, _ = networkpolicy.New(*npOptions)
|
||||
InitHeadless(options.AllowLocalFileAccess, NetworkPolicy)
|
||||
networkPolicy, _ := networkpolicy.New(*npOptions)
|
||||
InitHeadless(ctx, options.AllowLocalFileAccess, networkPolicy)
|
||||
|
||||
switch {
|
||||
case options.SourceIP != "" && options.Interface != "":
|
||||
@ -152,7 +159,7 @@ func Init(options *types.Options) error {
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not create dialer")
|
||||
}
|
||||
Dialer = dialer
|
||||
dialers.Set(executionContext.ExecutionID, dialer)
|
||||
|
||||
// Set a custom dialer for the "nucleitcp" protocol. This is just plain TCP, but it's registered
|
||||
// with a different name so that we do not clobber the "tcp" dialer in the event that nuclei is
|
||||
@ -164,6 +171,7 @@ func Init(options *types.Options) error {
|
||||
addr += ":3306"
|
||||
}
|
||||
|
||||
// TODO: find a way to get dialer from context
|
||||
return Dialer.Dial(ctx, "tcp", addr)
|
||||
})
|
||||
|
||||
@ -226,13 +234,18 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) {
|
||||
}
|
||||
|
||||
// Close closes the global shared fastdialer
|
||||
func Close() {
|
||||
muDialer.Lock()
|
||||
defer muDialer.Unlock()
|
||||
|
||||
if Dialer != nil {
|
||||
Dialer.Close()
|
||||
Dialer = nil
|
||||
func Close(ctx context.Context) {
|
||||
executionContext := GetExecutionContext(ctx)
|
||||
dialer, ok := dialers.Get(executionContext.ExecutionID)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if dialer != nil {
|
||||
dialer.Close()
|
||||
}
|
||||
|
||||
dialers.Delete(executionContext.ExecutionID)
|
||||
|
||||
StopActiveMemGuardian()
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user