diff --git a/nats.go b/nats.go index 91b0a46f9..b3a87449a 100644 --- a/nats.go +++ b/nats.go @@ -28,6 +28,8 @@ import ( "math/rand" "net" "net/url" + "os" + "path/filepath" "runtime" "strconv" "strings" @@ -4016,7 +4018,12 @@ func wipeSlice(buf []byte) { } func userFromFile(userFile string) (string, error) { - contents, err := ioutil.ReadFile(userFile) + path, err := expandPath(userFile) + if err != nil { + return _EMPTY_, fmt.Errorf("nats: %v", err) + } + + contents, err := ioutil.ReadFile(path) if err != nil { return _EMPTY_, fmt.Errorf("nats: %v", err) } @@ -4024,6 +4031,46 @@ func userFromFile(userFile string) (string, error) { return jwt.ParseDecoratedJWT(contents) } +func homeDir() (string, error) { + if runtime.GOOS == "windows" { + homeDrive, homePath := os.Getenv("HOMEDRIVE"), os.Getenv("HOMEPATH") + userProfile := os.Getenv("USERPROFILE") + + var home string + if homeDrive == "" || homePath == "" { + if userProfile == "" { + return _EMPTY_, errors.New("nats: failed to get home dir, require %HOMEDRIVE% and %HOMEPATH% or %USERPROFILE%") + } + home = userProfile + } else { + home = filepath.Join(homeDrive, homePath) + } + + return home, nil + } + + home := os.Getenv("HOME") + if home == "" { + return _EMPTY_, errors.New("nats: failed to get home dir, require $HOME") + } + return home, nil +} + +func expandPath(p string) (string, error) { + p = os.ExpandEnv(p) + + if !strings.HasPrefix(p, "~") { + return p, nil + } + + home, err := homeDir() + if err != nil { + return _EMPTY_, err + } + + return filepath.Join(home, p[1:]), nil +} + func nkeyPairFromSeedFile(seedFile string) (nkeys.KeyPair, error) { contents, err := ioutil.ReadFile(seedFile) if err != nil { diff --git a/nats_test.go b/nats_test.go index 94e178459..f62f16f2c 100644 --- a/nats_test.go +++ b/nats_test.go @@ -98,6 +98,103 @@ func TestVersionMatchesTag(t *testing.T) { } } +func TestExpandPath(t *testing.T) { + if runtime.GOOS == "windows" { + origUserProfile := os.Getenv("USERPROFILE") + origHomeDrive, origHomePath := os.Getenv("HOMEDRIVE"), os.Getenv("HOMEPATH") + defer func() { + os.Setenv("USERPROFILE", origUserProfile) + os.Setenv("HOMEDRIVE", origHomeDrive) + os.Setenv("HOMEPATH", origHomePath) + }() + + cases := []struct { + path string + userProfile string + homeDrive string + homePath string + + wantPath string + wantErr bool + }{ + // Missing HOMEDRIVE and HOMEPATH. + {path: "/Foo/Bar", userProfile: `C:\Foo\Bar`, wantPath: "/Foo/Bar"}, + {path: "Foo/Bar", userProfile: `C:\Foo\Bar`, wantPath: "Foo/Bar"}, + {path: "~/Fizz", userProfile: `C:\Foo\Bar`, wantPath: `C:\Foo\Bar\Fizz`}, + {path: `${HOMEDRIVE}${HOMEPATH}\Fizz`, userProfile: `C:\Foo\Bar`, wantPath: `C:\Foo\Bar\Fizz`}, + + // Missing USERPROFILE. + {path: "~/Fizz", homeDrive: "X:", homePath: `\Foo\Bar`, wantPath: `X:\Foo\Bar\Fizz`}, + + // Set all environment variables. HOMEDRIVE and HOMEPATH take + // precedence. + {path: "~/Fizz", userProfile: `C:\Foo\Bar`, + homeDrive: "X:", homePath: `\Foo\Bar`, wantPath: `X:\Foo\Bar\Fizz`}, + + // Missing all environment variables. + {path: "~/Fizz", wantErr: true}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("windows case %d", i), func(t *testing.T) { + os.Setenv("USERPROFILE", c.userProfile) + os.Setenv("HOMEDRIVE", c.homeDrive) + os.Setenv("HOMEPATH", c.homePath) + + gotPath, err := expandPath(c.path) + if !c.wantErr && err != nil { + t.Fatalf("unexpected error: got=%v; want=%v", err, nil) + } else if c.wantErr && err == nil { + t.Fatalf("unexpected success: got=%v; want=%v", nil, "err") + } + + if gotPath != c.wantPath { + t.Fatalf("unexpected path: got=%v; want=%v", gotPath, c.wantPath) + } + }) + } + + return + } + + // Unix tests + + origHome := os.Getenv("HOME") + defer os.Setenv("HOME", origHome) + + cases := []struct { + path string + home string + testEnv string + + wantPath string + wantErr bool + }{ + {path: "/foo/bar", home: "/fizz/buzz", wantPath: "/foo/bar"}, + {path: "foo/bar", home: "/fizz/buzz", wantPath: "foo/bar"}, + {path: "~/fizz", home: "/foo/bar", wantPath: "/foo/bar/fizz"}, + {path: "$HOME/fizz", home: "/foo/bar", wantPath: "/foo/bar/fizz"}, + + // missing HOME env var + {path: "~/fizz", wantErr: true}, + } + for i, c := range cases { + t.Run(fmt.Sprintf("unix case %d", i), func(t *testing.T) { + os.Setenv("HOME", c.home) + + gotPath, err := expandPath(c.path) + if !c.wantErr && err != nil { + t.Fatalf("unexpected error: got=%v; want=%v", err, nil) + } else if c.wantErr && err == nil { + t.Fatalf("unexpected success: got=%v; want=%v", nil, "err") + } + + if gotPath != c.wantPath { + t.Fatalf("unexpected path: got=%v; want=%v", gotPath, c.wantPath) + } + }) + } +} + //////////////////////////////////////////////////////////////////////////////// // Reconnect tests ////////////////////////////////////////////////////////////////////////////////