diff --git a/.gitignore b/.gitignore index ed539aa9b..9a9771c71 100644 --- a/.gitignore +++ b/.gitignore @@ -11,3 +11,4 @@ v2/cmd/docgen/docgen v2/pkg/protocols/common/helpers/deserialization/testdata/Deserialize.class v2/pkg/protocols/common/helpers/deserialization/testdata/ValueObject.class v2/pkg/protocols/common/helpers/deserialization/testdata/ValueObject2.ser +.vscode \ No newline at end of file diff --git a/v2/pkg/protocols/http/request_annotations.go b/v2/pkg/protocols/http/request_annotations.go index 46c7adc4e..b14831cb7 100644 --- a/v2/pkg/protocols/http/request_annotations.go +++ b/v2/pkg/protocols/http/request_annotations.go @@ -24,8 +24,30 @@ var ( reSniAnnotation = regexp.MustCompile(`(?m)^@tls-sni:\s*(.+)\s*$`) // @timeout:duration overrides the input timout with a custom duration reTimeoutAnnotation = regexp.MustCompile(`(?m)^@timeout:\s*(.+)\s*$`) + // @once sets the request to be executed only once for a specific URL + reOnceAnnotation = regexp.MustCompile(`(?m)^@once\s*$`) ) +type flowMark int + +const ( + Once flowMark = iota +) + +// parseFlowAnnotations and override requests flow +func parseFlowAnnotations(rawRequest string) (flowMark, bool) { + var fm flowMark + // parse request for known ovverride annotations + var hasFlowOveride bool + // @once + if reOnceAnnotation.MatchString(rawRequest) { + fm = Once + hasFlowOveride = true + } + + return fm, hasFlowOveride +} + // parseAnnotations and override requests settings func (r *Request) parseAnnotations(rawRequest string, request *http.Request) (*http.Request, bool) { // parse request for known ovverride annotations diff --git a/v2/pkg/protocols/http/request_generator.go b/v2/pkg/protocols/http/request_generator.go index 58e31fbfb..3fb32f89b 100644 --- a/v2/pkg/protocols/http/request_generator.go +++ b/v2/pkg/protocols/http/request_generator.go @@ -19,6 +19,7 @@ type requestGenerator struct { options *protocols.ExecuterOptions payloadIterator *generators.Iterator interactshURLs []string + onceFlow map[string]struct{} } // LeaveDefaultPorts skips normalization of default standard ports @@ -26,7 +27,11 @@ var LeaveDefaultPorts = false // newGenerator creates a new request generator instance func (request *Request) newGenerator() *requestGenerator { - generator := &requestGenerator{request: request, options: request.options} + generator := &requestGenerator{ + request: request, + options: request.options, + onceFlow: make(map[string]struct{}), + } if len(request.Payloads) > 0 { generator.payloadIterator = request.generator.NewIterator() @@ -53,29 +58,68 @@ func (r *requestGenerator) nextValue() (value string, payloads map[string]interf } hasPayloadIterator := r.payloadIterator != nil - hasInitializedPayloads := r.currentPayloads != nil - if r.currentIndex == 0 && hasPayloadIterator && !hasInitializedPayloads { + if hasPayloadIterator && r.currentPayloads == nil { r.currentPayloads, r.okCurrentPayload = r.payloadIterator.Value() } - if r.currentIndex < len(sequence) { - currentRequest := sequence[r.currentIndex] - r.currentIndex++ - return currentRequest, r.currentPayloads, true + + var request string + var shouldContinue bool + if nextRequest, nextIndex, found := r.findNextIteration(sequence, r.currentIndex); found { + r.currentIndex = nextIndex + 1 + request = nextRequest + shouldContinue = true + } else if nextRequest, nextIndex, found := r.findNextIteration(sequence, 0); found && hasPayloadIterator { + r.currentIndex = nextIndex + 1 + request = nextRequest + shouldContinue = true } - if r.currentIndex == len(sequence) { - if r.okCurrentPayload { - r.currentIndex = 0 - currentRequest := sequence[r.currentIndex] - if hasPayloadIterator { - r.currentPayloads, r.okCurrentPayload = r.payloadIterator.Value() - if r.okCurrentPayload { - r.currentIndex++ - return currentRequest, r.currentPayloads, true - } - } + + if shouldContinue { + if r.hasMarker(request, Once) { + r.applyMark(request, Once) + } + if hasPayloadIterator { + return request, r.currentPayloads, r.okCurrentPayload + } + return request, r.currentPayloads, true + } else { + return "", nil, false + } +} + +func (r *requestGenerator) findNextIteration(sequence []string, index int) (string, int, bool) { + for i, request := range sequence[index:] { + if !r.wasMarked(request, Once) { + return request, index + i, true } } - return "", nil, false + if r.payloadIterator != nil { + r.currentPayloads, r.okCurrentPayload = r.payloadIterator.Value() + } + + return "", 0, false +} + +func (r *requestGenerator) applyMark(request string, mark flowMark) { + switch mark { + case Once: + r.onceFlow[request] = struct{}{} + } + +} + +func (r *requestGenerator) wasMarked(request string, mark flowMark) bool { + switch mark { + case Once: + _, ok := r.onceFlow[request] + return ok + } + return false +} + +func (r *requestGenerator) hasMarker(request string, mark flowMark) bool { + fo, hasOverrides := parseFlowAnnotations(request) + return hasOverrides && fo == mark }