Add split DSL function (#2838)

* Add support for showing overloaded DSL method signatures

* Add `split` DSL function #2837

* fixing lint warnings

* replacing faulty regex with strings methods

Co-authored-by: Mzack9999 <mzack9999@protonmail.com>
Co-authored-by: mzack <marco.rivoli.nvh@gmail.com>
This commit is contained in:
forgedhallpass 2022-11-14 02:38:12 +02:00 committed by GitHub
parent 2403c50c36
commit 0295ca19bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 37 deletions

View File

@ -7,6 +7,7 @@ info:
requests: requests:
- raw: - raw:
# Note for the integration test: dsl expression should not contain commas
- | - |
GET / HTTP/1.1 GET / HTTP/1.1
Host: {{Hostname}} Host: {{Hostname}}
@ -90,9 +91,11 @@ requests:
78: {{line_starts_with("Hi\nHello", "He")}} 78: {{line_starts_with("Hi\nHello", "He")}}
79: {{line_ends_with("Hello\nHi", "lo")}} 79: {{line_ends_with("Hello\nHi", "lo")}}
80: {{sort("a1b2c3d4e5")}} 80: {{sort("a1b2c3d4e5")}}
81: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}} 81: {{uniq("abcabdaabbccd")}}
82: {{uniq("abcabdaabbccd")}} 82: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}}
83: {{join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))}} 83: {{join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))}}
84: {{split("ab,cd,efg", ",")}}
85: {{split("ab,cd,efg", ",", 2)}}
extractors: extractors:
- type: regex - type: regex

View File

