Sandeep Singh b4644af80a
Lint + test fixes after utils dep update (#6393)
* fix: remove undefined errorutil.ShowStackTrace

* feat: add make lint support and integrate with test

* refactor: migrate errorutil to errkit across codebase

- Replace deprecated errorutil with modern errkit
- Convert error declarations from var to func for better compatibility
- Fix all SA1019 deprecation warnings
- Maintain error chain support and stack traces

* fix: improve DNS test reliability using Google DNS

- Configure test to use Google DNS (8.8.8.8) for stability
- Fix nil pointer issue in DNS client initialization
- Keep production defaults unchanged

* fixing logic

* removing unwanted branches in makefile

---------

Co-authored-by: Mzack9999 <mzack9999@protonmail.com>
2025-08-20 05:28:23 +05:30

148 lines
4.2 KiB
Go

package protocolstate
import (
"context"
"fmt"
"net"
"strings"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/proto"
"github.com/projectdiscovery/networkpolicy"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/projectdiscovery/utils/errkit"
stringsutil "github.com/projectdiscovery/utils/strings"
urlutil "github.com/projectdiscovery/utils/url"
"go.uber.org/multierr"
)
// ErrURLDenied builds the error reported when a headless request is dropped
// because its URL violates a rule. url is the offending request URL and rule
// is a short human-readable description of the rule that rejected it.
func ErrURLDenied(url, rule string) error {
	msg := fmt.Sprintf("headless: url %v dropped by rule: %v", url, rule)
	return errkit.New(msg).Build()
}
// ErrHostDenied builds the error reported when host is rejected by the
// network policy.
func ErrHostDenied(host string) error {
	msg := fmt.Sprintf("host %v dropped by network policy", host)
	return errkit.New(msg).Build()
}
// GetNetworkPolicy returns the network policy of the dialers registered for
// the execution carried by ctx. It returns nil when ctx has no execution
// context or when no (non-nil) dialers are registered for its execution ID.
func GetNetworkPolicy(ctx context.Context) *networkpolicy.NetworkPolicy {
	execCtx := GetExecutionContext(ctx)
	if execCtx == nil {
		return nil
	}
	if d, ok := dialers.Get(execCtx.ExecutionID); ok && d != nil {
		return d.NetworkPolicy
	}
	return nil
}
// ValidateNFailRequest checks a paused headless request against, in order:
// the local-file-access setting, a list of blocked schemes, and the host
// network policy. On the first violation the request is failed via
// FailWithReason and the returned error combines any failure from that call
// with an ErrURLDenied describing the rule. It returns nil when the request
// passes every check.
//
// The javascript: scheme is intentionally not blocked so XSS fuzzing works.
func ValidateNFailRequest(options *types.Options, page *rod.Page, e *proto.FetchRequestPaused) error {
	reqURL := e.Request.URL
	// lowercase first, then trim leading & trailing whitespace
	normalized := strings.TrimSpace(strings.ToLower(reqURL))

	var rule string
	switch {
	case !IsLfaAllowed(options) && stringsutil.HasPrefixI(normalized, "file:"):
		rule = "use of file:// protocol disabled use '-lfa' to enable"
	case stringsutil.HasPrefixAnyI(normalized, "ftp:", "externalfile:", "chrome:", "chrome-extension:"):
		rule = "protocol blocked by network policy"
	case !isValidHost(options, reqURL):
		rule = "address blocked by network policy"
	default:
		return nil
	}
	return multierr.Combine(FailWithReason(page, e), ErrURLDenied(reqURL, rule))
}
// FailWithReason aborts the paused request e on page, reporting
// NetworkErrorReasonAccessDenied as the failure reason. It returns any error
// from issuing the Fetch.failRequest call.
func FailWithReason(page *rod.Page, e *proto.FetchRequestPaused) error {
	req := proto.FetchFailRequest{
		ErrorReason: proto.NetworkErrorReasonAccessDenied,
		RequestID:   e.RequestID,
	}
	return req.Call(page)
}
// InitHeadless copies the local-file-access and local-network-restriction
// flags from options into the dialers registered for options.ExecutionId,
// holding the dialers' lock while writing. It does nothing when no (non-nil)
// dialers are registered for that execution.
func InitHeadless(options *types.Options) {
	d, ok := dialers.Get(options.ExecutionId)
	if !ok || d == nil {
		return
	}
	d.Lock()
	defer d.Unlock()
	d.LocalFileAccessAllowed = options.AllowLocalFileAccess
	d.RestrictLocalNetworkAccess = options.RestrictLocalNetworkAccess
}
// IsRestrictLocalNetworkAccess reports the RestrictLocalNetworkAccess flag of
// the dialers registered for options.ExecutionId, read under the dialers'
// lock. It returns false when no (non-nil) dialers are registered.
func IsRestrictLocalNetworkAccess(options *types.Options) bool {
	d, ok := dialers.Get(options.ExecutionId)
	if !ok || d == nil {
		return false
	}
	d.Lock()
	defer d.Unlock()
	return d.RestrictLocalNetworkAccess
}
// isValidHost reports whether targetUrl is permitted by the network policy of
// the execution identified by options.ExecutionId.
//
// Only http/https URLs are checked. Any other scheme, a missing or nil dialers
// entry, or a missing policy is treated as allowed. An http(s) URL that cannot
// be parsed is rejected (fail closed).
//
// The scheme match is case-insensitive and ignores surrounding whitespace, so
// inputs such as "HTTP://127.0.0.1" or " http://127.0.0.1" are validated
// instead of silently bypassing the policy.
func isValidHost(options *types.Options, targetUrl string) bool {
	candidate := strings.TrimSpace(targetUrl)
	if !stringsutil.HasPrefixAnyI(candidate, "http:", "https:") {
		return true
	}
	d, ok := dialers.Get(options.ExecutionId)
	if !ok || d == nil {
		return true
	}
	np := d.NetworkPolicy
	if np == nil {
		return true
	}
	urlx, err := urlutil.Parse(candidate)
	if err != nil {
		// not a valid url
		return false
	}
	_, allowed := np.ValidateHost(urlx.Hostname())
	return allowed
}
// IsHostAllowed reports whether targetUrl is permitted by the network policy
// of the execution identified by executionId.
//
// NOTE(review): despite its name, targetUrl is handled as a hostname/IP,
// a host:port pair, or (when it has several colons) something the policy's
// Validate parses itself, e.g. an IPv6 address. Confirm callers do not pass
// full single-colon URLs such as "http://host".
//
// Every host is allowed when no (non-nil) dialers or no policy are registered
// for the execution. A single-colon input that net.SplitHostPort cannot parse
// is rejected rather than validated as an empty host.
func IsHostAllowed(executionId string, targetUrl string) bool {
	d, ok := dialers.Get(executionId)
	if !ok || d == nil {
		return true
	}
	np := d.NetworkPolicy
	if np == nil {
		return true
	}
	switch strings.Count(targetUrl, ":") {
	case 0:
		// just a hostname or ip without port
		_, allowed := np.ValidateHost(targetUrl)
		return allowed
	case 1:
		host, _, err := net.SplitHostPort(targetUrl)
		if err != nil {
			// malformed host:port; don't validate it as an empty host
			return false
		}
		_, allowed := np.ValidateHost(host)
		return allowed
	default:
		// most likely an ipv6 address; let the policy parse and validate it
		return np.Validate(targetUrl)
	}
}