Add split DSL function (#2838)

* Add support for showing overloaded DSL method signatures

* Add `split` DSL function #2837

* fixing lint warnings

* replacing faulty regex with strings methods

Co-authored-by: Mzack9999 <mzack9999@protonmail.com>
Co-authored-by: mzack <marco.rivoli.nvh@gmail.com>
This commit is contained in:
forgedhallpass 2022-11-14 02:38:12 +02:00 committed by GitHub
parent 2403c50c36
commit 0295ca19bc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 106 additions and 37 deletions

View File

@ -7,6 +7,7 @@ info:
requests:
- raw:
# Note for the integration test: dsl expression should not contain commas
- |
GET / HTTP/1.1
Host: {{Hostname}}
@ -90,9 +91,11 @@ requests:
78: {{line_starts_with("Hi\nHello", "He")}}
79: {{line_ends_with("Hello\nHi", "lo")}}
80: {{sort("a1b2c3d4e5")}}
81: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}}
82: {{uniq("abcabdaabbccd")}}
81: {{uniq("abcabdaabbccd")}}
82: {{join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))}}
83: {{join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))}}
84: {{split("ab,cd,efg", ",")}}
85: {{split("ab,cd,efg", ",", 2)}}
extractors:
- type: regex

View File

@ -8,7 +8,6 @@ import (
"net/http"
"net/http/httptest"
"net/http/httputil"
"regexp"
"strconv"
"strings"
"time"
@ -16,6 +15,7 @@ import (
"github.com/julienschmidt/httprouter"
"github.com/projectdiscovery/nuclei/v2/pkg/testutils"
stringsutil "github.com/projectdiscovery/utils/strings"
)
var httpTestcases = map[string]testutils.TestCase{
@ -284,19 +284,22 @@ func (h *httpDSLFunctions) Execute(filePath string) error {
return err
}
resultPattern := regexp.MustCompile(`\[[^]]+] \[[^]]+] \[[^]]+] [^]]+ \[([^]]+)]`)
submatch := resultPattern.FindStringSubmatch(results[0])
if len(submatch) != 2 {
return errors.New("could not parse the result")
// get result part
resultPart, err := stringsutil.After(results[0], ts.URL)
if err != nil {
return err
}
totalExtracted := strings.Split(submatch[1], ",")
numberOfDslFunctions := 83
if len(totalExtracted) != numberOfDslFunctions {
// remove additional characters till the first valid result and ignore last ] which doesn't alter the total count
resultPart = stringsutil.TrimPrefixAny(resultPart, "/", " ", "[")
extracted := strings.Split(resultPart, ",")
numberOfDslFunctions := 85
if len(extracted) != numberOfDslFunctions {
return errors.New("incorrect number of results")
}
for _, header := range totalExtracted {
for _, header := range extracted {
parts := strings.Split(header, ": ")
index, err := strconv.Atoi(parts[0])
if err != nil {

View File

@ -62,7 +62,7 @@ var functionSignaturePattern = regexp.MustCompile(`(\w+)\s*\((?:([\w\d,\s]+)\s+(
var dateFormatRegex = regexp.MustCompile("%([A-Za-z])")
type dslFunction struct {
signature string
signatures []string
expressFunc govaluate.ExpressionFunction
}
@ -88,8 +88,10 @@ func init() {
"to_lower": makeDslFunction(1, func(args ...interface{}) (interface{}, error) {
return strings.ToLower(types.ToString(args[0])), nil
}),
"sort": makeDslWithOptionalArgsFunction(
"(args ...interface{}) interface{}",
"sort": makeMultiSignatureDslFunction([]string{
"(input string) string",
"(input number) string",
"(elements ...interface{}) []interface{}"},
func(args ...interface{}) (interface{}, error) {
argCount := len(args)
if argCount == 0 {
@ -110,8 +112,10 @@ func init() {
}
},
),
"uniq": makeDslWithOptionalArgsFunction(
"(args ...interface{}) interface{}",
"uniq": makeMultiSignatureDslFunction([]string{
"(input string) string",
"(input number) string",
"(elements ...interface{}) []interface{}"},
func(args ...interface{}) (interface{}, error) {
argCount := len(args)
if argCount == 0 {
@ -410,12 +414,40 @@ func init() {
return builder.String(), nil
},
),
"join": makeDslWithOptionalArgsFunction(
"split": makeMultiSignatureDslFunction([]string{
"(input string, n int) []string",
"(input string, separator string, optionalChunkSize) []string"},
func(arguments ...interface{}) (interface{}, error) {
argumentsSize := len(arguments)
if argumentsSize == 2 {
input := types.ToString(arguments[0])
separatorOrCount := types.ToString(arguments[1])
count, err := strconv.Atoi(separatorOrCount)
if err != nil {
return strings.SplitN(input, separatorOrCount, -1), nil
}
return toChunks(input, count), nil
} else if argumentsSize == 3 {
input := types.ToString(arguments[0])
separator := types.ToString(arguments[1])
count, err := strconv.Atoi(types.ToString(arguments[2]))
if err != nil {
return nil, invalidDslFunctionError
}
return strings.SplitN(input, separator, count), nil
} else {
return nil, invalidDslFunctionError
}
},
),
"join": makeMultiSignatureDslFunction([]string{
"(separator string, elements ...interface{}) string",
"(separator string, elements []interface{}) string"},
func(arguments ...interface{}) (interface{}, error) {
argumentsSize := len(arguments)
if argumentsSize < 2 {
return nil, errors.New("incorrect number of arguments received")
return nil, invalidDslFunctionError
} else if argumentsSize == 2 {
separator := types.ToString(arguments[0])
elements, ok := arguments[1].([]string)
@ -431,7 +463,6 @@ func init() {
stringElements := make([]string, 0, argumentsSize)
for _, element := range elements {
if _, ok := element.([]string); ok {
return nil, errors.New("cannot use join on more than one slice element")
}
@ -794,9 +825,18 @@ func init() {
}
func makeDslWithOptionalArgsFunction(signaturePart string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
return makeMultiSignatureDslFunction([]string{signaturePart}, dslFunctionLogic)
}
func makeMultiSignatureDslFunction(signatureParts []string, dslFunctionLogic govaluate.ExpressionFunction) func(functionName string) dslFunction {
return func(functionName string) dslFunction {
methodSignatures := make([]string, 0, len(signatureParts))
for _, signaturePart := range signatureParts {
methodSignatures = append(methodSignatures, functionName+signaturePart)
}
return dslFunction{
functionName + signaturePart,
methodSignatures,
dslFunctionLogic,
}
}
@ -806,7 +846,7 @@ func makeDslFunction(numberOfParameters int, dslFunctionLogic govaluate.Expressi
return func(functionName string) dslFunction {
signature := functionName + createSignaturePart(numberOfParameters)
return dslFunction{
signature,
[]string{signature},
func(args ...interface{}) (interface{}, error) {
if len(args) != numberOfParameters {
return nil, fmt.Errorf(invalidDslFunctionMessageTemplate, invalidDslFunctionError, signature)
@ -843,7 +883,7 @@ func helperFunctions() map[string]govaluate.ExpressionFunction {
func AddHelperFunction(key string, value func(args ...interface{}) (interface{}, error)) error {
if _, ok := dslFunctions[key]; !ok {
dslFunction := dslFunctions[key]
dslFunction.signature = "(args ...interface{}) interface{}"
dslFunction.signatures = []string{"(args ...interface{}) interface{}"}
dslFunction.expressFunc = value
return nil
}
@ -873,7 +913,7 @@ func getDslFunctionSignatures() []string {
result := make([]string, 0, len(dslFunctions))
for _, dslFunction := range dslFunctions {
result = append(result, dslFunction.signature)
result = append(result, dslFunction.signatures...)
}
return result
@ -1028,6 +1068,25 @@ func stringNumberToDecimal(args []interface{}, prefix string, base int) (interfa
return nil, fmt.Errorf("invalid number: %s", input)
}
func toChunks(input string, chunkSize int) []string {
if chunkSize <= 0 || chunkSize >= len(input) {
return []string{input}
}
var chunks = make([]string, 0, (len(input)-1)/chunkSize+1)
currentLength := 0
currentStart := 0
for i := range input {
if currentLength == chunkSize {
chunks = append(chunks, input[currentStart:i])
currentLength = 0
currentStart = i
}
currentLength++
}
chunks = append(chunks, input[currentStart:])
return chunks
}
type CompilationError struct {
DslSignature string
WrappedError error

View File

@ -4,7 +4,6 @@ import (
"fmt"
"math"
"regexp"
"strconv"
"testing"
"time"
@ -118,6 +117,7 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
html_escape(arg1 interface{}) interface{}
html_unescape(arg1 interface{}) interface{}
join(separator string, elements ...interface{}) string
join(separator string, elements []interface{}) string
len(arg1 interface{}) interface{}
line_ends_with(str string, suffix ...string) bool
line_starts_with(str string, prefix ...string) bool
@ -141,7 +141,11 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
sha1(arg1 interface{}) interface{}
sha256(arg1 interface{}) interface{}
sha512(arg1 interface{}) interface{}
sort(args ...interface{}) interface{}
sort(elements ...interface{}) []interface{}
sort(input number) string
sort(input string) string
split(input string, n int) []string
split(input string, separator string, optionalChunkSize) []string
starts_with(str string, prefix ...string) bool
substr(str string, start int, optionalEnd int)
to_lower(arg1 interface{}) interface{}
@ -155,7 +159,9 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
trim_right(arg1, arg2 interface{}) interface{}
trim_space(arg1 interface{}) interface{}
trim_suffix(arg1, arg2 interface{}) interface{}
uniq(args ...interface{}) interface{}
uniq(elements ...interface{}) []interface{}
uniq(input number) string
uniq(input string) string
unix_time(optionalSeconds uint) float64
url_decode(arg1 interface{}) interface{}
url_encode(arg1 interface{}) interface{}
@ -172,7 +178,6 @@ func TestGetPrintableDslFunctionSignatures(t *testing.T) {
}
func TestDslExpressions(t *testing.T) {
dslExpressions := map[string]interface{}{
`base64("Hello")`: "SGVsbG8=",
`base64(1234)`: "MTIzNA==",
@ -244,27 +249,27 @@ func TestDslExpressions(t *testing.T) {
`substr('xxtestxxx',2)`: "testxxx",
`substr('xxtestxxx',2,-2)`: "testx",
`substr('xxtestxxx',2,6)`: "test",
`sort(12453)`: "12345",
`sort("a1b2c3d4e5")`: "12345abcde",
`sort("b", "a", "2", "c", "3", "1", "d", "4")`: []string{"1", "2", "3", "4", "a", "b", "c", "d"},
`split("abcdefg", 2)`: []string{"ab", "cd", "ef", "g"},
`split("ab,cd,efg", ",", 1)`: []string{"ab,cd,efg"},
`split("ab,cd,efg", ",", 2)`: []string{"ab", "cd,efg"},
`split("ab,cd,efg", ",", "3")`: []string{"ab", "cd", "efg"},
`split("ab,cd,efg", ",", -1)`: []string{"ab", "cd", "efg"},
`split("ab,cd,efg", ",")`: []string{"ab", "cd", "efg"},
`join(" ", sort("b", "a", "2", "c", "3", "1", "d", "4"))`: "1 2 3 4 a b c d",
`uniq(123123231)`: "123",
`uniq("abcabdaabbccd")`: "abcd",
`uniq("ab", "cd", "12", "34", "12", "cd")`: []string{"ab", "cd", "12", "34"},
`join(" ", uniq("ab", "cd", "12", "34", "12", "cd"))`: "ab cd 12 34",
`join(", ", split(hex_encode("abcdefg"), 2))`: "61, 62, 63, 64, 65, 66, 67",
}
testDslExpressionScenarios(t, dslExpressions)
}
func Test(t *testing.T) {
if number, err := strconv.ParseInt("0o1234567", 0, 64); err == nil {
fmt.Println(number)
} else {
fmt.Println(err)
}
}
func TestDateTimeDSLFunction(t *testing.T) {
testDateTimeFormat := func(t *testing.T, dateTimeFormat string, dateTimeFunction *govaluate.EvaluableExpression, expectedFormattedTime string, currentUnixTime int64) {
dslFunctionParameters := map[string]interface{}{"dateTimeFormat": dateTimeFormat}
@ -302,7 +307,6 @@ func TestDateTimeDSLFunction(t *testing.T) {
}
func TestDateTimeDslExpressions(t *testing.T) {
t.Run("date_time", func(t *testing.T) {
now := time.Now()