This commit is contained in:
Mzack9999 2025-05-06 10:13:46 +02:00
parent a87b310e11
commit e44e3a4c82
28 changed files with 217 additions and 171 deletions

View File

@ -38,7 +38,10 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) {
}
httpxOptions.RetryMax = r.options.Retries
httpxOptions.Timeout = time.Duration(r.options.Timeout) * time.Second
httpxOptions.NetworkPolicy = protocolstate.NetworkPolicy
dialers := protocolstate.GetDialersWithId(r.options.ExecutionId)
httpxOptions.NetworkPolicy = dialers.NetworkPolicy
httpxClient, err := httpx.New(&httpxOptions)
if err != nil {
return nil, errors.Wrap(err, "could not create httpx client")

View File

@ -413,7 +413,7 @@ func (r *Runner) Close() {
if r.inputProvider != nil {
r.inputProvider.Close()
}
protocolinit.Close()
protocolinit.Close(r.options.ExecutionId)
if r.pprofServer != nil {
r.pprofServer.Stop()
}
@ -655,7 +655,7 @@ func (r *Runner) RunEnumeration() error {
}
ret := uncover.GetUncoverTargetsFromMetadata(context.TODO(), store.Templates(), r.options.UncoverField, uncoverOpts)
for host := range ret {
_ = r.inputProvider.SetWithExclusions(host)
_ = r.inputProvider.SetWithExclusions(r.options.ExecutionId, host)
}
}
// display execution info like version , templates used etc

View File

@ -154,7 +154,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t
}
store.Load()
inputProvider := provider.NewSimpleInputProviderWithUrls(targets...)
inputProvider := provider.NewSimpleInputProviderWithUrls(e.eng.opts.ExecutionId, targets...)
if len(store.Templates()) == 0 && len(store.Workflows()) == 0 {
return ErrNoTemplatesAvailable

View File

@ -124,9 +124,9 @@ func (e *NucleiEngine) GetWorkflows() []*templates.Template {
func (e *NucleiEngine) LoadTargets(targets []string, probeNonHttp bool) {
for _, target := range targets {
if probeNonHttp {
_ = e.inputProvider.SetWithProbe(target, e.httpxClient)
_ = e.inputProvider.SetWithProbe(e.opts.ExecutionId, target, e.httpxClient)
} else {
e.inputProvider.Set(target)
e.inputProvider.Set(e.opts.ExecutionId, target)
}
}
}
@ -136,9 +136,9 @@ func (e *NucleiEngine) LoadTargetsFromReader(reader io.Reader, probeNonHttp bool
buff := bufio.NewScanner(reader)
for buff.Scan() {
if probeNonHttp {
_ = e.inputProvider.SetWithProbe(buff.Text(), e.httpxClient)
_ = e.inputProvider.SetWithProbe(e.opts.ExecutionId, buff.Text(), e.httpxClient)
} else {
e.inputProvider.Set(buff.Text())
e.inputProvider.Set(e.opts.ExecutionId, buff.Text())
}
}
}
@ -229,7 +229,7 @@ func (e *NucleiEngine) closeInternal() {
// Close all resources used by nuclei engine
func (e *NucleiEngine) Close() {
e.closeInternal()
protocolinit.Close()
protocolinit.Close(e.opts.ExecutionId)
}
// ExecuteCallbackWithCtx executes templates on targets and calls callback on each result(only if results are found)

View File

@ -112,7 +112,7 @@ func (e *NucleiEngine) init(ctx context.Context) error {
e.parser = templates.NewParser()
if sharedInit == nil || protocolstate.ShouldInit() {
if sharedInit == nil || protocolstate.ShouldInit(e.opts.ExecutionId) {
sharedInit = &sync.Once{}
}

View File

@ -115,17 +115,17 @@ func (i *HttpInputProvider) Iterate(callback func(value *contextargs.MetaInput)
// Set adds item to input provider
// No-op for this provider
func (i *HttpInputProvider) Set(value string) {}
func (i *HttpInputProvider) Set(_ string, value string) {}
// SetWithProbe adds item to input provider with http probing
// No-op for this provider
func (i *HttpInputProvider) SetWithProbe(value string, probe types.InputLivenessProbe) error {
func (i *HttpInputProvider) SetWithProbe(_ string, value string, probe types.InputLivenessProbe) error {
return nil
}
// SetWithExclusions adds item to input provider if it doesn't match any of the exclusions
// No-op for this provider
func (i *HttpInputProvider) SetWithExclusions(value string) error {
func (i *HttpInputProvider) SetWithExclusions(_ string, value string) error {
return nil
}

View File

@ -59,11 +59,11 @@ type InputProvider interface {
// Iterate over all inputs in order
Iterate(callback func(value *contextargs.MetaInput) bool)
// Set adds item to input provider
Set(value string)
Set(executionId string, value string)
// SetWithProbe adds item to input provider with http probing
SetWithProbe(value string, probe types.InputLivenessProbe) error
SetWithProbe(executionId string, value string, probe types.InputLivenessProbe) error
// SetWithExclusions adds item to input provider if it doesn't match any of the exclusions
SetWithExclusions(value string) error
SetWithExclusions(executionId string, value string) error
// InputType returns the type of input provider
InputType() string
// Close the input provider and cleanup any resources

View File

@ -139,7 +139,7 @@ func (i *ListInputProvider) Iterate(callback func(value *contextargs.MetaInput)
}
// Set normalizes and stores passed input values
func (i *ListInputProvider) Set(value string) {
func (i *ListInputProvider) Set(executionId string, value string) {
URL := strings.TrimSpace(value)
if URL == "" {
return
@ -169,7 +169,8 @@ func (i *ListInputProvider) Set(value string) {
if i.ipOptions.ScanAllIPs {
// scan all ips
dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname())
dialers := protocolstate.GetDialersWithId(executionId)
dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname())
if err == nil {
if (len(dnsData.A) + len(dnsData.AAAA)) > 0 {
var ips []string
@ -201,7 +202,8 @@ func (i *ListInputProvider) Set(value string) {
ips := []string{}
// only scan the target but ipv6 if it has one
if i.ipOptions.IPV6 {
dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname())
dialers := protocolstate.GetDialersWithId(executionId)
dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname())
if err == nil && len(dnsData.AAAA) > 0 {
// pick/ prefer 1st
ips = append(ips, dnsData.AAAA[0])
@ -228,17 +230,17 @@ func (i *ListInputProvider) Set(value string) {
}
// SetWithProbe only sets the input if it is live
func (i *ListInputProvider) SetWithProbe(value string, probe providerTypes.InputLivenessProbe) error {
func (i *ListInputProvider) SetWithProbe(executionId string, value string, probe providerTypes.InputLivenessProbe) error {
probedValue, err := probe.ProbeURL(value)
if err != nil {
return err
}
i.Set(probedValue)
i.Set(executionId, probedValue)
return nil
}
// SetWithExclusions normalizes and stores passed input values if not excluded
func (i *ListInputProvider) SetWithExclusions(value string) error {
func (i *ListInputProvider) SetWithExclusions(executionId string, value string) error {
URL := strings.TrimSpace(value)
if URL == "" {
return nil
@ -247,7 +249,7 @@ func (i *ListInputProvider) SetWithExclusions(value string) error {
i.skippedCount++
return nil
}
i.Set(URL)
i.Set(executionId, URL)
return nil
}
@ -273,18 +275,20 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error {
switch {
case iputil.IsCIDR(target):
ips := expand.CIDR(target)
i.addTargets(ips)
i.addTargets(options.ExecutionId, ips)
case asn.IsASN(target):
ips := expand.ASN(target)
i.addTargets(ips)
i.addTargets(options.ExecutionId, ips)
default:
i.Set(target)
i.Set(options.ExecutionId, target)
}
}
// Handle stdin
if options.Stdin {
i.scanInputFromReader(readerutil.TimeoutReader{Reader: os.Stdin, Timeout: time.Duration(options.InputReadTimeout)})
i.scanInputFromReader(
options.ExecutionId,
readerutil.TimeoutReader{Reader: os.Stdin, Timeout: time.Duration(options.InputReadTimeout)})
}
// Handle target file
@ -297,7 +301,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error {
}
}
if input != nil {
i.scanInputFromReader(input)
i.scanInputFromReader(options.ExecutionId, input)
input.Close()
}
}
@ -317,7 +321,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error {
return err
}
for c := range ch {
i.Set(c)
i.Set(options.ExecutionId, c)
}
}
@ -331,7 +335,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error {
ips := expand.ASN(target)
i.removeTargets(ips)
default:
i.Del(target)
i.Del(options.ExecutionId, target)
}
}
}
@ -340,19 +344,19 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error {
}
// scanInputFromReader scans a line of input from reader and passes it for storage
func (i *ListInputProvider) scanInputFromReader(reader io.Reader) {
func (i *ListInputProvider) scanInputFromReader(executionId string, reader io.Reader) {
scanner := bufio.NewScanner(reader)
for scanner.Scan() {
item := scanner.Text()
switch {
case iputil.IsCIDR(item):
ips := expand.CIDR(item)
i.addTargets(ips)
i.addTargets(executionId, ips)
case asn.IsASN(item):
ips := expand.ASN(item)
i.addTargets(ips)
i.addTargets(executionId, ips)
default:
i.Set(item)
i.Set(executionId, item)
}
}
}
@ -371,7 +375,7 @@ func (i *ListInputProvider) isExcluded(URL string) bool {
return exists
}
func (i *ListInputProvider) Del(value string) {
func (i *ListInputProvider) Del(executionId string, value string) {
URL := strings.TrimSpace(value)
if URL == "" {
return
@ -401,7 +405,8 @@ func (i *ListInputProvider) Del(value string) {
if i.ipOptions.ScanAllIPs {
// scan all ips
dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname())
dialers := protocolstate.GetDialersWithId(executionId)
dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname())
if err == nil {
if (len(dnsData.A) + len(dnsData.AAAA)) > 0 {
var ips []string
@ -433,7 +438,8 @@ func (i *ListInputProvider) Del(value string) {
ips := []string{}
// only scan the target but ipv6 if it has one
if i.ipOptions.IPV6 {
dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname())
dialers := protocolstate.GetDialersWithId(executionId)
dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname())
if err == nil && len(dnsData.AAAA) > 0 {
// pick/ prefer 1st
ips = append(ips, dnsData.AAAA[0])
@ -519,9 +525,9 @@ func (i *ListInputProvider) setHostMapStream(data string) {
}
}
func (i *ListInputProvider) addTargets(targets []string) {
func (i *ListInputProvider) addTargets(executionId string, targets []string) {
for _, target := range targets {
i.Set(target)
i.Set(executionId, target)
}
}

View File

@ -19,10 +19,10 @@ func NewSimpleInputProvider() *SimpleInputProvider {
}
// NewSimpleInputProviderWithUrls creates a new simple input provider with the given urls
func NewSimpleInputProviderWithUrls(urls ...string) *SimpleInputProvider {
func NewSimpleInputProviderWithUrls(executionId string, urls ...string) *SimpleInputProvider {
provider := NewSimpleInputProvider()
for _, url := range urls {
provider.Set(url)
provider.Set(executionId, url)
}
return provider
}
@ -42,14 +42,14 @@ func (s *SimpleInputProvider) Iterate(callback func(value *contextargs.MetaInput
}
// Set adds an item to the input provider
func (s *SimpleInputProvider) Set(value string) {
func (s *SimpleInputProvider) Set(_ string, value string) {
metaInput := contextargs.NewMetaInput()
metaInput.Input = value
s.Inputs = append(s.Inputs, metaInput)
}
// SetWithProbe adds an item to the input provider with HTTP probing
func (s *SimpleInputProvider) SetWithProbe(value string, probe types.InputLivenessProbe) error {
func (s *SimpleInputProvider) SetWithProbe(_ string, value string, probe types.InputLivenessProbe) error {
probedValue, err := probe.ProbeURL(value)
if err != nil {
return err
@ -61,7 +61,7 @@ func (s *SimpleInputProvider) SetWithProbe(value string, probe types.InputLivene
}
// SetWithExclusions adds an item to the input provider if it doesn't match any of the exclusions
func (s *SimpleInputProvider) SetWithExclusions(value string) error {
func (s *SimpleInputProvider) SetWithExclusions(_ string, value string) error {
metaInput := contextargs.NewMetaInput()
metaInput.Input = value
s.Inputs = append(s.Inputs, metaInput)

View File

@ -32,6 +32,9 @@ func New() *Compiler {
// ExecuteOptions provides options for executing a script.
type ExecuteOptions struct {
// ExecutionId is the id of the execution
ExecutionId string
// Callback can be used to register new runtime helper functions
// ex: export etc
Callback func(runtime *goja.Runtime) error

View File

@ -84,6 +84,7 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg
if opts != nil && opts.Cleanup != nil {
opts.Cleanup(runtime)
}
_ = runtime.GlobalObject().Delete("executionId")
}()
// TODO(dwisiswant0): remove this once we get the RCA.
@ -108,8 +109,11 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg
if err := opts.Callback(runtime); err != nil {
return nil, err
}
}
// inject execution id
_ = runtime.Set("executionId", opts.ExecutionId)
// execute the script
return runtime.RunProgram(p)
}

View File

@ -109,7 +109,8 @@ func NewKerberosClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.O
if controller != "" {
// validate controller hostport
if !protocolstate.IsHostAllowed(controller) {
executionId := c.nj.ExecutionId()
if !protocolstate.IsHostAllowed(executionId, controller) {
c.nj.Throw("domain controller address blacklisted by network policy")
}
@ -246,16 +247,18 @@ func (c *Client) GetServiceTicket(User, Pass, SPN string) (TGS, error) {
c.nj.Require(Pass != "", "Pass cannot be empty")
c.nj.Require(SPN != "", "SPN cannot be empty")
executionId := c.nj.ExecutionId()
if len(c.Krb5Config.Realms) > 0 {
// this means dc address was given
for _, r := range c.Krb5Config.Realms {
for _, kdc := range r.KDC {
if !protocolstate.IsHostAllowed(kdc) {
if !protocolstate.IsHostAllowed(executionId, kdc) {
c.nj.Throw("KDC address %v blacklisted by network policy", kdc)
}
}
for _, kpasswd := range r.KPasswdServer {
if !protocolstate.IsHostAllowed(kpasswd) {
if !protocolstate.IsHostAllowed(executionId, kpasswd) {
c.nj.Throw("Kpasswd address %v blacklisted by network policy", kpasswd)
}
}
@ -265,7 +268,7 @@ func (c *Client) GetServiceTicket(User, Pass, SPN string) (TGS, error) {
// and check if they are allowed by network policy
_, kdcs, _ := c.Krb5Config.GetKDCs(c.Realm, true)
for _, v := range kdcs {
if !protocolstate.IsHostAllowed(v) {
if !protocolstate.IsHostAllowed(executionId, v) {
c.nj.Throw("KDC address %v blacklisted by network policy", v)
}
}

View File

@ -68,6 +68,9 @@ func sendToKDCTcp(kclient *Client, msg string) ([]byte, error) {
kclient.nj.HandleError(err, "error getting KDCs")
kclient.nj.Require(len(kdcs) > 0, "no KDCs found")
executionId := kclient.nj.ExecutionId()
dialers := protocolstate.GetDialersWithId(executionId)
var errs []string
for i := 1; i <= len(kdcs); i++ {
host, port, err := net.SplitHostPort(kdcs[i])
@ -75,7 +78,7 @@ func sendToKDCTcp(kclient *Client, msg string) ([]byte, error) {
// use that ip address instead of realm/domain for resolving
host = kclient.config.ip
}
tcpConn, err := protocolstate.Dialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port))
tcpConn, err := dialers.Fastdialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port))
if err != nil {
errs = append(errs, fmt.Sprintf("error establishing connection to %s: %v", kdcs[i], err))
continue
@ -101,6 +104,9 @@ func sendToKDCUdp(kclient *Client, msg string) ([]byte, error) {
kclient.nj.HandleError(err, "error getting KDCs")
kclient.nj.Require(len(kdcs) > 0, "no KDCs found")
executionId := kclient.nj.ExecutionId()
dialers := protocolstate.GetDialersWithId(executionId)
var errs []string
for i := 1; i <= len(kdcs); i++ {
host, port, err := net.SplitHostPort(kdcs[i])
@ -108,7 +114,7 @@ func sendToKDCUdp(kclient *Client, msg string) ([]byte, error) {
// use that ip address instead of realm/domain for resolving
host = kclient.config.ip
}
udpConn, err := protocolstate.Dialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port))
udpConn, err := dialers.Fastdialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port))
if err != nil {
errs = append(errs, fmt.Sprintf("error establishing connection to %s: %v", kdcs[i], err))
continue

View File

@ -86,12 +86,15 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object {
u, err := url.Parse(ldapUrl)
c.nj.HandleError(err, "invalid ldap url supported schemas are ldap://, ldaps://, ldapi://, and cldap://")
executionId := c.nj.ExecutionId()
dialers := protocolstate.GetDialersWithId(executionId)
var conn net.Conn
if u.Scheme == "ldapi" {
if u.Path == "" || u.Path == "/" {
u.Path = "/var/run/slapd/ldapi"
}
conn, err = protocolstate.Dialer.Dial(context.TODO(), "unix", u.Path)
conn, err = dialers.Fastdialer.Dial(context.TODO(), "unix", u.Path)
c.nj.HandleError(err, "failed to connect to ldap server")
} else {
host, port, err := net.SplitHostPort(u.Host)
@ -110,12 +113,12 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object {
if port == "" {
port = ldap.DefaultLdapPort
}
conn, err = protocolstate.Dialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port))
conn, err = dialers.Fastdialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port))
case "ldap":
if port == "" {
port = ldap.DefaultLdapPort
}
conn, err = protocolstate.Dialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port))
conn, err = dialers.Fastdialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port))
case "ldaps":
if port == "" {
port = ldap.DefaultLdapsPort
@ -124,7 +127,7 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object {
if c.cfg.ServerName != "" {
serverName = c.cfg.ServerName
}
conn, err = protocolstate.Dialer.DialTLSWithConfig(context.TODO(), "tcp", net.JoinHostPort(host, port),
conn, err = dialers.Fastdialer.DialTLSWithConfig(context.TODO(), "tcp", net.JoinHostPort(host, port),
&tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS10, ServerName: serverName})
default:
err = fmt.Errorf("unsupported ldap url schema %v", u.Scheme)
@ -331,7 +334,7 @@ func (c *Client) CollectMetadata() Metadata {
// ```
func (c *Client) GetVersion() []string {
c.nj.Require(c.conn != nil, "no existing connection")
// Query root DSE for supported LDAP versions
sr := ldap.NewSearchRequest(
"",
@ -341,18 +344,17 @@ func (c *Client) GetVersion() []string {
"(objectClass=*)",
[]string{"supportedLDAPVersion"},
nil)
res, err := c.conn.Search(sr)
c.nj.HandleError(err, "failed to get LDAP version")
if len(res.Entries) > 0 {
return res.Entries[0].GetAttributeValues("supportedLDAPVersion")
}
return []string{"unknown"}
}
// close the ldap connection
// @example
// ```javascript

View File

@ -42,6 +42,10 @@ func (j *NucleiJS) runtime() *goja.Runtime {
return j.vm
}
func (j *NucleiJS) ExecutionId() string {
return j.runtime().Get("executionId").String()
}
// see: https://arc.net/l/quote/wpenftpc for throwing docs
// ThrowError throws an error in goja runtime if is not nil

View File

@ -201,6 +201,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa
result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, args,
&compiler.ExecuteOptions{
ExecutionId: request.options.Options.ExecutionId,
TimeoutVariants: request.options.Options.GetTimeouts(),
Source: &request.PreCondition,
Callback: registerPreConditionFunctions,

View File

@ -1,8 +1,6 @@
package protocolinit
import (
"context"
"github.com/projectdiscovery/nuclei/v3/pkg/js/compiler"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns/dnsclientpool"
@ -15,8 +13,8 @@ import (
)
// Init initializes the client pools for the protocols
func Init(ctx context.Context, options *types.Options) error {
if err := protocolstate.Init(ctx, options); err != nil {
func Init(options *types.Options) error {
if err := protocolstate.Init(options); err != nil {
return err
}
if err := dnsclientpool.Init(options); err != nil {
@ -40,6 +38,6 @@ func Init(ctx context.Context, options *types.Options) error {
return nil
}
func Close(ctx context.Context) {
protocolstate.Close(ctx)
func Close(executionId string) {
protocolstate.Close(executionId)
}

View File

@ -0,0 +1,18 @@
package protocolstate
import (
"github.com/projectdiscovery/fastdialer/fastdialer"
"github.com/projectdiscovery/networkpolicy"
"github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/retryablehttp-go"
mapsutil "github.com/projectdiscovery/utils/maps"
)
type Dialers struct {
Fastdialer *fastdialer.Dialer
RawHTTPClient *rawhttp.Client
DefaultHTTPClient *retryablehttp.Client
HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
NetworkPolicy *networkpolicy.NetworkPolicy
}

View File

@ -8,8 +8,8 @@ import (
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
"github.com/projectdiscovery/networkpolicy"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
errorutil "github.com/projectdiscovery/utils/errors"
mapsutil "github.com/projectdiscovery/utils/maps"
stringsutil "github.com/projectdiscovery/utils/strings"
urlutil "github.com/projectdiscovery/utils/url"
"go.uber.org/multierr"
@ -18,9 +18,9 @@ import (
// initalize state of headless protocol
var (
ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v")
ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy")
networkPolicies = mapsutil.NewSyncLockMap[string, *networkpolicy.NetworkPolicy]()
ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v")
ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy")
allowLocalFileAccess bool
)
@ -29,16 +29,16 @@ func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy {
if execCtx == nil {
return nil
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
if !ok || np == nil {
dialers, ok := dialers.Get(execCtx.ExecutionID)
if !ok || dialers == nil {
return nil
}
return np
return dialers.NetworkPolicy
}
// ValidateNFailRequest validates and fails request
// if the request does not respect the rules, it will be canceled with reason
func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchRequestPaused) error {
func ValidateNFailRequest(options *types.Options, page *rod.Page, e *proto.FetchRequestPaused) error {
reqURL := e.Request.URL
normalized := strings.ToLower(reqURL) // normalize url to lowercase
normalized = strings.TrimSpace(normalized) // trim leading & trailing whitespaces
@ -50,7 +50,7 @@ func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchReq
if stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:") {
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "protocol blocked by network policy"))
}
if !isValidHost(ctx, reqURL) {
if !isValidHost(options, reqURL) {
return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "address blocked by network policy"))
}
return nil
@ -66,28 +66,22 @@ func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error {
}
// InitHeadless initializes headless protocol state
func InitHeadless(ctx context.Context, localFileAccess bool, np *networkpolicy.NetworkPolicy) {
func InitHeadless(localFileAccess bool) {
allowLocalFileAccess = localFileAccess
if np != nil {
execCtx := GetExecutionContext(ctx)
if execCtx != nil {
networkPolicies.Set(execCtx.ExecutionID, np)
}
}
}
// isValidHost checks if the host is valid (only limited to http/https protocols)
func isValidHost(ctx context.Context, targetUrl string) bool {
func isValidHost(options *types.Options, targetUrl string) bool {
if !stringsutil.HasPrefixAny(targetUrl, "http:", "https:") {
return true
}
execCtx := GetExecutionContext(ctx)
if execCtx == nil {
dialers, ok := dialers.Get(options.ExecutionId)
if !ok {
return true
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
np := dialers.NetworkPolicy
if !ok || np == nil {
return true
}
@ -103,13 +97,13 @@ func isValidHost(ctx context.Context, targetUrl string) bool {
}
// IsHostAllowed checks if the host is allowed by network policy
func IsHostAllowed(ctx context.Context, targetUrl string) bool {
execCtx := GetExecutionContext(ctx)
if execCtx == nil {
func IsHostAllowed(executionId string, targetUrl string) bool {
dialers, ok := dialers.Get(executionId)
if !ok {
return true
}
np, ok := networkPolicies.Get(execCtx.ExecutionID)
np := dialers.NetworkPolicy
if !ok || np == nil {
return true
}

View File

@ -20,31 +20,37 @@ import (
// Dialer is a shared fastdialer instance for host DNS resolution
var (
dialers *mapsutil.SyncLockMap[string, *fastdialer.Dialer]
dialers *mapsutil.SyncLockMap[string, *Dialers]
)
func GetDialer(ctx context.Context) *fastdialer.Dialer {
func GetDialers(ctx context.Context) *Dialers {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
dialers, ok := dialers.Get(executionContext.ExecutionID)
if !ok {
return nil
}
return dialer
return dialers
}
func ShouldInit(ctx context.Context) bool {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
func GetDialersWithId(id string) *Dialers {
dialers, ok := dialers.Get(id)
if !ok {
return nil
}
return dialers
}
func ShouldInit(id string) bool {
dialer, ok := dialers.Get(id)
if !ok {
return false
}
return dialer == nil
}
// Init creates the Dialer instance based on user configuration
func Init(ctx context.Context, options *types.Options) error {
executionContext := GetExecutionContext(ctx)
if GetDialer(ctx) != nil {
// Init creates the Dialers instance based on user configuration
func Init(options *types.Options) error {
if GetDialersWithId(options.ExecutionId) != nil {
return nil
}
@ -73,8 +79,7 @@ func Init(ctx context.Context, options *types.Options) error {
DenyList: expandedDenyList,
}
opts.WithNetworkPolicyOptions = npOptions
networkPolicy, _ := networkpolicy.New(*npOptions)
InitHeadless(ctx, options.AllowLocalFileAccess, networkPolicy)
InitHeadless(options.AllowLocalFileAccess)
switch {
case options.SourceIP != "" && options.Interface != "":
@ -159,7 +164,15 @@ func Init(ctx context.Context, options *types.Options) error {
if err != nil {
return errors.Wrap(err, "could not create dialer")
}
dialers.Set(executionContext.ExecutionID, dialer)
networkPolicy, _ := networkpolicy.New(*npOptions)
dialersInstance := &Dialers{
Fastdialer: dialer,
NetworkPolicy: networkPolicy,
}
dialers.Set(options.ExecutionId, dialersInstance)
// Set a custom dialer for the "nucleitcp" protocol. This is just plain TCP, but it's registered
// with a different name so that we do not clobber the "tcp" dialer in the event that nuclei is
@ -234,18 +247,17 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) {
}
// Close closes the global shared fastdialer
func Close(ctx context.Context) {
executionContext := GetExecutionContext(ctx)
dialer, ok := dialers.Get(executionContext.ExecutionID)
func Close(executionId string) {
dialersInstance, ok := dialers.Get(executionId)
if !ok {
return
}
if dialer != nil {
dialer.Close()
if dialersInstance != nil {
dialersInstance.Fastdialer.Close()
}
dialers.Delete(executionContext.ExecutionID)
dialers.Delete(executionId)
StopActiveMemGuardian()
}

View File

@ -19,7 +19,7 @@ import (
// newHttpClient creates a new http client for headless communication with a timeout
func newHttpClient(options *types.Options) (*http.Client, error) {
dialer := protocolstate.Dialer
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
// Set the base TLS configuration definition
tlsConfig := &tls.Config{
@ -41,15 +41,15 @@ func newHttpClient(options *types.Options) (*http.Client, error) {
transport := &http.Transport{
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
DialContext: dialer.Dial,
DialContext: dialers.Fastdialer.Dial,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.TlsImpersonate {
return dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
}
if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
return dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
}
return dialer.DialTLS(ctx, network, addr)
return dialers.Fastdialer.DialTLS(ctx, network, addr)
},
MaxIdleConns: 500,
MaxIdleConnsPerHost: 500,

View File

@ -110,7 +110,7 @@ func (p *Page) routingRuleHandlerNative(e *proto.FetchRequestPaused) error {
// ValidateNFailRequest validates if Local file access is enabled
// and local network access is enables if not it will fail the request
// that don't match the rules
if err := protocolstate.ValidateNFailRequest(p.page, e); err != nil {
if err := protocolstate.ValidateNFailRequest(p.options.Options, p.page, e); err != nil {
return err
}
body, _ := FetchGetResponseBody(p.page, e)

View File

@ -25,36 +25,19 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy"
"github.com/projectdiscovery/rawhttp"
"github.com/projectdiscovery/retryablehttp-go"
mapsutil "github.com/projectdiscovery/utils/maps"
urlutil "github.com/projectdiscovery/utils/url"
)
var (
rawHttpClient *rawhttp.Client
rawHttpClientOnce sync.Once
forceMaxRedirects int
normalClient *retryablehttp.Client
clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client]
)
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
// Don't create clients if already created in the past.
if normalClient != nil {
return nil
}
if options.ShouldFollowHTTPRedirects() {
forceMaxRedirects = options.MaxRedirects
}
clientPool = &mapsutil.SyncLockMap[string, *retryablehttp.Client]{
Map: make(mapsutil.Map[string, *retryablehttp.Client]),
}
client, err := wrappedGet(options, &Configuration{})
if err != nil {
return err
}
normalClient = client
return nil
}
@ -158,25 +141,30 @@ func (c *Configuration) HasStandardOptions() bool {
// GetRawHTTP returns the rawhttp request client
func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client {
rawHttpClientOnce.Do(func() {
rawHttpOptions := rawhttp.DefaultOptions
if options.Options.AliveHttpProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveHttpProxy
} else if options.Options.AliveSocksProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveSocksProxy
} else if protocolstate.Dialer != nil {
rawHttpOptions.FastDialer = protocolstate.Dialer
}
rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout
rawHttpClient = rawhttp.NewClient(rawHttpOptions)
})
return rawHttpClient
dialers := protocolstate.GetDialersWithId(options.Options.ExecutionId)
if dialers.RawHTTPClient != nil {
return dialers.RawHTTPClient
}
rawHttpOptions := rawhttp.DefaultOptions
if options.Options.AliveHttpProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveHttpProxy
} else if options.Options.AliveSocksProxy != "" {
rawHttpOptions.Proxy = options.Options.AliveSocksProxy
} else if dialers.Fastdialer != nil {
rawHttpOptions.FastDialer = dialers.Fastdialer
}
rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout
dialers.RawHTTPClient = rawhttp.NewClient(rawHttpOptions)
return dialers.RawHTTPClient
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
if configuration.HasStandardOptions() {
return normalClient, nil
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
return dialers.DefaultHTTPClient, nil
}
return wrappedGet(options, configuration)
}
@ -185,8 +173,10 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C
func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) {
var err error
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
hash := configuration.Hash()
if client, ok := clientPool.Get(hash); ok {
if client, ok := dialers.HTTPClientPool.Get(hash); ok {
return client, nil
}
@ -263,15 +253,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
transport := &http.Transport{
ForceAttemptHTTP2: options.ForceAttemptHTTP2,
DialContext: protocolstate.GetDialer().Dial,
DialContext: dialers.Fastdialer.Dial,
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
if options.TlsImpersonate {
return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil)
}
if options.HasClientCertificates() || options.ForceAttemptHTTP2 {
return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig)
}
return protocolstate.GetDialer().DialTLS(ctx, network, addr)
return dialers.Fastdialer.DialTLS(ctx, network, addr)
},
MaxIdleConns: maxIdleConns,
MaxIdleConnsPerHost: maxIdleConnsPerHost,
@ -338,7 +328,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl
// Only add to client pool if we don't have a cookie jar in place.
if jar == nil {
if err := clientPool.Set(hash, client); err != nil {
if err := dialers.HTTPClientPool.Set(hash, client); err != nil {
return nil, err
}
}

View File

@ -817,6 +817,8 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
}
}
dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId)
if err != nil {
// rawhttp doesn't support draining response bodies.
if resp != nil && resp.Body != nil && generatedRequest.rawRequest == nil && !generatedRequest.original.Pipeline {
@ -837,7 +839,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
if input.MetaInput.CustomIP != "" {
outputEvent["ip"] = input.MetaInput.CustomIP
} else {
outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname)
outputEvent["ip"] = dialers.Fastdialer.GetDialedIP(hostname)
// try getting cname
request.addCNameIfAvailable(hostname, outputEvent)
}
@ -957,7 +959,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ
if input.MetaInput.CustomIP != "" {
outputEvent["ip"] = input.MetaInput.CustomIP
} else {
dialer := protocolstate.GetDialer()
dialer := dialers.Fastdialer
if dialer != nil {
outputEvent["ip"] = dialer.GetDialedIP(hostname)
}
@ -1081,11 +1083,13 @@ func (request *Request) validateNFixEvent(input *contextargs.Context, gr *genera
// addCNameIfAvailable adds the cname to the event if available
func (request *Request) addCNameIfAvailable(hostname string, outputEvent map[string]interface{}) {
if protocolstate.Dialer == nil {
dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId)
if dialers.Fastdialer == nil {
return
}
data, err := protocolstate.Dialer.GetDNSData(hostname)
data, err := dialers.Fastdialer.GetDNSData(hostname)
if err == nil {
switch len(data.CNAME) {
case 0:

View File

@ -611,6 +611,8 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte
// generateEventData generates event data for the request
func (request *Request) generateEventData(input *contextargs.Context, values map[string]interface{}, matched string) map[string]interface{} {
dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId)
data := make(map[string]interface{})
for k, v := range values {
data[k] = v
@ -643,7 +645,7 @@ func (request *Request) generateEventData(input *contextargs.Context, values map
}
}
}
data["ip"] = protocolstate.Dialer.GetDialedIP(hostname)
data["ip"] = dialers.Fastdialer.GetDialedIP(hostname)
// if input itself was an ip, use it
if iputil.IsIP(hostname) {
data["ip"] = hostname
@ -651,7 +653,7 @@ func (request *Request) generateEventData(input *contextargs.Context, values map
// if ip is not found,this is because ssh and other protocols do not use fastdialer
// although its not perfect due to its use case dial and get ip
dnsData, err := protocolstate.Dialer.GetDNSData(hostname)
dnsData, err := dialers.Fastdialer.GetDNSData(hostname)
if err == nil {
for _, v := range dnsData.A {
data["ip"] = v

View File

@ -6,17 +6,8 @@ import (
"github.com/projectdiscovery/nuclei/v3/pkg/types"
)
var (
normalClient *fastdialer.Dialer
)
// Init initializes the clientpool implementation
func Init(options *types.Options) error {
// Don't create clients if already created in the past.
if normalClient != nil {
return nil
}
normalClient = protocolstate.Dialer
return nil
}
@ -29,6 +20,7 @@ func (c *Configuration) Hash() string {
}
// Get creates or gets a client for the protocol based on custom configuration
func Get(options *types.Options, configuration *Configuration /*TODO review unused parameters*/) (*fastdialer.Dialer, error) {
return normalClient, nil
func Get(options *types.Options, configuration *Configuration) (*fastdialer.Dialer, error) {
dialers := protocolstate.GetDialersWithId(options.ExecutionId)
return dialers.Fastdialer, nil
}

View File

@ -64,7 +64,8 @@ func (request *Request) getOpenPorts(target *contextargs.Context) ([]string, err
errs = append(errs, err)
continue
}
conn, err := protocolstate.Dialer.Dial(target.Context(), "tcp", addr)
dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId)
conn, err := dialers.Fastdialer.Dial(target.Context(), "tcp", addr)
if err != nil {
errs = append(errs, err)
continue

View File

@ -449,6 +449,9 @@ type Options struct {
// This is internally managed and does not need to be set by user by explicitly setting
// this overrides the default/derived one
timeouts *Timeouts
// Unique identifier of the execution session
ExecutionId string
}
// SetTimeouts sets the timeout variants to use for the executor