Multiple bug fixes in query param fuzzing (#4925)

* fuzz: check and handle typed slice

* do not query encode params + fuzz/allow duplicates params

* sometimes order matters ~query params

* component: fix broken iterator

* result upload add meta params
This commit is contained in:
Tarun Koyalwar 2024-03-25 10:08:26 +05:30 committed by GitHub
parent bc268174ab
commit c1bd4f82ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 421 additions and 130 deletions

View File

@ -13,10 +13,12 @@ import (
"time" "time"
"github.com/projectdiscovery/gologger" "github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/catalog/config"
"github.com/projectdiscovery/nuclei/v3/pkg/output" "github.com/projectdiscovery/nuclei/v3/pkg/output"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
pdcpauth "github.com/projectdiscovery/utils/auth/pdcp" pdcpauth "github.com/projectdiscovery/utils/auth/pdcp"
errorutil "github.com/projectdiscovery/utils/errors" errorutil "github.com/projectdiscovery/utils/errors"
updateutils "github.com/projectdiscovery/utils/update"
urlutil "github.com/projectdiscovery/utils/url" urlutil "github.com/projectdiscovery/utils/url"
) )
@ -217,6 +219,8 @@ func (u *UploadWriter) getRequest(bin []byte) (*retryablehttp.Request, error) {
if err != nil { if err != nil {
return nil, errorutil.NewWithErr(err).Msgf("could not create cloud upload request") return nil, errorutil.NewWithErr(err).Msgf("could not create cloud upload request")
} }
// add pdtm meta params
req.URL.RawQuery = updateutils.GetpdtmParams(config.Version)
req.Header.Set(pdcpauth.ApiKeyHeaderName, u.creds.APIKey) req.Header.Set(pdcpauth.ApiKeyHeaderName, u.creds.APIKey)
req.Header.Set("Content-Type", "application/octet-stream") req.Header.Set("Content-Type", "application/octet-stream")
req.Header.Set("Accept", "application/json") req.Header.Set("Accept", "application/json")

View File

@ -55,16 +55,17 @@ func (b *Body) Parse(req *retryablehttp.Request) (bool, error) {
} }
b.value = NewValue(dataStr) b.value = NewValue(dataStr)
if b.value.Parsed() != nil { tmp := b.value.Parsed()
if !tmp.IsNIL() {
return true, nil return true, nil
} }
switch { switch {
case strings.Contains(contentType, "application/json") && b.value.Parsed() == nil: case strings.Contains(contentType, "application/json") && tmp.IsNIL():
return b.parseBody(dataformat.JSONDataFormat, req) return b.parseBody(dataformat.JSONDataFormat, req)
case strings.Contains(contentType, "application/xml") && b.value.Parsed() == nil: case strings.Contains(contentType, "application/xml") && tmp.IsNIL():
return b.parseBody(dataformat.XMLDataFormat, req) return b.parseBody(dataformat.XMLDataFormat, req)
case strings.Contains(contentType, "multipart/form-data") && b.value.Parsed() == nil: case strings.Contains(contentType, "multipart/form-data") && tmp.IsNIL():
return b.parseBody(dataformat.MultiPartFormDataFormat, req) return b.parseBody(dataformat.MultiPartFormDataFormat, req)
} }
parsed, err := b.parseBody(dataformat.FormDataFormat, req) parsed, err := b.parseBody(dataformat.FormDataFormat, req)
@ -93,16 +94,18 @@ func (b *Body) parseBody(decoderName string, req *retryablehttp.Request) (bool,
} }
// Iterate iterates through the component // Iterate iterates through the component
func (b *Body) Iterate(callback func(key string, value interface{}) error) error { func (b *Body) Iterate(callback func(key string, value interface{}) error) (errx error) {
for key, value := range b.value.Parsed() { b.value.parsed.Iterate(func(key string, value any) bool {
if strings.HasPrefix(key, "#_") { if strings.HasPrefix(key, "#_") {
continue return true
} }
if err := callback(key, value); err != nil { if err := callback(key, value); err != nil {
return err errx = err
return false
} }
} return true
return nil })
return
} }
// SetValue sets a value in the component // SetValue sets a value in the component

View File

