diff --git a/integration_tests/http/dsl-functions.yaml b/integration_tests/http/dsl-functions.yaml index 8545c58f9..f311394a9 100644 --- a/integration_tests/http/dsl-functions.yaml +++ b/integration_tests/http/dsl-functions.yaml @@ -7,6 +7,7 @@ info: requests: - raw: + # Note for the integration test: dsl expression should not contain commas - | GET / HTTP/1.1 Host: {{Hostname}} @@ -90,9 +91,11 @@ requests: 78: {{line_starts_with("Hi\nHello", "He")}} 79: {{line_ends_with("Hello\nHi", "lo")}} 80: {{sort("a1b2c3d4e5")}} - 81: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}} - 82: {{uniq("abcabdaabbccd")}} + 81: {{uniq("abcabdaabbccd")}} + 82: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}} 83: {{join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))}} + 84: {{split("ab,cd,efg", ",")}} + 85: {{split("ab,cd,efg", ",", 2)}} extractors: - type: regex diff --git a/v2/cmd/integration-test/http.go b/v2/cmd/integration-test/http.go index 9f1f55486..6df7aceed 100644 --- a/v2/cmd/integration-test/http.go +++ b/v2/cmd/integration-test/http.go @@ -8,7 +8,6 @@ import ( "net/http" "net/http/httptest" "net/http/httputil" - "regexp" "strconv" "strings" "time" @@ -16,6 +15,7 @@ import ( "github.com/julienschmidt/httprouter" "github.com/projectdiscovery/nuclei/v2/pkg/testutils" + stringsutil "github.com/projectdiscovery/utils/strings" ) var httpTestcases = map[string]testutils.TestCase{ @@ -284,19 +284,22 @@ func (h *httpDSLFunctions) Execute(filePath string) error { return err } - resultPattern := regexp.MustCompile(`\[[^]]+] \[[^]]+] \[[^]]+] [^]]+ \[([^]]+)]`) - submatch := resultPattern.FindStringSubmatch(results[0]) - if len(submatch) != 2 { - return errors.New("could not parse the result") + // get result part + resultPart, err := stringsutil.After(results[0], ts.URL) + if err != nil { + return err } - totalExtracted := strings.Split(submatch[1], ",") - numberOfDslFunctions := 83 - if len(totalExtracted) != numberOfDslFunctions { + // remove additional characters till the first valid result and ignore last ] which doesn't alter the total count + resultPart = stringsutil.TrimPrefixAny(resultPart, "/", " ", "[") + + extracted := strings.Split(resultPart, ",") + numberOfDslFunctions := 85 + if len(extracted) != numberOfDslFunctions { return errors.New("incorrect number of results") } - for _, header := range totalExtracted { + for _, header := range extracted { parts := strings.Split(header, ": ") index, err := strconv.Atoi(parts[0]) if err != nil { diff --git a/v2/pkg/operators/common/dsl/dsl.go b/v2/pkg/operators/common/dsl/dsl.go index 7e7e365ce..26978c221 100644 --- a/v2/pkg/operators/common/dsl/dsl.go +++ b/v2/pkg/operators/common/dsl/dsl.go @@ -62,7 +62,7 @@ var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+( var dateFormatRegex = regexp.MustCompile("%([A-Za-z])") type dslFunction struct { - signature string + signatures []string expressFunc govaluate.ExpressionFunction } @@ -88,8 +88,10 @@ func init() { "to_lower": makeDslFunction(1, func(args ...interface{}) (interface{}, error) { return strings.ToLower(types.ToString(args[0])), nil }), - "sort": makeDslWithOptionalArgsFunction( - "(args ...interface{}) interface{}", + "sort": makeMultiSignatureDslFunction([]string{ + "(input string) string", + "(input number) string", + "(elements ...interface{}) []interface{}"}, func(args ...interface{}) (interface{}, error) { argCount := len(args) if argCount == 0 { @@ -110,8 +112,10 @@ func init() { } }, ), - "uniq": makeDslWithOptionalArgsFunction( - "(args ...interface{}) interface{}", + "uniq": makeMultiSignatureDslFunction([]string{ + "(input string) string", + "(input number) string", + "(elements ...interface{}) []interface{}"}, func(args ...interface{}) (interface{}, error) { argCount := len(args) if argCount == 0 { @@ -410,12 +414,40 @@ func init() { return builder.String(), nil }, ), - "join": makeDslWithOptionalArgsFunction( + "split": makeMultiSignatureDslFunction([]string{ + "(input string, n int) []string", + "(input string, separator string, optionalChunkSize) []string"}, + func(arguments ...interface{}) (interface{}, error) { + argumentsSize := len(arguments) + if argumentsSize == 2 { + input := types.ToString(arguments[0]) + separatorOrCount := types.ToString(arguments[1]) + + count, err := strconv.Atoi(separatorOrCount) + if err != nil { + return strings.SplitN(input, separatorOrCount, -1), nil + } + return toChunks(input, count), nil + } else if argumentsSize == 3 { + input := types.ToString(arguments[0]) + separator := types.ToString(arguments[1]) + count, err := strconv.Atoi(types.ToString(arguments[2])) + if err != nil { + return nil, invalidDslFunctionError + } + return strings.SplitN(input, separator, count), nil + } else { + return nil, invalidDslFunctionError + } + }, + ), + "join": makeMultiSignatureDslFunction([]string{ "(separator string, elements ...interface{}) string", + "(separator string, elements []interface{}) string"}, func(arguments ...interface{}) (interface{}, error) { argumentsSize := len(arguments) if argumentsSize < 2 { - return nil, errors.New("incorrect number of arguments received") + return nil, invalidDslFunctionError } else if argumentsSize == 2 { separator := types.ToString(arguments[0]) elements, ok := arguments[1].([]string) @@ -431,7 +463,6 @@ func init() { stringElements := make([]string, 0, argumentsSize) for _, element := range elements { - if _, ok := element.([]string); ok { return nil, errors.New("cannot use join on more than one slice element") } @@ -794,9 +825,18 @@ func init() { } func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction { + return makeMultiSignatureDslFunction([]string{signaturePart}, dslFunctionLogic) +} + +func makeMultiSignatureDslFunction(signatureParts []string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction { return func(functionName string) dslFunction { + methodSignatures := make([]string, 0, len(signatureParts)) + for _, signaturePart := range signatureParts { + methodSignatures = append(methodSignatures, functionName+signaturePart) + } + return dslFunction{ - functionName + signaturePart, + methodSignatures, dslFunctionLogic, } } @@ -806,7 +846,7 @@ func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.Expressi return func(functionName string) dslFunction { signature := functionName + createSignaturePart(numberOfParameters) return dslFunction{ - signature, + []string{signature}, func(args ...interface{}) (interface{}, error) { if len(args) != numberOfParameters { return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature) @@ -843,7 +883,7 @@ func helperFunctions() map[string]govaluate.ExpressionFunction { func AddHelperFunction(key string, value func(args ...interface{}) (interface{}, error)) error { if _, ok := dslFunctions[key]; !ok { dslFunction := dslFunctions[key] - dslFunction.signature = "(args ...interface{}) interface{}" + dslFunction.signatures = []string{"(args ...interface{}) interface{}"} dslFunction.expressFunc = value return nil } @@ -873,7 +913,7 @@ func getDslFunctionSignatures() []string { result := make([]string, 0, len(dslFunctions)) for _, dslFunction := range dslFunctions { - result = append(result, dslFunction.signature) + result = append(result, dslFunction.signatures...) } return result @@ -1028,6 +1068,25 @@ func stringNumberToDecimal(args []interface{}, prefix string, base int) (interfa return nil, fmt.Errorf("invalid number: %s", input) } +func toChunks(input string, chunkSize int) []string { + if chunkSize <= 0 || chunkSize >= len(input) { + return []string{input} + } + var chunks = make([]string, 0, (len(input)-1)/chunkSize+1) + currentLength := 0 + currentStart := 0 + for i := range input { + if currentLength == chunkSize { + chunks = append(chunks, input[currentStart:i]) + currentLength = 0 + currentStart = i + } + currentLength++ + } + chunks = append(chunks, input[currentStart:]) + return chunks +} + type CompilationError struct { DslSignature string WrappedError error diff --git a/v2/pkg/operators/common/dsl/dsl_test.go b/v2/pkg/operators/common/dsl/dsl_test.go index ba302952d..74fe9f58b 100644 --- a/v2/pkg/operators/common/dsl/dsl_test.go +++ b/v2/pkg/operators/common/dsl/dsl_test.go @@ -4,7 +4,6 @@ import ( "fmt" "math" "regexp" - "strconv" "testing" "time" @@ -118,6 +117,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) { html_escape(arg1 interface{}) interface{} html_unescape(arg1 interface{}) interface{} join(separator string, elements ...interface{}) string + join(separator string, elements []interface{}) string len(arg1 interface{}) interface{} line_ends_with(str string, suffix ...string) bool line_starts_with(str string, prefix ...string) bool @@ -141,7 +141,11 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) { sha1(arg1 interface{}) interface{} sha256(arg1 interface{}) interface{} sha512(arg1 interface{}) interface{} - sort(args ...interface{}) interface{} + sort(elements ...interface{}) []interface{} + sort(input number) string + sort(input string) string + split(input string, n int) []string + split(input string, separator string, optionalChunkSize) []string starts_with(str string, prefix ...string) bool substr(str string, start int, optionalEnd int) to_lower(arg1 interface{}) interface{} @@ -155,7 +159,9 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) { trim_right(arg1, arg2 interface{}) interface{} trim_space(arg1 interface{}) interface{} trim_suffix(arg1, arg2 interface{}) interface{} - uniq(args ...interface{}) interface{} + uniq(elements ...interface{}) []interface{} + uniq(input number) string + uniq(input string) string unix_time(optionalSeconds uint) float64 url_decode(arg1 interface{}) interface{} url_encode(arg1 interface{}) interface{} @@ -172,7 +178,6 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) { } func TestDslExpressions(t *testing.T) { - dslExpressions := map[string]interface{}{ `base64("Hello")`: "SGVsbG8=", `base64(1234)`: "MTIzNA==", @@ -244,27 +249,27 @@ func TestDslExpressions(t *testing.T) { `substr('xxtestxxx',2)`: "testxxx", `substr('xxtestxxx',2,-2)`: "testx", `substr('xxtestxxx',2,6)`: "test", + `sort(12453)`: "12345", `sort("a1b2c3d4e5")`: "12345abcde", `sort("b", "a", "2", "c", "3", "1", "d", "4")`: []string{"1", "2", "3", "4", "a", "b", "c", "d"}, + `split("abcdefg", 2)`: []string{"ab", "cd", "ef", "g"}, + `split("ab,cd,efg", ",", 1)`: []string{"ab,cd,efg"}, + `split("ab,cd,efg", ",", 2)`: []string{"ab", "cd,efg"}, + `split("ab,cd,efg", ",", "3")`: []string{"ab", "cd", "efg"}, + `split("ab,cd,efg", ",", -1)`: []string{"ab", "cd", "efg"}, + `split("ab,cd,efg", ",")`: []string{"ab", "cd", "efg"}, `join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))`: "1 2 3 4 a b c d", + `uniq(123123231)`: "123", `uniq("abcabdaabbccd")`: "abcd", `uniq("ab", "cd", "12", "34", "12", "cd")`: []string{"ab", "cd", "12", "34"}, `join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))`: "ab cd 12 34", + `join(", ", split(hex_encode("abcdefg"), 2))`: "61, 62, 63, 64, 65, 66, 67", } testDslExpressionScenarios(t, dslExpressions) } -func Test(t *testing.T) { - if number, err := strconv.ParseInt("0o1234567", 0, 64); err == nil { - fmt.Println(number) - } else { - fmt.Println(err) - } -} - func TestDateTimeDSLFunction(t *testing.T) { - testDateTimeFormat := func(t *testing.T, dateTimeFormat string, dateTimeFunction *govaluate.EvaluableExpression, expectedFormattedTime string, currentUnixTime int64) { dslFunctionParameters := map[string]interface{}{"dateTimeFormat": dateTimeFormat} @@ -302,7 +307,6 @@ func TestDateTimeDSLFunction(t *testing.T) { } func TestDateTimeDslExpressions(t *testing.T) { - t.Run("date_time", func(t *testing.T) { now := time.Now()