nuclei/pkg/fuzz/dataformat/graphql.go
2024-10-13 02:50:42 +05:30

279 lines
6.7 KiB
Go

package dataformat
import (
"encoding/json"
"fmt"
"log"
"strings"
"github.com/graphql-go/graphql/language/kinds"
"github.com/graphql-go/graphql/language/parser"
"github.com/graphql-go/graphql/language/printer"
"github.com/graphql-go/graphql/language/source"
jsoniter "github.com/json-iterator/go"
"github.com/pkg/errors"
"github.com/projectdiscovery/nuclei/v3/pkg/types"
"github.com/graphql-go/graphql/language/ast"
)
// Graphql is the data format handler for GraphQL requests. It carries no
// state of its own.
type Graphql struct{}

// Compile-time check that *Graphql satisfies DataFormat.
var _ DataFormat = (*Graphql)(nil)
// NewGraphql constructs a GraphQL data format handler.
func NewGraphql() *Graphql {
	return new(Graphql)
}
// IsType reports whether data is recognised as a GraphQL request by
// isGraphQLOperation with query validation enabled.
func (m *Graphql) IsType(data string) bool {
	payload := []byte(data)
	// The decoded request and the error are intentionally dropped: only the
	// yes/no answer matters to the caller.
	_, matched, _ := isGraphQLOperation(payload, true)
	return matched
}
// graphQLRequest is the container for a GraphQL request as it appears in a
// JSON body. Each field is omitted from the marshalled JSON when empty.
type graphQLRequest struct {
	// Query is the GraphQL document text (JSON member "query").
	Query string `json:"query,omitempty"`
	// OperationName is the JSON member "operationName".
	OperationName string `json:"operationName,omitempty"`
	// Variables is the JSON member "variables", keyed by variable name.
	Variables map[string]any `json:"variables,omitempty"`
}
// isGraphQLOperation decodes jsonData as the JSON form of a GraphQL request.
//
// It returns the decoded request and true when the payload is a JSON object
// in which at least one of query, operationName or variables is set. When
// validate is true the query must additionally parse as a GraphQL document
// holding at least one definition.
//
// Payloads that are not shaped like a JSON object, that set none of the
// three members, or (with validate) whose document has no definitions yield
// (zero request, false, nil). Malformed JSON or an unparsable query yields a
// non-nil error.
func isGraphQLOperation(jsonData []byte, validate bool) (graphQLRequest, bool, error) {
	// Cheap pre-filter: ignoring surrounding whitespace, a JSON object must
	// BOTH start with '{' and end with '}', so failing either test rejects
	// the payload.
	trimmed := strings.TrimSpace(string(jsonData))
	if !strings.HasPrefix(trimmed, "{") || !strings.HasSuffix(trimmed, "}") {
		return graphQLRequest{}, false, nil
	}

	var request graphQLRequest
	if err := json.Unmarshal(jsonData, &request); err != nil {
		return graphQLRequest{}, false, errors.Wrap(err, "could not unmarshal json")
	}
	if request.Query == "" && request.OperationName == "" && len(request.Variables) == 0 {
		return graphQLRequest{}, false, nil
	}

	// Without validation any JSON object carrying one of the known members
	// is accepted as-is.
	if !validate {
		return request, true, nil
	}

	// Validate that the query actually is GraphQL and not just some random
	// JSON that happens to use the same member names.
	doc, err := parser.Parse(parser.ParseParams{
		Source: &source.Source{
			Body: []byte(request.Query),
		},
	})
	if err != nil {
		return graphQLRequest{}, false, err
	}
	if len(doc.Definitions) == 0 {
		return graphQLRequest{}, false, nil
	}
	return request, true, nil
}
// Encode serialises data back into the JSON form of a GraphQL request.
//
// data must hold a graphQLRequest under the "#_parsedReq" key (Decode stores
// it there). Every key that does not start with "#_" is treated as a value
// to apply: when "#_hasVariables" is true the value is written into the
// request variables, otherwise it replaces the matching inline argument of
// the query document. The document is then printed back to text and the
// request marshalled to JSON.
//
// An error is returned when the parsed request is missing or of the wrong
// type, when its query cannot be parsed, or when marshalling fails. A
// failure while modifying the document is only logged; iteration stops and
// the request is still returned with the modifications applied so far.
func (m *Graphql) Encode(data KV) (string, error) {
	parsedRequest := data.Get("#_parsedReq")
	if parsedRequest == nil {
		return "", fmt.Errorf("parsed request not found")
	}
	request, ok := parsedRequest.(graphQLRequest)
	if !ok {
		return "", fmt.Errorf("parsed request is not of type graphQLRequest")
	}
	_, astDoc, err := m.parseGraphQLRequest(request.Query, false)
	if err != nil {
		return "", fmt.Errorf("error parsing graphql request: %w", err)
	}

	var hasVariables bool
	if item := data.Get("#_hasVariables"); item != nil {
		hasVariables, _ = item.(bool)
	}
	if hasVariables {
		// request is a copy of the struct held in data, but its Variables
		// map still aliases the stored one. Write into a private copy so the
		// stored request is not modified, and so a nil map cannot cause a
		// panic on assignment.
		variables := make(map[string]interface{}, len(request.Variables))
		for name, val := range request.Variables {
			variables[name] = val
		}
		request.Variables = variables
	}

	data.Iterate(func(key string, value any) bool {
		// Keys prefixed with "#_" are bookkeeping entries, not user values.
		if strings.HasPrefix(key, "#_") {
			return true
		}
		if hasVariables {
			request.Variables[key] = value
			return true
		}
		if err := m.modifyASTWithKeyValue(astDoc, key, value); err != nil {
			log.Printf("error modifying ast with key value: %v", err)
			return false
		}
		return true
	})

	modifiedQuery := printer.Print(astDoc)
	request.Query = types.ToString(modifiedQuery)

	marshalled, err := jsoniter.Marshal(request)
	if err != nil {
		return "", fmt.Errorf("error marshalling parsed request: %w", err)
	}
	return string(marshalled), nil
}
// modifyASTWithKeyValue replaces the value of every argument named key on
// the top-level fields of each operation definition in astDoc with value.
//
// A value that convertGoValueToASTValue cannot represent (it returns nil) is
// skipped, leaving the original argument untouched instead of overwriting it
// with a nil AST value. The returned error is currently always nil.
func (m *Graphql) modifyASTWithKeyValue(astDoc *ast.Document, key string, value any) error {
	for _, def := range astDoc.Definitions {
		operation, ok := def.(*ast.OperationDefinition)
		if !ok || operation.SelectionSet == nil {
			continue
		}
		for _, selection := range operation.SelectionSet.Selections {
			field, ok := selection.(*ast.Field)
			if !ok {
				continue
			}
			for _, arg := range field.Arguments {
				if arg == nil || arg.Name == nil || arg.Name.Value != key {
					continue
				}
				// Build a fresh node per match so no AST node ends up shared
				// between two arguments.
				replacement := convertGoValueToASTValue(value)
				if replacement == nil {
					// Unsupported Go type: keep the argument as it was.
					continue
				}
				arg.Value = replacement
			}
		}
	}
	return nil
}
// Decode parses a JSON-wrapped GraphQL request into a KV.
//
// The decoded request is stored under "#_parsedReq". Each request variable
// is then stored under its own name, and "#_hasVariables" is set to true
// when at least one variable exists. Finally populateGraphQLKV adds the
// arguments it finds in the query document.
func (m *Graphql) Decode(data string) (KV, error) {
	request, document, err := m.parseGraphQLRequest(data, true)
	if err != nil {
		return KV{}, fmt.Errorf("error parsing graphql request: %v", err)
	}

	result := KVMap(make(map[string]interface{}))
	result.Set("#_parsedReq", request)

	for name, val := range request.Variables {
		result.Set(name, val)
	}
	if len(request.Variables) != 0 {
		result.Set("#_hasVariables", true)
	}

	err = m.populateGraphQLKV(document, result)
	if err != nil {
		return KV{}, fmt.Errorf("error populating graphql kv: %v", err)
	}
	return result, nil
}
// populateGraphQLKV copies the arguments returned by getSelectionSetArguments
// for every operation definition in astDoc into kv. A name for which kv.Get
// already returns a non-nil value is left as it is.
func (m *Graphql) populateGraphQLKV(astDoc *ast.Document, kv KV) error {
	for _, node := range astDoc.Definitions {
		operation, isOperation := node.(*ast.OperationDefinition)
		if !isOperation {
			continue
		}
		arguments, err := getSelectionSetArguments(operation)
		if err != nil {
			return fmt.Errorf("error getting selection set arguments: %v", err)
		}
		for name, val := range arguments {
			if kv.Get(name) == nil {
				kv.Set(name, val)
			}
		}
	}
	return nil
}
// parseGraphQLRequest turns query into a request plus its parsed document.
//
// When unmarshal is true, query is the raw JSON request body: it is decoded
// with isGraphQLOperation (without query validation) and the embedded query
// text is parsed. When unmarshal is false, query is the GraphQL document
// text itself and the returned request only has Query set.
//
// An error is returned when the JSON cannot be decoded, when the body is not
// recognised as a GraphQL request, or when the document fails to parse.
func (m *Graphql) parseGraphQLRequest(query string, unmarshal bool) (graphQLRequest, *ast.Document, error) {
	parsedReq := graphQLRequest{Query: query}
	if unmarshal {
		decoded, ok, err := isGraphQLOperation([]byte(query), false)
		if err != nil {
			return graphQLRequest{}, nil, fmt.Errorf("error parsing query: %w", err)
		}
		// A rejected body comes back as a zero request; parsing its empty
		// query would succeed and hide the fact that nothing was decoded.
		if !ok {
			return graphQLRequest{}, nil, fmt.Errorf("error parsing query: not a graphql request")
		}
		parsedReq = decoded
	}

	astDoc, err := parser.Parse(parser.ParseParams{
		Source: &source.Source{
			Body: []byte(parsedReq.Query),
		},
	})
	if err != nil {
		return graphQLRequest{}, nil, fmt.Errorf("error parsing ast: %w", err)
	}
	return parsedReq, astDoc, nil
}
// getSelectionSetArguments gathers the arguments of the fields that sit
// directly in def's selection set, keyed by argument name and converted with
// convertValueToGoType. An operation without a selection set yields an empty
// map. The returned error is always nil.
func getSelectionSetArguments(def *ast.OperationDefinition) (map[string]interface{}, error) {
	collected := map[string]interface{}{}
	if set := def.SelectionSet; set != nil {
		for _, sel := range set.Selections {
			field, isField := sel.(*ast.Field)
			if !isField {
				continue
			}
			for _, argument := range field.Arguments {
				collected[argument.Name.Value] = convertValueToGoType(argument.Value)
			}
		}
	}
	return collected, nil
}
// convertGoValueToASTValue builds the GraphQL AST value node for a Go value.
//
// Supported inputs are string, bool, []interface{} and
// map[string]interface{}; lists and maps are converted recursively. nil is
// returned for any other type, and for a list or map that contains an
// element which cannot be converted.
//
// Object fields are emitted in Go map iteration order, which is unspecified.
func convertGoValueToASTValue(value any) ast.Value {
	switch v := value.(type) {
	case string:
		return &ast.StringValue{
			Kind:  kinds.StringValue,
			Value: v,
		}
	case bool:
		return &ast.BooleanValue{
			Kind:  kinds.BooleanValue,
			Value: v,
		}
	case []interface{}:
		values := make([]ast.Value, 0, len(v))
		for _, item := range v {
			converted := convertGoValueToASTValue(item)
			if converted == nil {
				// Refuse the whole list rather than emit a nil element.
				return nil
			}
			values = append(values, converted)
		}
		return &ast.ListValue{
			Kind:   kinds.ListValue,
			Values: values,
		}
	case map[string]interface{}:
		fields := make([]*ast.ObjectField, 0, len(v))
		for name, item := range v {
			converted := convertGoValueToASTValue(item)
			if converted == nil {
				// Refuse the whole object rather than emit a nil field value.
				return nil
			}
			fields = append(fields, &ast.ObjectField{
				Kind:  kinds.ObjectField,
				Name:  &ast.Name{Kind: kinds.Name, Value: name},
				Value: converted,
			})
		}
		return &ast.ObjectValue{
			Kind:   kinds.ObjectValue,
			Fields: fields,
		}
	}
	return nil
}
// convertValueToGoType unwraps a GraphQL AST value into a plain Go value.
//
// String, int, float, boolean and enum nodes yield their Value field. List
// nodes yield a []interface{} of converted elements (a nil slice when the
// list is empty). Object nodes yield a map[string]interface{} keyed by field
// name. Any other node type yields nil.
func convertValueToGoType(value ast.Value) interface{} {
	if node, ok := value.(*ast.StringValue); ok {
		return node.Value
	}
	if node, ok := value.(*ast.IntValue); ok {
		return node.Value
	}
	if node, ok := value.(*ast.FloatValue); ok {
		return node.Value
	}
	if node, ok := value.(*ast.BooleanValue); ok {
		return node.Value
	}
	if node, ok := value.(*ast.EnumValue); ok {
		return node.Value
	}
	if node, ok := value.(*ast.ListValue); ok {
		// Deliberately left nil until the first append.
		var items []interface{}
		for i := range node.Values {
			items = append(items, convertValueToGoType(node.Values[i]))
		}
		return items
	}
	if node, ok := value.(*ast.ObjectValue); ok {
		fields := map[string]interface{}{}
		for i := range node.Fields {
			entry := node.Fields[i]
			fields[entry.Name.Value] = convertValueToGoType(entry.Value)
		}
		return fields
	}
	return nil
}
// Name reports the identifier of this data format, "graphql".
func (m *Graphql) Name() string {
	const formatName = "graphql"
	return formatName
}