@ -2,9 +2,12 @@ package component
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"github.com/projectdiscovery/nuclei/v3/pkg/fuzz/dataformat"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
mapsutil "github.com/projectdiscovery/utils/maps"
) )
// Cookie is a component for a request cookie // Cookie is a component for a request cookie
@ -35,29 +38,31 @@ func (c *Cookie) Parse(req *retryablehttp.Request) (bool, error) {
c.req = req c.req = req
c.value = NewValue("") c.value = NewValue("")
parsedCookies := make(map[string]interface{}) parsedCookies := mapsutil.NewOrderedMap[string, any]()
for _, cookie := range req.Cookies() { for _, cookie := range req.Cookies() {
parsedCookies[cookie.Name] = cookie.Value parsedCookies.Set(cookie.Name, cookie.Value)
} }
if len(parsedCookies) == 0 { if parsedCookies.Len() == 0 {
return false, nil return false, nil
} }
c.value.SetParsed(parsedCookies, "") c.value.SetParsed(dataformat.KVOrderedMap(&parsedCookies), "")
return true, nil return true, nil
} }
// Iterate iterates through the component // Iterate iterates through the component
func (c *Cookie) Iterate(callback func(key string, value interface{}) error) error { func (c *Cookie) Iterate(callback func(key string, value interface{}) error) (err error) {
for key, value := range c.value.Parsed() { c.value.parsed.Iterate(func(key string, value any) bool {
// Skip ignored cookies // Skip ignored cookies
if _, ok := defaultIgnoredCookieKeys[key]; ok { if _, ok := defaultIgnoredCookieKeys[key]; ok {
continue return ok
} }
if err := callback(key, value); err != nil { if errx := callback(key, value); errx != nil {
return err err = errx
return false
} }
} return true
return nil })
return
} }
// SetValue sets a value in the component // SetValue sets a value in the component
@ -83,13 +88,14 @@ func (c *Cookie) Rebuild() (*retryablehttp.Request, error) {
cloned := c.req.Clone(context.Background()) cloned := c.req.Clone(context.Background())
cloned.Header.Del("Cookie") cloned.Header.Del("Cookie")
for key, value := range c.value.Parsed() { c.value.parsed.Iterate(func(key string, value any) bool {
cookie := &http.Cookie{ cookie := &http.Cookie{
Name: key, Name: key,
Value: value.(string), // Assume the value is always a string for cookies Value: fmt.Sprint(value), // Assume the value is always a string for cookies
} }
cloned.AddCookie(cookie) cloned.AddCookie(cookie)
} return true
})
return cloned, nil return cloned, nil
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"strings" "strings"
"github.com/projectdiscovery/nuclei/v3/pkg/fuzz/dataformat"
"github.com/projectdiscovery/retryablehttp-go" "github.com/projectdiscovery/retryablehttp-go"
) )
@ -40,22 +41,24 @@ func (q *Header) Parse(req *retryablehttp.Request) (bool, error) {
} }
parsedHeaders[key] = value parsedHeaders[key] = value
} }
q.value.SetParsed(parsedHeaders, "") q.value.SetParsed(dataformat.KVMap(parsedHeaders), "")
return true, nil return true, nil
} }
// Iterate iterates through the component // Iterate iterates through the component
func (q *Header) Iterate(callback func(key string, value interface{}) error) error { func (q *Header) Iterate(callback func(key string, value interface{}) error) (errx error) {
for key, value := range q.value.Parsed() { q.value.parsed.Iterate(func(key string, value any) bool {
// Skip ignored headers // Skip ignored headers
if _, ok := defaultIgnoredHeaderKeys[key]; ok { if _, ok := defaultIgnoredHeaderKeys[key]; ok {
continue return ok
} }
if err := callback(key, value); err != nil { if err := callback(key, value); err != nil {
return err errx = err
return false
} }
} return true
return nil })
return
} }
// SetValue sets a value in the component // SetValue sets a value in the component
@ -79,22 +82,23 @@ func (q *Header) Delete(key string) error {
// component rebuilt // component rebuilt
func (q *Header) Rebuild() (*retryablehttp.Request, error) { func (q *Header) Rebuild() (*retryablehttp.Request, error) {
cloned := q.req.Clone(context.Background()) cloned := q.req.Clone(context.Background())
for key, value := range q.value.parsed { q.value.parsed.Iterate(func(key string, value any) bool {
if strings.EqualFold(key, "Host") { if strings.EqualFold(key, "Host") {
cloned.Host = value.(string) return true
} }
switch v := value.(type) { if vx, ok := IsTypedSlice(value); ok {
case []interface{}: // convert to []interface{}
value = vx
}
if v, ok := value.([]interface{}); ok {
for _, vv := range v { for _, vv := range v {
if cloned.Header[key] == nil { cloned.Header.Add(key, vv.(string))
cloned.Header[key] = make([]string, 0)
}
cloned.Header[key] = append(cloned.Header[key], vv.(string))
}
case string:
cloned.Header[key] = []string{v}
} }
return true
} }
cloned.Header.Set(key, value.(string))
return true
})
return cloned, nil return cloned, nil
} }

View File

@ -42,13 +42,15 @@ func (q *Path) Parse(req *retryablehttp.Request) (bool, error) {
} }
// Iterate iterates through the component // Iterate iterates through the component
func (q *Path) Iterate(callback func(key string, value interface{}) error) error { func (q *Path) Iterate(callback func(key string, value interface{}) error) (err error) {
for key, value := range q.value.Parsed() { q.value.parsed.Iterate(func(key string, value any) bool {
if err := callback(key, value); err != nil { if errx := callback(key, value); errx != nil {
return err err = errx
return false
} }
} return true
return nil })
return
} }
// SetValue sets a value in the component // SetValue sets a value in the component

