From 8013cfcc2d94c6c782aaa34bfbc46f8088c560e7 Mon Sep 17 00:00:00 2001 From: ykadowak Date: Wed, 10 Jan 2024 06:37:45 +0000 Subject: [PATCH] Add isSymlink function and test to gen license to avoid for symlink to become normal file. This commit adds a new function called `isSymlink` to check if a given path is a symbolic link. It also includes a corresponding test in the `main_test.go` file. This functionality will be useful for handling symbolic links in the codebase. --- hack/license/gen/main.go | 17 +++++++++++ hack/license/gen/main_test.go | 55 +++++++++++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/hack/license/gen/main.go b/hack/license/gen/main.go index d970d9b880..6ae6618718 100644 --- a/hack/license/gen/main.go +++ b/hack/license/gen/main.go @@ -226,7 +226,24 @@ func dirwalk(dir string) []string { return paths } +func isSymlink(path string) (bool, error) { + lst, err := os.Lstat(path) + if err != nil { + return false, err + } + return lst.Mode()&os.ModeSymlink != 0, nil +} + func readAndRewrite(path string) error { + // return if it is a symlink + isSym, err := isSymlink(path) + if err != nil { + return err + } + if isSym { + return nil + } + f, err := os.OpenFile(path, os.O_RDWR|os.O_SYNC, fs.ModePerm) if err != nil { return errors.Errorf("filepath %s, could not open", path) diff --git a/hack/license/gen/main_test.go b/hack/license/gen/main_test.go index f14d15f3c0..8952824ca4 100644 --- a/hack/license/gen/main_test.go +++ b/hack/license/gen/main_test.go @@ -15,3 +15,58 @@ // package main + +import ( + "os" + "path/filepath" + "testing" +) + +func TestIsSymlink(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + symlinkPath := filepath.Join(dir, "target") + filePath := filepath.Join(dir, "file") + + _, err := os.Create(filePath) + if err != nil { + t.Error(err) + } + + err = os.Symlink(filePath, symlinkPath) + if err != nil { + t.Error(err) + } + + tests := []struct { + name string + path string + expected bool + }{ + { + name: "return true when it is a symlink", + path: symlinkPath, + expected: true, + }, + { + name: "return false when it is a normal file", + path: filePath, + expected: false, + }, + } + + for _, tc := range tests { + test := tc + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + isSymlink, err := isSymlink(test.path) + if err != nil { + tt.Error(err) + } + if isSymlink != test.expected { + tt.Errorf("expected %v, got %v", test.expected, isSymlink) + } + }) + } +}