@ -8,7 +8,6 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"net/http/httputil" "net/http/httputil"
"regexp"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -16,6 +15,7 @@ import (
"github.com/julienschmidt/httprouter" "github.com/julienschmidt/httprouter"
"github.com/projectdiscovery/nuclei/v2/pkg/testutils" "github.com/projectdiscovery/nuclei/v2/pkg/testutils"
stringsutil "github.com/projectdiscovery/utils/strings"
) )
var httpTestcases = map[string]testutils.TestCase{ var httpTestcases = map[string]testutils.TestCase{
@ -284,19 +284,22 @@ func (h *httpDSLFunctions) Execute(filePath string) error {
return err return err
} }
resultPattern := regexp.MustCompile(`\[[^]]+] \[[^]]+] \[[^]]+] [^]]+ \[([^]]+)]`) // get result part
submatch := resultPattern.FindStringSubmatch(results[0]) resultPart, err := stringsutil.After(results[0], ts.URL)
if len(submatch) != 2 { if err != nil {
return errors.New("could not parse the result") return err
} }
totalExtracted := strings.Split(submatch[1], ",") // remove additional characters till the first valid result and ignore last ] which doesn't alter the total count
numberOfDslFunctions := 83 resultPart = stringsutil.TrimPrefixAny(resultPart, "/", " ", "[")
if len(totalExtracted) != numberOfDslFunctions {
extracted := strings.Split(resultPart, ",")
numberOfDslFunctions := 85
if len(extracted) != numberOfDslFunctions {
return errors.New("incorrect number of results") return errors.New("incorrect number of results")
} }
for _, header := range totalExtracted { for _, header := range extracted {
parts := strings.Split(header, ": ") parts := strings.Split(header, ": ")
index, err := strconv.Atoi(parts[0]) index, err := strconv.Atoi(parts[0])
if err != nil { if err != nil {

View File

@ -62,7 +62,7 @@ var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+(
var dateFormatRegex = regexp.MustCompile("%([A-Za-z])") var dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
type dslFunction struct { type dslFunction struct {
signature string signatures []string
expressFunc govaluate.ExpressionFunction expressFunc govaluate.ExpressionFunction
} }
@ -88,8 +88,10 @@ func init() {
"to_lower": makeDslFunction(1, func(args ...interface{}) (interface{}, error) { "to_lower": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
return strings.ToLower(types.ToString(args[0])), nil return strings.ToLower(types.ToString(args[0])), nil
}), }),
"sort": makeDslWithOptionalArgsFunction( "sort": makeMultiSignatureDslFunction([]string{
"(args ...interface{}) interface{}", "(input string) string",
"(input number) string",
"(elements ...interface{}) []interface{}"},
func(args ...interface{}) (interface{}, error) { func(args ...interface{}) (interface{}, error) {
argCount := len(args) argCount := len(args)
if argCount == 0 { if argCount == 0 {
@ -110,8 +112,10 @@ func init() {
} }
}, },
), ),
"uniq": makeDslWithOptionalArgsFunction( "uniq": makeMultiSignatureDslFunction([]string{
"(args ...interface{}) interface{}", "(input string) string",
"(input number) string",
"(elements ...interface{}) []interface{}"},
func(args ...interface{}) (interface{}, error) { func(args ...interface{}) (interface{}, error) {
argCount := len(args) argCount := len(args)
if argCount == 0 { if argCount == 0 {
@ -410,12 +414,40 @@ func init() {
return builder.String(), nil return builder.String(), nil
}, },
), ),
"join": makeDslWithOptionalArgsFunction( "split": makeMultiSignatureDslFunction([]string{
"(input string, n int) []string",
"(input string, separator string, optionalChunkSize) []string"},
func(arguments ...interface{}) (interface{}, error) {
argumentsSize := len(arguments)
if argumentsSize == 2 {
input := types.ToString(arguments[0])
separatorOrCount := types.ToString(arguments[1])
count, err := strconv.Atoi(separatorOrCount)
if err != nil {
return strings.SplitN(input, separatorOrCount, -1), nil
}
return toChunks(input, count), nil
} else if argumentsSize == 3 {
input := types.ToString(arguments[0])
separator := types.ToString(arguments[1])
count, err := strconv.Atoi(types.ToString(arguments[2]))
if err != nil {
return nil, invalidDslFunctionError
}
return strings.SplitN(input, separator, count), nil
} else {
return nil, invalidDslFunctionError
}
},
),
"join": makeMultiSignatureDslFunction([]string{
"(separator string, elements ...interface{}) string", "(separator string, elements ...interface{}) string",
"(separator string, elements []interface{}) string"},
func(arguments ...interface{}) (interface{}, error) { func(arguments ...interface{}) (interface{}, error) {
argumentsSize := len(arguments) argumentsSize := len(arguments)
if argumentsSize < 2 { if argumentsSize < 2 {
return nil, errors.New("incorrect number of arguments received") return nil, invalidDslFunctionError
} else if argumentsSize == 2 { } else if argumentsSize == 2 {
separator := types.ToString(arguments[0]) separator := types.ToString(arguments[0])
elements, ok := arguments[1].([]string) elements, ok := arguments[1].([]string)
@ -431,7 +463,6 @@ func init() {
stringElements := make([]string, 0, argumentsSize) stringElements := make([]string, 0, argumentsSize)
for _, element := range elements { for _, element := range elements {
if _, ok := element.([]string); ok { if _, ok := element.([]string); ok {
return nil, errors.New("cannot use join on more than one slice element") return nil, errors.New("cannot use join on more than one slice element")
} }
@ -794,9 +825,18 @@ func init() {
} }
func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction { func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
return makeMultiSignatureDslFunction([]string{signaturePart}, dslFunctionLogic)
}
func makeMultiSignatureDslFunction(signatureParts []string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
return func(functionName string) dslFunction { return func(functionName string) dslFunction {
methodSignatures := make([]string, 0, len(signatureParts))
for _, signaturePart := range signatureParts {
methodSignatures = append(methodSignatures, functionName+signaturePart)
}
return dslFunction{ return dslFunction{
functionName + signaturePart, methodSignatures,
dslFunctionLogic, dslFunctionLogic,
} }
} }
@ -806,7 +846,7 @@ func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.Expressi
return func(functionName string) dslFunction { return func(functionName string) dslFunction {
signature := functionName + createSignaturePart(numberOfParameters) signature := functionName + createSignaturePart(numberOfParameters)
return dslFunction{ return dslFunction{
signature, []string{signature},
func(args ...interface{}) (interface{}, error) { func(args ...interface{}) (interface{}, error) {
if len(args) != numberOfParameters { if len(args) != numberOfParameters {
return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature) return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature)
@ -843,7 +883,7 @@ func helperFunctions() map[string]govaluate.ExpressionFunction {
func AddHelperFunction(key string, value func(args ...interface{}) (interface{}, error)) error { func AddHelperFunction(key string, value func(args ...interface{}) (interface{}, error)) error {
if _, ok := dslFunctions[key]; !ok { if _, ok := dslFunctions[key]; !ok {
dslFunction := dslFunctions[key] dslFunction := dslFunctions[key]
dslFunction.signature = "(args ...interface{}) interface{}" dslFunction.signatures = []string{"(args ...interface{}) interface{}"}
dslFunction.expressFunc = value dslFunction.expressFunc = value
return nil return nil
} }
@ -873,7 +913,7 @@ func getDslFunctionSignatures() []string {
result := make([]string, 0, len(dslFunctions)) result := make([]string, 0, len(dslFunctions))
for _, dslFunction := range dslFunctions { for _, dslFunction := range dslFunctions {
result = append(result, dslFunction.signature) result = append(result, dslFunction.signatures...)
} }
return result return result
@ -1028,6 +1068,25 @@ func stringNumberToDecimal(args []interface{}, prefix string, base int) (interfa
return nil, fmt.Errorf("invalid number: %s", input) return nil, fmt.Errorf("invalid number: %s", input)
} }
func toChunks(input string, chunkSize int) []string {
if chunkSize <= 0 || chunkSize >= len(input) {
return []string{input}
}
var chunks = make([]string, 0, (len(input)-1)/chunkSize+1)
currentLength := 0
currentStart := 0
for i := range input {
if currentLength == chunkSize {
chunks = append(chunks, input[currentStart:i])
currentLength = 0
currentStart = i
}
currentLength++
}
chunks = append(chunks, input[currentStart:])
return chunks
}
type CompilationError struct { type CompilationError struct {
DslSignature string DslSignature string
WrappedError error WrappedError error

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"math" "math"
"regexp" "regexp"
"strconv"
"testing" "testing"
"time" "time"
@ -118,6 +117,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
html_escape(arg1 interface{}) interface{} html_escape(arg1 interface{}) interface{}
html_unescape(arg1 interface{}) interface{} html_unescape(arg1 interface{}) interface{}
join(separator string, elements ...interface{}) string join(separator string, elements ...interface{}) string
join(separator string, elements []interface{}) string
len(arg1 interface{}) interface{} len(arg1 interface{}) interface{}
line_ends_with(str string, suffix ...string) bool line_ends_with(str string, suffix ...string) bool
line_starts_with(str string, prefix ...string) bool line_starts_with(str string, prefix ...string) bool
@ -141,7 +141,11 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
sha1(arg1 interface{}) interface{} sha1(arg1 interface{}) interface{}
sha256(arg1 interface{}) interface{} sha256(arg1 interface{}) interface{}
sha512(arg1 interface{}) interface{} sha512(arg1 interface{}) interface{}
sort(args ...interface{}) interface{} sort(elements ...interface{}) []interface{}
sort(input number) string
sort(input string) string
split(input string, n int) []string
split(input string, separator string, optionalChunkSize) []string
starts_with(str string, prefix ...string) bool starts_with(str string, prefix ...string) bool
substr(str string, start int, optionalEnd int) substr(str string, start int, optionalEnd int)
to_lower(arg1 interface{}) interface{} to_lower(arg1 interface{}) interface{}
@ -155,7 +159,9 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
trim_right(arg1, arg2 interface{}) interface{} trim_right(arg1, arg2 interface{}) interface{}
trim_space(arg1 interface{}) interface{} trim_space(arg1 interface{}) interface{}
trim_suffix(arg1, arg2 interface{}) interface{} trim_suffix(arg1, arg2 interface{}) interface{}
uniq(args ...interface{}) interface{} uniq(elements ...interface{}) []interface{}
uniq(input number) string
uniq(input string) string
unix_time(optionalSeconds uint) float64 unix_time(optionalSeconds uint) float64
url_decode(arg1 interface{}) interface{} url_decode(arg1 interface{}) interface{}
url_encode(arg1 interface{}) interface{} url_encode(arg1 interface{}) interface{}
@ -172,7 +178,6 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
} }
func TestDslExpressions(t *testing.T) { func TestDslExpressions(t *testing.T) {
dslExpressions := map[string]interface{}{ dslExpressions := map[string]interface{}{
`base64("Hello")`: "SGVsbG8=", `base64("Hello")`: "SGVsbG8=",
`base64(1234)`: "MTIzNA==", `base64(1234)`: "MTIzNA==",
@ -244,27 +249,27 @@ func TestDslExpressions(t *testing.T) {
`substr('xxtestxxx',2)`: "testxxx", `substr('xxtestxxx',2)`: "testxxx",
`substr('xxtestxxx',2,-2)`: "testx", `substr('xxtestxxx',2,-2)`: "testx",
`substr('xxtestxxx',2,6)`: "test", `substr('xxtestxxx',2,6)`: "test",
`sort(12453)`: "12345",
`sort("a1b2c3d4e5")`: "12345abcde", `sort("a1b2c3d4e5")`: "12345abcde",
`sort("b", "a", "2", "c", "3", "1", "d", "4")`: []string{"1", "2", "3", "4", "a", "b", "c", "d"}, `sort("b", "a", "2", "c", "3", "1", "d", "4")`: []string{"1", "2", "3", "4", "a", "b", "c", "d"},
`split("abcdefg", 2)`: []string{"ab", "cd", "ef", "g"},
`split("ab,cd,efg", ",", 1)`: []string{"ab,cd,efg"},
`split("ab,cd,efg", ",", 2)`: []string{"ab", "cd,efg"},
`split("ab,cd,efg", ",", "3")`: []string{"ab", "cd", "efg"},
`split("ab,cd,efg", ",", -1)`: []string{"ab", "cd", "efg"},
`split("ab,cd,efg", ",")`: []string{"ab", "cd", "efg"},
`join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))`: "1 2 3 4 a b c d", `join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))`: "1 2 3 4 a b c d",
`uniq(123123231)`: "123",
`uniq("abcabdaabbccd")`: "abcd", `uniq("abcabdaabbccd")`: "abcd",
`uniq("ab", "cd", "12", "34", "12", "cd")`: []string{"ab", "cd", "12", "34"}, `uniq("ab", "cd", "12", "34", "12", "cd")`: []string{"ab", "cd", "12", "34"},
`join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))`: "ab cd 12 34", `join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))`: "ab cd 12 34",
`join(", ", split(hex_encode("abcdefg"), 2))`: "61, 62, 63, 64, 65, 66, 67",
} }
testDslExpressionScenarios(t, dslExpressions) testDslExpressionScenarios(t, dslExpressions)
} }
func Test(t *testing.T) {
if number, err := strconv.ParseInt("0o1234567", 0, 64); err == nil {
fmt.Println(number)
} else {
fmt.Println(err)
}
}
func TestDateTimeDSLFunction(t *testing.T) { func TestDateTimeDSLFunction(t *testing.T) {
testDateTimeFormat := func(t *testing.T, dateTimeFormat string, dateTimeFunction *govaluate.EvaluableExpression, expectedFormattedTime string, currentUnixTime int64) { testDateTimeFormat := func(t *testing.T, dateTimeFormat string, dateTimeFunction *govaluate.EvaluableExpression, expectedFormattedTime string, currentUnixTime int64) {
dslFunctionParameters := map[string]interface{}{"dateTimeFormat": dateTimeFormat} dslFunctionParameters := map[string]interface{}{"dateTimeFormat": dateTimeFormat}
@ -302,7 +307,6 @@ func TestDateTimeDSLFunction(t *testing.T) {
} }
func TestDateTimeDslExpressions(t *testing.T) { func TestDateTimeDslExpressions(t *testing.T) {
t.Run("date_time", func(t *testing.T) { t.Run("date_time", func(t *testing.T) {
now := time.Now() now := time.Now()