diff --git a/internal/runner/inputs.go b/internal/runner/inputs.go index 3d51ca7e8..5efa55d90 100644 --- a/internal/runner/inputs.go +++ b/internal/runner/inputs.go @@ -38,7 +38,10 @@ func (r *Runner) initializeTemplatesHTTPInput() (*hybrid.HybridMap, error) { } httpxOptions.RetryMax = r.options.Retries httpxOptions.Timeout = time.Duration(r.options.Timeout) * time.Second - httpxOptions.NetworkPolicy = protocolstate.NetworkPolicy + + dialers := protocolstate.GetDialersWithId(r.options.ExecutionId) + + httpxOptions.NetworkPolicy = dialers.NetworkPolicy httpxClient, err := httpx.New(&httpxOptions) if err != nil { return nil, errors.Wrap(err, "could not create httpx client") diff --git a/internal/runner/runner.go b/internal/runner/runner.go index 85fe0ea75..e4817848b 100644 --- a/internal/runner/runner.go +++ b/internal/runner/runner.go @@ -413,7 +413,7 @@ func (r *Runner) Close() { if r.inputProvider != nil { r.inputProvider.Close() } - protocolinit.Close() + protocolinit.Close(r.options.ExecutionId) if r.pprofServer != nil { r.pprofServer.Stop() } @@ -655,7 +655,7 @@ func (r *Runner) RunEnumeration() error { } ret := uncover.GetUncoverTargetsFromMetadata(context.TODO(), store.Templates(), r.options.UncoverField, uncoverOpts) for host := range ret { - _ = r.inputProvider.SetWithExclusions(host) + _ = r.inputProvider.SetWithExclusions(r.options.ExecutionId, host) } } // display execution info like version , templates used etc diff --git a/lib/multi.go b/lib/multi.go index 1aa870836..2a2ef52df 100644 --- a/lib/multi.go +++ b/lib/multi.go @@ -154,7 +154,7 @@ func (e *ThreadSafeNucleiEngine) ExecuteNucleiWithOptsCtx(ctx context.Context, t } store.Load() - inputProvider := provider.NewSimpleInputProviderWithUrls(targets...) + inputProvider := provider.NewSimpleInputProviderWithUrls(e.eng.opts.ExecutionId, targets...) if len(store.Templates()) == 0 && len(store.Workflows()) == 0 { return ErrNoTemplatesAvailable diff --git a/lib/sdk.go b/lib/sdk.go index 7f2ec5bcc..9b0017f54 100644 --- a/lib/sdk.go +++ b/lib/sdk.go @@ -124,9 +124,9 @@ func (e *NucleiEngine) GetWorkflows() []*templates.Template { func (e *NucleiEngine) LoadTargets(targets []string, probeNonHttp bool) { for _, target := range targets { if probeNonHttp { - _ = e.inputProvider.SetWithProbe(target, e.httpxClient) + _ = e.inputProvider.SetWithProbe(e.opts.ExecutionId, target, e.httpxClient) } else { - e.inputProvider.Set(target) + e.inputProvider.Set(e.opts.ExecutionId, target) } } } @@ -136,9 +136,9 @@ func (e *NucleiEngine) LoadTargetsFromReader(reader io.Reader, probeNonHttp bool buff := bufio.NewScanner(reader) for buff.Scan() { if probeNonHttp { - _ = e.inputProvider.SetWithProbe(buff.Text(), e.httpxClient) + _ = e.inputProvider.SetWithProbe(e.opts.ExecutionId, buff.Text(), e.httpxClient) } else { - e.inputProvider.Set(buff.Text()) + e.inputProvider.Set(e.opts.ExecutionId, buff.Text()) } } } @@ -229,7 +229,7 @@ func (e *NucleiEngine) closeInternal() { // Close all resources used by nuclei engine func (e *NucleiEngine) Close() { e.closeInternal() - protocolinit.Close() + protocolinit.Close(e.opts.ExecutionId) } // ExecuteCallbackWithCtx executes templates on targets and calls callback on each result(only if results are found) diff --git a/lib/sdk_private.go b/lib/sdk_private.go index c0d394acc..781280b7f 100644 --- a/lib/sdk_private.go +++ b/lib/sdk_private.go @@ -112,7 +112,7 @@ func (e *NucleiEngine) init(ctx context.Context) error { e.parser = templates.NewParser() - if sharedInit == nil || protocolstate.ShouldInit() { + if sharedInit == nil || protocolstate.ShouldInit(e.opts.ExecutionId) { sharedInit = &sync.Once{} } diff --git a/pkg/input/provider/http/multiformat.go b/pkg/input/provider/http/multiformat.go index a534879c1..c86a17b84 100644 --- a/pkg/input/provider/http/multiformat.go +++ b/pkg/input/provider/http/multiformat.go @@ -115,17 +115,17 @@ func (i *HttpInputProvider) Iterate(callback func(value *contextargs.MetaInput) // Set adds item to input provider // No-op for this provider -func (i *HttpInputProvider) Set(value string) {} +func (i *HttpInputProvider) Set(_ string, value string) {} // SetWithProbe adds item to input provider with http probing // No-op for this provider -func (i *HttpInputProvider) SetWithProbe(value string, probe types.InputLivenessProbe) error { +func (i *HttpInputProvider) SetWithProbe(_ string, value string, probe types.InputLivenessProbe) error { return nil } // SetWithExclusions adds item to input provider if it doesn't match any of the exclusions // No-op for this provider -func (i *HttpInputProvider) SetWithExclusions(value string) error { +func (i *HttpInputProvider) SetWithExclusions(_ string, value string) error { return nil } diff --git a/pkg/input/provider/interface.go b/pkg/input/provider/interface.go index e6d5da14a..1ac068514 100644 --- a/pkg/input/provider/interface.go +++ b/pkg/input/provider/interface.go @@ -59,11 +59,11 @@ type InputProvider interface { // Iterate over all inputs in order Iterate(callback func(value *contextargs.MetaInput) bool) // Set adds item to input provider - Set(value string) + Set(executionId string, value string) // SetWithProbe adds item to input provider with http probing - SetWithProbe(value string, probe types.InputLivenessProbe) error + SetWithProbe(executionId string, value string, probe types.InputLivenessProbe) error // SetWithExclusions adds item to input provider if it doesn't match any of the exclusions - SetWithExclusions(value string) error + SetWithExclusions(executionId string, value string) error // InputType returns the type of input provider InputType() string // Close the input provider and cleanup any resources diff --git a/pkg/input/provider/list/hmap.go b/pkg/input/provider/list/hmap.go index 6f41920cd..b79c6e922 100644 --- a/pkg/input/provider/list/hmap.go +++ b/pkg/input/provider/list/hmap.go @@ -139,7 +139,7 @@ func (i *ListInputProvider) Iterate(callback func(value *contextargs.MetaInput) } // Set normalizes and stores passed input values -func (i *ListInputProvider) Set(value string) { +func (i *ListInputProvider) Set(executionId string, value string) { URL := strings.TrimSpace(value) if URL == "" { return @@ -169,7 +169,8 @@ func (i *ListInputProvider) Set(value string) { if i.ipOptions.ScanAllIPs { // scan all ips - dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname()) + dialers := protocolstate.GetDialersWithId(executionId) + dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname()) if err == nil { if (len(dnsData.A) + len(dnsData.AAAA)) > 0 { var ips []string @@ -201,7 +202,8 @@ func (i *ListInputProvider) Set(value string) { ips := []string{} // only scan the target but ipv6 if it has one if i.ipOptions.IPV6 { - dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname()) + dialers := protocolstate.GetDialersWithId(executionId) + dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname()) if err == nil && len(dnsData.AAAA) > 0 { // pick/ prefer 1st ips = append(ips, dnsData.AAAA[0]) @@ -228,17 +230,17 @@ func (i *ListInputProvider) Set(value string) { } // SetWithProbe only sets the input if it is live -func (i *ListInputProvider) SetWithProbe(value string, probe providerTypes.InputLivenessProbe) error { +func (i *ListInputProvider) SetWithProbe(executionId string, value string, probe providerTypes.InputLivenessProbe) error { probedValue, err := probe.ProbeURL(value) if err != nil { return err } - i.Set(probedValue) + i.Set(executionId, probedValue) return nil } // SetWithExclusions normalizes and stores passed input values if not excluded -func (i *ListInputProvider) SetWithExclusions(value string) error { +func (i *ListInputProvider) SetWithExclusions(executionId string, value string) error { URL := strings.TrimSpace(value) if URL == "" { return nil @@ -247,7 +249,7 @@ func (i *ListInputProvider) SetWithExclusions(value string) error { i.skippedCount++ return nil } - i.Set(URL) + i.Set(executionId, URL) return nil } @@ -273,18 +275,20 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error { switch { case iputil.IsCIDR(target): ips := expand.CIDR(target) - i.addTargets(ips) + i.addTargets(options.ExecutionId, ips) case asn.IsASN(target): ips := expand.ASN(target) - i.addTargets(ips) + i.addTargets(options.ExecutionId, ips) default: - i.Set(target) + i.Set(options.ExecutionId, target) } } // Handle stdin if options.Stdin { - i.scanInputFromReader(readerutil.TimeoutReader{Reader: os.Stdin, Timeout: time.Duration(options.InputReadTimeout)}) + i.scanInputFromReader( + options.ExecutionId, + readerutil.TimeoutReader{Reader: os.Stdin, Timeout: time.Duration(options.InputReadTimeout)}) } // Handle target file @@ -297,7 +301,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error { } } if input != nil { - i.scanInputFromReader(input) + i.scanInputFromReader(options.ExecutionId, input) input.Close() } } @@ -317,7 +321,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error { return err } for c := range ch { - i.Set(c) + i.Set(options.ExecutionId, c) } } @@ -331,7 +335,7 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error { ips := expand.ASN(target) i.removeTargets(ips) default: - i.Del(target) + i.Del(options.ExecutionId, target) } } } @@ -340,19 +344,19 @@ func (i *ListInputProvider) initializeInputSources(opts *Options) error { } // scanInputFromReader scans a line of input from reader and passes it for storage -func (i *ListInputProvider) scanInputFromReader(reader io.Reader) { +func (i *ListInputProvider) scanInputFromReader(executionId string, reader io.Reader) { scanner := bufio.NewScanner(reader) for scanner.Scan() { item := scanner.Text() switch { case iputil.IsCIDR(item): ips := expand.CIDR(item) - i.addTargets(ips) + i.addTargets(executionId, ips) case asn.IsASN(item): ips := expand.ASN(item) - i.addTargets(ips) + i.addTargets(executionId, ips) default: - i.Set(item) + i.Set(executionId, item) } } } @@ -371,7 +375,7 @@ func (i *ListInputProvider) isExcluded(URL string) bool { return exists } -func (i *ListInputProvider) Del(value string) { +func (i *ListInputProvider) Del(executionId string, value string) { URL := strings.TrimSpace(value) if URL == "" { return @@ -401,7 +405,8 @@ func (i *ListInputProvider) Del(value string) { if i.ipOptions.ScanAllIPs { // scan all ips - dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname()) + dialers := protocolstate.GetDialersWithId(executionId) + dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname()) if err == nil { if (len(dnsData.A) + len(dnsData.AAAA)) > 0 { var ips []string @@ -433,7 +438,8 @@ func (i *ListInputProvider) Del(value string) { ips := []string{} // only scan the target but ipv6 if it has one if i.ipOptions.IPV6 { - dnsData, err := protocolstate.Dialer.GetDNSData(urlx.Hostname()) + dialers := protocolstate.GetDialersWithId(executionId) + dnsData, err := dialers.Fastdialer.GetDNSData(urlx.Hostname()) if err == nil && len(dnsData.AAAA) > 0 { // pick/ prefer 1st ips = append(ips, dnsData.AAAA[0]) @@ -519,9 +525,9 @@ func (i *ListInputProvider) setHostMapStream(data string) { } } -func (i *ListInputProvider) addTargets(targets []string) { +func (i *ListInputProvider) addTargets(executionId string, targets []string) { for _, target := range targets { - i.Set(target) + i.Set(executionId, target) } } diff --git a/pkg/input/provider/simple.go b/pkg/input/provider/simple.go index c85f7871b..ac1b854df 100644 --- a/pkg/input/provider/simple.go +++ b/pkg/input/provider/simple.go @@ -19,10 +19,10 @@ func NewSimpleInputProvider() *SimpleInputProvider { } // NewSimpleInputProviderWithUrls creates a new simple input provider with the given urls -func NewSimpleInputProviderWithUrls(urls ...string) *SimpleInputProvider { +func NewSimpleInputProviderWithUrls(executionId string, urls ...string) *SimpleInputProvider { provider := NewSimpleInputProvider() for _, url := range urls { - provider.Set(url) + provider.Set(executionId, url) } return provider } @@ -42,14 +42,14 @@ func (s *SimpleInputProvider) Iterate(callback func(value *contextargs.MetaInput } // Set adds an item to the input provider -func (s *SimpleInputProvider) Set(value string) { +func (s *SimpleInputProvider) Set(_ string, value string) { metaInput := contextargs.NewMetaInput() metaInput.Input = value s.Inputs = append(s.Inputs, metaInput) } // SetWithProbe adds an item to the input provider with HTTP probing -func (s *SimpleInputProvider) SetWithProbe(value string, probe types.InputLivenessProbe) error { +func (s *SimpleInputProvider) SetWithProbe(_ string, value string, probe types.InputLivenessProbe) error { probedValue, err := probe.ProbeURL(value) if err != nil { return err @@ -61,7 +61,7 @@ func (s *SimpleInputProvider) SetWithProbe(value string, probe types.InputLivene } // SetWithExclusions adds an item to the input provider if it doesn't match any of the exclusions -func (s *SimpleInputProvider) SetWithExclusions(value string) error { +func (s *SimpleInputProvider) SetWithExclusions(_ string, value string) error { metaInput := contextargs.NewMetaInput() metaInput.Input = value s.Inputs = append(s.Inputs, metaInput) diff --git a/pkg/js/compiler/compiler.go b/pkg/js/compiler/compiler.go index b13e7f9ec..42b0b9da9 100644 --- a/pkg/js/compiler/compiler.go +++ b/pkg/js/compiler/compiler.go @@ -32,6 +32,9 @@ func New() *Compiler { // ExecuteOptions provides options for executing a script. type ExecuteOptions struct { + // ExecutionId is the id of the execution + ExecutionId string + // Callback can be used to register new runtime helper functions // ex: export etc Callback func(runtime *goja.Runtime) error diff --git a/pkg/js/compiler/pool.go b/pkg/js/compiler/pool.go index ac6a3dada..b271cd329 100644 --- a/pkg/js/compiler/pool.go +++ b/pkg/js/compiler/pool.go @@ -84,6 +84,7 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg if opts != nil && opts.Cleanup != nil { opts.Cleanup(runtime) } + _ = runtime.GlobalObject().Delete("executionId") }() // TODO(dwisiswant0): remove this once we get the RCA. @@ -108,8 +109,11 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg if err := opts.Callback(runtime); err != nil { return nil, err } - } + + // inject execution id + _ = runtime.Set("executionId", opts.ExecutionId) + // execute the script return runtime.RunProgram(p) } diff --git a/pkg/js/libs/kerberos/kerberosx.go b/pkg/js/libs/kerberos/kerberosx.go index ea3e5921d..131c9f905 100644 --- a/pkg/js/libs/kerberos/kerberosx.go +++ b/pkg/js/libs/kerberos/kerberosx.go @@ -109,7 +109,8 @@ func NewKerberosClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.O if controller != "" { // validate controller hostport - if !protocolstate.IsHostAllowed(controller) { + executionId := c.nj.ExecutionId() + if !protocolstate.IsHostAllowed(executionId, controller) { c.nj.Throw("domain controller address blacklisted by network policy") } @@ -246,16 +247,18 @@ func (c *Client) GetServiceTicket(User, Pass, SPN string) (TGS, error) { c.nj.Require(Pass != "", "Pass cannot be empty") c.nj.Require(SPN != "", "SPN cannot be empty") + executionId := c.nj.ExecutionId() + if len(c.Krb5Config.Realms) > 0 { // this means dc address was given for _, r := range c.Krb5Config.Realms { for _, kdc := range r.KDC { - if !protocolstate.IsHostAllowed(kdc) { + if !protocolstate.IsHostAllowed(executionId, kdc) { c.nj.Throw("KDC address %v blacklisted by network policy", kdc) } } for _, kpasswd := range r.KPasswdServer { - if !protocolstate.IsHostAllowed(kpasswd) { + if !protocolstate.IsHostAllowed(executionId, kpasswd) { c.nj.Throw("Kpasswd address %v blacklisted by network policy", kpasswd) } } @@ -265,7 +268,7 @@ func (c *Client) GetServiceTicket(User, Pass, SPN string) (TGS, error) { // and check if they are allowed by network policy _, kdcs, _ := c.Krb5Config.GetKDCs(c.Realm, true) for _, v := range kdcs { - if !protocolstate.IsHostAllowed(v) { + if !protocolstate.IsHostAllowed(executionId, v) { c.nj.Throw("KDC address %v blacklisted by network policy", v) } } diff --git a/pkg/js/libs/kerberos/sendtokdc.go b/pkg/js/libs/kerberos/sendtokdc.go index 7e14386a7..52d277fa4 100644 --- a/pkg/js/libs/kerberos/sendtokdc.go +++ b/pkg/js/libs/kerberos/sendtokdc.go @@ -68,6 +68,9 @@ func sendToKDCTcp(kclient *Client, msg string) ([]byte, error) { kclient.nj.HandleError(err, "error getting KDCs") kclient.nj.Require(len(kdcs) > 0, "no KDCs found") + executionId := kclient.nj.ExecutionId() + dialers := protocolstate.GetDialersWithId(executionId) + var errs []string for i := 1; i <= len(kdcs); i++ { host, port, err := net.SplitHostPort(kdcs[i]) @@ -75,7 +78,7 @@ func sendToKDCTcp(kclient *Client, msg string) ([]byte, error) { // use that ip address instead of realm/domain for resolving host = kclient.config.ip } - tcpConn, err := protocolstate.Dialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port)) + tcpConn, err := dialers.Fastdialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port)) if err != nil { errs = append(errs, fmt.Sprintf("error establishing connection to %s: %v", kdcs[i], err)) continue @@ -101,6 +104,9 @@ func sendToKDCUdp(kclient *Client, msg string) ([]byte, error) { kclient.nj.HandleError(err, "error getting KDCs") kclient.nj.Require(len(kdcs) > 0, "no KDCs found") + executionId := kclient.nj.ExecutionId() + dialers := protocolstate.GetDialersWithId(executionId) + var errs []string for i := 1; i <= len(kdcs); i++ { host, port, err := net.SplitHostPort(kdcs[i]) @@ -108,7 +114,7 @@ func sendToKDCUdp(kclient *Client, msg string) ([]byte, error) { // use that ip address instead of realm/domain for resolving host = kclient.config.ip } - udpConn, err := protocolstate.Dialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port)) + udpConn, err := dialers.Fastdialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port)) if err != nil { errs = append(errs, fmt.Sprintf("error establishing connection to %s: %v", kdcs[i], err)) continue diff --git a/pkg/js/libs/ldap/ldap.go b/pkg/js/libs/ldap/ldap.go index 463f86b6a..80819feba 100644 --- a/pkg/js/libs/ldap/ldap.go +++ b/pkg/js/libs/ldap/ldap.go @@ -86,12 +86,15 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object { u, err := url.Parse(ldapUrl) c.nj.HandleError(err, "invalid ldap url supported schemas are ldap://, ldaps://, ldapi://, and cldap://") + executionId := c.nj.ExecutionId() + dialers := protocolstate.GetDialersWithId(executionId) + var conn net.Conn if u.Scheme == "ldapi" { if u.Path == "" || u.Path == "/" { u.Path = "/var/run/slapd/ldapi" } - conn, err = protocolstate.Dialer.Dial(context.TODO(), "unix", u.Path) + conn, err = dialers.Fastdialer.Dial(context.TODO(), "unix", u.Path) c.nj.HandleError(err, "failed to connect to ldap server") } else { host, port, err := net.SplitHostPort(u.Host) @@ -110,12 +113,12 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object { if port == "" { port = ldap.DefaultLdapPort } - conn, err = protocolstate.Dialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port)) + conn, err = dialers.Fastdialer.Dial(context.TODO(), "udp", net.JoinHostPort(host, port)) case "ldap": if port == "" { port = ldap.DefaultLdapPort } - conn, err = protocolstate.Dialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port)) + conn, err = dialers.Fastdialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, port)) case "ldaps": if port == "" { port = ldap.DefaultLdapsPort @@ -124,7 +127,7 @@ func NewClient(call goja.ConstructorCall, runtime *goja.Runtime) *goja.Object { if c.cfg.ServerName != "" { serverName = c.cfg.ServerName } - conn, err = protocolstate.Dialer.DialTLSWithConfig(context.TODO(), "tcp", net.JoinHostPort(host, port), + conn, err = dialers.Fastdialer.DialTLSWithConfig(context.TODO(), "tcp", net.JoinHostPort(host, port), &tls.Config{InsecureSkipVerify: true, MinVersion: tls.VersionTLS10, ServerName: serverName}) default: err = fmt.Errorf("unsupported ldap url schema %v", u.Scheme) @@ -331,7 +334,7 @@ func (c *Client) CollectMetadata() Metadata { // ``` func (c *Client) GetVersion() []string { c.nj.Require(c.conn != nil, "no existing connection") - + // Query root DSE for supported LDAP versions sr := ldap.NewSearchRequest( "", @@ -341,18 +344,17 @@ func (c *Client) GetVersion() []string { "(objectClass=*)", []string{"supportedLDAPVersion"}, nil) - + res, err := c.conn.Search(sr) c.nj.HandleError(err, "failed to get LDAP version") - + if len(res.Entries) > 0 { return res.Entries[0].GetAttributeValues("supportedLDAPVersion") } - + return []string{"unknown"} } - // close the ldap connection // @example // ```javascript diff --git a/pkg/js/utils/nucleijs.go b/pkg/js/utils/nucleijs.go index 9d9e3f4ec..44497ed0c 100644 --- a/pkg/js/utils/nucleijs.go +++ b/pkg/js/utils/nucleijs.go @@ -42,6 +42,10 @@ func (j *NucleiJS) runtime() *goja.Runtime { return j.vm } +func (j *NucleiJS) ExecutionId() string { + return j.runtime().Get("executionId").String() +} + // see: https://arc.net/l/quote/wpenftpc for throwing docs // ThrowError throws an error in goja runtime if is not nil diff --git a/pkg/protocols/code/code.go b/pkg/protocols/code/code.go index b3344d08d..2ff664238 100644 --- a/pkg/protocols/code/code.go +++ b/pkg/protocols/code/code.go @@ -201,6 +201,7 @@ func (request *Request) ExecuteWithResults(input *contextargs.Context, dynamicVa result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, args, &compiler.ExecuteOptions{ + ExecutionId: request.options.Options.ExecutionId, TimeoutVariants: request.options.Options.GetTimeouts(), Source: &request.PreCondition, Callback: registerPreConditionFunctions, diff --git a/pkg/protocols/common/protocolinit/init.go b/pkg/protocols/common/protocolinit/init.go index f0ba77177..bdb6a6f3c 100644 --- a/pkg/protocols/common/protocolinit/init.go +++ b/pkg/protocols/common/protocolinit/init.go @@ -1,8 +1,6 @@ package protocolinit import ( - "context" - "github.com/projectdiscovery/nuclei/v3/pkg/js/compiler" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/dns/dnsclientpool" @@ -15,8 +13,8 @@ import ( ) // Init initializes the client pools for the protocols -func Init(ctx context.Context, options *types.Options) error { - if err := protocolstate.Init(ctx, options); err != nil { +func Init(options *types.Options) error { + if err := protocolstate.Init(options); err != nil { return err } if err := dnsclientpool.Init(options); err != nil { @@ -40,6 +38,6 @@ func Init(ctx context.Context, options *types.Options) error { return nil } -func Close(ctx context.Context) { - protocolstate.Close(ctx) +func Close(executionId string) { + protocolstate.Close(executionId) } diff --git a/pkg/protocols/common/protocolstate/dialers.go b/pkg/protocols/common/protocolstate/dialers.go new file mode 100644 index 000000000..ad10a7298 --- /dev/null +++ b/pkg/protocols/common/protocolstate/dialers.go @@ -0,0 +1,18 @@ +package protocolstate + +import ( + "github.com/projectdiscovery/fastdialer/fastdialer" + "github.com/projectdiscovery/networkpolicy" + "github.com/projectdiscovery/rawhttp" + "github.com/projectdiscovery/retryablehttp-go" + mapsutil "github.com/projectdiscovery/utils/maps" +) + +type Dialers struct { + Fastdialer *fastdialer.Dialer + RawHTTPClient *rawhttp.Client + DefaultHTTPClient *retryablehttp.Client + HTTPClientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] + + NetworkPolicy *networkpolicy.NetworkPolicy +} diff --git a/pkg/protocols/common/protocolstate/headless.go b/pkg/protocols/common/protocolstate/headless.go index 1b4e7b932..267c9b63f 100644 --- a/pkg/protocols/common/protocolstate/headless.go +++ b/pkg/protocols/common/protocolstate/headless.go @@ -8,8 +8,8 @@ import ( "github.com/go-rod/rod" "github.com/go-rod/rod/lib/proto" "github.com/projectdiscovery/networkpolicy" + "github.com/projectdiscovery/nuclei/v3/pkg/types" errorutil "github.com/projectdiscovery/utils/errors" - mapsutil "github.com/projectdiscovery/utils/maps" stringsutil "github.com/projectdiscovery/utils/strings" urlutil "github.com/projectdiscovery/utils/url" "go.uber.org/multierr" @@ -18,9 +18,9 @@ import ( // initalize state of headless protocol var ( - ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v") - ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy") - networkPolicies = mapsutil.NewSyncLockMap[string, *networkpolicy.NetworkPolicy]() + ErrURLDenied = errorutil.NewWithFmt("headless: url %v dropped by rule: %v") + ErrHostDenied = errorutil.NewWithFmt("host %v dropped by network policy") + allowLocalFileAccess bool ) @@ -29,16 +29,16 @@ func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy { if execCtx == nil { return nil } - np, ok := networkPolicies.Get(execCtx.ExecutionID) - if !ok || np == nil { + dialers, ok := dialers.Get(execCtx.ExecutionID) + if !ok || dialers == nil { return nil } - return np + return dialers.NetworkPolicy } // ValidateNFailRequest validates and fails request // if the request does not respect the rules, it will be canceled with reason -func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchRequestPaused) error { +func ValidateNFailRequest(options *types.Options, page *rod.Page, e *proto.FetchRequestPaused) error { reqURL := e.Request.URL normalized := strings.ToLower(reqURL) // normalize url to lowercase normalized = strings.TrimSpace(normalized) // trim leading & trailing whitespaces @@ -50,7 +50,7 @@ func ValidateNFailRequest(ctx context.Context, page *rod.Page, e *proto.FetchReq if stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:") { return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "protocol blocked by network policy")) } - if !isValidHost(ctx, reqURL) { + if !isValidHost(options, reqURL) { return multierr.Combine(FailWithReason(page, e), ErrURLDenied.Msgf(reqURL, "address blocked by network policy")) } return nil @@ -66,28 +66,22 @@ func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error { } // InitHeadless initializes headless protocol state -func InitHeadless(ctx context.Context, localFileAccess bool, np *networkpolicy.NetworkPolicy) { +func InitHeadless(localFileAccess bool) { allowLocalFileAccess = localFileAccess - if np != nil { - execCtx := GetExecutionContext(ctx) - if execCtx != nil { - networkPolicies.Set(execCtx.ExecutionID, np) - } - } } // isValidHost checks if the host is valid (only limited to http/https protocols) -func isValidHost(ctx context.Context, targetUrl string) bool { +func isValidHost(options *types.Options, targetUrl string) bool { if !stringsutil.HasPrefixAny(targetUrl, "http:", "https:") { return true } - execCtx := GetExecutionContext(ctx) - if execCtx == nil { + dialers, ok := dialers.Get(options.ExecutionId) + if !ok { return true } - np, ok := networkPolicies.Get(execCtx.ExecutionID) + np := dialers.NetworkPolicy if !ok || np == nil { return true } @@ -103,13 +97,13 @@ func isValidHost(ctx context.Context, targetUrl string) bool { } // IsHostAllowed checks if the host is allowed by network policy -func IsHostAllowed(ctx context.Context, targetUrl string) bool { - execCtx := GetExecutionContext(ctx) - if execCtx == nil { +func IsHostAllowed(executionId string, targetUrl string) bool { + dialers, ok := dialers.Get(executionId) + if !ok { return true } - np, ok := networkPolicies.Get(execCtx.ExecutionID) + np := dialers.NetworkPolicy if !ok || np == nil { return true } diff --git a/pkg/protocols/common/protocolstate/state.go b/pkg/protocols/common/protocolstate/state.go index aef024e3c..5120a1eb2 100644 --- a/pkg/protocols/common/protocolstate/state.go +++ b/pkg/protocols/common/protocolstate/state.go @@ -20,31 +20,37 @@ import ( // Dialer is a shared fastdialer instance for host DNS resolution var ( - dialers *mapsutil.SyncLockMap[string, *fastdialer.Dialer] + dialers *mapsutil.SyncLockMap[string, *Dialers] ) -func GetDialer(ctx context.Context) *fastdialer.Dialer { +func GetDialers(ctx context.Context) *Dialers { executionContext := GetExecutionContext(ctx) - dialer, ok := dialers.Get(executionContext.ExecutionID) + dialers, ok := dialers.Get(executionContext.ExecutionID) if !ok { return nil } - return dialer + return dialers } -func ShouldInit(ctx context.Context) bool { - executionContext := GetExecutionContext(ctx) - dialer, ok := dialers.Get(executionContext.ExecutionID) +func GetDialersWithId(id string) *Dialers { + dialers, ok := dialers.Get(id) + if !ok { + return nil + } + return dialers +} + +func ShouldInit(id string) bool { + dialer, ok := dialers.Get(id) if !ok { return false } return dialer == nil } -// Init creates the Dialer instance based on user configuration -func Init(ctx context.Context, options *types.Options) error { - executionContext := GetExecutionContext(ctx) - if GetDialer(ctx) != nil { +// Init creates the Dialers instance based on user configuration +func Init(options *types.Options) error { + if GetDialersWithId(options.ExecutionId) != nil { return nil } @@ -73,8 +79,7 @@ func Init(ctx context.Context, options *types.Options) error { DenyList: expandedDenyList, } opts.WithNetworkPolicyOptions = npOptions - networkPolicy, _ := networkpolicy.New(*npOptions) - InitHeadless(ctx, options.AllowLocalFileAccess, networkPolicy) + InitHeadless(options.AllowLocalFileAccess) switch { case options.SourceIP != "" && options.Interface != "": @@ -159,7 +164,15 @@ func Init(ctx context.Context, options *types.Options) error { if err != nil { return errors.Wrap(err, "could not create dialer") } - dialers.Set(executionContext.ExecutionID, dialer) + + networkPolicy, _ := networkpolicy.New(*npOptions) + + dialersInstance := &Dialers{ + Fastdialer: dialer, + NetworkPolicy: networkPolicy, + } + + dialers.Set(options.ExecutionId, dialersInstance) // Set a custom dialer for the "nucleitcp" protocol. This is just plain TCP, but it's registered // with a different name so that we do not clobber the "tcp" dialer in the event that nuclei is @@ -234,18 +247,17 @@ func interfaceAddresses(interfaceName string) ([]net.Addr, error) { } // Close closes the global shared fastdialer -func Close(ctx context.Context) { - executionContext := GetExecutionContext(ctx) - dialer, ok := dialers.Get(executionContext.ExecutionID) +func Close(executionId string) { + dialersInstance, ok := dialers.Get(executionId) if !ok { return } - if dialer != nil { - dialer.Close() + if dialersInstance != nil { + dialersInstance.Fastdialer.Close() } - dialers.Delete(executionContext.ExecutionID) + dialers.Delete(executionId) StopActiveMemGuardian() } diff --git a/pkg/protocols/headless/engine/http_client.go b/pkg/protocols/headless/engine/http_client.go index 5ecddf700..4ee74f2b9 100644 --- a/pkg/protocols/headless/engine/http_client.go +++ b/pkg/protocols/headless/engine/http_client.go @@ -19,7 +19,7 @@ import ( // newHttpClient creates a new http client for headless communication with a timeout func newHttpClient(options *types.Options) (*http.Client, error) { - dialer := protocolstate.Dialer + dialers := protocolstate.GetDialersWithId(options.ExecutionId) // Set the base TLS configuration definition tlsConfig := &tls.Config{ @@ -41,15 +41,15 @@ func newHttpClient(options *types.Options) (*http.Client, error) { transport := &http.Transport{ ForceAttemptHTTP2: options.ForceAttemptHTTP2, - DialContext: dialer.Dial, + DialContext: dialers.Fastdialer.Dial, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.TlsImpersonate { - return dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) + return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) } if options.HasClientCertificates() || options.ForceAttemptHTTP2 { - return dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) + return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) } - return dialer.DialTLS(ctx, network, addr) + return dialers.Fastdialer.DialTLS(ctx, network, addr) }, MaxIdleConns: 500, MaxIdleConnsPerHost: 500, diff --git a/pkg/protocols/headless/engine/rules.go b/pkg/protocols/headless/engine/rules.go index cf7fd3d4f..0ff933aea 100644 --- a/pkg/protocols/headless/engine/rules.go +++ b/pkg/protocols/headless/engine/rules.go @@ -110,7 +110,7 @@ func (p *Page) routingRuleHandlerNative(e *proto.FetchRequestPaused) error { // ValidateNFailRequest validates if Local file access is enabled // and local network access is enables if not it will fail the request // that don't match the rules - if err := protocolstate.ValidateNFailRequest(p.page, e); err != nil { + if err := protocolstate.ValidateNFailRequest(p.options.Options, p.page, e); err != nil { return err } body, _ := FetchGetResponseBody(p.page, e) diff --git a/pkg/protocols/http/httpclientpool/clientpool.go b/pkg/protocols/http/httpclientpool/clientpool.go index 3f10fcbab..609c04e99 100644 --- a/pkg/protocols/http/httpclientpool/clientpool.go +++ b/pkg/protocols/http/httpclientpool/clientpool.go @@ -25,36 +25,19 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/types/scanstrategy" "github.com/projectdiscovery/rawhttp" "github.com/projectdiscovery/retryablehttp-go" - mapsutil "github.com/projectdiscovery/utils/maps" urlutil "github.com/projectdiscovery/utils/url" ) var ( - rawHttpClient *rawhttp.Client - rawHttpClientOnce sync.Once forceMaxRedirects int - normalClient *retryablehttp.Client - clientPool *mapsutil.SyncLockMap[string, *retryablehttp.Client] ) // Init initializes the clientpool implementation func Init(options *types.Options) error { - // Don't create clients if already created in the past. - if normalClient != nil { - return nil - } if options.ShouldFollowHTTPRedirects() { forceMaxRedirects = options.MaxRedirects } - clientPool = &mapsutil.SyncLockMap[string, *retryablehttp.Client]{ - Map: make(mapsutil.Map[string, *retryablehttp.Client]), - } - client, err := wrappedGet(options, &Configuration{}) - if err != nil { - return err - } - normalClient = client return nil } @@ -158,25 +141,30 @@ func (c *Configuration) HasStandardOptions() bool { // GetRawHTTP returns the rawhttp request client func GetRawHTTP(options *protocols.ExecutorOptions) *rawhttp.Client { - rawHttpClientOnce.Do(func() { - rawHttpOptions := rawhttp.DefaultOptions - if options.Options.AliveHttpProxy != "" { - rawHttpOptions.Proxy = options.Options.AliveHttpProxy - } else if options.Options.AliveSocksProxy != "" { - rawHttpOptions.Proxy = options.Options.AliveSocksProxy - } else if protocolstate.Dialer != nil { - rawHttpOptions.FastDialer = protocolstate.Dialer - } - rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout - rawHttpClient = rawhttp.NewClient(rawHttpOptions) - }) - return rawHttpClient + dialers := protocolstate.GetDialersWithId(options.Options.ExecutionId) + + if dialers.RawHTTPClient != nil { + return dialers.RawHTTPClient + } + + rawHttpOptions := rawhttp.DefaultOptions + if options.Options.AliveHttpProxy != "" { + rawHttpOptions.Proxy = options.Options.AliveHttpProxy + } else if options.Options.AliveSocksProxy != "" { + rawHttpOptions.Proxy = options.Options.AliveSocksProxy + } else if dialers.Fastdialer != nil { + rawHttpOptions.FastDialer = dialers.Fastdialer + } + rawHttpOptions.Timeout = options.Options.GetTimeouts().HttpTimeout + dialers.RawHTTPClient = rawhttp.NewClient(rawHttpOptions) + return dialers.RawHTTPClient } // Get creates or gets a client for the protocol based on custom configuration func Get(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { if configuration.HasStandardOptions() { - return normalClient, nil + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + return dialers.DefaultHTTPClient, nil } return wrappedGet(options, configuration) } @@ -185,8 +173,10 @@ func Get(options *types.Options, configuration *Configuration) (*retryablehttp.C func wrappedGet(options *types.Options, configuration *Configuration) (*retryablehttp.Client, error) { var err error + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + hash := configuration.Hash() - if client, ok := clientPool.Get(hash); ok { + if client, ok := dialers.HTTPClientPool.Get(hash); ok { return client, nil } @@ -263,15 +253,15 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl transport := &http.Transport{ ForceAttemptHTTP2: options.ForceAttemptHTTP2, - DialContext: protocolstate.GetDialer().Dial, + DialContext: dialers.Fastdialer.Dial, DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) { if options.TlsImpersonate { - return protocolstate.Dialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) + return dialers.Fastdialer.DialTLSWithConfigImpersonate(ctx, network, addr, tlsConfig, impersonate.Random, nil) } if options.HasClientCertificates() || options.ForceAttemptHTTP2 { - return protocolstate.Dialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) + return dialers.Fastdialer.DialTLSWithConfig(ctx, network, addr, tlsConfig) } - return protocolstate.GetDialer().DialTLS(ctx, network, addr) + return dialers.Fastdialer.DialTLS(ctx, network, addr) }, MaxIdleConns: maxIdleConns, MaxIdleConnsPerHost: maxIdleConnsPerHost, @@ -338,7 +328,7 @@ func wrappedGet(options *types.Options, configuration *Configuration) (*retryabl // Only add to client pool if we don't have a cookie jar in place. if jar == nil { - if err := clientPool.Set(hash, client); err != nil { + if err := dialers.HTTPClientPool.Set(hash, client); err != nil { return nil, err } } diff --git a/pkg/protocols/http/request.go b/pkg/protocols/http/request.go index 2cc32f5bf..cfa5c1d57 100644 --- a/pkg/protocols/http/request.go +++ b/pkg/protocols/http/request.go @@ -817,6 +817,8 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ } } + dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId) + if err != nil { // rawhttp doesn't support draining response bodies. if resp != nil && resp.Body != nil && generatedRequest.rawRequest == nil && !generatedRequest.original.Pipeline { @@ -837,7 +839,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - outputEvent["ip"] = protocolstate.Dialer.GetDialedIP(hostname) + outputEvent["ip"] = dialers.Fastdialer.GetDialedIP(hostname) // try getting cname request.addCNameIfAvailable(hostname, outputEvent) } @@ -957,7 +959,7 @@ func (request *Request) executeRequest(input *contextargs.Context, generatedRequ if input.MetaInput.CustomIP != "" { outputEvent["ip"] = input.MetaInput.CustomIP } else { - dialer := protocolstate.GetDialer() + dialer := dialers.Fastdialer if dialer != nil { outputEvent["ip"] = dialer.GetDialedIP(hostname) } @@ -1081,11 +1083,13 @@ func (request *Request) validateNFixEvent(input *contextargs.Context, gr *genera // addCNameIfAvailable adds the cname to the event if available func (request *Request) addCNameIfAvailable(hostname string, outputEvent map[string]interface{}) { - if protocolstate.Dialer == nil { + dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId) + + if dialers.Fastdialer == nil { return } - data, err := protocolstate.Dialer.GetDNSData(hostname) + data, err := dialers.Fastdialer.GetDNSData(hostname) if err == nil { switch len(data.CNAME) { case 0: diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 344199fb3..68326840c 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -611,6 +611,8 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte // generateEventData generates event data for the request func (request *Request) generateEventData(input *contextargs.Context, values map[string]interface{}, matched string) map[string]interface{} { + dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId) + data := make(map[string]interface{}) for k, v := range values { data[k] = v @@ -643,7 +645,7 @@ func (request *Request) generateEventData(input *contextargs.Context, values map } } } - data["ip"] = protocolstate.Dialer.GetDialedIP(hostname) + data["ip"] = dialers.Fastdialer.GetDialedIP(hostname) // if input itself was an ip, use it if iputil.IsIP(hostname) { data["ip"] = hostname @@ -651,7 +653,7 @@ func (request *Request) generateEventData(input *contextargs.Context, values map // if ip is not found,this is because ssh and other protocols do not use fastdialer // although its not perfect due to its use case dial and get ip - dnsData, err := protocolstate.Dialer.GetDNSData(hostname) + dnsData, err := dialers.Fastdialer.GetDNSData(hostname) if err == nil { for _, v := range dnsData.A { data["ip"] = v diff --git a/pkg/protocols/network/networkclientpool/clientpool.go b/pkg/protocols/network/networkclientpool/clientpool.go index a67cee296..936d4211f 100644 --- a/pkg/protocols/network/networkclientpool/clientpool.go +++ b/pkg/protocols/network/networkclientpool/clientpool.go @@ -6,17 +6,8 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/types" ) -var ( - normalClient *fastdialer.Dialer -) - // Init initializes the clientpool implementation func Init(options *types.Options) error { - // Don't create clients if already created in the past. - if normalClient != nil { - return nil - } - normalClient = protocolstate.Dialer return nil } @@ -29,6 +20,7 @@ func (c *Configuration) Hash() string { } // Get creates or gets a client for the protocol based on custom configuration -func Get(options *types.Options, configuration *Configuration /*TODO review unused parameters*/) (*fastdialer.Dialer, error) { - return normalClient, nil +func Get(options *types.Options, configuration *Configuration) (*fastdialer.Dialer, error) { + dialers := protocolstate.GetDialersWithId(options.ExecutionId) + return dialers.Fastdialer, nil } diff --git a/pkg/protocols/network/request.go b/pkg/protocols/network/request.go index f7b11fbb5..fbdf76493 100644 --- a/pkg/protocols/network/request.go +++ b/pkg/protocols/network/request.go @@ -64,7 +64,8 @@ func (request *Request) getOpenPorts(target *contextargs.Context) ([]string, err errs = append(errs, err) continue } - conn, err := protocolstate.Dialer.Dial(target.Context(), "tcp", addr) + dialers := protocolstate.GetDialersWithId(request.options.Options.ExecutionId) + conn, err := dialers.Fastdialer.Dial(target.Context(), "tcp", addr) if err != nil { errs = append(errs, err) continue diff --git a/pkg/types/types.go b/pkg/types/types.go index 41c95ef68..0e529e975 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -449,6 +449,9 @@ type Options struct { // This is internally managed and does not need to be set by user by explicitly setting // this overrides the default/derived one timeouts *Timeouts + + // Unique identifier of the execution session + ExecutionId string } // SetTimeouts sets the timeout variants to use for the executor