nuclei/pkg/protocols/common/hosterrorscache/hosterrorscache_test.go
Dwi Siswanto 052fd8b79a
feat(hosterrorscache): add Remove and MarkFailedOrRemove methods (#5984)
* feat(hosterrorscache): add `Remove` and `MarkFailedOrRemove` methods, and deprecate `MarkFailed`
* refactor(*): unwrap `hosterrorscache.MarkFailed` invocations
* feat(hosterrorscache): add sync in `Check` and `MarkFailedOrRemove` methods
* test(hosterrorscache): add concurrent test for `Check` method
* refactor(hosterrorscache): do NOT change `MarkFailed` behavior
* feat(*): use `MarkFailedOrRemove` explicitly

Signed-off-by: Dwi Siswanto <git@dw1.io>
2025-01-31 15:46:57 +05:30


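// Usage sketch (inferred from the tests below; not part of the original file):
// a caller creates the cache, records failures per host, and asks Check whether
// the host should be skipped. MarkFailedOrRemove with a nil error removes the
// host from the cache, while the deprecated MarkFailed only increments the
// per-host error counter.
//
//	cache := New(3, DefaultMaxHostsCount, nil)
//	ctx := contextargs.NewWithInput(context.TODO(), "example.com:80")
//	cache.MarkFailedOrRemove("http", ctx, errors.New("timeout"))
//	skip := cache.Check("http", ctx)           // true once 3 errors accumulate
//	cache.MarkFailedOrRemove("http", ctx, nil) // nil error removes the entry
//	cache.Remove(ctx)                          // explicit removal is also possible
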
package hosterrorscache
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"github.com/projectdiscovery/nuclei/v3/pkg/protocols/common/contextargs"
"github.com/stretchr/testify/require"
)
const (
protoType = "http"
)
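
// TestCacheCheck verifies that Check only flags a host once it has accumulated
// the configured maximum of 3 errors, and that MarkFailedOrRemove with a nil
// error removes the host from the cache so Check reports false.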
func TestCacheCheck(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
err := errors.New("net/http: timeout awaiting response headers")
t.Run("increment host error", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i < 3; i++ {
cache.MarkFailed(protoType, ctx, err)
got := cache.Check(protoType, ctx)
require.Falsef(t, got, "got %v in iteration %d", got, i)
}
})
t.Run("flagged", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
for i := 1; i <= 3; i++ {
cache.MarkFailed(protoType, ctx, err)
}
got := cache.Check(protoType, ctx)
require.True(t, got)
})
t.Run("mark failed or remove", func(t *testing.T) {
ctx := newCtxArgs(t.Name())
cache.MarkFailedOrRemove(protoType, ctx, nil) // nil error should remove the host from cache
got := cache.Check(protoType, ctx)
require.False(t, got)
})
}
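
// TestTrackErrors verifies that an error containing the configured tracked
// substring ("custom error") counts toward flagging the host.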
func TestTrackErrors(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, []string{"custom error"})
for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, newCtxArgs("custom"), errors.New("got: nested: custom error"))
got := cache.Check(protoType, newCtxArgs("custom"))
if i < 2 {
// below the threshold of 3 errors the host is not yet flagged to skip
require.False(t, got)
} else {
// from the third error onwards it must remain flagged to skip
require.True(t, got)
}
}
value := cache.Check(protoType, newCtxArgs("custom"))
require.True(t, value, "expected the host to remain flagged to skip")
}
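
// TestCacheItemDo verifies that cacheItem.Do runs its callback exactly once,
// even when called concurrently from many goroutines.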
func TestCacheItemDo(t *testing.T) {
var (
count int
item cacheItem
)
wg := sync.WaitGroup{}
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
item.Do(func() {
count++
})
}()
}
wg.Wait()
// ensure the increment ran exactly once regardless of how many goroutines called Do
require.Equal(t, 1, count)
}
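
// TestRemove verifies that Remove clears a flagged host so that Check no
// longer reports it as failed.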
func TestRemove(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())
err := errors.New("net/http: timeout awaiting response headers")
for i := 0; i < 100; i++ {
cache.MarkFailed(protoType, ctx, err)
}
require.True(t, cache.Check(protoType, ctx))
cache.Remove(ctx)
require.False(t, cache.Check(protoType, ctx))
}
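
// TestCacheMarkFailed verifies host-key normalization: "example.com:80" maps
// to the same cache entry as "http://example.com:80", while "example.com:443"
// gets an entry of its own.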
func TestCacheMarkFailed(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
tests := []struct {
host string
expected int32
}{
{"http://example.com:80", 1},
{"example.com:80", 2},
// previously the port was omitted from the cache key when not provided;
// now a missing port defaults to the scheme-appropriate port (80 for http),
// so "example.com:80" shares a cache entry with "http://example.com:80"
{"example.com:443", 1},
}
for _, test := range tests {
normalizedCacheValue := cache.GetKeyFromContext(newCtxArgs(test.host), nil)
cache.MarkFailed(protoType, newCtxArgs(test.host), errors.New("no address found for host"))
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
require.EqualValues(t, test.expected, failedTarget.errors.Load())
}
}
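
// TestCacheMarkFailedConcurrent verifies that per-host error counters are
// incremented correctly when MarkFailed runs from many goroutines, with hosts
// that normalize to the same key sharing a single counter.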
func TestCacheMarkFailedConcurrent(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
tests := []struct {
host string
expected int32
}{
{"http://example.com:80", 200},
{"example.com:80", 200},
{"example.com:443", 100},
}
// item creation in the cache is not atomic, so pre-create the entries with their error counters set to zero
for _, test := range tests {
normalizedValue := cache.NormalizeCacheValue(test.host)
newItem := &cacheItem{errors: atomic.Int32{}}
newItem.errors.Store(0)
_ = cache.failedTargets.Set(normalizedValue, newItem)
}
wg := sync.WaitGroup{}
for _, test := range tests {
currentTest := test
for i := 0; i < 100; i++ {
wg.Add(1)
go func() {
defer wg.Done()
cache.MarkFailed(protoType, newCtxArgs(currentTest.host), errors.New("net/http: timeout awaiting response headers"))
}()
}
}
wg.Wait()
for _, test := range tests {
require.True(t, cache.Check(protoType, newCtxArgs(test.host)))
normalizedCacheValue := cache.NormalizeCacheValue(test.host)
failedTarget, err := cache.failedTargets.Get(normalizedCacheValue)
require.Nil(t, err)
require.NotNil(t, failedTarget)
require.EqualValues(t, test.expected, failedTarget.errors.Load())
}
}
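
// TestCacheCheckConcurrent verifies that Check can be called safely while
// MarkFailed is running concurrently from other goroutines.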
func TestCacheCheckConcurrent(t *testing.T) {
cache := New(3, DefaultMaxHostsCount, nil)
ctx := newCtxArgs(t.Name())
wg := sync.WaitGroup{}
for i := 1; i <= 100; i++ {
wg.Add(1)
i := i
go func() {
defer wg.Done()
cache.MarkFailed(protoType, ctx, errors.New("no address found for host"))
if i >= 3 {
got := cache.Check(protoType, ctx)
require.True(t, got)
}
}()
}
wg.Wait()
}
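
// newCtxArgs wraps the given input value in a contextargs.Context, which the
// cache uses to derive the host key.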
func newCtxArgs(value string) *contextargs.Context {
ctx := contextargs.NewWithInput(context.TODO(), value)
return ctx
}