diff --git a/pkg/ai/openai.go b/pkg/ai/openai.go index d9d84adea8..dca391e270 100644 --- a/pkg/ai/openai.go +++ b/pkg/ai/openai.go @@ -73,27 +73,22 @@ func (c *OpenAIClient) GetCompletion(ctx context.Context, prompt string) (string func (a *OpenAIClient) Parse(ctx context.Context, prompt []string, cache cache.ICache) (string, error) { inputKey := strings.Join(prompt, " ") // Check for cached data - sEnc := base64.StdEncoding.EncodeToString([]byte(inputKey)) - cacheKey := util.GetCacheKey(a.GetName(), a.language, sEnc) - // find in viper cache - if cache.Exists(cacheKey) { - // retrieve data from cache - response, err := cache.Load(cacheKey) + cacheKey := util.GetCacheKey(a.GetName(), a.language, inputKey) + if !cache.IsCacheDisabled() && cache.Exists(cacheKey) { + response, err := cache.Load(cacheKey) if err != nil { return "", err } - if response == "" { - color.Red("error retrieving cached data") - return "", nil - } - output, err := base64.StdEncoding.DecodeString(response) - if err != nil { - color.Red("error decoding cached data: %v", err) - return "", nil + if response != "" { + output, err := base64.StdEncoding.DecodeString(response) + if err != nil { + color.Red("error decoding cached data: %v", err) + return "", nil + } + return string(output), nil } - return string(output), nil } response, err := a.GetCompletion(ctx, inputKey) diff --git a/pkg/analysis/analysis.go b/pkg/analysis/analysis.go index 075cfea32c..3487f304c5 100644 --- a/pkg/analysis/analysis.go +++ b/pkg/analysis/analysis.go @@ -34,14 +34,14 @@ import ( ) type Analysis struct { - Context context.Context - Filters []string - Client *kubernetes.Client - AIClient ai.IAI - Results []common.Result - Namespace string - Cache cache.ICache - Explain bool + Context context.Context + Filters []string + Client *kubernetes.Client + AIClient ai.IAI + Results []common.Result + Namespace string + Cache cache.ICache + Explain bool MaxConcurrency int } @@ -102,13 +102,13 @@ func NewAnalysis(backend string, language string, filters []string, namespace st } return &Analysis{ - Context: ctx, - Filters: filters, - Client: client, - AIClient: aiClient, - Namespace: namespace, - Cache: cache.New(noCache), - Explain: explain, + Context: ctx, + Filters: filters, + Client: client, + AIClient: aiClient, + Namespace: namespace, + Cache: cache.New(noCache), + Explain: explain, MaxConcurrency: maxConcurrency, }, nil } diff --git a/pkg/cache/cache.go b/pkg/cache/cache.go index 6a6924425a..a81cd46f89 100644 --- a/pkg/cache/cache.go +++ b/pkg/cache/cache.go @@ -4,12 +4,11 @@ type ICache interface { Store(key string, data string) error Load(key string) (string, error) Exists(key string) bool + IsCacheDisabled() bool } func New(noCache bool) ICache { - if noCache { - return &NoopCache{} + return &FileBasedCache{ + noCache: noCache, } - - return &FileBasedCache{} } diff --git a/pkg/cache/file_based.go b/pkg/cache/file_based.go index 5c1b62123b..a4acaf8abe 100644 --- a/pkg/cache/file_based.go +++ b/pkg/cache/file_based.go @@ -11,7 +11,13 @@ import ( var _ (ICache) = (*FileBasedCache)(nil) -type FileBasedCache struct{} +type FileBasedCache struct { + noCache bool +} + +func (f *FileBasedCache) IsCacheDisabled() bool { + return f.noCache +} func (*FileBasedCache) Exists(key string) bool { path, err := xdg.CacheFile(filepath.Join("k8sgpt", key)) diff --git a/pkg/cache/noop.go b/pkg/cache/noop.go deleted file mode 100644 index 4a0b88f1e0..0000000000 --- a/pkg/cache/noop.go +++ /dev/null @@ -1,17 +0,0 @@ -package cache - -var _ (ICache) = (*NoopCache)(nil) - -type NoopCache struct{} - -func (c *NoopCache) Store(key string, data string) error { - return nil -} - -func (c *NoopCache) Load(key string) (string, error) { - return "", nil -} - -func (c *NoopCache) Exists(key string) bool { - return false -} diff --git a/pkg/util/util.go b/pkg/util/util.go index 889fd3c859..d6b2f975ea 100644 --- a/pkg/util/util.go +++ b/pkg/util/util.go @@ -15,7 +15,9 @@ package util import ( "context" + "crypto/sha256" "encoding/base64" + "encoding/hex" "errors" "fmt" "math/rand" @@ -148,7 +150,11 @@ func ReplaceIfMatch(text string, pattern string, replacement string) string { } func GetCacheKey(provider string, language string, sEnc string) string { - return fmt.Sprintf("%s-%s-%s", provider, language, sEnc) + data := fmt.Sprintf("%s-%s-%s", provider, language, sEnc) + + hash := sha256.Sum256([]byte(data)) + + return hex.EncodeToString(hash[:]) } func GetPodListByLabels(client k.Interface,