View File

@ -47,13 +47,15 @@ func (q *Query) Parse(req *retryablehttp.Request) (bool, error) {
} }
// Iterate iterates through the component // Iterate iterates through the component
func (q *Query) Iterate(callback func(key string, value interface{}) error) error { func (q *Query) Iterate(callback func(key string, value interface{}) error) (errx error) {
for key, value := range q.value.Parsed() { q.value.parsed.Iterate(func(key string, value interface{}) bool {
if err := callback(key, value); err != nil { if err := callback(key, value); err != nil {
return err errx = err
return false
} }
} return true
return nil })
return
} }
// SetValue sets a value in the component // SetValue sets a value in the component

View File

@ -1,9 +1,11 @@
package component package component
import ( import (
"reflect"
"strconv" "strconv"
"github.com/leslie-qiwa/flat" "github.com/leslie-qiwa/flat"
"github.com/logrusorgru/aurora"
"github.com/projectdiscovery/gologger" "github.com/projectdiscovery/gologger"
"github.com/projectdiscovery/nuclei/v3/pkg/fuzz/dataformat" "github.com/projectdiscovery/nuclei/v3/pkg/fuzz/dataformat"
) )
@ -15,7 +17,7 @@ import (
// all the data values that are used in a request. // all the data values that are used in a request.
type Value struct { type Value struct {
data string data string
parsed map[string]interface{} parsed dataformat.KV
dataFormat string dataFormat string
} }
@ -40,34 +42,43 @@ func (v *Value) String() string {
} }
// Parsed returns the parsed value // Parsed returns the parsed value
func (v *Value) Parsed() map[string]interface{} { func (v *Value) Parsed() dataformat.KV {
return v.parsed return v.parsed
} }
// SetParsed sets the parsed value map // SetParsed sets the parsed value map
func (v *Value) SetParsed(parsed map[string]interface{}, dataFormat string) { func (v *Value) SetParsed(data dataformat.KV, dataFormat string) {
v.dataFormat = dataFormat
if data.OrderedMap != nil {
v.parsed = data
return
}
parsed := data.Map
flattened, err := flat.Flatten(parsed, flatOpts) flattened, err := flat.Flatten(parsed, flatOpts)
if err == nil { if err == nil {
v.parsed = flattened v.parsed = dataformat.KVMap(flattened)
} else { } else {
v.parsed = parsed v.parsed = dataformat.KVMap(parsed)
} }
v.dataFormat = dataFormat
} }
// SetParsedValue sets the parsed value for a key // SetParsedValue sets the parsed value for a key
// in the parsed map // in the parsed map
func (v *Value) SetParsedValue(key string, value string) bool { func (v *Value) SetParsedValue(key string, value string) bool {
origValue, ok := v.parsed[key] origValue := v.parsed.Get(key)
if !ok { if origValue == nil {
v.parsed[key] = value v.parsed.Set(key, value)
return true return true
} }
// If the value is a list, append to it // If the value is a list, append to it
// otherwise replace it // otherwise replace it
switch v := origValue.(type) { switch v := origValue.(type) {
case []interface{}: case []interface{}:
origValue = append(v, value) // update last value
if len(v) > 0 {
v[len(v)-1] = value
}
origValue = v
case string: case string:
origValue = value origValue = value
case int, int32, int64, float32, float64: case int, int32, int64, float32, float64:
@ -82,35 +93,49 @@ func (v *Value) SetParsedValue(key string, value string) bool {
return false return false
} }
origValue = parsed origValue = parsed
case nil:
origValue = value
default: default:
gologger.Error().Msgf("unknown type %T for value %s", v, v) // explicitly check for typed slice
if val, ok := IsTypedSlice(v); ok {
if len(val) > 0 {
val[len(val)-1] = value
} }
v.parsed[key] = origValue origValue = val
} else {
// make it default warning instead of error
gologger.DefaultLogger.Print().Msgf("[%v] unknown type %T for value %s", aurora.BrightYellow("WARN"), v, v)
}
}
v.parsed.Set(key, origValue)
return true return true
} }
// Delete removes a key from the parsed value // Delete removes a key from the parsed value
func (v *Value) Delete(key string) bool { func (v *Value) Delete(key string) bool {
if _, ok := v.parsed[key]; !ok { return v.parsed.Delete(key)
return false
}
delete(v.parsed, key)
return true
} }
// Encode encodes the value into a string // Encode encodes the value into a string
// using the dataformat and encoding // using the dataformat and encoding
func (v *Value) Encode() (string, error) { func (v *Value) Encode() (string, error) {
toEncodeStr := v.data toEncodeStr := v.data
if v.parsed.OrderedMap != nil {
// flattening orderedmap not supported
if v.dataFormat != "" {
dataformatStr, err := dataformat.Encode(v.parsed, v.dataFormat)
if err != nil {
return "", err
}
toEncodeStr = dataformatStr
}
return toEncodeStr, nil
}
nested, err := flat.Unflatten(v.parsed, flatOpts) nested, err := flat.Unflatten(v.parsed.Map, flatOpts)
if err != nil { if err != nil {
return "", err return "", err
} }
if v.dataFormat != "" { if v.dataFormat != "" {
dataformatStr, err := dataformat.Encode(nested, v.dataFormat) dataformatStr, err := dataformat.Encode(dataformat.KVMap(nested), v.dataFormat)
if err != nil { if err != nil {
return "", err return "", err
} }
@ -118,3 +143,18 @@ func (v *Value) Encode() (string, error) {
} }
return toEncodeStr, nil return toEncodeStr, nil
} }
// In go, []int, []string are not implictily converted to []interface{}
// when using type assertion and they need to be handled separately.
func IsTypedSlice(v interface{}) ([]interface{}, bool) {
if reflect.ValueOf(v).Kind() == reflect.Slice {
// iterate and convert to []interface{}
slice := reflect.ValueOf(v)
interfaceSlice := make([]interface{}, slice.Len())
for i := 0; i < slice.Len(); i++ {
interfaceSlice[i] = slice.Index(i).Interface()
}
return interfaceSlice, true
}
return nil, false
}

View File

@ -37,3 +37,17 @@ func TestFlatMap_FlattenUnflatten(t *testing.T) {
} }
require.Equal(t, data, nested, "unexpected data") require.Equal(t, data, nested, "unexpected data")
} }
func TestAnySlice(t *testing.T) {
data := []any{}
data = append(data, []int{1, 2, 3})
data = append(data, []string{"foo", "bar"})
data = append(data, []bool{true, false})
data = append(data, []float64{1.1, 2.2, 3.3})
for _, d := range data {
val, ok := IsTypedSlice(d)
require.True(t, ok, "expected slice")
require.True(t, val != nil, "expected value but got nil")
}
}

