From 5d175314fe9a53a426c19077f53cb61c2cc40d65 Mon Sep 17 00:00:00 2001 From: Mzack9999 Date: Tue, 6 May 2025 10:48:56 +0200 Subject: [PATCH] . --- pkg/js/compiler/pool.go | 6 +++- pkg/js/gojs/gojs.go | 72 +++++++++++++++++++++++++++++++++++++---- pkg/js/gojs/set.go | 59 ++++++++++++++++++++++++++++++++- 3 files changed, 129 insertions(+), 8 deletions(-) diff --git a/pkg/js/compiler/pool.go b/pkg/js/compiler/pool.go index b271cd329..585099508 100644 --- a/pkg/js/compiler/pool.go +++ b/pkg/js/compiler/pool.go @@ -85,6 +85,7 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg opts.Cleanup(runtime) } _ = runtime.GlobalObject().Delete("executionId") + _ = runtime.GlobalObject().Delete("context") }() // TODO(dwisiswant0): remove this once we get the RCA. @@ -111,8 +112,11 @@ func executeWithRuntime(runtime *goja.Runtime, p *goja.Program, args *ExecuteArg } } - // inject execution id + // inject execution id and context _ = runtime.Set("executionId", opts.ExecutionId) + if opts.Context != nil { + _ = runtime.Set("context", opts.Context) + } // execute the script return runtime.RunProgram(p) diff --git a/pkg/js/gojs/gojs.go b/pkg/js/gojs/gojs.go index 3b43fe13f..3ec2a2dee 100644 --- a/pkg/js/gojs/gojs.go +++ b/pkg/js/gojs/gojs.go @@ -1,6 +1,8 @@ package gojs import ( + "context" + "reflect" "sync" "github.com/dop251/goja" @@ -47,21 +49,79 @@ func (p *GojaModule) Name() string { return p.name } -func (p *GojaModule) Set(objects Objects) Module { - - for k, v := range objects { - p.sets[k] = v +// wrapModuleFunc wraps a Go function with context injection for modules +func wrapModuleFunc(runtime *goja.Runtime, fn interface{}) interface{} { + fnType := reflect.TypeOf(fn) + if fnType.Kind() != reflect.Func { + return fn } + // Only wrap if first parameter is context.Context + if fnType.NumIn() == 0 || fnType.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() { + return fn // Return original function unchanged if it doesn't have context.Context as first arg + } + + // Create input and output type slices + inTypes := make([]reflect.Type, fnType.NumIn()) + for i := 0; i < fnType.NumIn(); i++ { + inTypes[i] = fnType.In(i) + } + outTypes := make([]reflect.Type, fnType.NumOut()) + for i := 0; i < fnType.NumOut(); i++ { + outTypes[i] = fnType.Out(i) + } + + // Create a new function with same signature + newFnType := reflect.FuncOf(inTypes, outTypes, fnType.IsVariadic()) + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + // Get context from runtime + var ctx context.Context + if ctxVal := runtime.Get("context"); ctxVal != nil { + if ctxObj, ok := ctxVal.Export().(context.Context); ok { + ctx = ctxObj + } + } + if ctx == nil { + ctx = context.Background() + } + + // Add execution ID to context if available + if execID := runtime.Get("executionId"); execID != nil { + ctx = context.WithValue(ctx, "executionId", execID.String()) + } + + // Replace first argument (context) with our context + args[0] = reflect.ValueOf(ctx) + + // Call original function with modified arguments + return reflect.ValueOf(fn).Call(args) + }) + + return newFn.Interface() +} + +func (p *GojaModule) Set(objects Objects) Module { + for k, v := range objects { + // If the value is a function, wrap it with context injection + if fnType := reflect.TypeOf(v); fnType != nil && fnType.Kind() == reflect.Func { + p.sets[k] = wrapModuleFunc(nil, v) // We'll inject the runtime later in Require + } else { + p.sets[k] = v + } + } return p } func (p *GojaModule) Require(runtime *goja.Runtime, module *goja.Object) { - o := module.Get("exports").(*goja.Object) for k, v := range p.sets { - _ = o.Set(k, v) + // If the value is a function, wrap it with context injection + if fnType := reflect.TypeOf(v); fnType != nil && fnType.Kind() == reflect.Func { + _ = o.Set(k, wrapModuleFunc(runtime, v)) + } else { + _ = o.Set(k, v) + } } } diff --git a/pkg/js/gojs/set.go b/pkg/js/gojs/set.go index 9703a3c6e..b18b91c0c 100644 --- a/pkg/js/gojs/set.go +++ b/pkg/js/gojs/set.go @@ -1,6 +1,9 @@ package gojs import ( + "context" + "reflect" + "github.com/dop251/goja" errorutil "github.com/projectdiscovery/utils/errors" ) @@ -22,6 +25,57 @@ func (f *FuncOpts) valid() bool { return f.Name != "" && f.FuncDecl != nil && len(f.Signatures) > 0 && f.Description != "" } +// wrapWithContext wraps a Go function with context injection +func wrapWithContext(runtime *goja.Runtime, fn interface{}) interface{} { + fnType := reflect.TypeOf(fn) + if fnType.Kind() != reflect.Func { + return fn + } + + // Only wrap if first parameter is context.Context + if fnType.NumIn() == 0 || fnType.In(0) != reflect.TypeOf((*context.Context)(nil)).Elem() { + return fn // Return original function unchanged if it doesn't have context.Context as first arg + } + + // Create input and output type slices + inTypes := make([]reflect.Type, fnType.NumIn()) + for i := 0; i < fnType.NumIn(); i++ { + inTypes[i] = fnType.In(i) + } + outTypes := make([]reflect.Type, fnType.NumOut()) + for i := 0; i < fnType.NumOut(); i++ { + outTypes[i] = fnType.Out(i) + } + + // Create a new function with same signature + newFnType := reflect.FuncOf(inTypes, outTypes, fnType.IsVariadic()) + newFn := reflect.MakeFunc(newFnType, func(args []reflect.Value) []reflect.Value { + // Get context from runtime + var ctx context.Context + if ctxVal := runtime.Get("context"); ctxVal != nil { + if ctxObj, ok := ctxVal.Export().(context.Context); ok { + ctx = ctxObj + } + } + if ctx == nil { + ctx = context.Background() + } + + // Add execution ID to context if available + if execID := runtime.Get("executionId"); execID != nil { + ctx = context.WithValue(ctx, "executionId", execID.String()) + } + + // Replace first argument (context) with our context + args[0] = reflect.ValueOf(ctx) + + // Call original function with modified arguments + return reflect.ValueOf(fn).Call(args) + }) + + return newFn.Interface() +} + // RegisterFunc registers a function with given name, signatures and description func RegisterFuncWithSignature(runtime *goja.Runtime, opts FuncOpts) error { if runtime == nil { @@ -30,5 +84,8 @@ func RegisterFuncWithSignature(runtime *goja.Runtime, opts FuncOpts) error { if !opts.valid() { return ErrInvalidFuncOpts.Msgf("name: %s, signatures: %v, description: %s", opts.Name, opts.Signatures, opts.Description) } - return runtime.Set(opts.Name, opts.FuncDecl) + + // Wrap the function with context injection + wrappedFn := wrapWithContext(runtime, opts.FuncDecl) + return runtime.Set(opts.Name, wrappedFn) }