diff --git a/cmd/integration-test/interactsh.go b/cmd/integration-test/interactsh.go index 536945fa4..8d737909d 100644 --- a/cmd/integration-test/interactsh.go +++ b/cmd/integration-test/interactsh.go @@ -1,9 +1,11 @@ package main +import osutils "github.com/projectdiscovery/utils/os" + // All Interactsh related testcases var interactshTestCases = []TestCaseInfo{ - {Path: "protocols/http/interactsh.yaml", TestCase: &httpInteractshRequest{}, DisableOn: func() bool { return false }}, - {Path: "protocols/http/interactsh-stop-at-first-match.yaml", TestCase: &httpInteractshStopAtFirstMatchRequest{}, DisableOn: func() bool { return false }}, // disable this test for now - {Path: "protocols/http/default-matcher-condition.yaml", TestCase: &httpDefaultMatcherCondition{}, DisableOn: func() bool { return false }}, + {Path: "protocols/http/interactsh.yaml", TestCase: &httpInteractshRequest{}, DisableOn: func() bool { return osutils.IsWindows() || osutils.IsOSX() }}, + {Path: "protocols/http/interactsh-stop-at-first-match.yaml", TestCase: &httpInteractshStopAtFirstMatchRequest{}, DisableOn: func() bool { return true }}, // disable this test for now + {Path: "protocols/http/default-matcher-condition.yaml", TestCase: &httpDefaultMatcherCondition{}, DisableOn: func() bool { return true }}, {Path: "protocols/http/interactsh-requests-mc-and.yaml", TestCase: &httpInteractshRequestsWithMCAnd{}}, } diff --git a/pkg/js/compiler/compiler.go b/pkg/js/compiler/compiler.go index d601b5122..265bbaf04 100644 --- a/pkg/js/compiler/compiler.go +++ b/pkg/js/compiler/compiler.go @@ -10,6 +10,7 @@ import ( "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/generators" contextutil "github.com/projectdiscovery/utils/context" + stringsutil "github.com/projectdiscovery/utils/strings" ) // Compiler provides a runtime to execute goja runtime @@ -33,6 +34,11 @@ type ExecuteOptions struct { /// Timeout for this script execution Timeout int + // Source is original source of the script + Source *string + + // Manually exported objects + exports map[string]interface{} } // ExecuteArgs is the arguments to pass to the script. @@ -67,7 +73,7 @@ func (e ExecuteResult) GetSuccess() bool { // Execute executes a script with the default options. func (c *Compiler) Execute(code string, args *ExecuteArgs) (ExecuteResult, error) { - p, err := goja.Compile("", code, false) + p, err := WrapScriptNCompile(code, false) if err != nil { return nil, err } @@ -108,10 +114,33 @@ func (c *Compiler) ExecuteWithOptions(program *goja.Program, args *ExecuteArgs, err = fmt.Errorf("panic: %v", r) } }() - return executeProgram(program, args, opts) + return ExecuteProgram(program, args, opts) }) if err != nil { return nil, err } - return ExecuteResult{"response": results.Export(), "success": results.ToBoolean()}, nil + var res ExecuteResult + if opts.exports != nil { + res = ExecuteResult(opts.exports) + opts.exports = nil + } else { + res = NewExecuteResult() + } + res["response"] = results.Export() + res["success"] = results.ToBoolean() + return res, nil +} + +// Wraps a script in a function and compiles it. +func WrapScriptNCompile(script string, strict bool) (*goja.Program, error) { + if !stringsutil.ContainsAny(script, exportAsToken, exportToken) { + // this will not be run in a pooled runtime + return goja.Compile("", script, strict) + } + val := fmt.Sprintf(` + (function() { + %s + })() + `, script) + return goja.Compile("", val, strict) } diff --git a/pkg/js/compiler/init.go b/pkg/js/compiler/init.go index 87f319c56..8a6d4a04b 100644 --- a/pkg/js/compiler/init.go +++ b/pkg/js/compiler/init.go @@ -6,8 +6,9 @@ import "github.com/projectdiscovery/nuclei/v3/pkg/types" var ( // Per Execution Javascript timeout in seconds - JsProtocolTimeout = 10 - JsVmConcurrency = 500 + JsProtocolTimeout = 10 + PoolingJsVmConcurrency = 100 + NonPoolingVMConcurrency = 20 ) // Init initializes the javascript protocol @@ -21,6 +22,7 @@ func Init(opts *types.Options) error { opts.JsConcurrency = 100 } JsProtocolTimeout = opts.Timeout - JsVmConcurrency = opts.JsConcurrency + PoolingJsVmConcurrency = opts.JsConcurrency + PoolingJsVmConcurrency -= NonPoolingVMConcurrency return nil } diff --git a/pkg/js/compiler/non-pool.go b/pkg/js/compiler/non-pool.go new file mode 100644 index 000000000..8057c4960 --- /dev/null +++ b/pkg/js/compiler/non-pool.go @@ -0,0 +1,23 @@ +package compiler + +import ( + "sync" + + "github.com/dop251/goja" + "github.com/remeh/sizedwaitgroup" +) + +var ( + ephemeraljsc = sizedwaitgroup.New(NonPoolingVMConcurrency) + lazyFixedSgInit = sync.OnceFunc(func() { + ephemeraljsc = sizedwaitgroup.New(NonPoolingVMConcurrency) + }) +) + +func executeWithoutPooling(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (result goja.Value, err error) { + lazyFixedSgInit() + ephemeraljsc.Add() + defer ephemeraljsc.Done() + runtime := createNewRuntime() + return executeWithRuntime(runtime, p, args, opts) +} diff --git a/pkg/js/compiler/pool.go b/pkg/js/compiler/pool.go index 5d600e68a..f1b65d310 100644 --- a/pkg/js/compiler/pool.go +++ b/pkg/js/compiler/pool.go @@ -1,7 +1,10 @@ package compiler import ( + "bytes" + "encoding/json" "fmt" + "reflect" "sync" "github.com/dop251/goja" @@ -29,11 +32,18 @@ import ( _ "github.com/projectdiscovery/nuclei/v3/pkg/js/generated/go/libtelnet" _ "github.com/projectdiscovery/nuclei/v3/pkg/js/generated/go/libvnc" "github.com/projectdiscovery/nuclei/v3/pkg/js/global" + "github.com/projectdiscovery/nuclei/v3/pkg/js/gojs" "github.com/projectdiscovery/nuclei/v3/pkg/js/libs/goconsole" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" + stringsutil "github.com/projectdiscovery/utils/strings" "github.com/remeh/sizedwaitgroup" ) +const ( + exportToken = "Export" + exportAsToken = "ExportAs" +) + var ( r *require.Registry lazyRegistryInit = sync.OnceFunc(func() { @@ -41,40 +51,19 @@ var ( // autoregister console node module with default printer it uses gologger backend require.RegisterNativeModule(console.ModuleName, console.RequireWithPrinter(goconsole.NewGoConsolePrinter())) }) - sg sizedwaitgroup.SizedWaitGroup + pooljsc sizedwaitgroup.SizedWaitGroup lazySgInit = sync.OnceFunc(func() { - sg = sizedwaitgroup.New(JsVmConcurrency) + pooljsc = sizedwaitgroup.New(PoolingJsVmConcurrency) }) ) -func getRegistry() *require.Registry { - lazyRegistryInit() - return r -} - var gojapool = &sync.Pool{ New: func() interface{} { - runtime := protocolstate.NewJSRuntime() - _ = getRegistry().Enable(runtime) - // by default import below modules every time - _ = runtime.Set("console", require.Require(runtime, console.ModuleName)) - - // Register embedded javacript helpers - if err := global.RegisterNativeScripts(runtime); err != nil { - gologger.Error().Msgf("Could not register scripts: %s\n", err) - } - return runtime + return createNewRuntime() }, } -// executes the actual js program -func executeProgram(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (result goja.Value, err error) { - // its unknown (most likely cannot be done) to limit max js runtimes at a moment without making it static - // unlike sync.Pool which reacts to GC and its purposes is to reuse objects rather than creating new ones - lazySgInit() - sg.Add() - defer sg.Done() - runtime := gojapool.Get().(*goja.Runtime) +func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (result goja.Value, err error) { defer func() { // reset before putting back to pool _ = runtime.GlobalObject().Delete("template") // template ctx @@ -85,7 +74,6 @@ func executeProgram(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (r if opts != nil && opts.Cleanup != nil { opts.Cleanup(runtime) } - gojapool.Put(runtime) }() defer func() { if r := recover(); r != nil { @@ -109,8 +97,126 @@ func executeProgram(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (r return runtime.RunProgram(p) } +// ExecuteProgram executes a compiled program with the default options. +// it deligates if a particular program should run in a pooled or non-pooled runtime +func ExecuteProgram(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (result goja.Value, err error) { + if opts.Source == nil { + // not-recommended anymore + return executeWithoutPooling(p, args, opts) + } + if !stringsutil.ContainsAny(*opts.Source, exportAsToken, exportToken) { + // not-recommended anymore + return executeWithoutPooling(p, args, opts) + } + return executeWithPoolingProgram(p, args, opts) +} + +// executes the actual js program +func executeWithPoolingProgram(p *goja.Program, args *ExecuteArgs, opts *ExecuteOptions) (result goja.Value, err error) { + // its unknown (most likely cannot be done) to limit max js runtimes at a moment without making it static + // unlike sync.Pool which reacts to GC and its purposes is to reuse objects rather than creating new ones + lazySgInit() + pooljsc.Add() + defer pooljsc.Done() + runtime := gojapool.Get().(*goja.Runtime) + defer gojapool.Put(runtime) + var buff bytes.Buffer + opts.exports = make(map[string]interface{}) + + defer func() { + // remove below functions from runtime + _ = runtime.GlobalObject().Delete(exportAsToken) + _ = runtime.GlobalObject().Delete(exportToken) + }() + // register export functions + _ = gojs.RegisterFuncWithSignature(runtime, gojs.FuncOpts{ + Name: "Export", // we use string instead of const for documentation generation + Signatures: []string{"Export(value any)"}, + Description: "Converts a given value to a string and is appended to output of script", + FuncDecl: func(call goja.FunctionCall, runtime *goja.Runtime) goja.Value { + if len(call.Arguments) == 0 { + return goja.Null() + } + for _, arg := range call.Arguments { + value := arg.Export() + if out := stringify(value); out != "" { + buff.WriteString(out) + } + } + return goja.Null() + }, + }) + // register exportAs function + _ = gojs.RegisterFuncWithSignature(runtime, gojs.FuncOpts{ + Name: "ExportAs", // Export + Signatures: []string{"ExportAs(key string,value any)"}, + Description: "Exports given value with specified key and makes it available in DSL and response", + FuncDecl: func(call goja.FunctionCall, runtime *goja.Runtime) goja.Value { + if len(call.Arguments) != 2 { + // this is how goja expects errors to be returned + // and internally it is done same way for all errors + panic(runtime.ToValue("ExportAs expects 2 arguments")) + } + key := call.Argument(0).String() + value := call.Argument(1).Export() + opts.exports[key] = stringify(value) + return goja.Null() + }, + }) + val, err := executeWithRuntime(runtime, p, args, opts) + if err != nil { + return nil, err + } + if val.Export() != nil { + // append last value to output + buff.WriteString(stringify(val.Export())) + } + // and return it as result + return runtime.ToValue(buff.String()), nil +} + // Internal purposes i.e generating bindings func InternalGetGeneratorRuntime() *goja.Runtime { runtime := gojapool.Get().(*goja.Runtime) return runtime } + +func getRegistry() *require.Registry { + lazyRegistryInit() + return r +} + +func createNewRuntime() *goja.Runtime { + runtime := protocolstate.NewJSRuntime() + _ = getRegistry().Enable(runtime) + // by default import below modules every time + _ = runtime.Set("console", require.Require(runtime, console.ModuleName)) + + // Register embedded javacript helpers + if err := global.RegisterNativeScripts(runtime); err != nil { + gologger.Error().Msgf("Could not register scripts: %s\n", err) + } + return runtime +} + +// stringify converts a given value to string +// if its a struct it will be marshalled to json +func stringify(value interface{}) string { + if value == nil { + return "" + } + kind := reflect.TypeOf(value).Kind() + if kind == reflect.Struct || kind == reflect.Ptr && reflect.ValueOf(value).Elem().Kind() == reflect.Struct { + // marshal structs or struct pointers to json automatically + val := value + if kind == reflect.Ptr { + val = reflect.ValueOf(value).Elem().Interface() + } + bin, err := json.Marshal(val) + if err == nil { + return string(bin) + } + } + // for everything else stringify + return fmt.Sprintf("%v", value) +} diff --git a/pkg/js/libs/mysql/mysql.go b/pkg/js/libs/mysql/mysql.go index 9dd49c3bd..b5911ae70 100644 --- a/pkg/js/libs/mysql/mysql.go +++ b/pkg/js/libs/mysql/mysql.go @@ -7,13 +7,12 @@ import ( "io" "log" "net" - "net/url" "time" "github.com/go-sql-driver/mysql" "github.com/praetorian-inc/fingerprintx/pkg/plugins" mysqlplugin "github.com/praetorian-inc/fingerprintx/pkg/plugins/services/mysql" - utils "github.com/projectdiscovery/nuclei/v3/pkg/js/utils" + "github.com/projectdiscovery/nuclei/v3/pkg/js/utils" "github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/protocolstate" ) @@ -22,16 +21,6 @@ import ( // Internally client uses go-sql-driver/mysql driver. type MySQLClient struct{} -// Connect connects to MySQL database using given credentials. -// -// If connection is successful, it returns true. -// If connection is unsuccessful, it returns false and error. -// -// The connection is closed after the function returns. -func (c *MySQLClient) Connect(host string, port int, username, password string) (bool, error) { - return connect(host, port, username, password, "INFORMATION_SCHEMA") -} - // IsMySQL checks if the given host is running MySQL database. // // If the host is running MySQL database, it returns true. @@ -58,83 +47,96 @@ func (c *MySQLClient) IsMySQL(host string, port int) (bool, error) { return true, nil } -// ConnectWithDB connects to MySQL database using given credentials and database name. +// Connect connects to MySQL database using given credentials. // // If connection is successful, it returns true. // If connection is unsuccessful, it returns false and error. -// // The connection is closed after the function returns. -func (c *MySQLClient) ConnectWithDB(host string, port int, username, password, dbName string) (bool, error) { - return connect(host, port, username, password, dbName) +func (c *MySQLClient) Connect(host string, port int, username, password string) (bool, error) { + if !protocolstate.IsHostAllowed(host) { + // host is not valid according to network policy + return false, protocolstate.ErrHostDenied.Msgf(host) + } + dsn, err := BuildDSN(MySQLOptions{ + Host: host, + Port: port, + DbName: "INFORMATION_SCHEMA", + Protocol: "tcp", + Username: username, + Password: password, + }) + if err != nil { + return false, err + } + return connectWithDSN(dsn) +} + +type MySQLInfo struct { + Host string `json:"host,omitempty"` + IP string `json:"ip"` + Port int `json:"port"` + Protocol string `json:"protocol"` + TLS bool `json:"tls"` + Transport string `json:"transport"` + Version string `json:"version,omitempty"` + Debug plugins.ServiceMySQL `json:"debug,omitempty"` + Raw string `json:"metadata"` +} + +// returns MySQLInfo when fingerpint is successful +func (c *MySQLClient) FingerprintMySQL(host string, port int) (MySQLInfo, error) { + info := MySQLInfo{} + if !protocolstate.IsHostAllowed(host) { + // host is not valid according to network policy + return info, protocolstate.ErrHostDenied.Msgf(host) + } + conn, err := protocolstate.Dialer.Dial(context.TODO(), "tcp", net.JoinHostPort(host, fmt.Sprintf("%d", port))) + if err != nil { + return info, err + } + defer conn.Close() + + plugin := &mysqlplugin.MYSQLPlugin{} + service, err := plugin.Run(conn, 5*time.Second, plugins.Target{Host: host}) + if err != nil { + return info, err + } + if service == nil { + return info, fmt.Errorf("something went wrong got null output") + } + // fill all fields + info.Host = service.Host + info.IP = service.IP + info.Port = service.Port + info.Protocol = service.Protocol + info.TLS = service.TLS + info.Transport = service.Transport + info.Version = service.Version + info.Debug = service.Metadata().(plugins.ServiceMySQL) + bin, _ := service.Raw.MarshalJSON() + info.Raw = string(bin) + return info, nil } // ConnectWithDSN connects to MySQL database using given DSN. // we override mysql dialer with fastdialer so it respects network policy func (c *MySQLClient) ConnectWithDSN(dsn string) (bool, error) { + return connectWithDSN(dsn) +} + +func (c *MySQLClient) ExecuteQueryWithOpts(opts MySQLOptions, query string) (*utils.SQLResult, error) { + if !protocolstate.IsHostAllowed(opts.Host) { + // host is not valid according to network policy + return nil, protocolstate.ErrHostDenied.Msgf(opts.Host) + } + dsn, err := BuildDSN(opts) + if err != nil { + return nil, err + } + db, err := sql.Open("mysql", dsn) if err != nil { - return false, err - } - defer db.Close() - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(0) - - _, err = db.Exec("select 1") - if err != nil { - return false, err - } - return true, nil -} - -func connect(host string, port int, username, password, dbName string) (bool, error) { - if host == "" || port <= 0 { - return false, fmt.Errorf("invalid host or port") - } - - if !protocolstate.IsHostAllowed(host) { - // host is not valid according to network policy - return false, protocolstate.ErrHostDenied.Msgf(host) - } - - target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - - db, err := sql.Open("mysql", fmt.Sprintf("%v:%v@tcp(%v)/%s?allowOldPasswords=1", - url.PathEscape(username), - url.PathEscape(password), - target, - dbName)) - if err != nil { - return false, err - } - defer db.Close() - db.SetMaxOpenConns(1) - db.SetMaxIdleConns(0) - - _, err = db.Exec("select 1") - if err != nil { - return false, err - } - return true, nil -} - -// ExecuteQuery connects to Mysql database using given credentials and database name. -// and executes a query on the db. -func (c *MySQLClient) ExecuteQuery(host string, port int, username, password, dbName, query string) (string, error) { - - if !protocolstate.IsHostAllowed(host) { - // host is not valid according to network policy - return "", protocolstate.ErrHostDenied.Msgf(host) - } - - target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) - - db, err := sql.Open("mysql", fmt.Sprintf("%v:%v@tcp(%v)/%s", - url.PathEscape(username), - url.PathEscape(password), - target, - dbName)) - if err != nil { - return "", err + return nil, err } defer db.Close() db.SetMaxOpenConns(1) @@ -142,13 +144,43 @@ func (c *MySQLClient) ExecuteQuery(host string, port int, username, password, db rows, err := db.Query(query) if err != nil { - return "", err + return nil, err } - resp, err := utils.UnmarshalSQLRows(rows) + + data, err := utils.UnmarshalSQLRows(rows) if err != nil { - return "", err + if len(data.Rows) > 0 { + // allow partial results + return data, nil + } + return nil, err } - return string(resp), nil + return data, nil +} + +// ExecuteQuery connects to Mysql database using given credentials +// and executes a query on the db. +func (c *MySQLClient) ExecuteQuery(host string, port int, username, password, query string) (*utils.SQLResult, error) { + return c.ExecuteQueryWithOpts(MySQLOptions{ + Host: host, + Port: port, + Protocol: "tcp", + Username: username, + Password: password, + }, query) +} + +// ExecuteQuery connects to Mysql database using given credentials +// and executes a query on the db. +func (c *MySQLClient) ExecuteQueryOnDB(host string, port int, username, password, dbname, query string) (*utils.SQLResult, error) { + return c.ExecuteQueryWithOpts(MySQLOptions{ + Host: host, + Port: port, + Protocol: "tcp", + Username: username, + Password: password, + DbName: dbname, + }, query) } func init() { diff --git a/pkg/js/libs/mysql/mysql_private.go b/pkg/js/libs/mysql/mysql_private.go new file mode 100644 index 000000000..c0e47c04c --- /dev/null +++ b/pkg/js/libs/mysql/mysql_private.go @@ -0,0 +1,65 @@ +package mysql + +import ( + "database/sql" + "fmt" + "net" + "net/url" + "strings" +) + +// MySQLOptions defines the data source name (DSN) options required to connect to a MySQL database. +// along with other options like Timeout etc +type MySQLOptions struct { + Host string // Host is the host name or IP address of the MySQL server. + Port int // Port is the port number on which the MySQL server is listening. + Protocol string // Protocol is the protocol used to connect to the MySQL server (ex: "tcp"). + Username string // Username is the user name used to authenticate with the MySQL server. + Password string // Password is the password used to authenticate with the MySQL server. + DbName string // DbName is the name of the database to connect to on the MySQL server. + RawQuery string // QueryStr is the query string to append to the DSN (ex: "?tls=skip-verify"). + Timeout int // Timeout is the timeout in seconds for the connection to the MySQL server. +} + +// BuildDSN builds a MySQL data source name (DSN) from the given options. +func BuildDSN(opts MySQLOptions) (string, error) { + if opts.Host == "" || opts.Port <= 0 { + return "", fmt.Errorf("invalid host or port") + } + if opts.Protocol == "" { + opts.Protocol = "tcp" + } + if opts.DbName == "" { + opts.DbName = "/" + } else { + opts.DbName = "/" + opts.DbName + } + target := net.JoinHostPort(opts.Host, fmt.Sprintf("%d", opts.Port)) + var dsn strings.Builder + dsn.WriteString(fmt.Sprintf("%v:%v", url.QueryEscape(opts.Username), url.QueryEscape(opts.Password))) + dsn.WriteString("@") + dsn.WriteString(fmt.Sprintf("%v(%v)", opts.Protocol, target)) + if opts.DbName != "" { + dsn.WriteString(opts.DbName) + } + if opts.RawQuery != "" { + dsn.WriteString(opts.RawQuery) + } + return dsn.String(), nil +} + +func connectWithDSN(dsn string) (bool, error) { + db, err := sql.Open("mysql", dsn) + if err != nil { + return false, err + } + defer db.Close() + db.SetMaxOpenConns(1) + db.SetMaxIdleConns(0) + + _, err = db.Exec("select 1") + if err != nil { + return false, err + } + return true, nil +} diff --git a/pkg/js/libs/postgres/postgres.go b/pkg/js/libs/postgres/postgres.go index 8d309f3a3..e3c21bf60 100644 --- a/pkg/js/libs/postgres/postgres.go +++ b/pkg/js/libs/postgres/postgres.go @@ -59,11 +59,10 @@ func (c *PGClient) Connect(host string, port int, username, password string) (bo // ExecuteQuery connects to Postgres database using given credentials and database name. // and executes a query on the db. -func (c *PGClient) ExecuteQuery(host string, port int, username, password, dbName, query string) (string, error) { - +func (c *PGClient) ExecuteQuery(host string, port int, username, password, dbName, query string) (*utils.SQLResult, error) { if !protocolstate.IsHostAllowed(host) { // host is not valid according to network policy - return "", protocolstate.ErrHostDenied.Msgf(host) + return nil, protocolstate.ErrHostDenied.Msgf(host) } target := net.JoinHostPort(host, fmt.Sprintf("%d", port)) @@ -71,18 +70,18 @@ func (c *PGClient) ExecuteQuery(host string, port int, username, password, dbNam connStr := fmt.Sprintf("postgres://%s:%s@%s/%s?sslmode=disable", username, password, target, dbName) db, err := sql.Open("postgres", connStr) if err != nil { - return "", err + return nil, err } rows, err := db.Query(query) if err != nil { - return "", err + return nil, err } resp, err := utils.UnmarshalSQLRows(rows) if err != nil { - return "", err + return nil, err } - return string(resp), nil + return resp, nil } // ConnectWithDB connects to Postgres database using given credentials and database name. diff --git a/pkg/js/utils/util.go b/pkg/js/utils/util.go index d801a5174..df08fb414 100644 --- a/pkg/js/utils/util.go +++ b/pkg/js/utils/util.go @@ -2,29 +2,41 @@ package utils import ( "database/sql" - - jsoniter "github.com/json-iterator/go" ) -// UnmarshalSQLRows unmarshals sql rows to json +// SQLResult holds the result of a SQL query. // -// This function provides a way to unmarshal arbitrary sql rows -// to json. -func UnmarshalSQLRows(rows *sql.Rows) ([]byte, error) { +// It contains the count of rows, the columns present, and the actual row data. +type SQLResult struct { + Count int // Count is the number of rows returned. + Columns []string // Columns is the slice of column names. + Rows []interface{} // Rows is a slice of row data, where each row is a map of column name to value. +} + +// UnmarshalSQLRows converts sql.Rows into a more structured SQLResult. +// +// This function takes *sql.Rows as input and attempts to unmarshal the data into +// a SQLResult struct. It handles different SQL data types by using the appropriate +// sql.Null* types during scanning. It returns a pointer to a SQLResult or an error. +// +// The function closes the sql.Rows when finished. +func UnmarshalSQLRows(rows *sql.Rows) (*SQLResult, error) { + defer rows.Close() columnTypes, err := rows.ColumnTypes() if err != nil { return nil, err } + result := &SQLResult{} + result.Columns, err = rows.Columns() + if err != nil { + return nil, err + } count := len(columnTypes) - finalRows := []interface{}{} - for rows.Next() { - + result.Count++ scanArgs := make([]interface{}, count) - for i, v := range columnTypes { - switch v.DatabaseTypeName() { case "VARCHAR", "TEXT", "UUID", "TIMESTAMP": scanArgs[i] = new(sql.NullString) @@ -36,17 +48,13 @@ func UnmarshalSQLRows(rows *sql.Rows) ([]byte, error) { scanArgs[i] = new(sql.NullString) } } - err := rows.Scan(scanArgs...) - if err != nil { - return nil, err + // Return the result accumulated so far along with the error. + return result, err } - - masterData := map[string]interface{}{} - + masterData := make(map[string]interface{}) for i, v := range columnTypes { - if z, ok := (scanArgs[i]).(*sql.NullBool); ok { masterData[v.Name()] = z.Bool continue @@ -74,8 +82,7 @@ func UnmarshalSQLRows(rows *sql.Rows) ([]byte, error) { masterData[v.Name()] = scanArgs[i] } - - finalRows = append(finalRows, masterData) + result.Rows = append(result.Rows, masterData) } - return jsoniter.Marshal(finalRows) + return result, nil } diff --git a/pkg/protocols/javascript/js.go b/pkg/protocols/javascript/js.go index 0fa43213c..3ca587e49 100644 --- a/pkg/protocols/javascript/js.go +++ b/pkg/protocols/javascript/js.go @@ -151,6 +151,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { opts := &compiler.ExecuteOptions{ Timeout: request.Timeout, + Source: &request.Init, } // register 'export' function to export variables from init code // these are saved in args and are available in pre-condition and request code @@ -212,7 +213,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { // proceed with whatever args we have args.Args, _ = request.evaluateArgs(allVars, options, true) - initCompiled, err := goja.Compile("", request.Init, false) + initCompiled, err := compiler.WrapScriptNCompile(request.Init, false) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not compile init code: %s", err) } @@ -233,7 +234,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { // compile pre-condition if any if request.PreCondition != "" { - preConditionCompiled, err := goja.Compile("", request.PreCondition, false) + preConditionCompiled, err := compiler.WrapScriptNCompile(request.PreCondition, false) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not compile pre-condition: %s", err) } @@ -242,7 +243,7 @@ func (request *Request) Compile(options *protocols.ExecutorOptions) error { // compile actual source code if request.Code != "" { - scriptCompiled, err := goja.Compile("", request.Code, false) + scriptCompiled, err := compiler.WrapScriptNCompile(request.Code, false) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not compile javascript code: %s", err) } @@ -339,7 +340,8 @@ func (request *Request) ExecuteWithResults(target *contextargs.Context, dynamicV } argsCopy.TemplateCtx = templateCtx.GetAll() - result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy, &compiler.ExecuteOptions{Timeout: request.Timeout}) + result, err := request.options.JsCompiler.ExecuteWithOptions(request.preConditionCompiled, argsCopy, + &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.PreCondition}) if err != nil { return errorutil.NewWithTag(request.TemplateID, "could not execute pre-condition: %s", err) } @@ -471,7 +473,8 @@ func (request *Request) executeRequestWithPayloads(hostPort string, input *conte } } - results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy, &compiler.ExecuteOptions{Timeout: request.Timeout}) + results, err := request.options.JsCompiler.ExecuteWithOptions(request.scriptCompiled, argsCopy, + &compiler.ExecuteOptions{Timeout: request.Timeout, Source: &request.Code}) if err != nil { // shouldn't fail even if it returned error instead create a failure event results = compiler.ExecuteResult{"success": false, "error": err.Error()} diff --git a/pkg/tmplexec/exec.go b/pkg/tmplexec/exec.go index ca9b54bf5..a2facd108 100644 --- a/pkg/tmplexec/exec.go +++ b/pkg/tmplexec/exec.go @@ -8,6 +8,7 @@ import ( "github.com/dop251/goja" "github.com/projectdiscovery/gologger" + "github.com/projectdiscovery/nuclei/v3/pkg/js/compiler" "github.com/projectdiscovery/nuclei/v3/pkg/operators/common/dsl" "github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/protocols" @@ -48,7 +49,7 @@ func NewTemplateExecuter(requests []protocols.Request, options *protocols.Execut // we use a dummy input here because goal of flow executor at this point is to just check // syntax and other things are correct before proceeding to actual execution // during execution new instance of flow will be created as it is tightly coupled with lot of executor options - p, err := goja.Compile("flow.js", options.Flow, false) + p, err := compiler.WrapScriptNCompile(options.Flow, false) if err != nil { return nil, fmt.Errorf("could not compile flow: %s", err) }