From 085589a9383ab38192d7d808b772760a55d04989 Mon Sep 17 00:00:00 2001 From: andrey Date: Thu, 25 Jan 2024 20:49:37 +0300 Subject: [PATCH] add stop watch --- error.go | 1 + provider.go | 2 +- provider/watcher/provider.go | 9 ++++++++- provider/watcher/provider_test.go | 15 ++++++++++++--- 4 files changed, 22 insertions(+), 5 deletions(-) diff --git a/error.go b/error.go index 93c30be..fbd9a95 100644 --- a/error.go +++ b/error.go @@ -7,4 +7,5 @@ var ( ErrInvalidValue = errors.New("invalid value") ErrUnknowType = errors.New("unknow type") ErrInitFactory = errors.New("init factory") + ErrStopWatch = errors.New("stop watch") ) diff --git a/provider.go b/provider.go index d7fa9c0..aec566f 100644 --- a/provider.go +++ b/provider.go @@ -11,7 +11,7 @@ type NamedProvider interface { Provider } -type WatchCallback func(ctx context.Context, oldVar, newVar Value) +type WatchCallback func(ctx context.Context, oldVar, newVar Value) error type WatchProvider interface { Watch(ctx context.Context, callback WatchCallback, path ...string) error diff --git a/provider/watcher/provider.go b/provider/watcher/provider.go index 37e538b..4ab7561 100644 --- a/provider/watcher/provider.go +++ b/provider/watcher/provider.go @@ -2,6 +2,7 @@ package watcher import ( "context" + "errors" "fmt" "log/slog" "time" @@ -61,7 +62,13 @@ func (p *Provider) Watch(ctx context.Context, callback config.WatchCallback, key if err != nil { p.logger(ctx, "get value%v:%v", key, err.Error()) } else if !newVar.IsEquals(oldVar) { - callback(ctx, oldVar, newVar) + if err := callback(ctx, oldVar, newVar); err != nil { + if errors.Is(err, config.ErrStopWatch) { + return + } + p.logger(ctx, "callback %v:%v", key, err) + + } oldVar = newVar } case <-ctx.Done(): diff --git a/provider/watcher/provider_test.go b/provider/watcher/provider_test.go index b36fa96..ba5d8ba 100644 --- a/provider/watcher/provider_test.go +++ b/provider/watcher/provider_test.go @@ -31,7 +31,11 @@ func (p *provider) Value(context.Context, ...string) (config.Value, error) { func TestWatcher(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer func() { + cancel() + }() + prov := &provider{} w := watcher.New(time.Second, prov) @@ -42,14 +46,19 @@ func TestWatcher(t *testing.T) { err := w.Watch( ctx, - func(ctx context.Context, oldVar, newVar config.Value) { + func(ctx context.Context, oldVar, newVar config.Value) error { atomic.AddInt32(&cnt, 1) wg.Done() + if atomic.LoadInt32(&cnt) == 2 { + return config.ErrStopWatch + } + + return nil }, "tmpname", ) - require.NoError(t, err) wg.Wait() + require.NoError(t, err) require.Equal(t, int32(2), cnt) }