diff --git a/pkg/filesystem.go b/pkg/filesystem.go new file mode 100644 index 000000000..8e647d97a --- /dev/null +++ b/pkg/filesystem.go @@ -0,0 +1,44 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package pkg + +import ( + "io/ioutil" +) + +// ListDirectory generates a slice of all the file names that both exist in +// the provided directory and pass the filter. +// The returned file names are relative to the directory argument. +// filterFunc is called once for each file found in the directory. If +// filterFunc returns true, the given file will ignored. +func ListDirectory(dir string, filterFunc func(string) bool) ([]string, error) { + fis, err := ioutil.ReadDir(dir) + if err != nil { + return nil, err + } + + units := make([]string, 0) + for _, fi := range fis { + name := fi.Name() + if filterFunc(name) { + continue + } + units = append(units, name) + } + + return units, nil +} diff --git a/pkg/filesystem_test.go b/pkg/filesystem_test.go new file mode 100644 index 000000000..cc8d7dc5b --- /dev/null +++ b/pkg/filesystem_test.go @@ -0,0 +1,55 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package pkg + +import ( + "io/ioutil" + "os" + "path" + "reflect" + "testing" +) + +func TestListDirectory(t *testing.T) { + dir, err := ioutil.TempDir("", "fleet-testing-") + if err != nil { + t.Fatal(err.Error()) + } + + defer os.RemoveAll(dir) + + for _, name := range []string{"ping", "pong", "foo", "bar", "baz"} { + err := ioutil.WriteFile(path.Join(dir, name), []byte{}, 0400) + if err != nil { + t.Fatal(err.Error()) + } + } + + filterFunc := func(name string) bool { + return name == "foo" || name == "bar" + } + + got, err := ListDirectory(dir, filterFunc) + if err != nil { + t.Fatal(err.Error()) + } + + want := []string{"baz", "ping", "pong"} + if !reflect.DeepEqual(want, got) { + t.Fatalf("ListDirectory output incorrect: want=%v, got=%v", want, got) + } +} diff --git a/systemd/manager.go b/systemd/manager.go index 9367f6a2b..123f47791 100644 --- a/systemd/manager.go +++ b/systemd/manager.go @@ -36,7 +36,7 @@ const ( type systemdUnitManager struct { systemd *dbus.Conn - UnitsDir string + unitsDir string hashes map[string]unit.Hash mutex sync.RWMutex @@ -52,15 +52,53 @@ func NewSystemdUnitManager(uDir string) (*systemdUnitManager, error) { return nil, err } + hashes, err := hashUnitFiles(uDir) + if err != nil { + return nil, err + } + mgr := systemdUnitManager{ systemd: systemd, - UnitsDir: uDir, - hashes: make(map[string]unit.Hash), + unitsDir: uDir, + hashes: hashes, mutex: sync.RWMutex{}, } return &mgr, nil } +func hashUnitFiles(dir string) (map[string]unit.Hash, error) { + uNames, err := lsUnitsDir(dir) + if err != nil { + return nil, err + } + + hMap := make(map[string]unit.Hash) + for _, uName := range uNames { + h, err := hashUnitFile(path.Join(dir, uName)) + if err != nil { + return nil, err + } + + hMap[uName] = h + } + + return hMap, nil +} + +func hashUnitFile(loc string) (unit.Hash, error) { + b, err := ioutil.ReadFile(loc) + if err != nil { + return unit.Hash{}, err + } + + uf, err := unit.NewUnitFile(string(b)) + if err != nil { + return unit.Hash{}, err + } + + return uf.Hash(), nil +} + // Load writes the given Unit to disk, subscribing to relevant dbus // events, caching the Unit's Hash, and, if necessary, instructing the systemd // daemon to reload. @@ -163,21 +201,9 @@ func (m *systemdUnitManager) daemonReload() error { // Units enumerates all files recognized as valid systemd units in // this manager's units directory. -func (m *systemdUnitManager) Units() (units []string, err error) { - fis, err := ioutil.ReadDir(m.UnitsDir) - if err != nil { - return - } +func (m *systemdUnitManager) Units() ([]string, error) { + return lsUnitsDir(m.unitsDir) - for _, fi := range fis { - name := fi.Name() - if !unit.RecognizedUnitType(name) { - log.Warningf("Found unrecognized file in %s, ignoring", path.Join(m.UnitsDir, name)) - continue - } - units = append(units, name) - } - return } func (m *systemdUnitManager) GetUnitStates(filter pkg.Set) (map[string]*unit.UnitState, error) { @@ -256,5 +282,18 @@ func (m *systemdUnitManager) removeUnit(name string) { } func (m *systemdUnitManager) getUnitFilePath(name string) string { - return path.Join(m.UnitsDir, name) + return path.Join(m.unitsDir, name) +} + +func lsUnitsDir(dir string) ([]string, error) { + filterFunc := func(name string) bool { + if !unit.RecognizedUnitType(name) { + log.Warningf("Found unrecognized file in %s, ignoring", path.Join(dir, name)) + return true + } + + return false + } + + return pkg.ListDirectory(dir, filterFunc) } diff --git a/systemd/manager_test.go b/systemd/manager_test.go new file mode 100644 index 000000000..c8aa2c6e9 --- /dev/null +++ b/systemd/manager_test.go @@ -0,0 +1,115 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package systemd + +import ( + "io/ioutil" + "os" + "path" + "reflect" + "testing" +) + +func TestHashUnitFile(t *testing.T) { + f, err := ioutil.TempFile("", "fleet-testing-") + if err != nil { + t.Fatalf(err.Error()) + } + + defer os.Remove(f.Name()) + + contents := ` +[Service] +ExecStart=/usr/bin/sleep infinity +` + + if _, err := f.Write([]byte(contents)); err != nil { + t.Fatalf(err.Error()) + } + + if err := f.Close(); err != nil { + t.Fatalf(err.Error()) + } + + hash, err := hashUnitFile(f.Name()) + if err != nil { + t.Fatalf(err.Error()) + } + + want := "40ea6646945809f4b420a50475ee68503088f127" + got := hash.String() + if want != got { + t.Fatalf("unit hash incorrect: want=%s, got=%s", want, got) + } +} + +func TestHashUnitFileDirectory(t *testing.T) { + dir, err := ioutil.TempDir("", "fleet-testing-") + if err != nil { + t.Fatal(err.Error()) + } + + defer os.RemoveAll(dir) + + fixtures := []struct { + name string + contents string + hash string + }{ + { + name: "foo.service", + contents: "[Service]\nExecStart=/usr/bin/sleep infinity", + hash: "40ea6646945809f4b420a50475ee68503088f127", + }, + { + name: "bar.service", + contents: "[Service]\nExecStart=/usr/bin/sleep 10", + hash: "5bf16b98c62f35fcdc723d32989cdeba7a2dd2a8", + }, + { + name: "baz.service", + contents: "[Service]\nExecStart=/usr/bin/sleep 2000", + hash: "5ba5292ab6a82b623ee6086dc90b3354ba004832", + }, + } + + for _, f := range fixtures { + err := ioutil.WriteFile(path.Join(dir, f.name), []byte(f.contents), 0400) + if err != nil { + t.Fatal(err.Error()) + } + } + + hashes, err := hashUnitFileDirectory(dir) + if err != nil { + t.Fatal(err.Error()) + } + + got := make(map[string]string, len(hashes)) + for uName, hash := range hashes { + got[uName] = hash.String() + } + + want := make(map[string]string, len(fixtures)) + for _, f := range fixtures { + want[f.name] = f.hash + } + + if !reflect.DeepEqual(want, got) { + t.Fatalf("hashUnitFileDirectory returned unexpected values: want=%v, got=%v", want, got) + } +}