View File

@ -8,6 +8,12 @@ import (
// dataformats is a list of dataformats // dataformats is a list of dataformats
var dataformats map[string]DataFormat var dataformats map[string]DataFormat
const (
// DefaultKey is the key i.e used when given
// data is not of k-v type
DefaultKey = "value"
)
func init() { func init() {
dataformats = make(map[string]DataFormat) dataformats = make(map[string]DataFormat)
@ -49,9 +55,9 @@ type DataFormat interface {
// Name returns the name of the encoder // Name returns the name of the encoder
Name() string Name() string
// Encode encodes the data into a format // Encode encodes the data into a format
Encode(data map[string]interface{}) (string, error) Encode(data KV) (string, error)
// Decode decodes the data from a format // Decode decodes the data from a format
Decode(input string) (map[string]interface{}, error) Decode(input string) (KV, error)
} }
// Decoded is a decoded data format // Decoded is a decoded data format
@ -59,7 +65,7 @@ type Decoded struct {
// DataFormat is the data format // DataFormat is the data format
DataFormat string DataFormat string
// Data is the decoded data // Data is the decoded data
Data map[string]interface{} Data KV
} }
// Decode decodes the data from a format // Decode decodes the data from a format
@ -81,7 +87,7 @@ func Decode(data string) (*Decoded, error) {
} }
// Encode encodes the data into a format // Encode encodes the data into a format
func Encode(data map[string]interface{}, dataformat string) (string, error) { func Encode(data KV, dataformat string) (string, error) {
if dataformat == "" { if dataformat == "" {
return "", errors.New("dataformat is required") return "", errors.New("dataformat is required")
} }

View File

@ -14,7 +14,7 @@ func TestDataformatDecodeEncode_JSON(t *testing.T) {
if decoded.DataFormat != "json" { if decoded.DataFormat != "json" {
t.Fatal("unexpected data format") t.Fatal("unexpected data format")
} }
if decoded.Data["foo"] != "bar" { if decoded.Data.Get("foo") != "bar" {
t.Fatal("unexpected data") t.Fatal("unexpected data")
} }
@ -37,11 +37,19 @@ func TestDataformatDecodeEncode_XML(t *testing.T) {
if decoded.DataFormat != "xml" { if decoded.DataFormat != "xml" {
t.Fatal("unexpected data format") t.Fatal("unexpected data format")
} }
if decoded.Data["foo"].(map[string]interface{})["#text"] != "bar" { fooValue := decoded.Data.Get("foo")
t.Fatal("unexpected data") if fooValue == nil {
t.Fatal("key 'foo' not found")
} }
if decoded.Data["foo"].(map[string]interface{})["-attr"] != "baz" { fooMap, ok := fooValue.(map[string]interface{})
t.Fatal("unexpected data") if !ok {
t.Fatal("type assertion to map[string]interface{} failed")
}
if fooMap["#text"] != "bar" {
t.Fatal("unexpected data for '#text'")
}
if fooMap["-attr"] != "baz" {
t.Fatal("unexpected data for '-attr'")
} }
encoded, err := Encode(decoded.Data, decoded.DataFormat) encoded, err := Encode(decoded.Data, decoded.DataFormat)

View File

@ -1,9 +1,35 @@
package dataformat package dataformat
import ( import (
"fmt"
"net/url" "net/url"
"regexp"
"strconv"
"strings"
"github.com/projectdiscovery/gologger"
mapsutil "github.com/projectdiscovery/utils/maps"
urlutil "github.com/projectdiscovery/utils/url"
) )
const (
normalizedRegex = `_(\d+)$`
)
var (
reNormalized = regexp.MustCompile(normalizedRegex)
)
// == Handling Duplicate Query Parameters / Form Data ==
// Nuclei supports fuzzing duplicate query parameters by internally normalizing
// them and denormalizing them back when creating request this normalization
// can be leveraged to specify custom fuzzing behaviour in template as well
// if a query like `?foo=bar&foo=baz&foo=fuzzz` is provided, it will be normalized to
// foo_1=bar , foo_2=baz , foo=fuzzz (i.e last value is given original key which is usual behaviour in HTTP and its implementations)
// this way this change does not break any existing rules in template given by keys-regex or keys
// At same time if user wants to specify 2nd or 1st duplicate value in template, they can use foo_1 or foo_2 in keys-regex or keys
// Note: By default all duplicate query parameters are fuzzed
type Form struct{} type Form struct{}
var ( var (
@ -21,38 +47,95 @@ func (f *Form) IsType(data string) bool {
} }
// Encode encodes the data into Form format // Encode encodes the data into Form format
func (f *Form) Encode(data map[string]interface{}) (string, error) { func (f *Form) Encode(data KV) (string, error) {
query := url.Values{} params := urlutil.NewOrderedParams()
for key, value := range data {
switch v := value.(type) { data.Iterate(func(key string, value any) bool {
case []interface{}: params.Add(key, fmt.Sprint(value))
for _, val := range v { return true
query.Add(key, val.(string)) })
normalized := map[string]map[string]string{}
// Normalize the data
for _, origKey := range data.OrderedMap.GetKeys() {
// here origKey is base key without _1, _2 etc.
if origKey != "" && !reNormalized.MatchString(origKey) {
params.Iterate(func(key string, value []string) bool {
if strings.HasPrefix(key, origKey) && reNormalized.MatchString(key) {
m := map[string]string{}
if normalized[origKey] != nil {
m = normalized[origKey]
} }
case string: if len(value) == 1 {
query.Set(key, v) m[key] = value[0]
} else {
m[key] = ""
}
normalized[origKey] = m
params.Del(key)
}
return true
})
} }
} }
encoded := query.Encode()
if len(normalized) > 0 {
for k, v := range normalized {
maxIndex := -1
for key := range v {
matches := reNormalized.FindStringSubmatch(key)
if len(matches) == 2 {
dataIdx, err := strconv.Atoi(matches[1])
if err != nil {
gologger.Verbose().Msgf("error converting normalized index(%v) to integer: %v", matches[1], err)
continue
}
if dataIdx > maxIndex {
maxIndex = dataIdx
}
}
}
data := make([]string, maxIndex+1) // Ensure the slice is large enough
for key, value := range v {
matches := reNormalized.FindStringSubmatch(key)
if len(matches) == 2 {
dataIdx, _ := strconv.Atoi(matches[1]) // Error already checked above
data[dataIdx-1] = value // Use dataIdx-1 since slice is 0-indexed
}
}
data[maxIndex] = fmt.Sprint(params.Get(k)) // Use maxIndex which is the last index
// remove existing
params.Del(k)
params.Add(k, data...)
}
}
encoded := params.Encode()
return encoded, nil return encoded, nil
} }
// Decode decodes the data from Form format // Decode decodes the data from Form format
func (f *Form) Decode(data string) (map[string]interface{}, error) { func (f *Form) Decode(data string) (KV, error) {
parsed, err := url.ParseQuery(data) parsed, err := url.ParseQuery(data)
if err != nil { if err != nil {
return nil, err return KV{}, err
} }
values := make(map[string]interface{}) values := mapsutil.NewOrderedMap[string, any]()
for key, value := range parsed { for key, value := range parsed {
if len(value) == 1 { if len(value) == 1 {
values[key] = value[0] values.Set(key, value[0])
} else { } else {
values[key] = value // in case of multiple query params in form data
// last value is considered and previous values are exposed with _1, _2, _3 etc.
// note that last value will not be included in _1, _2, _3 etc.
for i := 0; i < len(value)-1; i++ {
values.Set(key+"_"+strconv.Itoa(i+1), value[i])
}
values.Set(key, value[len(value)-1])
} }
} }
return values, nil return KVOrderedMap(&values), nil
} }
// Name returns the name of the encoder // Name returns the name of the encoder

View File

@ -30,16 +30,16 @@ func (j *JSON) IsType(data string) bool {
} }
// Encode encodes the data into JSON format // Encode encodes the data into JSON format
func (j *JSON) Encode(data map[string]interface{}) (string, error) { func (j *JSON) Encode(data KV) (string, error) {
encoded, err := jsoniter.Marshal(data) encoded, err := jsoniter.Marshal(data.Map)
return string(encoded), err return string(encoded), err
} }
// Decode decodes the data from JSON format // Decode decodes the data from JSON format
func (j *JSON) Decode(data string) (map[string]interface{}, error) { func (j *JSON) Decode(data string) (KV, error) {
var decoded map[string]interface{} var decoded map[string]interface{}
err := jsoniter.Unmarshal([]byte(data), &decoded) err := jsoniter.Unmarshal([]byte(data), &decoded)
return decoded, err return KVMap(decoded), err
} }
// Name returns the name of the encoder // Name returns the name of the encoder

111
pkg/fuzz/dataformat/kv.go Normal file
View File

@ -0,0 +1,111 @@
package dataformat
import mapsutil "github.com/projectdiscovery/utils/maps"
// KV is a key-value struct
// that is implemented or used by fuzzing package
// to represent a key-value pair
// sometimes order or key-value pair is important (query params)
// so we use ordered map to represent the data
// if it's not important/significant (ex: json,xml) we use map
// this also allows us to iteratively implement ordered map
type KV struct {
Map map[string]interface{}
OrderedMap *mapsutil.OrderedMap[string, any]
}
// IsNIL returns true if the KV struct is nil
func (kv *KV) IsNIL() bool {
return kv.Map == nil && kv.OrderedMap == nil
}
// IsOrderedMap returns true if the KV struct is an ordered map
func (kv *KV) IsOrderedMap() bool {
return kv.OrderedMap != nil
}
// Set sets a value in the KV struct
func (kv *KV) Set(key string, value any) {
if kv.OrderedMap != nil {
kv.OrderedMap.Set(key, value)
return
}
if kv.Map == nil {
kv.Map = make(map[string]interface{})
}
kv.Map[key] = value
}
// Get gets a value from the KV struct
func (kv *KV) Get(key string) interface{} {
if kv.OrderedMap != nil {
value, ok := kv.OrderedMap.Get(key)
if !ok {
return nil
}
return value
}
return kv.Map[key]
}
// Iterate iterates over the KV struct in insertion order
func (kv *KV) Iterate(f func(key string, value any) bool) {
if kv.OrderedMap != nil {
kv.OrderedMap.Iterate(func(key string, value any) bool {
return f(key, value)
})
return
}
for key, value := range kv.Map {
if !f(key, value) {
break
}
}
}
// Delete deletes a key from the KV struct
func (kv *KV) Delete(key string) bool {
if kv.OrderedMap != nil {
_, ok := kv.OrderedMap.Get(key)
if !ok {
return false
}
kv.OrderedMap.Delete(key)
return true
}
_, ok := kv.Map[key]
if !ok {
return false
}
delete(kv.Map, key)
return true
}
// KVMap returns a new KV struct with the given map
func KVMap(data map[string]interface{}) KV {
return KV{Map: data}
}
// KVOrderedMap returns a new KV struct with the given ordered map
func KVOrderedMap(data *mapsutil.OrderedMap[string, any]) KV {
return KV{OrderedMap: data}
}
// ToMap converts the ordered map to a map
func ToMap(m *mapsutil.OrderedMap[string, any]) map[string]interface{} {
data := make(map[string]interface{})
m.Iterate(func(key string, value any) bool {
data[key] = value
return true
})
return data
}
// ToOrderedMap converts the map to an ordered map
func ToOrderedMap(data map[string]interface{}) *mapsutil.OrderedMap[string, any] {
m := mapsutil.NewOrderedMap[string, any]()
for key, value := range data {
m.Set(key, value)
}
return &m
}

View File

@ -6,6 +6,8 @@ import (
"io" "io"
"mime" "mime"
"mime/multipart" "mime/multipart"
mapsutil "github.com/projectdiscovery/utils/maps"
) )
type MultiPartForm struct { type MultiPartForm struct {
@ -28,23 +30,30 @@ func (m *MultiPartForm) IsType(data string) bool {
} }
// Encode encodes the data into MultiPartForm format // Encode encodes the data into MultiPartForm format
func (m *MultiPartForm) Encode(data map[string]interface{}) (string, error) { func (m *MultiPartForm) Encode(data KV) (string, error) {
var b bytes.Buffer var b bytes.Buffer
w := multipart.NewWriter(&b) w := multipart.NewWriter(&b)
if err := w.SetBoundary(m.boundary); err != nil { if err := w.SetBoundary(m.boundary); err != nil {
return "", err return "", err
} }
for key, value := range data { var Itererr error
data.Iterate(func(key string, value any) bool {
var fw io.Writer var fw io.Writer
var err error var err error
// Add field // Add field
if fw, err = w.CreateFormField(key); err != nil { if fw, err = w.CreateFormField(key); err != nil {
return "", err Itererr = err
return false
} }
if _, err = fw.Write([]byte(value.(string))); err != nil { if _, err = fw.Write([]byte(value.(string))); err != nil {
return "", err Itererr = err
return false
} }
return true
})
if Itererr != nil {
return "", Itererr
} }
w.Close() w.Close()
@ -65,7 +74,7 @@ func (m *MultiPartForm) ParseBoundary(contentType string) error {
} }
// Decode decodes the data from MultiPartForm format // Decode decodes the data from MultiPartForm format
func (m *MultiPartForm) Decode(data string) (map[string]interface{}, error) { func (m *MultiPartForm) Decode(data string) (KV, error) {
// Create a buffer from the string data // Create a buffer from the string data
b := bytes.NewBufferString(data) b := bytes.NewBufferString(data)
// The boundary parameter should be extracted from the Content-Type header of the HTTP request // The boundary parameter should be extracted from the Content-Type header of the HTTP request
@ -75,18 +84,18 @@ func (m *MultiPartForm) Decode(data string) (map[string]interface{}, error) {
form, err := r.ReadForm(32 << 20) // 32MB is the max memory used to parse the form form, err := r.ReadForm(32 << 20) // 32MB is the max memory used to parse the form
if err != nil { if err != nil {
return nil, err return KV{}, err
} }
defer func() { defer func() {
_ = form.RemoveAll() _ = form.RemoveAll()
}() }()
result := make(map[string]interface{}) result := mapsutil.NewOrderedMap[string, any]()
for key, values := range form.Value { for key, values := range form.Value {
if len(values) > 1 { if len(values) > 1 {
result[key] = values result.Set(key, values)
} else { } else {
result[key] = values[0] result.Set(key, values[0])
} }
} }
for key, files := range form.File { for key, files := range form.File {
@ -94,20 +103,19 @@ func (m *MultiPartForm) Decode(data string) (map[string]interface{}, error) {
for _, fileHeader := range files { for _, fileHeader := range files {
file, err := fileHeader.Open() file, err := fileHeader.Open()
if err != nil { if err != nil {
return nil, err return KV{}, err
} }
defer file.Close() defer file.Close()
buffer := new(bytes.Buffer) buffer := new(bytes.Buffer)
if _, err := buffer.ReadFrom(file); err != nil { if _, err := buffer.ReadFrom(file); err != nil {
return nil, err return KV{}, err
} }
fileContents = append(fileContents, buffer.String()) fileContents = append(fileContents, buffer.String())
} }
result[key] = fileContents result.Set(key, fileContents)
} }
return KVOrderedMap(&result), nil
return result, nil
} }
// Name returns the name of the encoder // Name returns the name of the encoder

View File

@ -17,15 +17,15 @@ func (r *Raw) IsType(data string) bool {
} }
// Encode encodes the data into Raw format // Encode encodes the data into Raw format
func (r *Raw) Encode(data map[string]interface{}) (string, error) { func (r *Raw) Encode(data KV) (string, error) {
return data["value"].(string), nil return data.Get("value").(string), nil
} }
// Decode decodes the data from Raw format // Decode decodes the data from Raw format
func (r *Raw) Decode(data string) (map[string]interface{}, error) { func (r *Raw) Decode(data string) (KV, error) {
return map[string]interface{}{ return KVMap(map[string]interface{}{
"value": data, "value": data,
}, nil }), nil
} }
// Name returns the name of the encoder // Name returns the name of the encoder

View File

@ -22,13 +22,13 @@ func (x *XML) IsType(data string) bool {
} }
// Encode encodes the data into XML format // Encode encodes the data into XML format
func (x *XML) Encode(data map[string]interface{}) (string, error) { func (x *XML) Encode(data KV) (string, error) {
var header string var header string
if value, ok := data["#_xml_header"]; ok && value != nil { if value := data.Get("#_xml_header"); value != nil {
header = value.(string) header = value.(string)
delete(data, "#_xml_header") data.Delete("#_xml_header")
} }
marshalled, err := mxj.Map(data).Xml() marshalled, err := mxj.Map(data.Map).Xml()
if err != nil { if err != nil {
return "", err return "", err
} }
@ -41,7 +41,7 @@ func (x *XML) Encode(data map[string]interface{}) (string, error) {
var xmlHeader = regexp.MustCompile(`\<\?(.*)\?\>`) var xmlHeader = regexp.MustCompile(`\<\?(.*)\?\>`)
// Decode decodes the data from XML format // Decode decodes the data from XML format
func (x *XML) Decode(data string) (map[string]interface{}, error) { func (x *XML) Decode(data string) (KV, error) {
var prefixStr string var prefixStr string
prefix := xmlHeader.FindAllStringSubmatch(data, -1) prefix := xmlHeader.FindAllStringSubmatch(data, -1)
if len(prefix) > 0 { if len(prefix) > 0 {
@ -50,10 +50,10 @@ func (x *XML) Decode(data string) (map[string]interface{}, error) {
decoded, err := mxj.NewMapXml([]byte(data)) decoded, err := mxj.NewMapXml([]byte(data))
if err != nil { if err != nil {
return nil, err return KV{}, err
} }
decoded["#_xml_header"] = prefixStr decoded["#_xml_header"] = prefixStr
return decoded, nil return KVMap(decoded), nil
} }
// Name returns the name of the encoder // Name returns the name of the encoder