diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b746f1ac3..7263b5edd 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -45,3 +45,21 @@ updates: commit-message: prefix: "deps" include: "scope" + - package-ecosystem: "gomod" + directory: "/updater" + schedule: + interval: "weekly" + ignore: + # Opentelemetry updates will be done manually + - dependency-name: "github.com/open-telemetry/opentelemetry-collector*" + - dependency-name: "go.opentelemetry.io/collector/*" + commit-message: + prefix: "deps" + include: "scope" + - package-ecosystem: "gomod" + directory: "/packagestate" + schedule: + interval: "weekly" + commit-message: + prefix: "deps" + include: "scope" diff --git a/.github/workflows/manual_msi_build.yml b/.github/workflows/manual_msi_build.yml index 8e28b4013..481ea393b 100644 --- a/.github/workflows/manual_msi_build.yml +++ b/.github/workflows/manual_msi_build.yml @@ -36,8 +36,10 @@ jobs: args: build --single-target --skip-validate --rm-dist --snapshot env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Copy Windows Binary + - name: Copy Windows Collector Binary run: cp dist/collector_windows_amd64_v1/observiq-otel-collector.exe windows/observiq-otel-collector.exe + - name: Copy Windows Updater Binary + run: cp dist/updater_windows_amd64_v1/updater.exe windows/updater.exe - name: Copy Plugins to MSI Build Directory run: cp -r release_deps/plugins windows/ - name: Copy Example Config diff --git a/.github/workflows/multi_build.yml b/.github/workflows/multi_build.yml index b8407291f..f3d0f0f0b 100644 --- a/.github/workflows/multi_build.yml +++ b/.github/workflows/multi_build.yml @@ -30,7 +30,7 @@ jobs: - name: Scan Third Party Dependency Licenses run: | go install github.com/uw-labs/lichen@v0.1.5 - lichen --config=./license.yaml $(find dist/collector_*) + lichen --config=./license.yaml $(find dist/collector_* dist/updater_*) build_darwin: runs-on: macos-11 steps: @@ -55,7 +55,7 @@ jobs: - name: Scan Third Party Dependency Licenses run: | go install github.com/uw-labs/lichen@v0.1.5 - lichen --config=./license.yaml $(find dist/collector_*) + lichen --config=./license.yaml $(find dist/collector_* dist/updater_*) build_windows: runs-on: windows-2019 steps: @@ -80,4 +80,4 @@ jobs: - name: Scan Third Party Dependency Licenses run: | go install github.com/uw-labs/lichen@v0.1.5 - lichen --config=./license.yaml dist/collector_windows_amd64.exe + lichen --config=./license.yaml dist/collector_windows_amd64.exe dist/updater_windows_amd64.exe diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 98439ff2a..3b03e4121 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -33,8 +33,10 @@ jobs: args: build --single-target --skip-validate --rm-dist env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Copy Windows Binary + - name: Copy Windows Collector Binary run: cp dist/collector_windows_amd64_v1/observiq-otel-collector.exe windows/observiq-otel-collector.exe + - name: Copy Windows Updater Binary + run: cp dist/updater_windows_amd64_v1/updater.exe windows/updater.exe - name: Copy Plugins to MSI Build Directory run: cp -r release_deps/plugins windows/ - name: Copy Example Config diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 419d31a72..cd3af13a1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -51,3 +51,9 @@ jobs: ${{ runner.os }}-go- - name: Run Tests run: make test + - name: Run Updater Integration Tests (non-linux) + if: matrix.os != 'ubuntu-20.04' + run: make test-updater-integration + - name: Run Updater Integration Tests (linux) + if: matrix.os == 'ubuntu-20.04' + run: sudo make test-updater-integration diff --git a/.gitignore b/.gitignore index 528ff7666..c23c8c5b7 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ cosign.* gpg.* *.msi *.exe -*.zip +windows/*.zip windows/**/wix.dynamic.json windows/**/wix windows/config.yaml @@ -20,7 +20,7 @@ windows/plugins opentelemetry-java-contrib-jmx-metrics.jar VERSION.txt release_deps -tmp +/tmp # OpAmp Files collector.yaml diff --git a/.goreleaser.yml b/.goreleaser.yml index 37c710bbe..2b850ce5a 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -31,6 +31,32 @@ builds: - -X github.com/observiq/observiq-otel-collector/internal/version.gitHash={{ .FullCommit }} - -X github.com/observiq/observiq-otel-collector/internal/version.date={{ .Date }} no_unique_dist_dir: false + - id: updater + binary: updater + dir: ./updater/ + main: ./cmd/updater + env: + - CGO_ENABLED=0 + mod_timestamp: "{{ .CommitTimestamp }}" + goos: + - windows + - linux + - darwin + goarch: + - amd64 + - arm64 + - arm + ignore: + - goos: windows + goarch: arm + - goos: windows + goarch: arm64 + ldflags: + - -s -w + - -X github.com/observiq/observiq-otel-collector/updater/internal/version.version=v{{ .Version }} + - -X github.com/observiq/observiq-otel-collector/updater/internal/version.gitHash={{ .FullCommit }} + - -X github.com/observiq/observiq-otel-collector/updater/internal/version.date={{ .Date }} + no_unique_dist_dir: false # https://goreleaser.com/customization/archive/ archives: @@ -56,7 +82,12 @@ archives: - src: release_deps/com.observiq.collector.plist dst: "install" strip_parent: true - + - src: release_deps/windows_service.json + dst: install + strip_parent: true + - src: release_deps/observiq-otel-collector.service + dst: "install" + strip_parent: true format_overrides: - goos: windows format: zip diff --git a/Makefile b/Makefile index abbf7f25a..04bb6eddf 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,8 @@ OUTDIR=./dist GOOS ?= $(shell go env GOOS) GOARCH ?= $(shell go env GOARCH) +INTEGRATION_TEST_ARGS?=-tags integration + ifeq ($(GOOS), windows) EXT?=.exe else @@ -21,11 +23,23 @@ CURRENT_TAG := $(shell git tag --sort=v:refname --points-at HEAD | grep -E "v[0- # Version will be the tag pointing to the current commit, or the previous version tag if there is no such tag VERSION ?= $(if $(CURRENT_TAG),$(CURRENT_TAG),$(PREVIOUS_TAG)) -# Default build target; making this should build for the current os/arch +# Build binaries for current GOOS/GOARCH by default +.DEFAULT_GOAL := build-binaries + +# Builds just the collector for current GOOS/GOARCH pair .PHONY: collector collector: go build -ldflags "-s -w -X github.com/observiq/observiq-otel-collector/internal/version.version=$(VERSION)" -o $(OUTDIR)/collector_$(GOOS)_$(GOARCH)$(EXT) ./cmd/collector +# Builds just the updater for current GOOS/GOARCH pair +.PHONY: updater +updater: + cd ./updater/; go build -ldflags "-s -w -X github.com/observiq/observiq-otel-collector/internal/version.version=$(VERSION)" -o ../$(OUTDIR)/updater_$(GOOS)_$(GOARCH)$(EXT) ./cmd/updater + +# Builds the updater + collector for current GOOS/GOARCH pair +.PHONY: build-binaries +build-binaries: collector updater + .PHONY: build-all build-all: build-linux build-darwin build-windows @@ -40,27 +54,27 @@ build-windows: build-windows-amd64 .PHONY: build-linux-amd64 build-linux-amd64: - GOOS=linux GOARCH=amd64 $(MAKE) collector + GOOS=linux GOARCH=amd64 $(MAKE) build-binaries -j2 .PHONY: build-linux-arm64 build-linux-arm64: - GOOS=linux GOARCH=arm64 $(MAKE) collector + GOOS=linux GOARCH=arm64 $(MAKE) build-binaries -j2 .PHONY: build-linux-arm build-linux-arm: - GOOS=linux GOARCH=arm $(MAKE) collector + GOOS=linux GOARCH=arm $(MAKE) build-binaries -j2 .PHONY: build-darwin-amd64 build-darwin-amd64: - GOOS=darwin GOARCH=amd64 $(MAKE) collector + GOOS=darwin GOARCH=amd64 $(MAKE) build-binaries -j2 .PHONY: build-darwin-arm64 build-darwin-arm64: - GOOS=darwin GOARCH=arm64 $(MAKE) collector + GOOS=darwin GOARCH=arm64 $(MAKE) build-binaries -j2 .PHONY: build-windows-amd64 build-windows-amd64: - GOOS=windows GOARCH=amd64 $(MAKE) collector + GOOS=windows GOARCH=amd64 $(MAKE) build-binaries -j2 # tool-related commands .PHONY: install-tools @@ -97,6 +111,10 @@ test-with-cover: $(MAKE) for-all CMD="go test -coverprofile=cover.out ./..." $(MAKE) for-all CMD="go tool cover -html=cover.out -o cover.html" +.PHONY: test-updater-integration +test-updater-integration: + cd updater; go test $(INTEGRATION_TEST_ARGS) -race ./... + .PHONY: bench bench: $(MAKE) for-all CMD="go test -benchmem -run=^$$ -bench ^* ./..." @@ -115,7 +133,9 @@ tidy: .PHONY: gosec gosec: - gosec ./... + gosec -exclude-dir updater ./... +# exclude the testdata dir; it contains a go program for testing. + cd updater; gosec -exclude-dir internal/service/testdata ./... # This target performs all checks that CI will do (excluding the build itself) .PHONY: ci-checks @@ -157,6 +177,8 @@ release-prep: @cp config/example.yaml release_deps/config.yaml @cp config/logging.yaml release_deps/logging.yaml @cp service/com.observiq.collector.plist release_deps/com.observiq.collector.plist + @jq ".files[] | select(.service != null)" windows/wix.json >> release_deps/windows_service.json + @cp service/observiq-otel-collector.service release_deps/observiq-otel-collector.service # Build, sign, and release .PHONY: release @@ -184,7 +206,7 @@ clean: .PHONY: scan-licenses scan-licenses: - lichen --config=./license.yaml $$(find dist/collector_* | grep -v 'sig\|json\|CHANGELOG.md\|yaml\|SHA256' | xargs) + lichen --config=./license.yaml $$(find dist/collector_* dist/updater_*) .PHONY: generate generate: diff --git a/go.mod b/go.mod index f790dcee4..599d0d68d 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-collector v0.0.3-0.20220711143229-08f2752ed367 github.com/google/uuid v1.3.0 github.com/observiq/observiq-otel-collector/exporter/googlecloudexporter v1.3.0 + github.com/observiq/observiq-otel-collector/packagestate v0.0.0 github.com/observiq/observiq-otel-collector/processor/resourceattributetransposerprocessor v1.3.0 github.com/observiq/observiq-otel-collector/receiver/pluginreceiver v1.3.0 github.com/open-telemetry/opamp-go v0.2.0 @@ -162,6 +163,7 @@ require ( github.com/alecthomas/participle/v2 v2.0.0-alpha9 // indirect github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect github.com/aliyun/aliyun-log-go-sdk v0.1.37 // indirect + github.com/andybalholm/brotli v1.0.1 // indirect github.com/antonmedv/expr v1.9.0 // indirect github.com/apache/thrift v0.16.0 // indirect github.com/armon/go-metrics v0.3.10 // indirect @@ -190,6 +192,7 @@ require ( github.com/docker/docker v20.10.17+incompatible // indirect github.com/docker/go-connections v0.4.1-0.20210727194412-58542c764a11 // indirect github.com/docker/go-units v0.4.0 // indirect + github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 // indirect github.com/eapache/go-resiliency v1.3.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect github.com/eapache/queue v1.1.0 // indirect @@ -276,6 +279,7 @@ require ( github.com/jpillora/backoff v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/karrick/godirwalk v1.16.1 // indirect + github.com/klauspost/pgzip v1.2.5 // indirect github.com/knadh/koanf v1.4.2 // indirect github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b // indirect github.com/leoluk/perflib_exporter v0.1.0 // indirect @@ -306,6 +310,7 @@ require ( github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f // indirect github.com/nginxinc/nginx-prometheus-exporter v0.8.1-0.20201110005315-f5a5f8086c19 // indirect + github.com/nwaples/rardecode v1.1.0 // indirect github.com/observiq/ctimefmt v1.0.0 // indirect github.com/open-telemetry/opentelemetry-collector-contrib/exporter/googlecloudexporter v0.57.2 // indirect github.com/open-telemetry/opentelemetry-collector-contrib/internal/aws/awsutil v0.57.2 // indirect @@ -391,6 +396,7 @@ require ( github.com/tklauser/numcpus v0.4.0 // indirect github.com/uber/jaeger-client-go v2.30.0+incompatible // indirect github.com/uber/jaeger-lib v2.4.1+incompatible // indirect + github.com/ulikunitz/xz v0.5.9 // indirect github.com/vishvananda/netlink v1.1.1-0.20210330154013-f5de75959ad5 // indirect github.com/vishvananda/netns v0.0.0-20210104183010-2eb08e3e575f // indirect github.com/vmware/govmomi v0.28.0 // indirect @@ -398,6 +404,7 @@ require ( github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/scram v1.1.1 // indirect github.com/xdg-go/stringprep v1.0.3 // indirect + github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect github.com/yuin/gopher-lua v0.0.0-20220504180219-658193537a64 // indirect github.com/yusufpapurcu/wmi v1.2.2 // indirect @@ -454,6 +461,7 @@ require ( require ( github.com/containerd/containerd v1.6.6 // indirect github.com/klauspost/compress v1.15.9 // indirect + github.com/mholt/archiver/v3 v3.5.1 github.com/shirou/gopsutil/v3 v3.22.7 github.com/spf13/cobra v1.5.0 // indirect golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e @@ -465,4 +473,6 @@ replace github.com/observiq/observiq-otel-collector/receiver/pluginreceiver => . replace github.com/observiq/observiq-otel-collector/exporter/googlecloudexporter => ./exporter/googlecloudexporter +replace github.com/observiq/observiq-otel-collector/packagestate => ./packagestate + replace github.com/GoogleCloudPlatform/opentelemetry-operations-collector v0.0.3-0.20220711143229-08f2752ed367 => github.com/observIQ/opentelemetry-operations-collector v0.0.3-0.20220804143341-7ae64090f52c diff --git a/go.sum b/go.sum index a5f131177..fdef627f0 100644 --- a/go.sum +++ b/go.sum @@ -174,6 +174,8 @@ github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAu github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE= github.com/aliyun/aliyun-log-go-sdk v0.1.37 h1:GvswbgLqVOHNeMWssQ9zA+R7YVDP6arLUP92bKyGZNw= github.com/aliyun/aliyun-log-go-sdk v0.1.37/go.mod h1:1QQ59pEJiVVXqKgbHcU6FWIgxT5RKBt+CT8AiQ2bEts= +github.com/andybalholm/brotli v1.0.1 h1:KqhlKozYbRtJvsPrrEeXcO+N2l6NYT5A2QAFmSULpEc= +github.com/andybalholm/brotli v1.0.1/go.mod h1:loMXtMfwqflxFJPmdbJO0a3KNoPuLBgiu3qAvBg8x/Y= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antonmedv/expr v1.9.0 h1:j4HI3NHEdgDnN9p6oI6Ndr0G5QryMY0FNxT4ONrFDGU= github.com/antonmedv/expr v1.9.0/go.mod h1:5qsM3oLGDND7sDmQGDXHkYfkjYMUX14qsgqmHhwGEk8= @@ -312,6 +314,9 @@ github.com/docker/go-units v0.4.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDD github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= github.com/dropbox/godropbox v0.0.0-20180512210157-31879d3884b9 h1:NAvZb7gqQfLSNBPzVsvI7eZMosXtg2g2kxXrei90CtU= github.com/dropbox/godropbox v0.0.0-20180512210157-31879d3884b9/go.mod h1:glr97hP/JuXb+WMYCizc4PIFuzw1lCR97mwbe1VVXhQ= +github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5 h1:iFaUwBSo5Svw6L7HYpRu/0lE3e0BaElwnNO1qkNQxBY= +github.com/dsnet/compress v0.0.2-0.20210315054119-f66993602bf5/go.mod h1:qssHWj60/X5sZFNxpG4HBPDHVqxNm4DfnCKgrbZOT+s= +github.com/dsnet/golib v0.0.0-20171103203638-1ea166775780/go.mod h1:Lj+Z9rebOhdfkVLjJ8T6VcRQv3SXugXy999NBtR9aFY= github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-resiliency v1.2.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= @@ -505,6 +510,7 @@ github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.2/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= @@ -764,10 +770,15 @@ github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvW github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/klauspost/compress v1.4.1/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.11.4/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= github.com/klauspost/compress v1.15.8/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/klauspost/compress v1.15.9 h1:wKRjX6JRtDdrE9qwa4b/Cip7ACOshUI4smpCQanqjSY= github.com/klauspost/compress v1.15.9/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/klauspost/cpuid v1.2.0/go.mod h1:Pj4uuM528wm8OyEC2QMXAi2YiTZ96dNQPGgoMS4s3ek= +github.com/klauspost/pgzip v1.2.5 h1:qnWYvvKqedOF2ulHpMG72XQol4ILEJ8k2wwRl/Km8oE= +github.com/klauspost/pgzip v1.2.5/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/knadh/koanf v1.4.2 h1:2itp+cdC6miId4pO4Jw7c/3eiYD26Z/Sz3ATJMwHxIs= github.com/knadh/koanf v1.4.2/go.mod h1:4NCo0q4pmU398vF9vq2jStF9MWQZ8JEDcDMHlDCr4h0= github.com/kolo/xmlrpc v0.0.0-20201022064351-38db28db192b h1:iNjcivnc6lhbvJA3LD622NPrUponluJrBWPIwGG/3Bg= @@ -830,6 +841,8 @@ github.com/mattn/go-runewidth v0.0.8/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 h1:I0XW9+e1XWDxdcEniV4rQAIOPUGDq67JSCiRCgGCZLI= github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4= +github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Clwo= +github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4= github.com/microsoft/ApplicationInsights-Go v0.4.4 h1:G4+H9WNs6ygSCe6sUyxRc2U81TI5Es90b2t/MwX5KqY= github.com/microsoft/ApplicationInsights-Go v0.4.4/go.mod h1:fKRUseBqkw6bDiXTs3ESTiU/4YTIHsQS4W3fP2ieF4U= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= @@ -907,6 +920,8 @@ github.com/nginxinc/nginx-prometheus-exporter v0.8.1-0.20201110005315-f5a5f8086c github.com/nginxinc/nginx-prometheus-exporter v0.8.1-0.20201110005315-f5a5f8086c19/go.mod h1:L58Se1nwn3cEyHWlcfdlXgiGbHe/efvDbkbi+psz3lA= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/npillmayer/nestext v0.1.3/go.mod h1:h2lrijH8jpicr25dFY+oAJLyzlya6jhnuG+zWp9L0Uk= +github.com/nwaples/rardecode v1.1.0 h1:vSxaY8vQhOcVr4mm5e8XllHWTiM4JF507A0Katqw7MQ= +github.com/nwaples/rardecode v1.1.0/go.mod h1:5DzqNKiOdpKKBH87u8VlvAnPZMXcGRhxWkRpHbbfGS0= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= github.com/nxadm/tail v1.4.8/go.mod h1:+ncqLTQzXmGhMZNUePPaPqPvBxHAIsmXswZKocGu+AU= @@ -1269,6 +1284,7 @@ github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi github.com/pierrec/lz4 v2.6.0+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= github.com/pierrec/lz4 v2.6.1+incompatible h1:9UY3+iC23yxF0UfGaYrGplQ+79Rg+h/q9FV9ix19jjM= github.com/pierrec/lz4 v2.6.1+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY= +github.com/pierrec/lz4/v4 v4.1.2/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pierrec/lz4/v4 v4.1.15 h1:MO0/ucJhngq7299dKLwIMtgTfbkoSPF6AoMYDd8Q4q0= github.com/pierrec/lz4/v4 v4.1.15/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -1469,6 +1485,9 @@ github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaO github.com/uber/jaeger-client-go v2.30.0+incompatible/go.mod h1:WVhlPFC8FDjOFMMWRy2pZqQJSXxYSwNYOkTr/Z6d3Kk= github.com/uber/jaeger-lib v2.4.1+incompatible h1:td4jdvLcExb4cBISKIpHuGoVXh+dVKhn2Um6rjCsSsg= github.com/uber/jaeger-lib v2.4.1+incompatible/go.mod h1:ComeNDZlWwrWnDv8aPp0Ba6+uUTzImX/AauajbLI56U= +github.com/ulikunitz/xz v0.5.8/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= +github.com/ulikunitz/xz v0.5.9 h1:RsKRIA2MO8x56wkkcd3LbtcE/uMszhb6DpRf+3uwa3I= +github.com/ulikunitz/xz v0.5.9/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/urfave/cli v1.20.0/go.mod h1:70zkFmudgCuE/ngEzBv17Jvp/497gISqfk5gWijbERA= github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI= @@ -1492,6 +1511,8 @@ github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23n github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM= github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= diff --git a/internal/service/managed.go b/internal/service/managed.go index 20fafc76a..1500a48e3 100644 --- a/internal/service/managed.go +++ b/internal/service/managed.go @@ -47,6 +47,7 @@ func NewManagedCollectorService(col collector.Collector, logger *zap.Logger, man DefaultLogger: logger, Config: *opampConfig, Collector: col, + TmpPath: "./tmp", ManagerConfigPath: managerConfigPath, CollectorConfigPath: collectorConfigPath, LoggerConfigPath: loggerConfigPath, diff --git a/license.yaml b/license.yaml index 0c53d7c1f..dfaf327da 100644 --- a/license.yaml +++ b/license.yaml @@ -23,6 +23,11 @@ exceptions: # are creative commons. https://github.com/opencontainers/go-digest#copyright-and-license - path: "github.com/opencontainers/go-digest" unresolvableLicense: + # uses a custom license that says we can basically do whatever we want with it + - path: "github.com/xi2/xz" - path: "./processor/resourceattributetransposerprocessor" - path: "./receiver/pluginreceiver" - path: "./exporter/googlecloudexporter" + - path: "./packagestate" + - path: "../packagestate" + - path: "./opamp/observiq/testdata/latest" diff --git a/opamp/downloadable_file_manager.go b/opamp/downloadable_file_manager.go new file mode 100644 index 000000000..dd146138d --- /dev/null +++ b/opamp/downloadable_file_manager.go @@ -0,0 +1,31 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package opamp contains configurations and protocol implementations to handle OpAmp communication. +package opamp + +import ( + "github.com/open-telemetry/opamp-go/protobufs" +) + +// DownloadableFileManager handles DownloadableFile's from a PackagesAvailable message +type DownloadableFileManager interface { + // FetchAndExtractArchive fetches the archive at the specified URL. + // It then checks to see if it matches the expected sha256 sum of the file. + // If it matches, the archive is extracted. + // If the archive cannot be extracted, downloaded, or verified, then an error is returned. + FetchAndExtractArchive(*protobufs.DownloadableFile) error + // CleanupArtifacts removes temporary artifacts from previous download/installs + CleanupArtifacts() +} diff --git a/opamp/mocks/mock_downloadable_file_manager.go b/opamp/mocks/mock_downloadable_file_manager.go new file mode 100644 index 000000000..8ef78ffe1 --- /dev/null +++ b/opamp/mocks/mock_downloadable_file_manager.go @@ -0,0 +1,48 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + protobufs "github.com/open-telemetry/opamp-go/protobufs" +) + +// MockDownloadableFileManager is an autogenerated mock type for the DownloadableFileManager type +type MockDownloadableFileManager struct { + mock.Mock +} + +// CleanupArtifacts provides a mock function with given fields: +func (_m *MockDownloadableFileManager) CleanupArtifacts() { + _m.Called() +} + +// FetchAndExtractArchive provides a mock function with given fields: _a0 +func (_m *MockDownloadableFileManager) FetchAndExtractArchive(_a0 *protobufs.DownloadableFile) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(*protobufs.DownloadableFile) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewMockDownloadableFileManager interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockDownloadableFileManager creates a new instance of MockDownloadableFileManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockDownloadableFileManager(t mockConstructorTestingTNewMockDownloadableFileManager) *MockDownloadableFileManager { + mock := &MockDownloadableFileManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/opamp/mocks/mock_packages_state_provider.go b/opamp/mocks/mock_packages_state_provider.go new file mode 100644 index 000000000..55fa3d9bb --- /dev/null +++ b/opamp/mocks/mock_packages_state_provider.go @@ -0,0 +1,231 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. + +package mocks + +import ( + context "context" + io "io" + + mock "github.com/stretchr/testify/mock" + + protobufs "github.com/open-telemetry/opamp-go/protobufs" + + types "github.com/open-telemetry/opamp-go/client/types" +) + +// MockPackagesStateProvider is an autogenerated mock type for the PackagesStateProvider type +type MockPackagesStateProvider struct { + mock.Mock +} + +// AllPackagesHash provides a mock function with given fields: +func (_m *MockPackagesStateProvider) AllPackagesHash() ([]byte, error) { + ret := _m.Called() + + var r0 []byte + if rf, ok := ret.Get(0).(func() []byte); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// CreatePackage provides a mock function with given fields: packageName, typ +func (_m *MockPackagesStateProvider) CreatePackage(packageName string, typ protobufs.PackageAvailable_PackageType) error { + ret := _m.Called(packageName, typ) + + var r0 error + if rf, ok := ret.Get(0).(func(string, protobufs.PackageAvailable_PackageType) error); ok { + r0 = rf(packageName, typ) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeletePackage provides a mock function with given fields: packageName +func (_m *MockPackagesStateProvider) DeletePackage(packageName string) error { + ret := _m.Called(packageName) + + var r0 error + if rf, ok := ret.Get(0).(func(string) error); ok { + r0 = rf(packageName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// FileContentHash provides a mock function with given fields: packageName +func (_m *MockPackagesStateProvider) FileContentHash(packageName string) ([]byte, error) { + ret := _m.Called(packageName) + + var r0 []byte + if rf, ok := ret.Get(0).(func(string) []byte); ok { + r0 = rf(packageName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]byte) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(packageName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// LastReportedStatuses provides a mock function with given fields: +func (_m *MockPackagesStateProvider) LastReportedStatuses() (*protobufs.PackageStatuses, error) { + ret := _m.Called() + + var r0 *protobufs.PackageStatuses + if rf, ok := ret.Get(0).(func() *protobufs.PackageStatuses); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*protobufs.PackageStatuses) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// PackageState provides a mock function with given fields: packageName +func (_m *MockPackagesStateProvider) PackageState(packageName string) (types.PackageState, error) { + ret := _m.Called(packageName) + + var r0 types.PackageState + if rf, ok := ret.Get(0).(func(string) types.PackageState); ok { + r0 = rf(packageName) + } else { + r0 = ret.Get(0).(types.PackageState) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(packageName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Packages provides a mock function with given fields: +func (_m *MockPackagesStateProvider) Packages() ([]string, error) { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SetAllPackagesHash provides a mock function with given fields: hash +func (_m *MockPackagesStateProvider) SetAllPackagesHash(hash []byte) error { + ret := _m.Called(hash) + + var r0 error + if rf, ok := ret.Get(0).(func([]byte) error); ok { + r0 = rf(hash) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetLastReportedStatuses provides a mock function with given fields: statuses +func (_m *MockPackagesStateProvider) SetLastReportedStatuses(statuses *protobufs.PackageStatuses) error { + ret := _m.Called(statuses) + + var r0 error + if rf, ok := ret.Get(0).(func(*protobufs.PackageStatuses) error); ok { + r0 = rf(statuses) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetPackageState provides a mock function with given fields: packageName, state +func (_m *MockPackagesStateProvider) SetPackageState(packageName string, state types.PackageState) error { + ret := _m.Called(packageName, state) + + var r0 error + if rf, ok := ret.Get(0).(func(string, types.PackageState) error); ok { + r0 = rf(packageName, state) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// UpdateContent provides a mock function with given fields: ctx, packageName, data, contentHash +func (_m *MockPackagesStateProvider) UpdateContent(ctx context.Context, packageName string, data io.Reader, contentHash []byte) error { + ret := _m.Called(ctx, packageName, data, contentHash) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, io.Reader, []byte) error); ok { + r0 = rf(ctx, packageName, data, contentHash) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NewPackagesStateProviderT interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockPackagesStateProvider creates a new instance of PackagesStateProvider. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockPackagesStateProvider(t NewPackagesStateProviderT) *MockPackagesStateProvider { + mock := &MockPackagesStateProvider{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/opamp/mocks/mock_updater_manager.go b/opamp/mocks/mock_updater_manager.go new file mode 100644 index 000000000..252598589 --- /dev/null +++ b/opamp/mocks/mock_updater_manager.go @@ -0,0 +1,39 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// MockUpdaterManager is an autogenerated mock type for the UpdaterManager type +type MockUpdaterManager struct { + mock.Mock +} + +// StartAndMonitorUpdater provides a mock function with given fields: +func (_m *MockUpdaterManager) StartAndMonitorUpdater() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type NewUpdaterManagerT interface { + mock.TestingT + Cleanup(func()) +} + +// NewMockUpdaterManager creates a new instance of updaterManager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockUpdaterManager(t NewUpdaterManagerT) *MockUpdaterManager { + mock := &MockUpdaterManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/opamp/observiq/observiq_client.go b/opamp/observiq/observiq_client.go index 3d9865352..f2f89f25d 100644 --- a/opamp/observiq/observiq_client.go +++ b/opamp/observiq/observiq_client.go @@ -17,14 +17,17 @@ package observiq import ( "context" + "encoding/hex" "errors" "fmt" "net/http" "net/url" + "sync" "github.com/observiq/observiq-otel-collector/collector" "github.com/observiq/observiq-otel-collector/internal/version" "github.com/observiq/observiq-otel-collector/opamp" + "github.com/observiq/observiq-otel-collector/packagestate" "github.com/open-telemetry/opamp-go/client" "github.com/open-telemetry/opamp-go/client/types" "github.com/open-telemetry/opamp-go/protobufs" @@ -41,11 +44,19 @@ var _ opamp.Client = (*Client)(nil) // Client represents a client that is connected to Iris via OpAmp type Client struct { - opampClient client.OpAMPClient - logger *zap.Logger - ident *identity - configManager opamp.ConfigManager - collector collector.Collector + opampClient client.OpAMPClient + logger *zap.Logger + ident *identity + configManager opamp.ConfigManager + downloadableFileManager opamp.DownloadableFileManager + collector collector.Collector + packagesStateProvider types.PackagesStateProvider + updaterManager updaterManager + mutex sync.Mutex + updatingPackage bool + + // To signal if we are disconnecting already and not take any actions on connection failures + disconnecting bool currentConfig opamp.Config } @@ -56,6 +67,7 @@ type NewClientArgs struct { Config opamp.Config Collector collector.Collector + TmpPath string ManagerConfigPath string CollectorConfigPath string LoggerConfigPath string @@ -66,13 +78,20 @@ func NewClient(args *NewClientArgs) (opamp.Client, error) { clientLogger := args.DefaultLogger.Named("opamp") configManager := NewAgentConfigManager(args.DefaultLogger) + updaterManger, err := newUpdaterManager(clientLogger, args.TmpPath) + if err != nil { + return nil, fmt.Errorf("failed to create updaterManager: %w", err) + } observiqClient := &Client{ - logger: clientLogger, - ident: newIdentity(clientLogger, args.Config), - configManager: configManager, - collector: args.Collector, - currentConfig: args.Config, + logger: clientLogger, + ident: newIdentity(clientLogger, args.Config), + configManager: configManager, + downloadableFileManager: newDownloadableFileManager(clientLogger, args.TmpPath), + collector: args.Collector, + currentConfig: args.Config, + packagesStateProvider: newPackagesStateProvider(clientLogger, packagestate.DefaultFileName), + updaterManager: updaterManger, } // Parse URL to determin scheme @@ -125,11 +144,18 @@ func (c *Client) Connect(ctx context.Context) error { // Compose and set the agent description if err := c.opampClient.SetAgentDescription(c.ident.ToAgentDescription()); err != nil { c.logger.Error("Error while setting agent description", zap.Error(err)) + + // Set package status file for error (for Updater to pick up), but do not force send to Server + c.attemptFailedInstall(fmt.Sprintf("Error while setting agent description: %s", err.Error())) + return err } tlsCfg, err := c.currentConfig.ToTLS() if err != nil { + // Set package status file for error (for Updater to pick up), but do not force send to Server + c.attemptFailedInstall(fmt.Sprintf("Failed creating TLS config: %s", err.Error())) + return fmt.Errorf("failed creating TLS config: %w", err) } @@ -157,20 +183,31 @@ func (c *Client) Connect(ctx context.Context) error { // OnCommandFunc // SaveRemoteConfigStatusFunc }, + PackagesStateProvider: c.packagesStateProvider, } // Start the embedded collector // Pass in the background context here so it's clear we need to shutdown the collector instead // of the context shutting it down via a cancel. if err := c.collector.Run(context.Background()); err != nil { + // Set package status file for error (for Updater to pick up), but do not force send to Server + c.attemptFailedInstall(fmt.Sprintf("Collector failed to start: %s", err.Error())) + return fmt.Errorf("collector failed to start: %w", err) } - return c.opampClient.Start(ctx, settings) + err = c.opampClient.Start(ctx, settings) + if err != nil { + // Set package status file for error (for Updater to pick up), but do not force send to Server + c.attemptFailedInstall(fmt.Sprintf("OpAMP client failed to start: %s", err.Error())) + } + + return err } // Disconnect disconnects from the server func (c *Client) Disconnect(ctx context.Context) error { + c.safeSetDisconnecting(true) c.collector.Stop() return c.opampClient.Stop(ctx) } @@ -179,10 +216,59 @@ func (c *Client) Disconnect(ctx context.Context) error { func (c *Client) onConnectHandler() { c.logger.Info("Successfully connected to server") + + // See if we can retrieve the PackageStatuses where the main package is in an installing state + lastPackageStatuses := c.getMainPackageInstallingLastStatuses() + if lastPackageStatuses == nil { + return + } + + lastMainPackageStatus := lastPackageStatuses.Packages[packagestate.CollectorPackageName] + // If in the middle of an install and we just connected, this is most likely becasue the collector was just spun up fresh by the Updater. + // If the current version matches the server offered version, this implies a good install and so we should set the PackageStatuses and + // send it to the OpAMP Server. If the version does not match, just change the PackageStatues JSON so that the Updater can start rollback. + if lastMainPackageStatus.ServerOfferedVersion == version.Version() { + c.logger.Info("Package update was successful", + zap.String("AllPackagesHash", hex.EncodeToString(lastPackageStatuses.ServerProvidedAllPackagesHash)), + zap.String("package", packagestate.CollectorPackageName)) + lastMainPackageStatus.Status = protobufs.PackageStatus_Installed + lastMainPackageStatus.AgentHasVersion = version.Version() + lastMainPackageStatus.AgentHasHash = lastMainPackageStatus.ServerOfferedHash + + if err := c.packagesStateProvider.SetLastReportedStatuses(lastPackageStatuses); err != nil { + c.logger.Error("Failed to set last reported package statuses", zap.Error(err)) + } + + // Only immediately send to server on success. Rollback will send this for failure. + if err := c.opampClient.SetPackageStatuses(lastPackageStatuses); err != nil { + c.logger.Error("OpAMP client failed to set package statuses", zap.Error(err)) + } + } else { + c.logger.Error( + fmt.Sprintf( + "Package update failed because of collector version mismatch: expected %s, actual %s", + lastMainPackageStatus.ServerOfferedVersion, version.Version()), + zap.String("package", packagestate.CollectorPackageName)) + + lastMainPackageStatus.Status = protobufs.PackageStatus_InstallFailed + lastMainPackageStatus.ErrorMessage = + fmt.Sprintf("Failed because of collector version mismatch: expected %s, actual %s", + lastMainPackageStatus.ServerOfferedVersion, version.Version()) + + if err := c.packagesStateProvider.SetLastReportedStatuses(lastPackageStatuses); err != nil { + c.logger.Error("Failed to set last reported package statuses", zap.Error(err)) + } + } } func (c *Client) onConnectFailedHandler(err error) { c.logger.Error("Failed to connect to server", zap.Error(err)) + + // We are currently disconnecting so any Connection failed error is expected and should not affect an install + if !c.safeGetDisconnecting() { + // Set package status file for error (for Updater to pick up), but do not force send to Server + c.attemptFailedInstall(fmt.Sprintf("Failed to connect to BindPlane: %s", err.Error())) + } } func (c *Client) onErrorHandler(errResp *protobufs.ServerErrorResponse) { @@ -196,6 +282,11 @@ func (c *Client) onMessageFuncHandler(ctx context.Context, msg *types.MessageDat c.logger.Error("Error while processing Remote Config Change", zap.Error(err)) } } + if msg.PackagesAvailable != nil { + if err := c.onPackagesAvailableHandler(msg.PackagesAvailable); err != nil { + c.logger.Error("Error while processing Packages Available Change", zap.Error(err)) + } + } } func (c *Client) onRemoteConfigHandler(ctx context.Context, remoteConfig *protobufs.AgentRemoteConfig) error { @@ -229,7 +320,266 @@ func (c *Client) onRemoteConfigHandler(ctx context.Context, remoteConfig *protob return nil } +func (c *Client) onPackagesAvailableHandler(packagesAvailable *protobufs.PackagesAvailable) error { + c.logger.Debug("Packages available handler") + + // Initialize PackageStatuses that will eventually be sent back to server + curPackageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: packagesAvailable.GetAllPackagesHash(), + Packages: map[string]*protobufs.PackageStatus{}, + } + + // Don't respond to PackagesAvailable messages while currently installing. We use this in memory data rather than the + // PackageStatuses persistant data in order to ensure that we don't get stuck in a stuck state + if c.safeGetUpdatingPackage() { + c.logger.Warn( + "Not starting new package update as already installing new packages", + zap.String("AllPackagesHash", hex.EncodeToString(packagesAvailable.GetAllPackagesHash()))) + curPackageStatuses.ErrorMessage = "Already installing new packages" + if err := c.opampClient.SetPackageStatuses(curPackageStatuses); err != nil { + c.logger.Error("OpAMP client failed to set package statuses", zap.Error(err)) + } + return fmt.Errorf("failed because already installing packages") + } + + // Retrieve last known status (this should return with minimal info even on first time) + lastPackageStatuses, err := c.packagesStateProvider.LastReportedStatuses() + + // If there is a problem retrieving the last saved PackageStatuses, we will log the error + // but continue on as the only thing missing will be the agent package hash. + if err != nil { + c.logger.Warn("Failed to retrieve last reported package statuses", zap.Error(err)) + } + + var lastPkgStatusMap map[string]*protobufs.PackageStatus + if lastPackageStatuses != nil { + lastPkgStatusMap = lastPackageStatuses.GetPackages() + } + + curPackages, curPackageFiles := c.createPackageMaps(packagesAvailable.GetPackages(), lastPkgStatusMap) + curPackageStatuses.Packages = curPackages + + // This is an error because we need this file for communication during the update + if err = c.packagesStateProvider.SetLastReportedStatuses(curPackageStatuses); err != nil { + return fmt.Errorf("failed to save last reported package statuses: %w", err) + } + + if err = c.opampClient.SetPackageStatuses(curPackageStatuses); err != nil { + return fmt.Errorf("opamp client failed to set package statuses: %w", err) + } + + // Start update if applicable + collectorDownloadableFile := curPackageFiles[packagestate.CollectorPackageName] + if collectorDownloadableFile != nil { + c.safeSetUpdatingPackage(true) + go c.installPackageFromFile(collectorDownloadableFile, curPackageStatuses) + } + + return nil +} + +func (c *Client) createPackageMaps( + pkgAvailMap map[string]*protobufs.PackageAvailable, + lastPkgStatusMap map[string]*protobufs.PackageStatus) (map[string]*protobufs.PackageStatus, map[string]*protobufs.DownloadableFile) { + pkgStatusMap := map[string]*protobufs.PackageStatus{} + pkgFileMap := map[string]*protobufs.DownloadableFile{} + + // Loop through all of the available packages sent from the server + for name, availPkg := range pkgAvailMap { + switch name { + // If it's an expected package, return an installing status + case packagestate.CollectorPackageName: + var agentHash []byte + if lastPkgStatusMap != nil && lastPkgStatusMap[name] != nil { + if lastPkgStatusMap[name].GetAgentHasVersion() != version.Version() { + c.logger.Debug(fmt.Sprintf( + "Version: %s and last reported package status version: %s differ", + version.Version(), + lastPkgStatusMap[name].GetAgentHasVersion())) + } else { + agentHash = lastPkgStatusMap[name].GetAgentHasHash() + } + } + + pkgStatusMap[name] = &protobufs.PackageStatus{ + Name: name, + AgentHasVersion: version.Version(), + AgentHasHash: agentHash, + ServerOfferedVersion: availPkg.GetVersion(), + ServerOfferedHash: availPkg.GetHash(), + Status: protobufs.PackageStatus_Installed, + } + + if version.Version() == availPkg.GetVersion() { + c.logger.Info("Package update ignored because no new version offered", + zap.String("package", name)) + if agentHash == nil { + pkgStatusMap[name].AgentHasHash = availPkg.GetHash() + } + break + } + + if availPkg.GetVersion() != "" { + if availPkg.File != nil { + pkgStatusMap[name].Status = protobufs.PackageStatus_Installing + pkgFileMap[name] = availPkg.GetFile() + } else { + c.logger.Error( + "Package update failed to determine valid downloadable file", + zap.String("package", name)) + pkgStatusMap[name].Status = protobufs.PackageStatus_InstallFailed + pkgStatusMap[name].ErrorMessage = fmt.Sprintf("Package %s does not have a valid downloadable file", name) + } + } + // If it's not an expected package, return a failed status + default: + c.logger.Error( + "Package update failed because it is not supported", + zap.String("package", name)) + pkgStatusMap[name] = &protobufs.PackageStatus{ + Name: name, + ServerOfferedVersion: availPkg.GetVersion(), + ServerOfferedHash: availPkg.GetHash(), + Status: protobufs.PackageStatus_InstallFailed, + ErrorMessage: fmt.Sprintf("Package %s not supported", name), + } + } + } + + return pkgStatusMap, pkgFileMap +} + +// installPackageFromFile tries to download and extract the given tarball and then start up the new Updater binary that was +// inside of it +func (c *Client) installPackageFromFile(file *protobufs.DownloadableFile, curPackageStatuses *protobufs.PackageStatuses) { + c.logger.Info("Package update started", + zap.String("AllPackagesHash", hex.EncodeToString(curPackageStatuses.ServerProvidedAllPackagesHash)), + zap.String("package", packagestate.CollectorPackageName)) + // There should be no reason for us to exit this function unless we detected a problem with the Updater's installation + defer c.safeSetUpdatingPackage(false) + + if fileManagerErr := c.downloadableFileManager.FetchAndExtractArchive(file); fileManagerErr != nil { + c.logger.Error( + fmt.Sprintf( + "Package update failed to download and verify downloadable file: %s", fileManagerErr.Error()), + zap.String("package", packagestate.CollectorPackageName)) + // Remove the update artifacts that may exist, depending on where FetchAndExtractArchive failed. + c.downloadableFileManager.CleanupArtifacts() + + // Change existing status to show that install failed and get ready to send + curPackageStatuses.Packages[packagestate.CollectorPackageName].Status = protobufs.PackageStatus_InstallFailed + curPackageStatuses.Packages[packagestate.CollectorPackageName].ErrorMessage = + fmt.Sprintf("Failed to download and verify downloadable file: %s", fileManagerErr.Error()) + + if err := c.packagesStateProvider.SetLastReportedStatuses(curPackageStatuses); err != nil { + c.logger.Error("Failed to save last reported package statuses", zap.Error(err)) + } + + if err := c.opampClient.SetPackageStatuses(curPackageStatuses); err != nil { + c.logger.Error("OpAMP client failed to set package statuses", zap.Error(err)) + } + + return + } + + if monitorErr := c.updaterManager.StartAndMonitorUpdater(); monitorErr != nil { + c.logger.Error( + fmt.Sprintf("Package update failed because of issue with latest Updater: %s", monitorErr), + zap.String("package", packagestate.CollectorPackageName)) + // Remove the update artifacts + c.downloadableFileManager.CleanupArtifacts() + + // Reread package statuses in case Updater changed anything + newPackageStatuses, err := c.packagesStateProvider.LastReportedStatuses() + if err != nil { + c.logger.Error("Failed to read last reported package statuses", zap.Error(err)) + } + + // Change existing status to show that install failed and get ready to send + newPackageStatuses.Packages[packagestate.CollectorPackageName].Status = protobufs.PackageStatus_InstallFailed + if newPackageStatuses.Packages[packagestate.CollectorPackageName].ErrorMessage == "" { + newPackageStatuses.Packages[packagestate.CollectorPackageName].ErrorMessage = fmt.Sprintf("Failed to run the latest Updater: %s", monitorErr) + } + + if err := c.packagesStateProvider.SetLastReportedStatuses(newPackageStatuses); err != nil { + c.logger.Error("Failed to save last reported package statuses", zap.Error(err)) + } + + if err := c.opampClient.SetPackageStatuses(newPackageStatuses); err != nil { + c.logger.Error("OpAMP client failed to set package statuses", zap.Error(err)) + } + } +} + func (c *Client) onGetEffectiveConfigHandler(_ context.Context) (*protobufs.EffectiveConfig, error) { c.logger.Debug("Remote Compose Effective config handler") return c.configManager.ComposeEffectiveConfig() } + +// attemptFailedInstall sets PackageStatuses status to failed and error message if we are in the middle of an install. +// This should allow the updater to pick this up and start the rollback process +func (c *Client) attemptFailedInstall(errMsg string) { + // See if we can retrieve the PackageStatuses where the main package is in an installing state + lastPackageStatuses := c.getMainPackageInstallingLastStatuses() + if lastPackageStatuses == nil { + return + } + + c.logger.Error(fmt.Sprintf("Package update failed: %s", errMsg), + zap.String("package", packagestate.CollectorPackageName)) + + lastMainPackageStatus := lastPackageStatuses.Packages[packagestate.CollectorPackageName] + lastMainPackageStatus.Status = protobufs.PackageStatus_InstallFailed + lastMainPackageStatus.ErrorMessage = errMsg + + if err := c.packagesStateProvider.SetLastReportedStatuses(lastPackageStatuses); err != nil { + c.logger.Error("Failed to set last reported package statuses", zap.Error(err)) + } +} + +func (c *Client) getMainPackageInstallingLastStatuses() *protobufs.PackageStatuses { + lastPackageStatuses, err := c.packagesStateProvider.LastReportedStatuses() + if err != nil { + c.logger.Error("Failed to retrieve last reported package statuses", zap.Error(err)) + return nil + } + + // If we have no info on our main package, nothing else to do + if lastPackageStatuses == nil || lastPackageStatuses.Packages == nil || lastPackageStatuses.Packages[packagestate.CollectorPackageName] == nil { + c.logger.Warn("Failed to retrieve last reported package statuses for main package") + return nil + } + + lastMainPackageStatus := lastPackageStatuses.Packages[packagestate.CollectorPackageName] + + // If we were not installing before the connection, nothing else to do + if lastMainPackageStatus.Status != protobufs.PackageStatus_Installing { + return nil + } + + return lastPackageStatuses +} + +func (c *Client) safeSetUpdatingPackage(value bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.updatingPackage = value +} + +func (c *Client) safeGetUpdatingPackage() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.updatingPackage +} + +func (c *Client) safeSetDisconnecting(value bool) { + c.mutex.Lock() + defer c.mutex.Unlock() + c.disconnecting = value +} + +func (c *Client) safeGetDisconnecting() bool { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.disconnecting +} diff --git a/opamp/observiq/observiq_client_test.go b/opamp/observiq/observiq_client_test.go index 07c063d54..f3ea3f5cb 100644 --- a/opamp/observiq/observiq_client_test.go +++ b/opamp/observiq/observiq_client_test.go @@ -21,12 +21,15 @@ import ( "net/http" "os" "path/filepath" + "sync" "testing" + "time" colmocks "github.com/observiq/observiq-otel-collector/collector/mocks" "github.com/observiq/observiq-otel-collector/internal/version" "github.com/observiq/observiq-otel-collector/opamp" "github.com/observiq/observiq-otel-collector/opamp/mocks" + "github.com/observiq/observiq-otel-collector/packagestate" "github.com/open-telemetry/opamp-go/client/types" "github.com/open-telemetry/opamp-go/protobufs" "github.com/stretchr/testify/assert" @@ -121,10 +124,13 @@ func TestNewClient(t *testing.T) { // Do a shallow check on all fields to assert they exist and are equal to passed in params were possible assert.NotNil(t, observiqClient.opampClient) assert.NotNil(t, observiqClient.configManager) + assert.NotNil(t, observiqClient.packagesStateProvider) assert.Equal(t, testLogger.Named("opamp"), observiqClient.logger) assert.Equal(t, mockCollector, observiqClient.collector) assert.NotNil(t, observiqClient.ident) assert.Equal(t, observiqClient.currentConfig, tc.config) + assert.False(t, observiqClient.safeGetDisconnecting()) + assert.False(t, observiqClient.safeGetUpdatingPackage()) } }) @@ -144,6 +150,8 @@ func TestClientConnect(t *testing.T) { mockOpAmpClient := new(mocks.MockOpAMPClient) mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(expectedErr) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, nil) c := &Client{ opampClient: mockOpAmpClient, @@ -155,6 +163,7 @@ func TestClientConnect(t *testing.T) { Endpoint: "ws://localhost:1234", SecretKey: &secretKeyContents, }, + packagesStateProvider: mockStateProvider, } err := c.Connect(context.Background()) @@ -167,6 +176,8 @@ func TestClientConnect(t *testing.T) { mockOpAmpClient := new(mocks.MockOpAMPClient) mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(nil) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, nil) badCAFile := "bad-ca.cert" c := &Client{ @@ -182,6 +193,7 @@ func TestClientConnect(t *testing.T) { CAFile: &badCAFile, }, }, + packagesStateProvider: mockStateProvider, } err := c.Connect(context.Background()) @@ -196,6 +208,8 @@ func TestClientConnect(t *testing.T) { mockOpAmpClient := mocks.NewMockOpAMPClient(t) mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(nil) mockOpAmpClient.On("Start", mock.Anything, mock.Anything).Return(expectedErr) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, nil) mockCollector := colmocks.NewMockCollector(t) mockCollector.On("Run", mock.Anything).Return(nil) @@ -210,6 +224,7 @@ func TestClientConnect(t *testing.T) { Endpoint: "ws://localhost:1234", SecretKey: &secretKeyContents, }, + packagesStateProvider: mockStateProvider, } err := c.Connect(context.Background()) @@ -221,6 +236,8 @@ func TestClientConnect(t *testing.T) { testFunc: func(*testing.T) { mockOpAmpClient := mocks.NewMockOpAMPClient(t) mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(nil) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, nil) expectedErr := errors.New("oops") @@ -237,6 +254,7 @@ func TestClientConnect(t *testing.T) { Endpoint: "ws://localhost:1234", SecretKey: &secretKeyContents, }, + packagesStateProvider: mockStateProvider, } err := c.Connect(context.Background()) @@ -252,6 +270,8 @@ func TestClientConnect(t *testing.T) { mockCollector := colmocks.NewMockCollector(t) mockCollector.On("Run", mock.Anything).Return(nil) + mockPackagesStateProvider := mocks.NewMockPackagesStateProvider(t) + c := &Client{ opampClient: mockOpAmpClient, logger: zap.NewNop(), @@ -265,6 +285,7 @@ func TestClientConnect(t *testing.T) { Endpoint: "ws://localhost:1234", SecretKey: &secretKeyContents, }, + packagesStateProvider: mockPackagesStateProvider, } expectedSettings := types.StartSettings{ @@ -286,6 +307,7 @@ func TestClientConnect(t *testing.T) { OnMessageFunc: c.onMessageFuncHandler, GetEffectiveConfigFunc: c.onGetEffectiveConfigHandler, }, + PackagesStateProvider: c.packagesStateProvider, } mockOpAmpClient.On("Start", mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) { settings := args.Get(1).(types.StartSettings) @@ -293,6 +315,7 @@ func TestClientConnect(t *testing.T) { assert.Equal(t, expectedSettings.Header, settings.Header) assert.Equal(t, expectedSettings.TLSConfig, settings.TLSConfig) assert.Equal(t, expectedSettings.InstanceUid, settings.InstanceUid) + assert.Equal(t, expectedSettings.PackagesStateProvider, settings.PackagesStateProvider) // assert is unable to compare function pointers }) @@ -300,6 +323,104 @@ func TestClientConnect(t *testing.T) { assert.NoError(t, err) }, }, + { + desc: "Problem connecting & not installing", + testFunc: func(*testing.T) { + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + ServerOfferedVersion: version.Version(), + Status: protobufs.PackageStatus_Installed, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + Packages: statuses, + } + + expectedErr := errors.New("oops") + + mockOpAmpClient := new(mocks.MockOpAMPClient) + mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(expectedErr) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + + c := &Client{ + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + ident: &identity{}, + configManager: nil, + collector: nil, + currentConfig: opamp.Config{ + Endpoint: "ws://localhost:1234", + SecretKey: &secretKeyContents, + }, + packagesStateProvider: mockStateProvider, + } + + c.Connect(context.Background()) + }, + }, + { + desc: "Problem connecting & installing", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + newVersion := "99.99.99" + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_Installing, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + expectedErr := errors.New("oops") + + mockOpAmpClient := new(mocks.MockOpAMPClient) + mockOpAmpClient.On("SetAgentDescription", mock.Anything).Return(expectedErr) + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockStateProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, allHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagestate.CollectorPackageName, status.Packages[packagestate.CollectorPackageName].Name) + assert.Equal(t, version.Version(), status.Packages[packagestate.CollectorPackageName].AgentHasVersion) + assert.Equal(t, hash, status.Packages[packagestate.CollectorPackageName].AgentHasHash) + assert.Equal(t, newVersion, status.Packages[packagestate.CollectorPackageName].ServerOfferedVersion) + assert.Equal(t, newHash, status.Packages[packagestate.CollectorPackageName].ServerOfferedHash) + assert.Equal(t, fmt.Sprintf("Error while setting agent description: %s", expectedErr), status.Packages[packagestate.CollectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[packagestate.CollectorPackageName].Status) + }) + + c := &Client{ + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + ident: &identity{}, + configManager: nil, + collector: nil, + currentConfig: opamp.Config{ + Endpoint: "ws://localhost:1234", + SecretKey: &secretKeyContents, + }, + packagesStateProvider: mockStateProvider, + } + + c.Connect(context.Background()) + }, + }, } for _, tc := range testCases { @@ -320,9 +441,336 @@ func TestClientDisconnect(t *testing.T) { } c.Disconnect(ctx) + assert.True(t, c.safeGetDisconnecting()) mockOpAmpClient.AssertExpectations(t) } +func TestClient_onConnectHandler(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "LastReportedStatus error", + testFunc: func(*testing.T) { + expectedErr := errors.New("oops") + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, expectedErr) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectHandler() + }, + }, + { + desc: "LastReportedStatus no main package info", + testFunc: func(*testing.T) { + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: []byte("allHash"), + Packages: make(map[string]*protobufs.PackageStatus), + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectHandler() + }, + }, + { + desc: "Good LastReportedStatus but not installing", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + newVersion := "99.99.99" + errorMessage := "problem" + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_InstallFailed, + ErrorMessage: errorMessage, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectHandler() + }, + }, + { + desc: "Installing with version mismatch", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + newVersion := "99.99.99" + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_Installing, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockStateProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, allHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagestate.CollectorPackageName, status.Packages[packagestate.CollectorPackageName].Name) + assert.Equal(t, version.Version(), status.Packages[packagestate.CollectorPackageName].AgentHasVersion) + assert.Equal(t, hash, status.Packages[packagestate.CollectorPackageName].AgentHasHash) + assert.Equal(t, newVersion, status.Packages[packagestate.CollectorPackageName].ServerOfferedVersion) + assert.Equal(t, newHash, status.Packages[packagestate.CollectorPackageName].ServerOfferedHash) + assert.Equal(t, "Failed because of collector version mismatch: expected 99.99.99, actual latest", status.Packages[packagestate.CollectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[packagestate.CollectorPackageName].Status) + }) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectHandler() + }, + }, + { + desc: "Installing with new version match", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + oldVersion := "99.99.99" + newVersion := version.Version() + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: oldVersion, + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_Installing, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockStateProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, allHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagestate.CollectorPackageName, status.Packages[packagestate.CollectorPackageName].Name) + assert.Equal(t, newVersion, status.Packages[packagestate.CollectorPackageName].AgentHasVersion) + assert.Equal(t, newHash, status.Packages[packagestate.CollectorPackageName].AgentHasHash) + assert.Equal(t, newVersion, status.Packages[packagestate.CollectorPackageName].ServerOfferedVersion) + assert.Equal(t, newHash, status.Packages[packagestate.CollectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[packagestate.CollectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[packagestate.CollectorPackageName].Status) + }) + + c := &Client{ + logger: zap.NewNop(), + opampClient: mockOpAmpClient, + packagesStateProvider: mockStateProvider, + } + + c.onConnectHandler() + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestClient_onConnectFailedHandler(t *testing.T) { + expectedErr := errors.New("oops") + + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "LastReportedStatus error", + testFunc: func(*testing.T) { + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(nil, expectedErr) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectFailedHandler(expectedErr) + }, + }, + { + desc: "LastReportedStatus no main package info", + testFunc: func(*testing.T) { + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: []byte("allHash"), + Packages: make(map[string]*protobufs.PackageStatus), + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectFailedHandler(expectedErr) + }, + }, + { + desc: "Disconnect do not change package status", + testFunc: func(*testing.T) { + mockStateProvider := new(mocks.MockPackagesStateProvider) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + disconnecting: true, + } + + c.onConnectFailedHandler(expectedErr) + }, + }, + { + desc: "Good LastReportedStatus but not installing", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + newVersion := "99.99.99" + errorMessage := "problem" + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_InstallFailed, + ErrorMessage: errorMessage, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectFailedHandler(expectedErr) + }, + }, + { + desc: "Good LastReportedStatus and installing", + testFunc: func(*testing.T) { + allHash := []byte("allHash") + hash := []byte("hash") + newHash := []byte("newHash") + newVersion := "99.99.99" + statuses := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: hash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newHash, + Status: protobufs.PackageStatus_Installing, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + mockStateProvider := new(mocks.MockPackagesStateProvider) + mockStateProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockStateProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, allHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagestate.CollectorPackageName, status.Packages[packagestate.CollectorPackageName].Name) + assert.Equal(t, version.Version(), status.Packages[packagestate.CollectorPackageName].AgentHasVersion) + assert.Equal(t, hash, status.Packages[packagestate.CollectorPackageName].AgentHasHash) + assert.Equal(t, newVersion, status.Packages[packagestate.CollectorPackageName].ServerOfferedVersion) + assert.Equal(t, newHash, status.Packages[packagestate.CollectorPackageName].ServerOfferedHash) + assert.Equal(t, fmt.Sprintf("Failed to connect to BindPlane: %s", expectedErr), status.Packages[packagestate.CollectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[packagestate.CollectorPackageName].Status) + }) + + c := &Client{ + logger: zap.NewNop(), + packagesStateProvider: mockStateProvider, + } + + c.onConnectFailedHandler(expectedErr) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + func TestClient_onGetEffectiveConfigHandler(t *testing.T) { mockManager := mocks.NewMockConfigManager(t) @@ -499,3 +947,525 @@ func TestClient_onRemoteConfigHandler(t *testing.T) { t.Run(tc.desc, tc.testFunc) } } + +func TestClient_onPackagesAvailableHandler(t *testing.T) { + collectorPackageName := packagestate.CollectorPackageName + allHash := []byte("totalhash0") + newAllHash := []byte("totalhash1") + packageHash := []byte("hash0") + newPackageHash := []byte("hash1") + newVersion := "999.999.999" + expectedErr := errors.New("oops") + + packages := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: version.Version(), + Hash: packageHash, + File: &protobufs.DownloadableFile{}, + }, + } + packagesAvailable := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packages, + } + + statuses := map[string]*protobufs.PackageStatus{ + collectorPackageName: { + Name: collectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: packageHash, + ServerOfferedVersion: version.Version(), + ServerOfferedHash: packageHash, + Status: protobufs.PackageStatus_Installed, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + Packages: statuses, + } + + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Same PackagesAvailable version but bad Last PackagesStatuses", + testFunc: func(t *testing.T) { + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(nil, expectedErr) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailable.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailable) + assert.NoError(t, err) + }, + }, + { + desc: "Same PackagesAvailable version", + testFunc: func(t *testing.T) { + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailable.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailable) + assert.NoError(t, err) + }, + }, + { + desc: "Same PackagesAvailable version and non supported package", + testFunc: func(t *testing.T) { + badPackageName := "no-support-package" + packagesNotSupported := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: version.Version(), + Hash: packageHash, + File: &protobufs.DownloadableFile{}, + }, + badPackageName: { + Version: newVersion, + Hash: packageHash, + }, + } + packagesAvailableNotSupported := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packagesNotSupported, + } + + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNotSupported.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 2, len(status.Packages)) + assert.Equal(t, packagesAvailableNotSupported.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNotSupported.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packagesAvailableNotSupported.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packagesAvailableNotSupported.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].AgentHasVersion) + assert.Equal(t, packagesAvailableNotSupported.Packages[badPackageName].Version, status.Packages[badPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNotSupported.Packages[badPackageName].Hash, status.Packages[badPackageName].ServerOfferedHash) + assert.Equal(t, fmt.Sprintf("Package %s not supported", badPackageName), status.Packages[badPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[badPackageName].Status) + assert.Equal(t, badPackageName, status.Packages[badPackageName].Name) + assert.Nil(t, status.Packages[badPackageName].AgentHasHash) + assert.Equal(t, "", status.Packages[badPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailableNotSupported) + assert.NoError(t, err) + }, + }, + { + desc: "Same PackagesAvailable version but Last PackageStatuses version mismatch", + testFunc: func(t *testing.T) { + statusesDiffHash := map[string]*protobufs.PackageStatus{ + collectorPackageName: { + Name: collectorPackageName, + AgentHasVersion: newVersion, + AgentHasHash: newPackageHash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newPackageHash, + Status: protobufs.PackageStatus_Installed, + }, + } + packageStatusesDiffHash := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: newAllHash, + Packages: statusesDiffHash, + } + + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatusesDiffHash, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailable.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].AgentHasHash) + assert.NotEqual(t, statusesDiffHash[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].AgentHasVersion) + assert.NotEqual(t, statusesDiffHash[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailable) + assert.NoError(t, err) + }, + }, + // The version of this test where the update goes well can't exist because + // it would kill the collector. StartAndMonitorUpdater will always return an error + // if it does return. + { + desc: "New PackagesAvailable version with good file but bad update", + testFunc: func(t *testing.T) { + packagesNew := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: newVersion, + Hash: newPackageHash, + File: &protobufs.DownloadableFile{}, + }, + } + packagesAvailableNew := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packagesNew, + } + savedStatuses := map[string]*protobufs.PackageStatus{ + collectorPackageName: { + Name: collectorPackageName, + AgentHasVersion: version.Version(), + AgentHasHash: packageHash, + ServerOfferedVersion: newVersion, + ServerOfferedHash: newPackageHash, + Status: protobufs.PackageStatus_Installing, + }, + } + savedPackageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: newAllHash, + Packages: savedStatuses, + } + wg := sync.WaitGroup{} + wg.Add(2) + mockUpdaterManager := mocks.NewMockUpdaterManager(t) + mockUpdaterManager.On("StartAndMonitorUpdater").Return(expectedErr) + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil).Once() + mockProvider.On("LastReportedStatuses").Return(savedPackageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockFileManager := mocks.NewMockDownloadableFileManager(t) + mockFileManager.On("FetchAndExtractArchive", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + wg.Done() + }) + mockFileManager.On("CleanupArtifacts").Return().Times(1) + + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Once().Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNew.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installing, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNew.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "Failed to run the latest Updater: oops", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + wg.Done() + }) + + c := &Client{ + packagesStateProvider: mockProvider, + downloadableFileManager: mockFileManager, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + updaterManager: mockUpdaterManager, + } + + err := c.onPackagesAvailableHandler(packagesAvailableNew) + assert.NoError(t, err) + wg.Wait() + assert.False(t, c.safeGetUpdatingPackage()) + }, + }, + { + desc: "New PackagesAvailable version while already installing", + testFunc: func(t *testing.T) { + packagesNew := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: newVersion, + Hash: newPackageHash, + File: &protobufs.DownloadableFile{}, + }, + } + packagesAvailableNew := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packagesNew, + } + + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "Already installing new packages", status.ErrorMessage) + assert.Equal(t, packagesAvailableNew.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 0, len(status.Packages)) + }) + + c := &Client{ + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + updatingPackage: true, + } + + err := c.onPackagesAvailableHandler(packagesAvailableNew) + assert.ErrorContains(t, err, "failed because already installing packages") + }, + }, + { + desc: "New PackagesAvailable version with no DownloadableFile", + testFunc: func(t *testing.T) { + packagesNoFile := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: newVersion, + Hash: newPackageHash, + }, + } + packagesAvailableNoFile := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packagesNoFile, + } + + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNoFile.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailableNoFile.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNoFile.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "Package observiq-otel-collector does not have a valid downloadable file", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailableNoFile) + assert.NoError(t, err) + }, + }, + { + desc: "New PackagesAvailable version with bad DownloadableFile", + testFunc: func(t *testing.T) { + packagesNew := map[string]*protobufs.PackageAvailable{ + collectorPackageName: { + Version: newVersion, + Hash: newPackageHash, + File: &protobufs.DownloadableFile{}, + }, + } + packagesAvailableNew := &protobufs.PackagesAvailable{ + AllPackagesHash: newAllHash, + Packages: packagesNew, + } + + mockFileManager := mocks.NewMockDownloadableFileManager(t) + mockFileManager.On("FetchAndExtractArchive", mock.Anything).Return(expectedErr) + mockFileManager.On("CleanupArtifacts").Return().Times(1) + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + // This is for the initial status that is sent in the main function. + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Once().Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNew.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installing, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + // This will be called within the goroutine that is spun up from the main function. + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(nil).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailableNew.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailableNew.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "Failed to download and verify downloadable file: oops", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_InstallFailed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasHash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packageStatuses.Packages[collectorPackageName].AgentHasVersion, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + downloadableFileManager: mockFileManager, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailableNew) + assert.NoError(t, err) + assert.Eventually(t, func() bool { return c.safeGetUpdatingPackage() == false }, 10*time.Second, 10*time.Millisecond) + }, + }, + { + desc: "Same PackagesAvailable version but bad set last PackageStatuses", + testFunc: func(t *testing.T) { + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(expectedErr).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailable.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].AgentHasVersion) + }) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailable) + assert.ErrorIs(t, err, expectedErr) + }, + }, + { + desc: "Same PackagesAvailable version but bad SEND PackageStatuses", + testFunc: func(t *testing.T) { + mockProvider := mocks.NewMockPackagesStateProvider(t) + mockProvider.On("LastReportedStatuses").Return(packageStatuses, nil) + mockProvider.On("SetLastReportedStatuses", mock.Anything).Return(nil) + mockOpAmpClient := mocks.NewMockOpAMPClient(t) + mockOpAmpClient.On("SetPackageStatuses", mock.Anything).Return(expectedErr).Run(func(args mock.Arguments) { + status := args.Get(0).(*protobufs.PackageStatuses) + + assert.NotNil(t, status) + assert.Equal(t, "", status.ErrorMessage) + assert.Equal(t, packagesAvailable.AllPackagesHash, status.ServerProvidedAllPackagesHash) + assert.Equal(t, 1, len(status.Packages)) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].ServerOfferedVersion) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].ServerOfferedHash) + assert.Equal(t, "", status.Packages[collectorPackageName].ErrorMessage) + assert.Equal(t, protobufs.PackageStatus_Installed, status.Packages[collectorPackageName].Status) + assert.Equal(t, collectorPackageName, status.Packages[collectorPackageName].Name) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Hash, status.Packages[collectorPackageName].AgentHasHash) + assert.Equal(t, packagesAvailable.Packages[collectorPackageName].Version, status.Packages[collectorPackageName].AgentHasVersion) + }) + + c := &Client{ + packagesStateProvider: mockProvider, + opampClient: mockOpAmpClient, + logger: zap.NewNop(), + } + + err := c.onPackagesAvailableHandler(packagesAvailable) + assert.ErrorIs(t, err, expectedErr) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/opamp/observiq/observiq_downloadable_file_manager.go b/opamp/observiq/observiq_downloadable_file_manager.go new file mode 100644 index 000000000..1b700d42c --- /dev/null +++ b/opamp/observiq/observiq_downloadable_file_manager.go @@ -0,0 +1,169 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observiq + +import ( + "crypto/sha256" + "crypto/subtle" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "os" + "path/filepath" + + archiver "github.com/mholt/archiver/v3" + "github.com/observiq/observiq-otel-collector/opamp" + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +const extractFolder = "latest" + +// Ensure interface is satisfied +var _ opamp.DownloadableFileManager = (*DownloadableFileManager)(nil) + +// DownloadableFileManager handles DownloadableFile's from a PackagesAvailable message +type DownloadableFileManager struct { + tmpPath string + logger *zap.Logger +} + +// newDownloadableFileManager creates a new OpAmp DownloadableFileManager +func newDownloadableFileManager(logger *zap.Logger, tmpPath string) *DownloadableFileManager { + return &DownloadableFileManager{ + tmpPath: filepath.Clean(tmpPath), + logger: logger, + } +} + +// FetchAndExtractArchive fetches the archive at the specified URL, placing it into dir. +// It then checks to see if it matches the "expectedHash", a hex-encoded string representing the expected sha256 sum of the file. +// If it matches, the archive is extracted into the $dir/latest directory. +// If the archive cannot be extracted, downloaded, or verified, then an error is returned. +func (m DownloadableFileManager) FetchAndExtractArchive(file *protobufs.DownloadableFile) error { + archiveFilePath, err := getOutputFilePath(m.tmpPath, file.GetDownloadUrl()) + if err != nil { + return fmt.Errorf("failed to determine archive download path: %w", err) + } + + if err := m.downloadFile(file.GetDownloadUrl(), archiveFilePath); err != nil { + return fmt.Errorf("failed to download file: %w", err) + } + + extractPath := filepath.Join(m.tmpPath, extractFolder) + + if err := m.verifyContentHash(archiveFilePath, file.GetContentHash()); err != nil { + return fmt.Errorf("content hash could not be verified: %w", err) + } + + // Clean the "latest" dir before extraction + if err := os.RemoveAll(extractPath); err != nil { + return fmt.Errorf("error cleaning archive extraction target path: %w", err) + } + + if err := archiver.Unarchive(archiveFilePath, extractPath); err != nil { + return fmt.Errorf("failed to extract file: %w", err) + } + + return nil +} + +// Downloads the file into the outPath, truncating the file if it already exists +func (m DownloadableFileManager) downloadFile(downloadURL string, outPath string) error { + //#nosec G107 HTTP request must be dynamic based on input + resp, err := http.Get(downloadURL) + if err != nil { + return fmt.Errorf("could not GET url: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("got non-200 status code (%d)", resp.StatusCode) + } + + outPathClean := filepath.Clean(outPath) + f, err := os.OpenFile(outPathClean, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer func() { + err := f.Close() + if err != nil { + m.logger.Warn("Failed to close file", zap.Error(err)) + } + }() + + if _, err = io.Copy(f, resp.Body); err != nil { + return fmt.Errorf("failed to copy request body to file: %w", err) + } + + return nil +} + +// getOutputFilePath gets the output path relative to the base dir for the archive from the given URL. +func getOutputFilePath(basePath, downloadURL string) (string, error) { + err := os.MkdirAll(basePath, 0700) + if err != nil { + return "", fmt.Errorf("problem with base url: %w", err) + } + + url, err := url.Parse(downloadURL) + if err != nil { + return "", fmt.Errorf("cannot parse url: %w", err) + } + + if url.Path == "" { + return "", errors.New("input url must have path") + } + + return filepath.Join(basePath, filepath.Base(url.Path)), nil +} + +func (m DownloadableFileManager) verifyContentHash(contentPath string, expectedFileHash []byte) error { + // Hash file at contentPath using sha256 + fileHash := sha256.New() + contentPathClean := filepath.Clean(contentPath) + + f, err := os.Open(contentPathClean) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer func() { + err := f.Close() + if err != nil { + m.logger.Warn("Failed to close file", zap.Error(err)) + } + }() + + if _, err = io.Copy(fileHash, f); err != nil { + return fmt.Errorf("failed to calculate file hash: %w", err) + } + + actualContentHash := fileHash.Sum(nil) + if subtle.ConstantTimeCompare(expectedFileHash, actualContentHash) == 0 { + return errors.New("file hash did not match expected") + } + + return nil +} + +// CleanupArtifacts removes previous installation artifacts by removing the temporary directory. +func (m DownloadableFileManager) CleanupArtifacts() { + if err := os.RemoveAll(m.tmpPath); err != nil { + m.logger.Error("Failed to remove temporary directory", zap.Error(err)) + } +} diff --git a/opamp/observiq/observiq_downloadable_file_manager_test.go b/opamp/observiq/observiq_downloadable_file_manager_test.go new file mode 100644 index 000000000..013f1aae3 --- /dev/null +++ b/opamp/observiq/observiq_downloadable_file_manager_test.go @@ -0,0 +1,349 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observiq + +import ( + "bytes" + "encoding/hex" + "fmt" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestDownloadFile(t *testing.T) { + tmpDir := t.TempDir() + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + t.Run("Downloads File Over HTTP", func(t *testing.T) { + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("Invalid request method: %s", r.Method) + return + } + + w.Write([]byte("Hello")) + })) + defer s.Close() + + outPath := filepath.Join(tmpDir, "out.txt") + + err := downloadableFileManager.downloadFile(s.URL, outPath) + require.NoError(t, err) + + b, err := os.ReadFile(outPath) + require.NoError(t, err) + assert.Equal(t, []byte("Hello"), b) + }) + + t.Run("Output file is existing directory", func(t *testing.T) { + tmpDir := t.TempDir() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("Invalid request method: %s", r.Method) + return + } + + w.Write([]byte("Hello")) + })) + defer s.Close() + + err := downloadableFileManager.downloadFile(s.URL, tmpDir) + require.ErrorContains(t, err, "failed to open file:") + }) + + t.Run("Invalid URL", func(t *testing.T) { + tmpDir := t.TempDir() + outPath := filepath.Join(tmpDir, "out.txt") + + err := downloadableFileManager.downloadFile("http://localhost:9999999", outPath) + require.ErrorContains(t, err, "could not GET url") + }) + + t.Run("Server returns 404", func(t *testing.T) { + tmpDir := t.TempDir() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer s.Close() + + outPath := filepath.Join(tmpDir, "out.txt") + + err := downloadableFileManager.downloadFile(s.URL, outPath) + require.ErrorContains(t, err, "got non-200 status code (404)") + }) +} + +func TestGetOutputFilePath(t *testing.T) { + testCases := []struct { + name string + basepath string + url string + out string + expectedErr string + }{ + { + name: "Input url is valid zip", + basepath: filepath.Join("/", "tmp", "observiq-otel-collector-update"), + url: "http://example.com/some-file.zip", + out: filepath.Join("/", "tmp", "observiq-otel-collector-update", "some-file.zip"), + }, + { + name: "Input url is valid tar", + basepath: filepath.Join("/", "tmp", "observiq-otel-collector-update"), + url: "http://example.com/some-file.tar.gz", + out: filepath.Join("/", "tmp", "observiq-otel-collector-update", "some-file.tar.gz"), + }, + { + name: "Input url is invalid", + basepath: filepath.Join("/", "tmp", "observiq-otel-collector-update"), + url: "http://local\thost/some-file.zip", + expectedErr: "cannot parse url", + }, + { + name: "Input url has no path", + basepath: filepath.Join("/", "tmp", "observiq-otel-collector-update"), + url: "http://example.com", + expectedErr: "input url must have path", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + out, err := getOutputFilePath(tc.basepath, tc.url) + if tc.expectedErr == "" { + require.NoError(t, err) + require.Equal(t, tc.out, out) + } else { + require.ErrorContains(t, err, tc.expectedErr) + } + }) + } +} + +func TestVerifyContentHash(t *testing.T) { + tmpDir := t.TempDir() + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + + hash1, _ := hex.DecodeString("c87e2ca771bab6024c269b933389d2a92d4941c848c52f155b9b84e1f109fe35") + hash2, _ := hex.DecodeString("7e4ead2053637d9fcb7f3316e748becb8af163c6f851446eeef878a994ae5c4b") + testCases := []struct { + name string + contentPath string + hash []byte + expectedErr string + }{ + { + name: "Content hash matches", + contentPath: filepath.Join("testdata", "test.txt"), + hash: hash1, + }, + { + name: "File does not exist", + contentPath: filepath.Join("testdata", "non-existant-file.txt"), + hash: hash1, + expectedErr: "failed to open file", + }, + { + name: "Content hash does not match", + contentPath: filepath.Join("testdata", "test.txt"), + hash: hash2, + expectedErr: "file hash did not match expected", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + _, statErr := os.Stat(tc.contentPath) + if runtime.GOOS == "windows" && statErr == nil { + // Cloning the repo on windows changes the line endings depending on git configuration. + // We need to thwart that mechanism. + // Make sure test.txt exists in the output dir + tmpDir := t.TempDir() + fileBytes, err := os.ReadFile(tc.contentPath) + require.NoError(t, err) + + // Replace \r\n with \n so tests pass on windows systems + newlinesOnly := bytes.ReplaceAll(fileBytes, []byte("\r\n"), []byte("\n")) + + // Change content path to new file, and write it. + tc.contentPath = filepath.Join(tmpDir, filepath.Base(tc.contentPath)) + err = os.WriteFile(tc.contentPath, newlinesOnly, 0666) + require.NoError(t, err) + + } + err := downloadableFileManager.verifyContentHash(tc.contentPath, tc.hash) + if tc.expectedErr == "" { + require.NoError(t, err) + } else { + require.ErrorContains(t, err, tc.expectedErr) + } + }) + } +} + +func TestDownloadAndVerifyExtraction(t *testing.T) { + hash1, _ := hex.DecodeString("d3bf2375be7372b34eae9bc16296ce9e40e53f5b79b329e23056c4aaf77eb47c") + hash2, _ := hex.DecodeString("5594349d022f7f374fa3ee777ded15f4f06a47aa08eec300bd06cdb0d2688fac") + hash3, _ := hex.DecodeString("e7045ebfc48a850a8ac2d342c172099f8c937a4265c55cd93cb39908278952b4") + testCases := []struct { + name string + archivePath string + expectedHash []byte + expectedErr string + }{ + { + name: "Download and extracts tar.gz files", + archivePath: filepath.Join("testdata", "test.tar.gz"), + expectedHash: hash1, + }, + { + name: "Download and extracts zip files", + archivePath: filepath.Join("testdata", "test.zip"), + expectedHash: hash2, + }, + { + name: "Fails to extract non-archive", + archivePath: filepath.Join("testdata", "not-actually-tar.tar.gz"), + expectedHash: hash3, + expectedErr: "failed to extract file", + }, + { + name: "Hash does not match downloaded hash", + archivePath: filepath.Join("testdata", "test.tar.gz"), + expectedHash: hash3, + expectedErr: "content hash could not be verified", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tmpDir := t.TempDir() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + archiveBytes, err := os.ReadFile(tc.archivePath) + if err != nil { + t.Errorf("Failed to open archive for sending over http: %s", err) + } + + if filepath.Base(tc.archivePath) == "not-actually-tar.tar.gz" { + // This file is a text file, and git actually detects that and replaces line endings on windows + // Replace \r\n with \n so tests pass on windows systems + archiveBytes = bytes.ReplaceAll(archiveBytes, []byte("\r\n"), []byte("\n")) + } + + _, err = w.Write(archiveBytes) + if err != nil { + t.Errorf("Failed to copy archive for sending over http: %s", err) + } + })) + defer s.Close() + + file := &protobufs.DownloadableFile{ + DownloadUrl: fmt.Sprintf("%s/%s", s.URL, tc.archivePath), + ContentHash: []byte(tc.expectedHash), + } + + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + err := downloadableFileManager.FetchAndExtractArchive(file) + if tc.expectedErr == "" { + require.NoError(t, err) + + // Make sure test.txt exists in the output dir + expectedBytes, err := os.ReadFile(filepath.Join("testdata", "test.txt")) + require.NoError(t, err) + + // Replace \r\n with \n so tests pass on windows systems + expectedBytes = bytes.ReplaceAll(expectedBytes, []byte("\r\n"), []byte("\n")) + + actualBytes, err := os.ReadFile(filepath.Join(tmpDir, extractFolder, "test.txt")) + require.NoError(t, err) + + require.Equal(t, expectedBytes, actualBytes) + } else { + require.ErrorContains(t, err, tc.expectedErr) + } + }) + } +} + +func TestDownloadAndVerifyHTTPFailure(t *testing.T) { + tmpDir := t.TempDir() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer s.Close() + + file := &protobufs.DownloadableFile{ + DownloadUrl: fmt.Sprintf("%s/%s", s.URL, "some-archive.tar.gz"), + ContentHash: []byte{}, + } + + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + err := downloadableFileManager.FetchAndExtractArchive(file) + require.ErrorContains(t, err, "failed to download file:") +} + +func TestDownloadAndVerifyInvalidURL(t *testing.T) { + tmpDir := t.TempDir() + + file := &protobufs.DownloadableFile{ + DownloadUrl: "http://\t/some-archive.tar.gz", + ContentHash: []byte{}, + } + + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + err := downloadableFileManager.FetchAndExtractArchive(file) + require.ErrorContains(t, err, "failed to determine archive download path:") +} + +func TestCleanupArtifacts(t *testing.T) { + t.Run("Cleans up tmp dir if exists", func(t *testing.T) { + tmpDir := filepath.Join(t.TempDir(), "tmp") + + // Try to download -- this should create tmpDir, but fail to download + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + err := downloadableFileManager.FetchAndExtractArchive(&protobufs.DownloadableFile{ + DownloadUrl: "http://invalid-host:0/some-file.zip", + }) + + require.ErrorContains(t, err, "failed to download file") + require.DirExists(t, tmpDir) + + downloadableFileManager.CleanupArtifacts() + require.NoDirExists(t, tmpDir) + }) + + t.Run("Does nothing if tmp dir does not exist", func(t *testing.T) { + tmpDir := filepath.Join(t.TempDir(), "tmp") + downloadableFileManager := newDownloadableFileManager(zap.NewNop(), tmpDir) + + require.NoDirExists(t, tmpDir) + + downloadableFileManager.CleanupArtifacts() + require.NoDirExists(t, tmpDir) + }) +} diff --git a/opamp/observiq/observiq_packages_state_provider.go b/opamp/observiq/observiq_packages_state_provider.go new file mode 100644 index 000000000..e6a466052 --- /dev/null +++ b/opamp/observiq/observiq_packages_state_provider.go @@ -0,0 +1,152 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package observiq contains OpAmp structures compatible with the observiq client +package observiq + +import ( + "context" + "errors" + "fmt" + "io" + "os" + + "github.com/observiq/observiq-otel-collector/internal/version" + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/open-telemetry/opamp-go/client/types" + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +// Ensure interface is satisfied +var _ types.PackagesStateProvider = (*packagesStateProvider)(nil) + +// packagesStateProvider represents a PackagesStateProvider which uses a PackageStateManager to persist PackageStatuses +type packagesStateProvider struct { + packageStateManager packagestate.StateManager + logger *zap.Logger +} + +// newPackagesStateProvider creates a new OpAmp PackagesStateProvider +func newPackagesStateProvider(logger *zap.Logger, jsonPath string) types.PackagesStateProvider { + return &packagesStateProvider{ + packageStateManager: packagestate.NewFileStateManager(logger, jsonPath), + logger: logger, + } +} + +// AllPackagesHash not implemented so returns an error with this info +func (p *packagesStateProvider) AllPackagesHash() ([]byte, error) { + p.logger.Debug("Retrieve all packages hash") + + return nil, errors.New("method not implemented: PackageStateProvider AllPackagesHash") +} + +// SetAllPackagesHash not implemented so returns an error with this info +func (p *packagesStateProvider) SetAllPackagesHash(_ []byte) error { + p.logger.Debug("Set all packages hash") + + return errors.New("method not implemented: PackageStateProvider SetAllPackagesHash") +} + +// Packages not implemented so returns an error with this info +func (p *packagesStateProvider) Packages() ([]string, error) { + p.logger.Debug("Retrieve package names") + + return nil, errors.New("method not implemented: PackageStateProvider Packages") +} + +// PackageState not implemented so returns an error with this info +func (p *packagesStateProvider) PackageState(_ string) (state types.PackageState, err error) { + p.logger.Debug("Retrieve package state") + + packageState := types.PackageState{} + + return packageState, errors.New("method not implemented: PackageStateProvider PackageState") +} + +// SetPackageState not implemented so returns an error with this info +func (p *packagesStateProvider) SetPackageState(_ string, _ types.PackageState) error { + p.logger.Debug("Set package state") + + return errors.New("method not implemented: PackageStateProvider SetPackageState") +} + +// CreatePackage not implemented so returns an error with this info +func (p *packagesStateProvider) CreatePackage(_ string, _ protobufs.PackageAvailable_PackageType) error { + p.logger.Debug("Create package") + + return errors.New("method not implemented: PackageStateProvider CreatePackage") +} + +// FileContentHash not implemented so returns an error with this info +func (p *packagesStateProvider) FileContentHash(_ string) ([]byte, error) { + p.logger.Debug("Retrieve package content hash") + + return nil, errors.New("method not implemented: PackageStateProvider FileContentHash") +} + +// UpdateContent not implemented so returns an error with this info +func (p *packagesStateProvider) UpdateContent(_ context.Context, _ string, _ io.Reader, _ []byte) error { + p.logger.Debug("Update package content") + + return errors.New("method not implemented: PackageStateProvider UpdateContent") +} + +// DeletePackage not implemented so returns an error with this info +func (p *packagesStateProvider) DeletePackage(_ string) error { + p.logger.Debug("Delete package") + + return errors.New("method not implemented: PackageStateProvider DeletePackage") +} + +// LastReportedStatuses retrieves the PackagesStatuses from a saved json file +func (p *packagesStateProvider) LastReportedStatuses() (*protobufs.PackageStatuses, error) { + p.logger.Debug("Retrieve last reported package statuses") + + packages := map[string]*protobufs.PackageStatus{ + packagestate.CollectorPackageName: { + Name: packagestate.CollectorPackageName, + AgentHasVersion: version.Version(), + Status: protobufs.PackageStatus_Installed, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + Packages: packages, + } + + loadedStatues, err := p.packageStateManager.LoadStatuses() + + switch { + // No File exists so return the status we constructed + case errors.Is(err, os.ErrNotExist): + p.logger.Debug("Package statuses json doesn't exist") + return packageStatuses, nil + + // File existed but error while parsing it + case err != nil: + return packageStatuses, fmt.Errorf("failed loading package statuses: %w", err) + + // Successful load + default: + return loadedStatues, nil + } +} + +// SetLastReportedStatuses saves the given PackageStatuses into a json file +func (p *packagesStateProvider) SetLastReportedStatuses(statuses *protobufs.PackageStatuses) error { + p.logger.Debug("Set last reported package statuses") + + return p.packageStateManager.SaveStatuses(statuses) +} diff --git a/opamp/observiq/observiq_packages_state_provider_test.go b/opamp/observiq/observiq_packages_state_provider_test.go new file mode 100644 index 000000000..ad613a99b --- /dev/null +++ b/opamp/observiq/observiq_packages_state_provider_test.go @@ -0,0 +1,469 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observiq + +import ( + "context" + "errors" + "io" + "os" + "testing" + + "github.com/observiq/observiq-otel-collector/internal/version" + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/observiq/observiq-otel-collector/packagestate/mocks" + "github.com/open-telemetry/opamp-go/client/types" + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewPackagesStateProvider(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "New PackagesStateProvider", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + actual := newPackagesStateProvider(logger, "test.json") + + packagesStateProvider, ok := actual.(*packagesStateProvider) + require.True(t, ok) + + assert.Equal(t, logger, packagesStateProvider.logger) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestAllPackagesHash(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + actual, err := p.AllPackagesHash() + + assert.Nil(t, actual) + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestSetAllPackagesHash(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + err := p.SetAllPackagesHash([]byte("hash")) + + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestPackages(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + actual, err := p.Packages() + + assert.Nil(t, actual) + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestPackageState(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + actual, err := p.PackageState("name") + + assert.Equal(t, types.PackageState{}, actual) + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestSetPackageState(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + err := p.SetPackageState("name", types.PackageState{}) + + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestCreatePackage(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + err := p.CreatePackage("name", protobufs.PackageAvailable_TopLevelPackage) + + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestFileContentHash(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + actual, err := p.FileContentHash("name") + + assert.Nil(t, actual) + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestUpdateContent(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + var r io.Reader + + err := p.UpdateContent(context.TODO(), "name", r, []byte("hash")) + + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestDeletePackage(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Not Implemented", + testFunc: func(t *testing.T) { + logger := zap.NewNop() + p := &packagesStateProvider{ + logger: logger, + } + + err := p.DeletePackage("name") + + assert.ErrorContains(t, err, "method not implemented") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestLastReportedStatuses(t *testing.T) { + pkgName := packagestate.CollectorPackageName + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "PackageStateManager returns error for missing file", + testFunc: func(t *testing.T) { + mockManager := mocks.NewMockStateManager(t) + mockManager.On("LoadStatuses").Return(nil, os.ErrNotExist) + + p := &packagesStateProvider{ + packageStateManager: mockManager, + logger: zap.NewNop(), + } + + actual, err := p.LastReportedStatuses() + + assert.NoError(t, err) + assert.Nil(t, actual.ServerProvidedAllPackagesHash) + assert.Equal(t, "", actual.ErrorMessage) + assert.Equal(t, 1, len(actual.Packages)) + assert.Equal(t, pkgName, actual.Packages[pkgName].GetName()) + assert.Equal(t, version.Version(), actual.Packages[pkgName].GetAgentHasVersion()) + assert.Nil(t, actual.Packages[pkgName].GetAgentHasHash()) + assert.Equal(t, "", actual.Packages[pkgName].GetServerOfferedVersion()) + assert.Nil(t, actual.Packages[pkgName].GetServerOfferedHash()) + assert.Equal(t, protobufs.PackageStatus_Installed, actual.Packages[pkgName].GetStatus()) + assert.Equal(t, "", actual.Packages[pkgName].GetErrorMessage()) + }, + }, + { + desc: "Load Error", + testFunc: func(t *testing.T) { + expectedErr := errors.New("bad") + mockManager := mocks.NewMockStateManager(t) + mockManager.On("LoadStatuses").Return(nil, expectedErr) + + p := &packagesStateProvider{ + packageStateManager: mockManager, + logger: zap.NewNop(), + } + + actual, err := p.LastReportedStatuses() + + assert.ErrorIs(t, err, expectedErr) + assert.Nil(t, actual.ServerProvidedAllPackagesHash) + assert.Equal(t, "", actual.ErrorMessage) + assert.Equal(t, 1, len(actual.Packages)) + assert.Equal(t, pkgName, actual.Packages[pkgName].GetName()) + assert.Equal(t, version.Version(), actual.Packages[pkgName].GetAgentHasVersion()) + assert.Nil(t, actual.Packages[pkgName].GetAgentHasHash()) + assert.Equal(t, "", actual.Packages[pkgName].GetServerOfferedVersion()) + assert.Nil(t, actual.Packages[pkgName].GetServerOfferedHash()) + assert.Equal(t, protobufs.PackageStatus_Installed, actual.Packages[pkgName].GetStatus()) + assert.Equal(t, "", actual.Packages[pkgName].GetErrorMessage()) + }, + }, + { + desc: "Successful file read", + testFunc: func(t *testing.T) { + expected := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + "package": { + Name: "package", + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "2.0", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + ErrorMessage: "bad", + }, + }, + ServerProvidedAllPackagesHash: []byte("hash"), + ErrorMessage: "whoops", + } + + mockManager := mocks.NewMockStateManager(t) + mockManager.On("LoadStatuses").Return(expected, nil) + + p := &packagesStateProvider{ + packageStateManager: mockManager, + logger: zap.NewNop(), + } + + actual, err := p.LastReportedStatuses() + + assert.NoError(t, err) + assert.Equal(t, expected, actual) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestSetLastReportedStatuses(t *testing.T) { + pkgName := "package" + agentVersion := "1.0" + agentHash := []byte("hash1") + serverVersion := "2.0" + serverHash := []byte("hash2") + errMsg := "bad" + allHash := []byte("hash") + allErrMsg := "whoops" + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "PackageStateManager Returns error", + testFunc: func(t *testing.T) { + expectedErr := errors.New("bad") + + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + ErrorMessage: allErrMsg, + Packages: map[string]*protobufs.PackageStatus{ + pkgName: { + Name: pkgName, + AgentHasVersion: agentVersion, + AgentHasHash: agentHash, + ServerOfferedVersion: serverVersion, + ServerOfferedHash: serverHash, + Status: protobufs.PackageStatus_InstallPending, + ErrorMessage: errMsg, + }, + }, + } + + mockManager := mocks.NewMockStateManager(t) + mockManager.On("SaveStatuses", packageStatuses).Return(expectedErr) + + p := &packagesStateProvider{ + packageStateManager: mockManager, + logger: zap.NewNop(), + } + + err := p.SetLastReportedStatuses(packageStatuses) + assert.ErrorIs(t, err, expectedErr) + }, + }, + { + desc: "PackageStateManager No error", + testFunc: func(t *testing.T) { + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + ErrorMessage: allErrMsg, + Packages: map[string]*protobufs.PackageStatus{ + pkgName: { + Name: pkgName, + AgentHasVersion: agentVersion, + AgentHasHash: agentHash, + ServerOfferedVersion: serverVersion, + ServerOfferedHash: serverHash, + Status: protobufs.PackageStatus_InstallPending, + ErrorMessage: errMsg, + }, + }, + } + + mockManager := mocks.NewMockStateManager(t) + mockManager.On("SaveStatuses", packageStatuses).Return(nil) + + p := &packagesStateProvider{ + packageStateManager: mockManager, + logger: zap.NewNop(), + } + + err := p.SetLastReportedStatuses(packageStatuses) + assert.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/opamp/observiq/testdata/latest/badupdater b/opamp/observiq/testdata/latest/badupdater new file mode 100755 index 000000000..e69de29bb diff --git a/opamp/observiq/testdata/latest/quickupdater b/opamp/observiq/testdata/latest/quickupdater new file mode 100755 index 000000000..1fe3a3469 --- /dev/null +++ b/opamp/observiq/testdata/latest/quickupdater @@ -0,0 +1,3 @@ +#!/bin/bash + +sleep 1 diff --git a/opamp/observiq/testdata/latest/quickupdater.exe b/opamp/observiq/testdata/latest/quickupdater.exe new file mode 100755 index 000000000..c69a66e54 Binary files /dev/null and b/opamp/observiq/testdata/latest/quickupdater.exe differ diff --git a/opamp/observiq/testdata/latest/quickupdater.test b/opamp/observiq/testdata/latest/quickupdater.test new file mode 100644 index 000000000..74708a296 --- /dev/null +++ b/opamp/observiq/testdata/latest/quickupdater.test @@ -0,0 +1,23 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// For generation of windows quickupdater.exe +// Change extension to .go and run `env GOOS=windows GOARCH=amd64 go build quickupdater.go` to build new quickupdater.exe +package main + +import "time" + +func main() { + time.Sleep(1 * time.Second) +} diff --git a/opamp/observiq/testdata/latest/slowupdater b/opamp/observiq/testdata/latest/slowupdater new file mode 100755 index 000000000..8f22dbcde --- /dev/null +++ b/opamp/observiq/testdata/latest/slowupdater @@ -0,0 +1,3 @@ +#!/bin/bash + +sleep 10 diff --git a/opamp/observiq/testdata/latest/slowupdater.exe b/opamp/observiq/testdata/latest/slowupdater.exe new file mode 100755 index 000000000..2a56bddce Binary files /dev/null and b/opamp/observiq/testdata/latest/slowupdater.exe differ diff --git a/opamp/observiq/testdata/latest/slowupdater.test b/opamp/observiq/testdata/latest/slowupdater.test new file mode 100644 index 000000000..d30819016 --- /dev/null +++ b/opamp/observiq/testdata/latest/slowupdater.test @@ -0,0 +1,23 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// For generation of windows slowupdater.exe +// Change extension to .go and run `env GOOS=windows GOARCH=amd64 go build slowupdater.go` to build new slowupdater.exe +package main + +import "time" + +func main() { + time.Sleep(10 * time.Second) +} diff --git a/opamp/observiq/testdata/not-actually-tar.tar.gz b/opamp/observiq/testdata/not-actually-tar.tar.gz new file mode 100644 index 000000000..a73ee7294 --- /dev/null +++ b/opamp/observiq/testdata/not-actually-tar.tar.gz @@ -0,0 +1 @@ +This is a test file with a .tar.gz extension diff --git a/opamp/observiq/testdata/test.tar.gz b/opamp/observiq/testdata/test.tar.gz new file mode 100644 index 000000000..90484c341 Binary files /dev/null and b/opamp/observiq/testdata/test.tar.gz differ diff --git a/opamp/observiq/testdata/test.txt b/opamp/observiq/testdata/test.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/opamp/observiq/testdata/test.txt @@ -0,0 +1 @@ +This is a test file diff --git a/opamp/observiq/testdata/test.zip b/opamp/observiq/testdata/test.zip new file mode 100644 index 000000000..2d794f018 Binary files /dev/null and b/opamp/observiq/testdata/test.zip differ diff --git a/opamp/observiq/updater_manager.go b/opamp/observiq/updater_manager.go new file mode 100644 index 000000000..52c799c10 --- /dev/null +++ b/opamp/observiq/updater_manager.go @@ -0,0 +1,82 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package observiq + +import ( + "fmt" + "io" + "os" + "path/filepath" + "time" + + "go.uber.org/zap" +) + +const updaterDir = "latest" +const defaultShutdownWaitTimeout = 30 * time.Second + +// updaterManager handles working with the Updater binary +type updaterManager interface { + // StartAndMonitorUpdater starts the Updater binary and monitors it for failure + StartAndMonitorUpdater() error +} + +// copyExecutable copies the executable at the input file path to the cwd. +// Returns the output path of the executable. +func copyExecutable(logger *zap.Logger, inputPath, cwd string) (string, error) { + inputPathClean := filepath.Clean(inputPath) + + inputFile, err := os.Open(inputPathClean) + if err != nil { + return "", fmt.Errorf("failed to open updater binary for reading: %w", err) + } + defer func() { + if err := inputFile.Close(); err != nil { + logger.Error("Failed to close input file", zap.Error(err)) + } + }() + + // Output path is just whatever the actual file name is (e.g. updater.exe), + // on top of the CWD. We take the absolute path, because it is needed to actually ensure you can + // exec a file not on your PATH. + outputPath, err := filepath.Abs(filepath.Join(cwd, filepath.Base(inputPath))) + if err != nil { + return "", fmt.Errorf("failed to get absolute path for output: %w", err) + } + + outputPathClean := filepath.Clean(outputPath) + + // Remove the file if it already exists, need this for macOS + if err := os.RemoveAll(outputPathClean); err != nil { + return "", fmt.Errorf("failed to remove any existing executable: %w", err) + } + + //#nosec G302 - 0700 instead of 0600 since the executable bit needs to be flipped + outputFile, err := os.OpenFile(outputPathClean, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, 0700) + if err != nil { + return "", fmt.Errorf("failed to open output file: %w", err) + } + defer func() { + if err := outputFile.Close(); err != nil { + logger.Error("Failed to close output file", zap.Error(err)) + } + }() + + if _, err := io.Copy(outputFile, inputFile); err != nil { + return "", fmt.Errorf("failed to copy executable to output: %w", err) + } + + return outputPathClean, nil +} diff --git a/opamp/observiq/updater_manager_others.go b/opamp/observiq/updater_manager_others.go new file mode 100644 index 000000000..7fdb36831 --- /dev/null +++ b/opamp/observiq/updater_manager_others.go @@ -0,0 +1,95 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +package observiq + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "syscall" + "time" + + "go.uber.org/zap" +) + +// Ensure interface is satisfied +var _ updaterManager = (*othersUpdaterManager)(nil) + +const defaultOthersUpdaterName = "updater" + +// othersUpdaterManager handles starting a Updater binary and watching it for failure with a timeout +type othersUpdaterManager struct { + tmpPath string + cwd string + updaterName string + logger *zap.Logger + shutdownWaitTimeout time.Duration +} + +// newUpdaterManager creates a new UpdaterManager +func newUpdaterManager(defaultLogger *zap.Logger, tmpPath string) (updaterManager, error) { + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get cwd: %w", err) + } + + return &othersUpdaterManager{ + tmpPath: filepath.Clean(tmpPath), + logger: defaultLogger.Named("updater manager"), + updaterName: defaultOthersUpdaterName, + cwd: cwd, + shutdownWaitTimeout: defaultShutdownWaitTimeout, + }, nil +} + +// StartAndMonitorUpdater will start the Updater binary and wait to see if it finishes unexpectedly. +// While waiting for Updater, it should kill the collector and we should never execute any code past running it +func (m othersUpdaterManager) StartAndMonitorUpdater() error { + initialUpdaterPath := filepath.Join(m.tmpPath, updaterDir, m.updaterName) + updaterPath, err := copyExecutable(m.logger.Named("copy-executable"), initialUpdaterPath, m.cwd) + if err != nil { + return fmt.Errorf("failed to copy updater to cwd: %w", err) + } + + //#nosec G204 -- paths are not determined via user input + cmd := exec.Command(updaterPath) + + // We need to set the processor group id to something different so that at least on mac, when the + // collector dies the updater won't die as well + cmd.SysProcAttr = &syscall.SysProcAttr{ + Setpgid: true, + Pgid: 0, + } + // Start does not block + if err := cmd.Start(); err != nil { + return fmt.Errorf("updater had an issue while starting: %w", err) + } + + // See if we're still alive after waiting for the timeout to pass + time.Sleep(m.shutdownWaitTimeout) + + // Updater might already be killed + if err := cmd.Process.Kill(); err != nil { + m.logger.Debug("Failed to kill failed Updater", zap.Error(err)) + } + + // Ideally we should not get here as we will be killed by the updater. + // Updater should either exit before us with error or we die before it does. + return errors.New("updater failed to update collector") +} diff --git a/opamp/observiq/updater_manager_others_test.go b/opamp/observiq/updater_manager_others_test.go new file mode 100644 index 000000000..7276c8007 --- /dev/null +++ b/opamp/observiq/updater_manager_others_test.go @@ -0,0 +1,148 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +package observiq + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewOthersUpdaterManager(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "New LinuxUpdaterManager", + testFunc: func(t *testing.T) { + tmpPath := "/tmp" + logger := zap.NewNop() + cwd, err := os.Getwd() + require.NoError(t, err) + + expected := &othersUpdaterManager{ + tmpPath: tmpPath, + logger: logger.Named("updater manager"), + updaterName: "updater", + cwd: cwd, + shutdownWaitTimeout: 30 * time.Second, + } + + actual, err := newUpdaterManager(logger, tmpPath) + require.NoError(t, err) + require.Equal(t, expected, actual) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +// We don't have a good way to unit test the happy path, +// which involves the entire collector being killed in the middle of this function +func TestStartAndMonitorUpdater(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Updater does not exist at path", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + updateManager, err := newUpdaterManager(zap.NewNop(), tmpDir) + require.NoError(t, err) + + updateManager.(*othersUpdaterManager).cwd = tmpDir + updateManager.(*othersUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "no such file or directory") + }, + }, + { + desc: "Updater is not executable", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*othersUpdaterManager).cwd = tmpDir + updateManager.(*othersUpdaterManager).updaterName = "badupdater" + updateManager.(*othersUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "updater had an issue while starting:") + }, + }, + { + desc: "Updater exits quickly", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*othersUpdaterManager).cwd = tmpDir + updateManager.(*othersUpdaterManager).updaterName = "quickupdater" + updateManager.(*othersUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.EqualError(t, err, "updater failed to update collector") + }, + }, + { + desc: "Updater times out", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*othersUpdaterManager).cwd = tmpDir + updateManager.(*othersUpdaterManager).updaterName = "slowupdater" + updateManager.(*othersUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "updater failed to update collector") + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/opamp/observiq/updater_manager_windows.go b/opamp/observiq/updater_manager_windows.go new file mode 100644 index 000000000..87209c3f6 --- /dev/null +++ b/opamp/observiq/updater_manager_windows.go @@ -0,0 +1,88 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows + +package observiq + +import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "time" + + "go.uber.org/zap" +) + +const defaultWindowsUpdaterName = "updater.exe" + +// Ensure interface is satisfied +var _ updaterManager = (*windowsUpdaterManager)(nil) + +// windowsUpdaterManager handles starting a Updater binary and watching it for failure with a timeout +type windowsUpdaterManager struct { + tmpPath string + updaterName string + cwd string + logger *zap.Logger + shutdownWaitTimeout time.Duration +} + +// newUpdaterManager creates a new updaterManager +func newUpdaterManager(defaultLogger *zap.Logger, tmpPath string) (updaterManager, error) { + cwd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get cwd: %w", err) + } + + return &windowsUpdaterManager{ + tmpPath: filepath.Clean(tmpPath), + logger: defaultLogger.Named("updater manager"), + updaterName: defaultWindowsUpdaterName, + cwd: cwd, + shutdownWaitTimeout: defaultShutdownWaitTimeout, + }, nil +} + +// StartAndMonitorUpdater will start the Updater binary and wait to see if it finishes unexpectedly. +// While waiting for Updater, it should kill the collector and we should never execute any code past running it +func (m windowsUpdaterManager) StartAndMonitorUpdater() error { + initialUpdaterPath := filepath.Join(m.tmpPath, updaterDir, m.updaterName) + updaterPath, err := copyExecutable(m.logger.Named("copy-executable"), initialUpdaterPath, m.cwd) + if err != nil { + return fmt.Errorf("failed to copy updater to cwd: %w", err) + } + + //#nosec G204 -- paths are not determined via user input + cmd := exec.Command(updaterPath) + + // Start does not block + if err := cmd.Start(); err != nil { + return fmt.Errorf("updater had an issue while starting: %w", err) + } + + // See if we're still alive after waiting for the timeout to pass + time.Sleep(m.shutdownWaitTimeout) + + // Updater might already be killed + if err := cmd.Process.Kill(); err != nil { + m.logger.Error("Failed to kill failed Updater", zap.Error(err)) + } + + // Ideally we should not get here as we will be killed by the updater. + // Updater should either exit before us with error or we die before it does. + return errors.New("updater failed to update collector") +} diff --git a/opamp/observiq/updater_manager_windows_test.go b/opamp/observiq/updater_manager_windows_test.go new file mode 100644 index 000000000..85779f563 --- /dev/null +++ b/opamp/observiq/updater_manager_windows_test.go @@ -0,0 +1,149 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows + +package observiq + +import ( + "os" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestNewWindowsUpdaterManager(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "New WindowsUpdaterManager", + testFunc: func(t *testing.T) { + tmpPath := "\\tmp" + logger := zap.NewNop() + cwd, err := os.Getwd() + require.NoError(t, err) + + expected := &windowsUpdaterManager{ + tmpPath: tmpPath, + logger: logger.Named("updater manager"), + updaterName: "updater.exe", + cwd: cwd, + shutdownWaitTimeout: 30 * time.Second, + } + + actual, err := newUpdaterManager(logger, tmpPath) + require.NoError(t, err) + require.Equal(t, expected, actual) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +// We don't have a good way to unit test the happy path, +// which involves the entire collector being killed in the middle of this function +func TestStartAndMonitorUpdater(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Updater does not exist at path", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + updateManager, err := newUpdaterManager(zap.NewNop(), tmpDir) + require.NoError(t, err) + + updateManager.(*windowsUpdaterManager).cwd = tmpDir + updateManager.(*windowsUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "failed to copy updater to cwd") + }, + }, + { + desc: "Updater is not executable", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*windowsUpdaterManager).cwd = tmpDir + updateManager.(*windowsUpdaterManager).updaterName = "badupdater" + updateManager.(*windowsUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "updater had an issue while starting:") + }, + }, + { + desc: "Updater exits quickly", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*windowsUpdaterManager).cwd = tmpDir + updateManager.(*windowsUpdaterManager).updaterName = "quickupdater.exe" + updateManager.(*windowsUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.EqualError(t, err, "updater failed to update collector") + }, + }, + { + desc: "Updater times out", + testFunc: func(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + updateManager, err := newUpdaterManager(zap.NewNop(), "./testdata") + require.NoError(t, err) + + updateManager.(*windowsUpdaterManager).cwd = tmpDir + updateManager.(*windowsUpdaterManager).updaterName = "slowupdater.exe" + updateManager.(*windowsUpdaterManager).shutdownWaitTimeout = 5 * time.Second + + err = updateManager.StartAndMonitorUpdater() + + assert.ErrorContains(t, err, "updater failed to update collector") + + // The slow updater needs time to shut down, so we wait an extra second. + // If the updater isn't killed, the tmpDir cannot be deleted and the test fails. + time.Sleep(1 * time.Second) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/packagestate/go.mod b/packagestate/go.mod new file mode 100644 index 000000000..706938a82 --- /dev/null +++ b/packagestate/go.mod @@ -0,0 +1,23 @@ +module github.com/observiq/observiq-otel-collector/packagestate + +go 1.17 + +require ( + github.com/open-telemetry/opamp-go v0.2.0 + github.com/stretchr/testify v1.8.0 + go.uber.org/zap v1.21.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/google/go-cmp v0.5.8 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/goleak v1.1.12 // indirect + go.uber.org/multierr v1.8.0 // indirect + google.golang.org/protobuf v1.28.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/packagestate/go.sum b/packagestate/go.sum new file mode 100644 index 000000000..8cdef17e7 --- /dev/null +++ b/packagestate/go.sum @@ -0,0 +1,80 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/open-telemetry/opamp-go v0.2.0 h1:dV7wTkG5XNiorU62N1CJPr3f5dM0PGEtUUBtvK+LEG0= +github.com/open-telemetry/opamp-go v0.2.0/go.mod h1:IMdeuHGVc5CjKSu5/oNV0o+UmiXuahoHvoZ4GOmAI9M= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= +go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/packagestate/mocks/mock_state_manager.go b/packagestate/mocks/mock_state_manager.go new file mode 100644 index 000000000..33259cadc --- /dev/null +++ b/packagestate/mocks/mock_state_manager.go @@ -0,0 +1,63 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package mocks + +import ( + mock "github.com/stretchr/testify/mock" + + protobufs "github.com/open-telemetry/opamp-go/protobufs" + + testing "testing" +) + +// MockStateManager is an autogenerated mock type for the StateManager type +type MockStateManager struct { + mock.Mock +} + +// LoadStatuses provides a mock function with given fields: +func (_m *MockStateManager) LoadStatuses() (*protobufs.PackageStatuses, error) { + ret := _m.Called() + + var r0 *protobufs.PackageStatuses + if rf, ok := ret.Get(0).(func() *protobufs.PackageStatuses); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*protobufs.PackageStatuses) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// SaveStatuses provides a mock function with given fields: statuses +func (_m *MockStateManager) SaveStatuses(statuses *protobufs.PackageStatuses) error { + ret := _m.Called(statuses) + + var r0 error + if rf, ok := ret.Get(0).(func(*protobufs.PackageStatuses) error); ok { + r0 = rf(statuses) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockStateManager creates a new instance of MockStateManager. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockStateManager(t testing.TB) *MockStateManager { + mock := &MockStateManager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/packagestate/packages_state_manager.go b/packagestate/packages_state_manager.go new file mode 100644 index 000000000..757d4d845 --- /dev/null +++ b/packagestate/packages_state_manager.go @@ -0,0 +1,161 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package packagestate contains structures for reading and writing the package status +package packagestate + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +// CollectorPackageName is the name for the top level packages for this collector +const CollectorPackageName = "observiq-otel-collector" + +// DefaultFileName is the default name of the file use to store state +const DefaultFileName = "package_statuses.json" + +// StateManager tracks Package states +type StateManager interface { + // LoadStatuses retrieves the previously saved PackagesStatuses. + // If none were saved returns error + LoadStatuses() (*protobufs.PackageStatuses, error) + + // SaveStatuses saves the given PackageStatuses + SaveStatuses(statuses *protobufs.PackageStatuses) error +} + +// FileStateManager manages state on disk via a JSON file +type FileStateManager struct { + jsonPath string + logger *zap.Logger +} + +type packageState struct { + Name string `json:"name"` + AgentVersion string `json:"agent_version"` + AgentHash []byte `json:"agent_hash"` + ServerVersion string `json:"server_version"` + ServerHash []byte `json:"server_hash"` + Status protobufs.PackageStatus_Status `json:"status"` + ErrorMessage string `json:"error_message"` +} +type packageStates struct { + AllPackagesHash []byte `json:"all_packages_hash"` + AllErrorMessage string `json:"all_error_message"` + PackageStates map[string]*packageState `json:"package_states"` +} + +// NewFileStateManager creates a new PackagesStateManager +func NewFileStateManager(logger *zap.Logger, jsonPath string) StateManager { + return &FileStateManager{ + jsonPath: filepath.Clean(jsonPath), + logger: logger, + } +} + +// LoadStatuses retrieves the PackagesStatuses from a saved json file +func (p *FileStateManager) LoadStatuses() (*protobufs.PackageStatuses, error) { + p.logger.Debug("Loading package statuses") + + statusesBytes, err := os.ReadFile(p.jsonPath) + if err != nil { + return nil, fmt.Errorf("failed to read package statuses json: %w", err) + } + + var packageStates packageStates + if err := json.Unmarshal(statusesBytes, &packageStates); err != nil { + return nil, fmt.Errorf("failed to unmarshal package statuses: %w", err) + } + + return packageStatesToStatuses(packageStates), nil +} + +// SaveStatuses saves the given PackageStatuses into a json file +func (p *FileStateManager) SaveStatuses(statuses *protobufs.PackageStatuses) error { + p.logger.Debug("Saving package statuses") + + // If there is any problem saving the new package statuses, make sure that we delete any existing file + // in order to not have outdated data as its better to start fresh + if err := os.Remove(p.jsonPath); err != nil { + p.logger.Debug("Failed to delete package statuses json", zap.Error(err)) + } + + states := packageStatusesToStates(statuses) + + data, err := json.Marshal(states) + if err != nil { + return fmt.Errorf("failed to marshal package statuses: %w", err) + } + + // Write data to a package_statuses.json file, with 0600 file permission + if err := os.WriteFile(p.jsonPath, data, 0600); err != nil { + return fmt.Errorf("failed to write package statuses json: %w", err) + } + + return nil +} + +func packageStatusesToStates(statuses *protobufs.PackageStatuses) *packageStates { + states := &packageStates{ + AllPackagesHash: statuses.GetServerProvidedAllPackagesHash(), + AllErrorMessage: statuses.GetErrorMessage(), + } + + packageStates := map[string]*packageState{} + for name, packageStatus := range statuses.Packages { + packageState := &packageState{ + Name: packageStatus.GetName(), + AgentVersion: packageStatus.GetAgentHasVersion(), + AgentHash: packageStatus.GetAgentHasHash(), + ServerVersion: packageStatus.GetServerOfferedVersion(), + ServerHash: packageStatus.GetServerOfferedHash(), + Status: packageStatus.GetStatus(), + ErrorMessage: packageStatus.GetErrorMessage(), + } + packageStates[name] = packageState + } + states.PackageStates = packageStates + + return states +} + +func packageStatesToStatuses(states packageStates) *protobufs.PackageStatuses { + statuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: states.AllPackagesHash, + ErrorMessage: states.AllErrorMessage, + } + + packages := map[string]*protobufs.PackageStatus{} + for name, packageState := range states.PackageStates { + packageStatus := &protobufs.PackageStatus{ + Name: packageState.Name, + AgentHasVersion: packageState.AgentVersion, + AgentHasHash: packageState.AgentHash, + ServerOfferedVersion: packageState.ServerVersion, + ServerOfferedHash: packageState.ServerHash, + Status: packageState.Status, + ErrorMessage: packageState.ErrorMessage, + } + packages[name] = packageStatus + } + statuses.Packages = packages + + return statuses +} diff --git a/packagestate/packages_state_manager_linux_test.go b/packagestate/packages_state_manager_linux_test.go new file mode 100644 index 000000000..ea4a4a73f --- /dev/null +++ b/packagestate/packages_state_manager_linux_test.go @@ -0,0 +1,98 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// go:build !windows + +package packagestate + +import ( + "os" + "path/filepath" + "testing" + + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func TestLoadStatusesLinux(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Problem reading file", + testFunc: func(t *testing.T) { + tmpDir := t.TempDir() + cantReadJSON := filepath.Join(tmpDir, "noread.json") + os.WriteFile(cantReadJSON, nil, 0000) + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: cantReadJSON, + } + + actual, err := p.LoadStatuses() + + assert.ErrorContains(t, err, "failed to read package statuses json:") + assert.Nil(t, actual) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestSaveStatusesLinux(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Can't write to file", + testFunc: func(t *testing.T) { + tmpDir := t.TempDir() + os.Chmod(tmpDir, 0400) + testJSON := filepath.Join(tmpDir, "test.json") + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: testJSON, + } + + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: []byte("hash"), + } + + err := p.SaveStatuses(packageStatuses) + + assert.ErrorContains(t, err, "failed to write package statuses json:") + + // Right now the following code won't work, because the file can't be deleted as we don't have write permissions. + // It would be nice to have a way to test a write failure, while still being able to delete the file. + // exists := true + // if _, err = os.Stat(testJSON); os.IsNotExist(err) { + // exists = false + // } + // assert.False(t, exists) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/packagestate/packages_state_manager_test.go b/packagestate/packages_state_manager_test.go new file mode 100644 index 000000000..678bcb726 --- /dev/null +++ b/packagestate/packages_state_manager_test.go @@ -0,0 +1,248 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packagestate + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +func TestNewPackagesStateManager(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "New PackagesStateManager", + testFunc: func(t *testing.T) { + jsonPath := "test.json" + logger := zap.NewNop() + actual := NewFileStateManager(logger, jsonPath) + + var expected StateManager = &FileStateManager{ + jsonPath: jsonPath, + logger: logger, + } + + assert.Equal(t, expected, actual) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestLoadStatuses(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "File doesn't exist", + testFunc: func(t *testing.T) { + noExistJSON := "garbage.json" + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: noExistJSON, + } + + actual, err := p.LoadStatuses() + + assert.ErrorIs(t, err, os.ErrNotExist) + assert.Nil(t, actual) + }, + }, + { + desc: "Bad json file", + testFunc: func(t *testing.T) { + badJSON := "testdata/package_statuses_bad.json" + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: badJSON, + } + + actual, err := p.LoadStatuses() + + assert.ErrorContains(t, err, "failed to unmarshal package statuses:") + assert.Nil(t, actual) + }, + }, + { + desc: "Good json file", + testFunc: func(t *testing.T) { + goodJSON := "testdata/package_statuses_good.json" + pkgName := "package" + agentVersion := "1.0" + agentHash := []byte("hash1") + serverVersion := "2.0" + serverHash := []byte("hash2") + errMsg := "bad" + allHash := []byte("hash") + allErrMsg := "whoops" + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: goodJSON, + } + + actual, err := p.LoadStatuses() + + assert.NoError(t, err) + assert.Equal(t, allHash, actual.ServerProvidedAllPackagesHash) + assert.Equal(t, allErrMsg, actual.ErrorMessage) + assert.Equal(t, 1, len(actual.Packages)) + assert.Equal(t, pkgName, actual.Packages[pkgName].GetName()) + assert.Equal(t, agentVersion, actual.Packages[pkgName].GetAgentHasVersion()) + assert.Equal(t, agentHash, actual.Packages[pkgName].GetAgentHasHash()) + assert.Equal(t, serverVersion, actual.Packages[pkgName].GetServerOfferedVersion()) + assert.Equal(t, serverHash, actual.Packages[pkgName].GetServerOfferedHash()) + assert.Equal(t, protobufs.PackageStatus_InstallPending, actual.Packages[pkgName].GetStatus()) + assert.Equal(t, errMsg, actual.Packages[pkgName].GetErrorMessage()) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestSaveStatuses(t *testing.T) { + pkgName := "package" + agentVersion := "1.0" + agentHash := []byte("hash1") + serverVersion := "2.0" + serverHash := []byte("hash2") + errMsg := "bad" + allHash := []byte("hash") + allErrMsg := "whoops" + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "New file", + testFunc: func(t *testing.T) { + tmpDir := t.TempDir() + testJSON := filepath.Join(tmpDir, "test.json") + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: testJSON, + } + + packages := map[string]*protobufs.PackageStatus{ + pkgName: { + Name: pkgName, + AgentHasVersion: agentVersion, + AgentHasHash: agentHash, + ServerOfferedVersion: serverVersion, + ServerOfferedHash: serverHash, + Status: protobufs.PackageStatus_InstallPending, + ErrorMessage: errMsg, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + ErrorMessage: allErrMsg, + Packages: packages, + } + + err := p.SaveStatuses(packageStatuses) + assert.NoError(t, err) + + bytes, err := os.ReadFile(testJSON) + assert.NoError(t, err) + var fileStates packageStates + err = json.Unmarshal(bytes, &fileStates) + assert.NoError(t, err) + assert.Equal(t, allHash, fileStates.AllPackagesHash) + assert.Equal(t, allErrMsg, fileStates.AllErrorMessage) + assert.Equal(t, 1, len(fileStates.PackageStates)) + assert.Equal(t, pkgName, fileStates.PackageStates[pkgName].Name) + assert.Equal(t, agentVersion, fileStates.PackageStates[pkgName].AgentVersion) + assert.Equal(t, agentHash, fileStates.PackageStates[pkgName].AgentHash) + assert.Equal(t, serverVersion, fileStates.PackageStates[pkgName].ServerVersion) + assert.Equal(t, serverHash, fileStates.PackageStates[pkgName].ServerHash) + assert.Equal(t, protobufs.PackageStatus_InstallPending, fileStates.PackageStates[pkgName].Status) + assert.Equal(t, errMsg, fileStates.PackageStates[pkgName].ErrorMessage) + }, + }, + { + desc: "Existing file", + testFunc: func(t *testing.T) { + tmpDir := t.TempDir() + testJSON := filepath.Join(tmpDir, "test.json") + os.WriteFile(testJSON, nil, 0600) + + logger := zap.NewNop() + p := &FileStateManager{ + logger: logger, + jsonPath: testJSON, + } + + packages := map[string]*protobufs.PackageStatus{ + pkgName: { + Name: pkgName, + AgentHasVersion: agentVersion, + AgentHasHash: agentHash, + ServerOfferedVersion: serverVersion, + ServerOfferedHash: serverHash, + Status: protobufs.PackageStatus_InstallPending, + ErrorMessage: errMsg, + }, + } + packageStatuses := &protobufs.PackageStatuses{ + ServerProvidedAllPackagesHash: allHash, + ErrorMessage: allErrMsg, + Packages: packages, + } + + err := p.SaveStatuses(packageStatuses) + assert.NoError(t, err) + + bytes, err := os.ReadFile(testJSON) + assert.NoError(t, err) + var fileStates packageStates + err = json.Unmarshal(bytes, &fileStates) + assert.NoError(t, err) + assert.Equal(t, allHash, fileStates.AllPackagesHash) + assert.Equal(t, allErrMsg, fileStates.AllErrorMessage) + assert.Equal(t, 1, len(fileStates.PackageStates)) + assert.Equal(t, pkgName, fileStates.PackageStates[pkgName].Name) + assert.Equal(t, agentVersion, fileStates.PackageStates[pkgName].AgentVersion) + assert.Equal(t, agentHash, fileStates.PackageStates[pkgName].AgentHash) + assert.Equal(t, serverVersion, fileStates.PackageStates[pkgName].ServerVersion) + assert.Equal(t, serverHash, fileStates.PackageStates[pkgName].ServerHash) + assert.Equal(t, protobufs.PackageStatus_InstallPending, fileStates.PackageStates[pkgName].Status) + assert.Equal(t, errMsg, fileStates.PackageStates[pkgName].ErrorMessage) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/packagestate/testdata/package_statuses_bad.json b/packagestate/testdata/package_statuses_bad.json new file mode 100644 index 000000000..7c591551d --- /dev/null +++ b/packagestate/testdata/package_statuses_bad.json @@ -0,0 +1 @@ +nothing to see here diff --git a/packagestate/testdata/package_statuses_good.json b/packagestate/testdata/package_statuses_good.json new file mode 100644 index 000000000..7556d2211 --- /dev/null +++ b/packagestate/testdata/package_statuses_good.json @@ -0,0 +1,33 @@ +{ + "all_packages_hash": [ + 104, + 97, + 115, + 104 + ], + "all_error_message": "whoops", + "package_states": { + "package": { + "name": "package", + "updater_start": false, + "agent_version": "1.0", + "agent_hash": [ + 104, + 97, + 115, + 104, + 49 + ], + "server_version": "2.0", + "server_hash": [ + 104, + 97, + 115, + 104, + 50 + ], + "status": 1, + "error_message": "bad" + } + } +} diff --git a/service/com.observiq.collector.plist b/service/com.observiq.collector.plist index 6f987a28e..2cf3816e6 100644 --- a/service/com.observiq.collector.plist +++ b/service/com.observiq.collector.plist @@ -25,5 +25,7 @@ WorkingDirectory [INSTALLDIR] + ExitTimeOut + 20 diff --git a/service/observiq-otel-collector.service b/service/observiq-otel-collector.service index 90c60d914..66e0c4fa7 100644 --- a/service/observiq-otel-collector.service +++ b/service/observiq-otel-collector.service @@ -13,9 +13,10 @@ Environment=OIQ_OTEL_COLLECTOR_STORAGE=/opt/observiq-otel-collector/storage WorkingDirectory=/opt/observiq-otel-collector ExecStart=/opt/observiq-otel-collector/observiq-otel-collector --config config.yaml SuccessExitStatus=0 -TimeoutSec=120 +TimeoutSec=20 StandardOutput=journal Restart=on-failure RestartSec=5s +KillMode=process [Install] WantedBy=multi-user.target diff --git a/updater/.gitignore b/updater/.gitignore new file mode 100644 index 000000000..b89285960 --- /dev/null +++ b/updater/.gitignore @@ -0,0 +1,4 @@ +# logging.yaml and manager.yaml are ignored in the base module, but +# they are important in this module for test data. +!logging.yaml +!manager.yaml diff --git a/updater/README.md b/updater/README.md new file mode 100644 index 000000000..6477f661f --- /dev/null +++ b/updater/README.md @@ -0,0 +1,5 @@ +# observIQ Distro for OpenTelemetry Updater + +The updater is a separate binary that runs as a separate process to update collector artifacts (including the collector itself) when managed by [BindPlane OP](https://github.com/observIQ/bindplane-op). + +Because the updater edits service configurations, it needs elevated privileges to run (root on Linux + macOS, administrative privileges on Windows). diff --git a/updater/cmd/updater/main.go b/updater/cmd/updater/main.go new file mode 100644 index 000000000..48328cff1 --- /dev/null +++ b/updater/cmd/updater/main.go @@ -0,0 +1,63 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "fmt" + "log" + + "github.com/observiq/observiq-otel-collector/updater/internal/logging" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/updater" + "github.com/observiq/observiq-otel-collector/updater/internal/version" + "github.com/spf13/pflag" + "go.uber.org/zap" +) + +func main() { + var showVersion = pflag.BoolP("version", "v", false, "Prints the version of the updater and exits, if specified.") + pflag.Parse() + + if *showVersion { + fmt.Println("observiq-otel-collector updater version", version.Version()) + fmt.Println("commit:", version.GitHash()) + fmt.Println("built at:", version.Date()) + return + } + + // We can't create the zap logger yet, because we don't know the install dir, which is needed + // to create the logger. So we pass a Nop logger here. + installDir, err := path.InstallDir(zap.NewNop()) + if err != nil { + // Can't use "fail" here since we don't know the install directory + log.Fatalf("Failed to determine install directory: %s", err) + } + + logger, err := logging.NewLogger(installDir) + if err != nil { + log.Fatalf("Failed to create logger: %s\n", err) + } + + updater, err := updater.NewUpdater(logger, installDir) + if err != nil { + logger.Fatal("Failed to create updater", zap.Error(err)) + } + + if err := updater.Update(); err != nil { + logger.Fatal("Failed to update", zap.Error(err)) + } + + logger.Info("Updater finished successfully") +} diff --git a/updater/go.mod b/updater/go.mod new file mode 100644 index 000000000..5da7f2eae --- /dev/null +++ b/updater/go.mod @@ -0,0 +1,26 @@ +module github.com/observiq/observiq-otel-collector/updater + +go 1.17 + +require ( + github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 + github.com/observiq/observiq-otel-collector/packagestate v0.0.0-00010101000000-000000000000 + github.com/open-telemetry/opamp-go v0.2.0 + github.com/spf13/pflag v1.0.5 + github.com/stretchr/testify v1.8.0 + go.uber.org/zap v1.21.0 + golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f +) + +require ( + github.com/benbjohnson/clock v1.1.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/objx v0.4.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.8.0 // indirect + google.golang.org/protobuf v1.28.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace github.com/observiq/observiq-otel-collector/packagestate => ../packagestate diff --git a/updater/go.sum b/updater/go.sum new file mode 100644 index 000000000..6a2992d64 --- /dev/null +++ b/updater/go.sum @@ -0,0 +1,82 @@ +github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8= +github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.8 h1:e6P7q2lk1O+qJJb4BtCQXlK8vWEO8V1ZeuEdJNOqZyg= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= +github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/open-telemetry/opamp-go v0.2.0 h1:dV7wTkG5XNiorU62N1CJPr3f5dM0PGEtUUBtvK+LEG0= +github.com/open-telemetry/opamp-go v0.2.0/go.mod h1:IMdeuHGVc5CjKSu5/oNV0o+UmiXuahoHvoZ4GOmAI9M= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0 h1:M2gUjqZET1qApGOWNSnZ49BAIMX4F/1plDv3+l31EJ4= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/goleak v1.1.11/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= +go.uber.org/multierr v1.8.0 h1:dg6GjLku4EH+249NNmoIciG9N/jURbDG+pFlTkhzIC8= +go.uber.org/multierr v1.8.0/go.mod h1:7EAYxJLBy9rStEaz58O2t4Uvip6FSURkq8/ppBp95ak= +go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= +go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f h1:8w7RhxzTVgUzw/AH/9mUV5q0vMgy40SQRursCcfmkCw= +golang.org/x/sys v0.0.0-20220408201424-a24fb2fb8a0f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.28.0 h1:w43yiav+6bVFTBQFZX0r7ipe9JQ1QsbMgHwbBziscLw= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/updater/internal/action/action.go b/updater/internal/action/action.go new file mode 100644 index 000000000..6ddbd7845 --- /dev/null +++ b/updater/internal/action/action.go @@ -0,0 +1,21 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +// RollbackableAction is an interface to represents an install action that may be rolled back. +//go:generate mockery --name RollbackableAction --filename rollbackable_action.go +type RollbackableAction interface { + Rollback() error +} diff --git a/updater/internal/action/file_action.go b/updater/internal/action/file_action.go new file mode 100644 index 000000000..9f65de006 --- /dev/null +++ b/updater/internal/action/file_action.go @@ -0,0 +1,87 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "errors" + "fmt" + "os" + "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "go.uber.org/zap" +) + +// CopyFileAction is an action that records a file being copied from FromPath to ToPath +type CopyFileAction struct { + // FromPathRel is the path where the file originated, relative to the "latest" + // directory + FromPathRel string + // ToPath is the path where the file was written. + ToPath string + // FileCreated is a bool that records whether this action had to create a new file or not + FileCreated bool + backupDir string + logger *zap.Logger +} + +var _ RollbackableAction = (*CopyFileAction)(nil) +var _ fmt.Stringer = (*CopyFileAction)(nil) + +// NewCopyFileAction creates a new CopyFileAction that indicates a file was copied from +// fromPathRel into toPath. backupDir is specified for rollback purposes. +// NOTE: This action MUST be created BEFORE the action actually takes place; This allows +// for previous existence of the file to be recorded. +func NewCopyFileAction(logger *zap.Logger, fromPathRel, toPath, backupDir string) (*CopyFileAction, error) { + fileExists := true + _, err := os.Stat(toPath) + switch { + case errors.Is(err, os.ErrNotExist): + fileExists = false + case err != nil: + return nil, fmt.Errorf("unexpected error stat-ing file: %w", err) + } + + return &CopyFileAction{ + FromPathRel: fromPathRel, + ToPath: toPath, + // The file will be created if it doesn't already exist + FileCreated: !fileExists, + backupDir: backupDir, + logger: logger.Named("copy-file-action"), + }, nil +} + +// Rollback will undo the file copy, by either deleting the file if the file did not originally exist, +// or it will copy the old file in the rollback dir if it already exists. +func (c CopyFileAction) Rollback() error { + if c.FileCreated { + // File did not exist before this action. + // We just need to delete this file. + return os.RemoveAll(c.ToPath) + } + + // join the relative path to the backup directory to get the location of the backup path + backupFilePath := filepath.Join(c.backupDir, c.FromPathRel) + if err := file.CopyFileRollback(c.logger.Named("copy-file"), backupFilePath, c.ToPath); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} + +func (c CopyFileAction) String() string { + return fmt.Sprintf("CopyFileAction{FromPathRel: '%s', ToPath: '%s', FileCreated: '%t'}", c.FromPathRel, c.ToPath, c.FileCreated) +} diff --git a/updater/internal/action/file_action_test.go b/updater/internal/action/file_action_test.go new file mode 100644 index 000000000..1d384f133 --- /dev/null +++ b/updater/internal/action/file_action_test.go @@ -0,0 +1,164 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "os" + "path/filepath" + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestNewCopyFileAction(t *testing.T) { + t.Run("out file does not exist", func(t *testing.T) { + scratchDir := t.TempDir() + testInstallDir := filepath.Join("testdata", "copyfileaction") + backupDir := path.BackupDir(testInstallDir) + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testInstallDir, "latest", "test.txt") + + a, err := NewCopyFileAction(zaptest.NewLogger(t), inFile, outFile, backupDir) + require.NoError(t, err) + + require.Equal(t, &CopyFileAction{ + FromPathRel: inFile, + ToPath: outFile, + FileCreated: true, + backupDir: backupDir, + logger: a.logger, + }, a) + }) + + t.Run("out file exists", func(t *testing.T) { + scratchDir := t.TempDir() + testInstallDir := filepath.Join("testdata", "copyfileaction") + backupDir := path.BackupDir(testInstallDir) + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testInstallDir, "latest", "test.txt") + + f, err := os.Create(outFile) + require.NoError(t, err) + require.NoError(t, f.Close()) + + a, err := NewCopyFileAction(zaptest.NewLogger(t), inFile, outFile, backupDir) + require.NoError(t, err) + + require.Equal(t, &CopyFileAction{ + FromPathRel: inFile, + ToPath: outFile, + FileCreated: false, + backupDir: backupDir, + logger: a.logger, + }, a) + }) +} + +func TestCopyFileActionRollback(t *testing.T) { + t.Run("deletes out file if it does not exist", func(t *testing.T) { + scratchDir := t.TempDir() + testInstallDir := filepath.Join("testdata", "copyfileaction") + backupDir := path.BackupDir(testInstallDir) + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testInstallDir, "tmp", "latest", "test.txt") + + a, err := NewCopyFileAction(zaptest.NewLogger(t), inFile, outFile, backupDir) + require.NoError(t, err) + + inBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, inBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.NoError(t, err) + + require.NoFileExists(t, outFile) + }) + + t.Run("Rolls back out file when it exists", func(t *testing.T) { + scratchDir := t.TempDir() + testInstallDir := filepath.Join("testdata", "copyfileaction") + backupDir := path.BackupDir(testInstallDir) + outFile := filepath.Join(scratchDir, "test.txt") + inFileRel := "test.txt" + inFile := filepath.Join(testInstallDir, "tmp", "latest", inFileRel) + originalFile := filepath.Join(testInstallDir, "tmp", "rollback", "test.txt") + + originalBytes, err := os.ReadFile(originalFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, originalBytes, 0600) + require.NoError(t, err) + + a, err := NewCopyFileAction(zaptest.NewLogger(t), inFileRel, outFile, backupDir) + require.NoError(t, err) + + // Overwrite original file with latest file + inBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, inBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.NoError(t, err) + + require.FileExists(t, outFile) + + rolledBackBytes, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, originalBytes, rolledBackBytes) + }) + + t.Run("Fails if backup file doesn't exist", func(t *testing.T) { + scratchDir := t.TempDir() + testInstallDir := filepath.Join("testdata", "copyfileaction") + backupDir := path.BackupDir(testInstallDir) + outFile := filepath.Join(scratchDir, "test.txt") + inFile := filepath.Join(testInstallDir, "tmp", "latest", "not_in_backup.txt") + originalFile := filepath.Join(testInstallDir, "tmp", "rollback", "test.txt") + + // The latest file exists in the directory already, but for some reason is not copied to backup + originalBytes, err := os.ReadFile(originalFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, originalBytes, 0600) + require.NoError(t, err) + + a, err := NewCopyFileAction(zaptest.NewLogger(t), inFile, outFile, backupDir) + require.NoError(t, err) + + // Overwrite original file with latest file + latestBytes, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, latestBytes, 0600) + require.NoError(t, err) + + err = a.Rollback() + require.ErrorContains(t, err, "failed to copy file") + require.FileExists(t, outFile) + + finalBytes, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, latestBytes, finalBytes) + }) + +} diff --git a/updater/internal/action/mocks/rollbackable_action.go b/updater/internal/action/mocks/rollbackable_action.go new file mode 100644 index 000000000..43904bbfe --- /dev/null +++ b/updater/internal/action/mocks/rollbackable_action.go @@ -0,0 +1,39 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// RollbackableAction is an autogenerated mock type for the RollbackableAction type +type RollbackableAction struct { + mock.Mock +} + +// Rollback provides a mock function with given fields: +func (_m *RollbackableAction) Rollback() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewRollbackableAction interface { + mock.TestingT + Cleanup(func()) +} + +// NewRollbackableAction creates a new instance of RollbackableAction. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewRollbackableAction(t mockConstructorTestingTNewRollbackableAction) *RollbackableAction { + mock := &RollbackableAction{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/action/service_action.go b/updater/internal/action/service_action.go new file mode 100644 index 000000000..b3cd2a606 --- /dev/null +++ b/updater/internal/action/service_action.go @@ -0,0 +1,83 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/service" + "go.uber.org/zap" +) + +// ServiceStopAction is an action that records that a service was stopped. +type ServiceStopAction struct { + svc service.Service +} + +var _ RollbackableAction = (*ServiceStopAction)(nil) + +// NewServiceStopAction creates a new ServiceStopAction +func NewServiceStopAction(svc service.Service) *ServiceStopAction { + return &ServiceStopAction{ + svc: svc, + } +} + +// Rollback rolls back the stop action (starts the service) +func (s ServiceStopAction) Rollback() error { + return s.svc.Start() +} + +// ServiceStartAction is an action that records that a service was started. +type ServiceStartAction struct { + svc service.Service +} + +var _ RollbackableAction = (*ServiceStartAction)(nil) + +// NewServiceStartAction creates a new ServiceStartAction +func NewServiceStartAction(svc service.Service) *ServiceStartAction { + return &ServiceStartAction{ + svc: svc, + } +} + +// Rollback rolls back the start action (stops the service) +func (s ServiceStartAction) Rollback() error { + return s.svc.Stop() +} + +// ServiceUpdateAction is an action that records that a service was updated. +type ServiceUpdateAction struct { + backupSvc service.Service +} + +var _ RollbackableAction = (*ServiceUpdateAction)(nil) + +// NewServiceUpdateAction creates a new ServiceUpdateAction +func NewServiceUpdateAction(logger *zap.Logger, installDir string) *ServiceUpdateAction { + namedLogger := logger.Named("service-update-action") + return &ServiceUpdateAction{ + backupSvc: service.NewService( + namedLogger, + installDir, + service.WithServiceFile(path.BackupServiceFile(installDir)), + ), + } +} + +// Rollback is an action that rolls back the service configuration to the one saved in the backup directory. +func (s ServiceUpdateAction) Rollback() error { + return s.backupSvc.Update() +} diff --git a/updater/internal/action/service_action_test.go b/updater/internal/action/service_action_test.go new file mode 100644 index 000000000..a67c6d5e5 --- /dev/null +++ b/updater/internal/action/service_action_test.go @@ -0,0 +1,54 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package action + +import ( + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestServiceStartAction(t *testing.T) { + svc := mocks.NewService(t) + ssa := NewServiceStartAction(svc) + + svc.On("Stop").Once().Return(nil) + + err := ssa.Rollback() + require.NoError(t, err) +} + +func TestServiceStopAction(t *testing.T) { + svc := mocks.NewService(t) + ssa := NewServiceStopAction(svc) + + svc.On("Start").Once().Return(nil) + + err := ssa.Rollback() + require.NoError(t, err) +} + +func TestServiceUpdateAction(t *testing.T) { + svc := mocks.NewService(t) + sua := NewServiceUpdateAction(zaptest.NewLogger(t), "./testdata") + sua.backupSvc = svc + + svc.On("Update").Once().Return(nil) + + err := sua.Rollback() + require.NoError(t, err) +} diff --git a/updater/internal/action/testdata/copyfileaction/tmp/latest/not_in_backup.txt b/updater/internal/action/testdata/copyfileaction/tmp/latest/not_in_backup.txt new file mode 100644 index 000000000..20f76d643 --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/tmp/latest/not_in_backup.txt @@ -0,0 +1 @@ +This file doesn't exist in backup diff --git a/updater/internal/action/testdata/copyfileaction/tmp/latest/test.txt b/updater/internal/action/testdata/copyfileaction/tmp/latest/test.txt new file mode 100644 index 000000000..6dfa057f0 --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/tmp/latest/test.txt @@ -0,0 +1 @@ +This is a new file diff --git a/updater/internal/action/testdata/copyfileaction/tmp/rollback/test.txt b/updater/internal/action/testdata/copyfileaction/tmp/rollback/test.txt new file mode 100644 index 000000000..684d5588a --- /dev/null +++ b/updater/internal/action/testdata/copyfileaction/tmp/rollback/test.txt @@ -0,0 +1 @@ +This is the old file diff --git a/updater/internal/file/file.go b/updater/internal/file/file.go new file mode 100644 index 000000000..11bb78f21 --- /dev/null +++ b/updater/internal/file/file.go @@ -0,0 +1,133 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + + "go.uber.org/zap" +) + +// CopyFileOverwrite copies the file from pathIn to pathOut. +// The output file is created if it does not exist. +// If the output file does exist, it is removed, then written from the input file, preserving the output file's mode. +func CopyFileOverwrite(logger *zap.Logger, pathIn, pathOut string) error { + fileMode := fs.FileMode(0600) + pathOutClean := filepath.Clean(pathOut) + + // Try to save existing file's permissions + outFileInfo, _ := os.Stat(pathOutClean) + if outFileInfo != nil { + fileMode = outFileInfo.Mode() + } + + pathInClean := filepath.Clean(pathIn) + // If the input file cannot be opened for some reason, do NOT delete the file + if _, err := os.Stat(pathInClean); err != nil { + return fmt.Errorf("failed to stat input file: %w", err) + } + + // Remove old file to prevent issues with mac + if err := os.Remove(pathOutClean); err != nil { + logger.Debug("Failed to remove output file", zap.Error(err)) + } + + return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode) +} + +// CopyFileNoOverwrite copies the file from pathIn to pathOut, preserving the input file's mode. +// If the output file already exists, this function returns an error. +func CopyFileNoOverwrite(logger *zap.Logger, pathIn, pathOut string) error { + pathInClean := filepath.Clean(pathIn) + + // Use the new file's permissions and fail if there's an issue (want to fail for backup) + inFileInfo, err := os.Stat(pathInClean) + if err != nil { + return fmt.Errorf("failed to retrieve fileinfo for input file: %w", err) + } + + // the os.O_EXCL flag will make OpenFile error if the file already exists + return copyFileInternal(logger, pathIn, pathOut, os.O_EXCL|os.O_CREATE|os.O_WRONLY, inFileInfo.Mode()) +} + +// CopyFileRollback copies the file to the file from pathIn to pathOut, preserving the input file's mode if possible +// Used to perform a rollback +func CopyFileRollback(logger *zap.Logger, pathIn, pathOut string) error { + // Default to 0600 if we can't determine the input file's mode + fileMode := fs.FileMode(0600) + pathInClean := filepath.Clean(pathIn) + // Use the backup file's permissions as a backup and don't fail on error (best chance for rollback) + inFileInfo, err := os.Stat(pathInClean) + switch { + case errors.Is(err, os.ErrNotExist): + return fmt.Errorf("input file does not exist: %w", err) + case err != nil: + // Even though we failed to stat, we'll continue in this case to give the best chance + // of rolling back successfully. + logger.Error("failed to retrieve fileinfo for input file", zap.Error(err)) + default: + fileMode = inFileInfo.Mode() + } + + pathOutClean := filepath.Clean(pathOut) + // Remove old file to prevent issues with mac + if err = os.Remove(pathOutClean); err != nil { + logger.Debug("Failed to remove output file", zap.Error(err)) + } + + return copyFileInternal(logger, pathIn, pathOut, os.O_CREATE|os.O_WRONLY, fileMode) +} + +// copyFileInternal copies the file at pathIn to pathOut, using the provided flags and file mode +func copyFileInternal(logger *zap.Logger, pathIn, pathOut string, outFlags int, outMode fs.FileMode) error { + pathInClean := filepath.Clean(pathIn) + + // Open the input file for reading. + inFile, err := os.Open(pathInClean) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer func() { + err := inFile.Close() + if err != nil { + logger.Error("Failed to close input file", zap.Error(err)) + } + }() + + pathOutClean := filepath.Clean(pathOut) + // Open the output file, creating it if it does not exist and truncating it. + //#nosec G304 -- out file is cleaned; this is a general purpose copy function + outFile, err := os.OpenFile(pathOutClean, outFlags, outMode) + if err != nil { + return fmt.Errorf("failed to open output file: %w", err) + } + defer func() { + err := outFile.Close() + if err != nil { + logger.Error("Failed to close output file", zap.Error(err)) + } + }() + + // Copy the input file to the output file. + if _, err := io.Copy(outFile, inFile); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + return nil +} diff --git a/updater/internal/file/file_test.go b/updater/internal/file/file_test.go new file mode 100644 index 000000000..33957e79a --- /dev/null +++ b/updater/internal/file/file_test.go @@ -0,0 +1,221 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package file + +import ( + "io/fs" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestCopyFileOverwrite(t *testing.T) { + t.Run("Copies file when output does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, contentsIn, contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0600), fi.Mode()) + } + }) + + t.Run("Copies file when output already exists", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + err = os.WriteFile(outFile, []byte("This is a file that already exists"), 0640) + require.NoError(t, err) + + fioOrig, err := os.Stat(outFile) + require.NoError(t, err) + + err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, contentsIn, contentsOut) + + fio, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fioOrig.Mode(), fio.Mode()) + } + }) + + t.Run("Fails when input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.ErrorContains(t, err, "failed to stat input file") + require.NoFileExists(t, outFile) + }) + + t.Run("Does not truncate if input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0600) + require.NoError(t, err) + + err = CopyFileOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.ErrorContains(t, err, "failed to stat input file") + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, []byte("This is a file that already exists"), contentsOut) + }) +} + +func TestCopyFileRollback(t *testing.T) { + t.Run("Copies file when output does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, contentsIn, contentsOut) + + fio, err := os.Stat(outFile) + require.NoError(t, err) + fii, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fii.Mode(), fio.Mode()) + } + }) + + t.Run("Fails to overwrite the output file", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := os.WriteFile(outFile, []byte("This is a file that already exists"), 0640) + require.NoError(t, err) + + err = CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.ErrorContains(t, err, "failed to open output file") + require.FileExists(t, outFile) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + require.Equal(t, []byte("This is a file that already exists"), contentsOut) + + fi, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fs.FileMode(0640), fi.Mode()) + } + }) + + t.Run("Fails when input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileNoOverwrite(zaptest.NewLogger(t), inFile, outFile) + require.ErrorContains(t, err, "failed to retrieve fileinfo for input file") + require.NoFileExists(t, outFile) + }) +} + +func TestCopyFileNoOverwrite(t *testing.T) { + t.Run("Copies file when output does not exist and uses inFile's permissions", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "test.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile) + require.NoError(t, err) + require.FileExists(t, outFile) + + contentsIn, err := os.ReadFile(inFile) + require.NoError(t, err) + + contentsOut, err := os.ReadFile(outFile) + require.NoError(t, err) + + require.Equal(t, contentsIn, contentsOut) + + fio, err := os.Stat(outFile) + require.NoError(t, err) + fii, err := os.Stat(outFile) + require.NoError(t, err) + // file mode on windows acts unlike unix, we'll only check for this on linux/darwin + if runtime.GOOS != "windows" { + require.Equal(t, fii.Mode(), fio.Mode()) + } + }) + + t.Run("Fails when input file does not exist", func(t *testing.T) { + tmpDir := t.TempDir() + + inFile := filepath.Join("testdata", "does-not-exist.txt") + outFile := filepath.Join(tmpDir, "test.txt") + + err := CopyFileRollback(zaptest.NewLogger(t), inFile, outFile) + require.ErrorContains(t, err, "input file does not exist") + require.NoFileExists(t, outFile) + }) +} diff --git a/updater/internal/file/testdata/test.txt b/updater/internal/file/testdata/test.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/updater/internal/file/testdata/test.txt @@ -0,0 +1 @@ +This is a test file diff --git a/updater/internal/install/install.go b/updater/internal/install/install.go new file mode 100644 index 000000000..c016a2a55 --- /dev/null +++ b/updater/internal/install/install.go @@ -0,0 +1,223 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package install + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/action" + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/rollback" + "github.com/observiq/observiq-otel-collector/updater/internal/service" + "go.uber.org/zap" +) + +//Installer is an interface that performs an Install of a new collector. +//go:generate mockery --name Installer --filename installer.go +type Installer interface { + // Install installs new artifacts over the old ones. + Install(rollback.Rollbacker) error +} + +// archiveInstaller allows you to install files from latestDir into installDir, +// as well as update the service configuration using the "Install" method. +type archiveInstaller struct { + latestDir string + installDir string + backupDir string + svc service.Service + logger *zap.Logger +} + +// NewInstaller returns a new instance of an Installer. +func NewInstaller(logger *zap.Logger, installDir string, service service.Service) Installer { + return &archiveInstaller{ + latestDir: path.LatestDir(installDir), + svc: service, + installDir: installDir, + backupDir: path.BackupDir(installDir), + logger: logger.Named("installer"), + } +} + +// Install installs the unpacked artifacts in latestDir to installDir, +// as well as installing the new service file using the installer's Service interface. +// It then starts the service. +func (i archiveInstaller) Install(rb rollback.Rollbacker) error { + // If JMX jar exists outside of install directory, make sure that gets backed up + if err := i.attemptSpecialJMXJarInstall(rb); err != nil { + return fmt.Errorf("failed to process special JMX jar: %w", err) + } + + // install files that go to installDirPath to their correct location, + // excluding any config files (logging.yaml, config.yaml, manager.yaml) + if err := installFiles(i.logger, i.latestDir, i.installDir, i.backupDir, rb); err != nil { + return fmt.Errorf("failed to install new files: %w", err) + } + i.logger.Debug("Install artifacts copied") + + // Update old service config to new service config + if err := i.svc.Update(); err != nil { + return fmt.Errorf("failed to update service: %w", err) + } + rb.AppendAction(action.NewServiceUpdateAction(i.logger, i.installDir)) + i.logger.Debug("Updated service configuration") + + // Start service + if err := i.svc.Start(); err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + rb.AppendAction(action.NewServiceStartAction(i.svc)) + i.logger.Debug("Service started") + + return nil +} + +// installFiles moves the file tree rooted at inputPath to installDir, +// skipping configuration files. Appends CopyFileAction-s to the Rollbacker as it copies file. +func installFiles(logger *zap.Logger, inputPath, installDir, backupDir string, rb rollback.Rollbacker) error { + err := filepath.WalkDir(inputPath, func(inPath string, d fs.DirEntry, err error) error { + switch { + case err != nil: + // if there was an error walking the directory, we want to bail out. + return err + case d.IsDir(): + // Skip directories, we'll create them when we get a file in the directory. + return nil + case skipConfigFiles(inPath): + // Found a config file that we should skip copying. + return nil + } + + // We want the path relative to the directory we are walking in order to calculate where the file should be + // mirrored in the destination directory. + relPath, err := filepath.Rel(inputPath, inPath) + if err != nil { + return err + } + + // use the relative path to get the outPath (where we should write the file), and + // to get the out directory (which we will create if it does not exist). + outPath := filepath.Join(installDir, relPath) + outDir := filepath.Dir(outPath) + + if err := os.MkdirAll(outDir, 0750); err != nil { + return fmt.Errorf("failed to create dir: %w", err) + } + + // We create the action record here, because we want to record whether the file exists or not before + // we open the file (which will end up creating the file). + cfa, err := action.NewCopyFileAction(logger, relPath, outPath, backupDir) + if err != nil { + return fmt.Errorf("failed to create copy file action: %w", err) + } + + // Record that we are performing copying the file. + // We record before we actually do the action here because the file may be partially written, + // and we will want to roll that back if that is the case. + rb.AppendAction(cfa) + + if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil + }) + + if err != nil { + return fmt.Errorf("failed to walk latest dir: %w", err) + } + + return nil +} + +func (i archiveInstaller) attemptSpecialJMXJarInstall(rb rollback.Rollbacker) error { + jarPath := path.SpecialJMXJarFile(i.installDir) + jarDirPath := path.SpecialJarDir(i.installDir) + latestJarPath := path.LatestJMXJarFile(i.latestDir) + _, err := os.Stat(jarPath) + switch { + case err == nil: + if err := installFile(i.logger, latestJarPath, jarDirPath, i.backupDir, rb); err != nil { + return fmt.Errorf("failed to install JMX jar from latest directory: %w", err) + } + // Just log this error as the worst case is that there will be two jars copied over + if err = os.Remove(latestJarPath); err != nil { + i.logger.Warn("Failed to remove JMX jar from latest directory", zap.Error(err)) + } + case !errors.Is(err, os.ErrNotExist): + return fmt.Errorf("failed determine where currently installed JMX jar is: %w", err) + } + + return nil +} + +// installFile moves new file to output path. +// Appends CopyFileAction-s to the Rollbacker as it copies file. +func installFile(logger *zap.Logger, inPath, installDirPath, backupDirPath string, rb rollback.Rollbacker) error { + baseInPath := filepath.Base(inPath) + + // use the relative path to get the outPath (where we should write the file), and + // to get the out directory (which we will create if it does not exist). + outPath := filepath.Join(installDirPath, baseInPath) + outDir := filepath.Dir(outPath) + + if err := os.MkdirAll(outDir, 0750); err != nil { + return fmt.Errorf("failed to create dir: %w", err) + } + + // We create the action record here, because we want to record whether the file exists or not before + // we open the file (which will end up creating the file). + cfa, err := action.NewCopyFileAction(logger, baseInPath, outPath, backupDirPath) + if err != nil { + return fmt.Errorf("failed to create copy file action: %w", err) + } + + // Record that we are performing copying the file. + // We record before we actually do the action here because the file may be partially written, + // and we will want to roll that back if that is the case. + rb.AppendAction(cfa) + + if err := file.CopyFileOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} + +// skipConfigFiles returns true if the given path is a special config file. +// These files should not be overwritten. +func skipConfigFiles(path string) bool { + var configFiles = []string{ + "config.yaml", + "logging.yaml", + "manager.yaml", + } + + fileName := filepath.Base(path) + + for _, f := range configFiles { + if fileName == f { + return true + } + } + + return false +} diff --git a/updater/internal/install/install_test.go b/updater/internal/install/install_test.go new file mode 100644 index 000000000..30cada4dd --- /dev/null +++ b/updater/internal/install/install_test.go @@ -0,0 +1,376 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package install + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/action" + rb_mocks "github.com/observiq/observiq-otel-collector/updater/internal/rollback/mocks" + "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestInstallArtifacts(t *testing.T) { + t.Run("Installs artifacts correctly", func(t *testing.T) { + outDir := t.TempDir() + svc := mocks.NewService(t) + rb := rb_mocks.NewRollbacker(t) + + installer := &archiveInstaller{ + latestDir: filepath.Join("testdata", "example-install"), + installDir: outDir, + backupDir: filepath.Join("testdata", "rollback"), + svc: svc, + logger: zaptest.NewLogger(t), + } + + latestJarPath := filepath.Join(installer.latestDir, "opentelemetry-java-contrib-jmx-metrics.jar") + _, err := os.Create(latestJarPath) + require.NoError(t, err) + err = os.WriteFile(latestJarPath, []byte("# The new jar file"), 0660) + require.NoError(t, err) + + outDirConfig := filepath.Join(outDir, "config.yaml") + outDirLogging := filepath.Join(outDir, "logging.yaml") + outDirManager := filepath.Join(outDir, "manager.yaml") + + err = os.WriteFile(outDirConfig, []byte("# The original config file"), 0600) + require.NoError(t, err) + err = os.WriteFile(outDirLogging, []byte("# The original logging file"), 0600) + require.NoError(t, err) + err = os.WriteFile(outDirManager, []byte("# The original manager file"), 0600) + require.NoError(t, err) + + svc.On("Update").Once().Return(nil) + svc.On("Start").Once().Return(nil) + + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) + require.NoError(t, err) + + contentsEqual(t, outDirConfig, "# The original config file") + contentsEqual(t, outDirManager, "# The original manager file") + contentsEqual(t, outDirLogging, "# The original logging file") + + require.FileExists(t, filepath.Join(outDir, "opentelemetry-java-contrib-jmx-metrics.jar")) + require.FileExists(t, filepath.Join(outDir, "test.txt")) + require.DirExists(t, filepath.Join(outDir, "test-folder")) + require.FileExists(t, filepath.Join(outDir, "test-folder", "another-test.txt")) + + contentsEqual(t, filepath.Join(outDir, "opentelemetry-java-contrib-jmx-metrics.jar"), "# The new jar file") + contentsEqual(t, filepath.Join(outDir, "test.txt"), "This is a test file\n") + contentsEqual(t, filepath.Join(outDir, "test-folder", "another-test.txt"), "This is a nested text file\n") + + copyTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyJarAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("opentelemetry-java-contrib-jmx-metrics.jar"), + filepath.Join(installer.installDir, "opentelemetry-java-contrib-jmx-metrics.jar"), + installer.backupDir, + ) + require.NoError(t, err) + copyJarAction.FileCreated = true + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + require.Equal(t, len(actions), 5) + require.Contains(t, actions, copyJarAction) + require.Contains(t, actions, copyNestedTestTxtAction) + require.Contains(t, actions, copyTestTxtAction) + require.Contains(t, actions, action.NewServiceUpdateAction(installer.logger, installer.installDir)) + require.Contains(t, actions, action.NewServiceStartAction(svc)) + }) + + if runtime.GOOS != "windows" { + t.Run("Installs artifacts correctly when linux jmx jar", func(t *testing.T) { + jarDir := t.TempDir() + specialJarPath := filepath.Join(jarDir, "opentelemetry-java-contrib-jmx-metrics.jar") + _, err := os.Create(specialJarPath) + require.NoError(t, err) + err = os.WriteFile(specialJarPath, []byte("# The original jar file"), 0600) + require.NoError(t, err) + outDir := filepath.Join(jarDir, "installdir") + os.MkdirAll(outDir, 0700) + + svc := mocks.NewService(t) + rb := rb_mocks.NewRollbacker(t) + + installer := &archiveInstaller{ + latestDir: filepath.Join("testdata", "example-install"), + installDir: outDir, + backupDir: filepath.Join("testdata", "rollback"), + svc: svc, + logger: zaptest.NewLogger(t), + } + + latestJarPath := filepath.Join(installer.latestDir, "opentelemetry-java-contrib-jmx-metrics.jar") + _, err = os.Create(latestJarPath) + require.NoError(t, err) + err = os.WriteFile(latestJarPath, []byte("# The new jar file"), 0660) + require.NoError(t, err) + + outDirConfig := filepath.Join(outDir, "config.yaml") + outDirLogging := filepath.Join(outDir, "logging.yaml") + outDirManager := filepath.Join(outDir, "manager.yaml") + + err = os.WriteFile(outDirConfig, []byte("# The original config file"), 0600) + require.NoError(t, err) + err = os.WriteFile(outDirLogging, []byte("# The original logging file"), 0600) + require.NoError(t, err) + err = os.WriteFile(outDirManager, []byte("# The original manager file"), 0600) + require.NoError(t, err) + + svc.On("Update").Once().Return(nil) + svc.On("Start").Once().Return(nil) + + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) + require.NoError(t, err) + + contentsEqual(t, outDirConfig, "# The original config file") + contentsEqual(t, outDirManager, "# The original manager file") + contentsEqual(t, outDirLogging, "# The original logging file") + + require.FileExists(t, filepath.Join(jarDir, "opentelemetry-java-contrib-jmx-metrics.jar")) + require.FileExists(t, filepath.Join(outDir, "test.txt")) + require.DirExists(t, filepath.Join(outDir, "test-folder")) + require.FileExists(t, filepath.Join(outDir, "test-folder", "another-test.txt")) + + contentsEqual(t, filepath.Join(jarDir, "opentelemetry-java-contrib-jmx-metrics.jar"), "# The new jar file") + contentsEqual(t, filepath.Join(outDir, "test.txt"), "This is a test file\n") + contentsEqual(t, filepath.Join(outDir, "test-folder", "another-test.txt"), "This is a nested text file\n") + + copyTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyJarAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("opentelemetry-java-contrib-jmx-metrics.jar"), + filepath.Join(jarDir, "opentelemetry-java-contrib-jmx-metrics.jar"), + installer.backupDir, + ) + require.NoError(t, err) + copyJarAction.FileCreated = false + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + require.Equal(t, len(actions), 5) + require.Contains(t, actions, copyJarAction) + require.Contains(t, actions, copyNestedTestTxtAction) + require.Contains(t, actions, copyTestTxtAction) + require.Contains(t, actions, action.NewServiceUpdateAction(installer.logger, installer.installDir)) + require.Contains(t, actions, action.NewServiceStartAction(svc)) + }) + } else { + t.Skip() + } + + t.Run("Update fails", func(t *testing.T) { + outDir := t.TempDir() + svc := mocks.NewService(t) + rb := rb_mocks.NewRollbacker(t) + installer := &archiveInstaller{ + latestDir: filepath.Join("testdata", "example-install"), + installDir: outDir, + backupDir: filepath.Join("testdata", "rollback"), + svc: svc, + logger: zaptest.NewLogger(t), + } + + latestJarPath := filepath.Join(installer.latestDir, "opentelemetry-java-contrib-jmx-metrics.jar") + _, err := os.Create(latestJarPath) + require.NoError(t, err) + err = os.WriteFile(latestJarPath, []byte("# The new jar file"), 0660) + require.NoError(t, err) + + svc.On("Update").Once().Return(errors.New("uninstall failed")) + + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) + require.ErrorContains(t, err, "failed to update service") + copyTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + copyJarAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("opentelemetry-java-contrib-jmx-metrics.jar"), + filepath.Join(installer.installDir, "opentelemetry-java-contrib-jmx-metrics.jar"), + installer.backupDir, + ) + require.NoError(t, err) + copyJarAction.FileCreated = true + + require.Equal(t, len(actions), 3) + require.Contains(t, actions, copyJarAction) + require.Contains(t, actions, copyNestedTestTxtAction) + require.Contains(t, actions, copyTestTxtAction) + }) + + t.Run("Start fails", func(t *testing.T) { + outDir := t.TempDir() + svc := mocks.NewService(t) + rb := rb_mocks.NewRollbacker(t) + installer := &archiveInstaller{ + latestDir: filepath.Join("testdata", "example-install"), + installDir: outDir, + backupDir: filepath.Join("testdata", "rollback"), + svc: svc, + logger: zaptest.NewLogger(t), + } + + latestJarPath := filepath.Join(installer.latestDir, "opentelemetry-java-contrib-jmx-metrics.jar") + _, err := os.Create(latestJarPath) + require.NoError(t, err) + err = os.WriteFile(latestJarPath, []byte("# The new jar file"), 0660) + require.NoError(t, err) + + svc.On("Update").Once().Return(nil) + svc.On("Start").Once().Return(errors.New("start failed")) + + actions := []action.RollbackableAction{} + rb.On("AppendAction", mock.Anything).Run(func(args mock.Arguments) { + action := args.Get(0).(action.RollbackableAction) + actions = append(actions, action) + }) + + err = installer.Install(rb) + require.ErrorContains(t, err, "failed to start service") + + copyTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test.txt"), + filepath.Join(installer.installDir, "test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyTestTxtAction.FileCreated = true + + copyNestedTestTxtAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("test-folder", "another-test.txt"), + filepath.Join(installer.installDir, "test-folder", "another-test.txt"), + installer.backupDir, + ) + require.NoError(t, err) + copyNestedTestTxtAction.FileCreated = true + + copyJarAction, err := action.NewCopyFileAction( + installer.logger, + filepath.Join("opentelemetry-java-contrib-jmx-metrics.jar"), + filepath.Join(installer.installDir, "opentelemetry-java-contrib-jmx-metrics.jar"), + installer.backupDir, + ) + require.NoError(t, err) + copyJarAction.FileCreated = true + + require.Equal(t, len(actions), 4) + require.Contains(t, actions, copyJarAction) + require.Contains(t, actions, copyNestedTestTxtAction) + require.Contains(t, actions, copyTestTxtAction) + require.Contains(t, actions, action.NewServiceUpdateAction(installer.logger, installer.installDir)) + }) + + t.Run("Latest dir does not exist", func(t *testing.T) { + outDir := t.TempDir() + svc := mocks.NewService(t) + rb := rb_mocks.NewRollbacker(t) + installer := &archiveInstaller{ + latestDir: filepath.Join("testdata", "non-existent-dir"), + installDir: outDir, + svc: svc, + logger: zaptest.NewLogger(t), + } + + err := installer.Install(rb) + require.ErrorContains(t, err, "failed to install new files") + }) +} + +func contentsEqual(t *testing.T, path, expectedContents string) { + t.Helper() + + contents, err := os.ReadFile(path) + require.NoError(t, err) + + // Replace \r\n with \n to normalize for windows tests. + contents = bytes.ReplaceAll(contents, []byte("\r\n"), []byte("\n")) + require.Equal(t, []byte(expectedContents), contents) +} diff --git a/updater/internal/install/mocks/installer.go b/updater/internal/install/mocks/installer.go new file mode 100644 index 000000000..398826afe --- /dev/null +++ b/updater/internal/install/mocks/installer.go @@ -0,0 +1,42 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + rollback "github.com/observiq/observiq-otel-collector/updater/internal/rollback" + mock "github.com/stretchr/testify/mock" +) + +// Installer is an autogenerated mock type for the Installer type +type Installer struct { + mock.Mock +} + +// Install provides a mock function with given fields: _a0 +func (_m *Installer) Install(_a0 rollback.Rollbacker) error { + ret := _m.Called(_a0) + + var r0 error + if rf, ok := ret.Get(0).(func(rollback.Rollbacker) error); ok { + r0 = rf(_a0) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewInstaller interface { + mock.TestingT + Cleanup(func()) +} + +// NewInstaller creates a new instance of Installer. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewInstaller(t mockConstructorTestingTNewInstaller) *Installer { + mock := &Installer{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/install/testdata/example-install/config.yaml b/updater/internal/install/testdata/example-install/config.yaml new file mode 100644 index 000000000..ffbf81d31 --- /dev/null +++ b/updater/internal/install/testdata/example-install/config.yaml @@ -0,0 +1 @@ +# This is a placeholder config file diff --git a/updater/internal/install/testdata/example-install/logging.yaml b/updater/internal/install/testdata/example-install/logging.yaml new file mode 100644 index 000000000..cf76a2844 --- /dev/null +++ b/updater/internal/install/testdata/example-install/logging.yaml @@ -0,0 +1 @@ +# This is a placeholder logging.yaml diff --git a/updater/internal/install/testdata/example-install/manager.yaml b/updater/internal/install/testdata/example-install/manager.yaml new file mode 100644 index 000000000..cef5a425c --- /dev/null +++ b/updater/internal/install/testdata/example-install/manager.yaml @@ -0,0 +1 @@ +# manager.yaml should not exist in the archive, but we check for it anyways, just in case. diff --git a/updater/internal/install/testdata/example-install/test-folder/another-test.txt b/updater/internal/install/testdata/example-install/test-folder/another-test.txt new file mode 100644 index 000000000..45b861001 --- /dev/null +++ b/updater/internal/install/testdata/example-install/test-folder/another-test.txt @@ -0,0 +1 @@ +This is a nested text file diff --git a/updater/internal/install/testdata/example-install/test.txt b/updater/internal/install/testdata/example-install/test.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/updater/internal/install/testdata/example-install/test.txt @@ -0,0 +1 @@ +This is a test file diff --git a/updater/internal/logging/logging_others.go b/updater/internal/logging/logging_others.go new file mode 100644 index 000000000..bf2be17b9 --- /dev/null +++ b/updater/internal/logging/logging_others.go @@ -0,0 +1,50 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build !windows + +package logging + +import ( + "fmt" + "os" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +// NewLogger returns a new logger, that logs to the log directory relative to installDir. +// It deletes the previous log file, as well. +func NewLogger(installDir string) (*zap.Logger, error) { + logFile := path.LogFile(installDir) + + conf := zap.NewProductionConfig() + conf.OutputPaths = []string{ + logFile, + } + conf.Level.SetLevel(zapcore.DebugLevel) + + err := os.RemoveAll(logFile) + if err != nil { + return nil, fmt.Errorf("failed to remove previous log file: %w", err) + } + + prodLogger, err := conf.Build() + if err != nil { + return nil, fmt.Errorf("failed to build logger: %w", err) + } + + return prodLogger, nil +} diff --git a/updater/internal/logging/logging_test.go b/updater/internal/logging/logging_test.go new file mode 100644 index 000000000..119488f8e --- /dev/null +++ b/updater/internal/logging/logging_test.go @@ -0,0 +1,76 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logging + +import ( + "bytes" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewLogger(t *testing.T) { + t.Run("Existing file is removed", func(t *testing.T) { + // We don't use t.TempDir here, because we can't clean up the out directory on windows. + // We also don't clean up the out directory; It's in the temporary directory and may be cleaned up manually at any time. + tmpDir, err := os.MkdirTemp("", "test-logger-existing-file") + require.NoError(t, err) + // Remove previous log directory if it exists + require.NoError(t, os.RemoveAll(filepath.Join(tmpDir, "log"))) + require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "log"), 0775)) + + updaterLogPath, err := filepath.Abs(filepath.Join(tmpDir, "log", "updater.log")) + require.NoError(t, err) + + initialBytes := []byte("Some existing bytes") + require.NoError(t, os.WriteFile(updaterLogPath, initialBytes, 0660)) + + logger, err := NewLogger(tmpDir) + require.NoError(t, err) + + currentBytes, err := os.ReadFile(updaterLogPath) + require.NoError(t, err) + + if bytes.HasPrefix(currentBytes, initialBytes) { + t.Fatalf("The log file was not deleted (current bytes: '%s')", currentBytes) + } + + logger.Info("This is a log message") + require.NoError(t, logger.Sync()) + + require.FileExists(t, updaterLogPath) + }) + + t.Run("Logger creates file if existing file does not exist", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "test-logger-no-existing-file") + require.NoError(t, err) + // Remove previous log directory if it exists + require.NoError(t, os.RemoveAll(filepath.Join(tmpDir, "log"))) + require.NoError(t, os.MkdirAll(filepath.Join(tmpDir, "log"), 0775)) + + updaterLogPath, err := filepath.Abs(filepath.Join(tmpDir, "log", "updater.log")) + require.NoError(t, err) + + logger, err := NewLogger(tmpDir) + require.NoError(t, err) + + logger.Info("This is a log message") + require.NoError(t, logger.Sync()) + + require.FileExists(t, updaterLogPath) + }) +} diff --git a/updater/internal/logging/logging_windows.go b/updater/internal/logging/logging_windows.go new file mode 100644 index 000000000..a1f3b106b --- /dev/null +++ b/updater/internal/logging/logging_windows.go @@ -0,0 +1,68 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package logging + +import ( + "fmt" + "net/url" + "os" + "sync" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "go.uber.org/zap" +) + +var registerSinkOnce = &sync.Once{} + +// NewLogger returns a new logger, that logs to the log directory relative to installDir. +// It deletes the previous log file, as well. +// NewLogger must only be called once, at the start of the program. +func NewLogger(installDir string) (*zap.Logger, error) { + // On windows, absolute paths do not work for zap's default sink, so we must register it. + // see: https://github.com/uber-go/zap/issues/621 + var err error + registerSinkOnce.Do(func() { + err = zap.RegisterSink("winfile", newWinFileSink) + }) + if err != nil { + return nil, fmt.Errorf("failed to registed windows file sink: %w", err) + } + + logFile := path.LogFile(installDir) + + err = os.RemoveAll(logFile) + if err != nil { + return nil, fmt.Errorf("failed to remove previous log file: %w", err) + } + + conf := zap.NewProductionConfig() + conf.OutputPaths = []string{ + "winfile:///" + logFile, + } + + prodLogger, err := conf.Build() + if err != nil { + return nil, fmt.Errorf("failed to build logger: %w", err) + } + + return prodLogger, nil +} + +// Windows requires a special sink, so that we may properly parse the file path +// See: https://github.com/uber-go/zap/issues/621 +func newWinFileSink(u *url.URL) (zap.Sink, error) { + // Remove leading slash left by url.Parse() + return os.OpenFile(u.Path[1:], os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600) +} diff --git a/updater/internal/path/path.go b/updater/internal/path/path.go new file mode 100644 index 000000000..bd2d91dc0 --- /dev/null +++ b/updater/internal/path/path.go @@ -0,0 +1,63 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import "path/filepath" + +// TempDir gets the path to the "tmp" dir, used for staging updates & backups +func TempDir(installDir string) string { + return filepath.Join(installDir, "tmp") +} + +// LatestDir gets the path to the "latest" dir, where the new artifacts are unpacked. +func LatestDir(installDir string) string { + return filepath.Join(TempDir(installDir), "latest") +} + +// BackupDir gets the path to the "rollback" dir, where current artifacts are backed up. +func BackupDir(installDir string) string { + return filepath.Join(TempDir(installDir), "rollback") +} + +// ServiceFileDir gets the directory of the service file definitions +func ServiceFileDir(installDir string) string { + return filepath.Join(installDir, "install") +} + +// SpecialJarDir gets the directory where linux and darwin installs put the JMX jar +// Keeping this relative for now so we don't have to deal with /opt in tests +func SpecialJarDir(installDir string) string { + return filepath.Join(installDir, "..") +} + +// BackupServiceFile returns the full path to the backup service file +func BackupServiceFile(installDir string) string { + return filepath.Join(BackupDir(installDir), "backup.service") +} + +// LogFile returns the full path to the log file for the updater +func LogFile(installDir string) string { + return filepath.Join(installDir, "log", "updater.log") +} + +// LatestJMXJarFile returns the full path to the latest JMX jar to be installed +func LatestJMXJarFile(latestDir string) string { + return filepath.Join(latestDir, "opentelemetry-java-contrib-jmx-metrics.jar") +} + +// SpecialJMXJarFile returns the full path to the JMX Jar on linux and darwin installs +func SpecialJMXJarFile(installDir string) string { + return filepath.Join(SpecialJarDir(installDir), "opentelemetry-java-contrib-jmx-metrics.jar") +} diff --git a/updater/internal/path/path_darwin.go b/updater/internal/path/path_darwin.go new file mode 100644 index 000000000..75d5631e6 --- /dev/null +++ b/updater/internal/path/path_darwin.go @@ -0,0 +1,25 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import "go.uber.org/zap" + +// DarwinInstallDir is the path to the install directory on Darwin. +const DarwinInstallDir = "/opt/observiq-otel-collector" + +// InstallDir returns the filepath to the install directory +func InstallDir(_ *zap.Logger) (string, error) { + return DarwinInstallDir, nil +} diff --git a/updater/internal/path/path_linux.go b/updater/internal/path/path_linux.go new file mode 100644 index 000000000..41377c621 --- /dev/null +++ b/updater/internal/path/path_linux.go @@ -0,0 +1,25 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import "go.uber.org/zap" + +// LinuxInstallDir is the install directory of the collector on linux. +const LinuxInstallDir = "/opt/observiq-otel-collector" + +// InstallDir returns the filepath to the install directory +func InstallDir(_ *zap.Logger) (string, error) { + return LinuxInstallDir, nil +} diff --git a/updater/internal/path/path_test.go b/updater/internal/path/path_test.go new file mode 100644 index 000000000..7992c2a64 --- /dev/null +++ b/updater/internal/path/path_test.go @@ -0,0 +1,58 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTempDir(t *testing.T) { + require.Equal(t, filepath.Join("install", "tmp"), TempDir("install")) +} + +func TestLatestDir(t *testing.T) { + require.Equal(t, filepath.Join("install", "tmp", "latest"), LatestDir("install")) +} + +func TestBackupDir(t *testing.T) { + require.Equal(t, filepath.Join("install", "tmp", "rollback"), BackupDir("install")) +} + +func TestServiceFileDir(t *testing.T) { + require.Equal(t, filepath.Join("install", "install"), ServiceFileDir("install")) +} + +func TestSpecialJarDir(t *testing.T) { + require.Equal(t, filepath.Join("install", ".."), SpecialJarDir("install")) +} + +func TestBackupServiceFile(t *testing.T) { + require.Equal(t, filepath.Join("install", "tmp", "rollback", "backup.service"), BackupServiceFile("install")) +} + +func TestLogFile(t *testing.T) { + require.Equal(t, filepath.Join("install", "log", "updater.log"), LogFile("install")) +} + +func TestLatestJMXJarFile(t *testing.T) { + require.Equal(t, filepath.Join("latest", "opentelemetry-java-contrib-jmx-metrics.jar"), LatestJMXJarFile("latest")) +} + +func TestSpecialJMXJarFile(t *testing.T) { + require.Equal(t, filepath.Join("install", "..", "opentelemetry-java-contrib-jmx-metrics.jar"), SpecialJMXJarFile("install")) +} diff --git a/updater/internal/path/path_windows.go b/updater/internal/path/path_windows.go new file mode 100644 index 000000000..f0787def1 --- /dev/null +++ b/updater/internal/path/path_windows.go @@ -0,0 +1,53 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package path + +import ( + "fmt" + + "go.uber.org/zap" + "golang.org/x/sys/windows/registry" +) + +const defaultProductName = "observIQ Distro for OpenTelemetry Collector" + +// installDirFromRegistry gets the installation dir of the given product from the Windows Registry +func installDirFromRegistry(logger *zap.Logger, productName string) (string, error) { + // this key is created when installing using the MSI installer + keyPath := fmt.Sprintf(`Software\Microsoft\Windows\CurrentVersion\Uninstall\%s`, productName) + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.READ) + if err != nil { + return "", fmt.Errorf("failed to open registry key: %w", err) + } + defer func() { + err := key.Close() + if err != nil { + logger.Error("InstallDirFromRegistry: failed to close registry key", zap.Error(err)) + } + }() + + // This value ("InstallLocation") contains the path to the install folder. + val, _, err := key.GetStringValue("InstallLocation") + if err != nil { + return "", fmt.Errorf("failed to read install dir: %w", err) + } + + return val, nil +} + +// InstallDir returns the filepath to the install directory +func InstallDir(logger *zap.Logger) (string, error) { + return installDirFromRegistry(logger, defaultProductName) +} diff --git a/updater/internal/path/path_windows_test.go b/updater/internal/path/path_windows_test.go new file mode 100644 index 000000000..330fab3dd --- /dev/null +++ b/updater/internal/path/path_windows_test.go @@ -0,0 +1,80 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows && integration + +package path + +import ( + "fmt" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "golang.org/x/sys/windows/registry" +) + +func TestInstallDirFromRegistry(t *testing.T) { + t.Run("Successfully grabs install dir from registry", func(t *testing.T) { + productName := "default-product-name" + installDir, err := filepath.Abs("C:/temp") + require.NoError(t, err) + + defer deleteInstallDirRegistryKey(t, productName) + createInstallDirRegistryKey(t, productName, installDir) + + dir, err := installDirFromRegistry(zaptest.NewLogger(t), productName) + require.NoError(t, err) + require.Equal(t, installDir+`\`, dir) + }) + + t.Run("Registry key does not exist", func(t *testing.T) { + productName := "default-product-name" + + _, err := installDirFromRegistry(zaptest.NewLogger(t), productName) + require.ErrorContains(t, err, "failed to open registry key") + }) +} + +func deleteInstallDirRegistryKey(t *testing.T, productName string) { + t.Helper() + + keyPath := fmt.Sprintf(`Software\Microsoft\Windows\CurrentVersion\Uninstall\%s`, productName) + key, err := registry.OpenKey(registry.LOCAL_MACHINE, keyPath, registry.WRITE) + if err != nil { + // Key may not exist, assume that's why we couldn't open it + return + } + defer key.Close() + + err = registry.DeleteKey(key, "") + require.NoError(t, err) +} + +func createInstallDirRegistryKey(t *testing.T, productName, installDir string) { + t.Helper() + + installDir, err := filepath.Abs(installDir) + require.NoError(t, err) + installDir += `\` + + keyPath := fmt.Sprintf(`Software\Microsoft\Windows\CurrentVersion\Uninstall\%s`, productName) + key, _, err := registry.CreateKey(registry.LOCAL_MACHINE, keyPath, registry.WRITE) + require.NoError(t, err) + defer key.Close() + + err = key.SetStringValue("InstallLocation", installDir) + require.NoError(t, err) +} diff --git a/updater/internal/rollback/mocks/rollbacker.go b/updater/internal/rollback/mocks/rollbacker.go new file mode 100644 index 000000000..dce4165b6 --- /dev/null +++ b/updater/internal/rollback/mocks/rollbacker.go @@ -0,0 +1,52 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import ( + action "github.com/observiq/observiq-otel-collector/updater/internal/action" + mock "github.com/stretchr/testify/mock" +) + +// Rollbacker is an autogenerated mock type for the Rollbacker type +type Rollbacker struct { + mock.Mock +} + +// AppendAction provides a mock function with given fields: _a0 +func (_m *Rollbacker) AppendAction(_a0 action.RollbackableAction) { + _m.Called(_a0) +} + +// Backup provides a mock function with given fields: +func (_m *Rollbacker) Backup() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Rollback provides a mock function with given fields: +func (_m *Rollbacker) Rollback() { + _m.Called() +} + +type mockConstructorTestingTNewRollbacker interface { + mock.TestingT + Cleanup(func()) +} + +// NewRollbacker creates a new instance of Rollbacker. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewRollbacker(t mockConstructorTestingTNewRollbacker) *Rollbacker { + mock := &Rollbacker{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/rollback/rollback.go b/updater/internal/rollback/rollback.go new file mode 100644 index 000000000..9fe544c9f --- /dev/null +++ b/updater/internal/rollback/rollback.go @@ -0,0 +1,195 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rollback + +import ( + "errors" + "fmt" + "io/fs" + "os" + "path/filepath" + "strings" + + "github.com/observiq/observiq-otel-collector/updater/internal/action" + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/service" + "go.uber.org/zap" +) + +//Rollbacker is an interface that performs rollback/backup actions. +//go:generate mockery --name Rollbacker --filename rollbacker.go +type Rollbacker interface { + // AppendAction saves the action so that it can be rolled back later. + AppendAction(action action.RollbackableAction) + // Backup backs up the current installation + Backup() error + // Rollback undoes the actions recorded by AppendAction. + Rollback() +} + +// filesystemRollbacker is a struct that records rollback information, +// and can use that information to perform a rollback using files backed up +// on the filesystem. +type filesystemRollbacker struct { + originalSvc service.Service + backupDir string + installDir string + actions []action.RollbackableAction + logger *zap.Logger +} + +// NewRollbacker returns a new Rollbacker +func NewRollbacker(logger *zap.Logger, installDir string) Rollbacker { + namedLogger := logger.Named("rollbacker") + + return &filesystemRollbacker{ + backupDir: path.BackupDir(installDir), + installDir: installDir, + logger: namedLogger, + originalSvc: service.NewService(namedLogger, installDir), + } +} + +// AppendAction records the action that was performed, so that it may be undone later. +func (r *filesystemRollbacker) AppendAction(action action.RollbackableAction) { + r.actions = append(r.actions, action) +} + +// Backup backs up the installDir to the rollbackDir +func (r filesystemRollbacker) Backup() error { + r.logger.Debug("Backing up current installation") + // Remove any pre-existing backup + if err := os.RemoveAll(r.backupDir); err != nil { + return fmt.Errorf("failed to remove previous backup: %w", err) + } + + // Copy all the files in the install directory to the backup directory + if err := backupFiles(r.logger, r.installDir, r.backupDir); err != nil { + return fmt.Errorf("failed to copy files to backup dir: %w", err) + } + + // If JMX jar exists outside of install directory, make sure that gets backed up + jarPath := path.SpecialJMXJarFile(r.installDir) + _, err := os.Stat(jarPath) + switch { + case err == nil: + if err := backupFile(r.logger, jarPath, r.backupDir); err != nil { + return fmt.Errorf("failed to copy JMX jar to jar backup dir: %w", err) + } + case !errors.Is(err, os.ErrNotExist): + return fmt.Errorf("failed determine where currently installed JMX jar is: %w", err) + } + + // Backup the service configuration so we can reload it in case of rollback + if err := r.originalSvc.Backup(); err != nil { + return fmt.Errorf("failed to backup service configuration: %w", err) + } + + return nil +} + +// Rollback performs a rollback by undoing all recorded actions. +func (r filesystemRollbacker) Rollback() { + r.logger.Debug("Performing rollback") + // We need to loop through the actions slice backwards, to roll back the actions in the correct order. + // e.g. if StartService was called last, we need to stop the service first, then rollback previous actions. + for i := len(r.actions) - 1; i >= 0; i-- { + action := r.actions[i] + r.logger.Debug("Rolling back action", zap.Any("action", action)) + if err := action.Rollback(); err != nil { + r.logger.Error("Failed to run rollback action", zap.Error(err)) + } + } +} + +// backupFiles copies files from installDir to output path, skipping tmpDir. +func backupFiles(logger *zap.Logger, installDir, outputPath string) error { + absTmpDir, err := filepath.Abs(path.TempDir(installDir)) + if err != nil { + return fmt.Errorf("failed to get absolute path for temporary directory: %w", err) + } + + err = filepath.WalkDir(installDir, func(inPath string, d fs.DirEntry, err error) error { + + fullPath, absErr := filepath.Abs(inPath) + if absErr != nil { + return fmt.Errorf("failed to determine absolute path of file: %w", absErr) + } + + switch { + case err != nil: + // if there was an error walking the directory, we want to bail out. + return err + case d.IsDir() && strings.HasPrefix(fullPath, absTmpDir): + // If this is the "tmp" directory, we want to skip copying this directory, + // since this folder is only for temporary files (and is where this binary is running right now) + return filepath.SkipDir + case d.IsDir(): + // Skip directories, we'll create them when we get a file in the directory. + return nil + } + + // We want the path relative to the directory we are walking in order to calculate where the file should be + // mirrored in the output directory. + relPath, err := filepath.Rel(installDir, inPath) + if err != nil { + return err + } + + // use the relative path to get the outPath (where we should write the file), and + // to get the out directory (which we will create if it does not exist). + outPath := filepath.Join(outputPath, relPath) + outDir := filepath.Dir(outPath) + + if err := os.MkdirAll(outDir, 0750); err != nil { + return fmt.Errorf("failed to create dir: %w", err) + } + + // Fail if copying the input file to the output file would fail + if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil + }) + + if err != nil { + return fmt.Errorf("failed to walk latest dir: %w", err) + } + + return nil +} + +// backupFile copies original file to output path +func backupFile(logger *zap.Logger, inPath, outputDirPath string) error { + baseInPath := filepath.Base(inPath) + + // use the relative path to get the outPath (where we should write the file), and + // to get the out directory (which we will create if it does not exist). + outPath := filepath.Join(outputDirPath, baseInPath) + outDir := filepath.Dir(outPath) + + if err := os.MkdirAll(outDir, 0750); err != nil { + return fmt.Errorf("failed to create dir: %w", err) + } + + // Fail if copying the input file to the output file would fail + if err := file.CopyFileNoOverwrite(logger.Named("copy-file"), inPath, outPath); err != nil { + return fmt.Errorf("failed to copy file: %w", err) + } + + return nil +} diff --git a/updater/internal/rollback/rollback_test.go b/updater/internal/rollback/rollback_test.go new file mode 100644 index 000000000..0f645ea6b --- /dev/null +++ b/updater/internal/rollback/rollback_test.go @@ -0,0 +1,168 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rollback + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "testing" + + action_mocks "github.com/observiq/observiq-otel-collector/updater/internal/action/mocks" + service_mocks "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestRollbackerBackup(t *testing.T) { + t.Run("Successfully backs up everything", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + + svc := service_mocks.NewService(t) + svc.On("Backup").Return(nil) + + rb := &filesystemRollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + installJarPath := filepath.Join(rb.installDir, "..", "opentelemetry-java-contrib-jmx-metrics.jar") + _, err := os.Create(installJarPath) + require.NoError(t, err) + err = os.WriteFile(installJarPath, []byte("# The old jar file"), 0660) + require.NoError(t, err) + + err = rb.Backup() + require.NoError(t, err) + + require.FileExists(t, filepath.Join(outDir, "opentelemetry-java-contrib-jmx-metrics.jar")) + require.FileExists(t, filepath.Join(outDir, "some-file.txt")) + require.FileExists(t, filepath.Join(outDir, "plugins-dir", "plugin.txt")) + require.NoDirExists(t, filepath.Join(outDir, "tmp-dir")) + }) + + t.Run("Service backup fails", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + + svc := service_mocks.NewService(t) + svc.On("Backup").Return(fmt.Errorf("invalid permissions")) + + rb := &filesystemRollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err := rb.Backup() + require.ErrorContains(t, err, "failed to backup service configuration") + }) + + t.Run("Removes pre-existing backup", func(t *testing.T) { + outDir := t.TempDir() + installDir := filepath.Join("testdata", "rollbacker") + leftoverFile := filepath.Join(outDir, "leftover-file.txt") + + svc := service_mocks.NewService(t) + svc.On("Backup").Return(nil) + + err := os.MkdirAll(outDir, 0750) + require.NoError(t, err) + err = os.WriteFile(leftoverFile, []byte("leftover file"), 0600) + require.NoError(t, err) + + rb := &filesystemRollbacker{ + originalSvc: svc, + backupDir: outDir, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err = rb.Backup() + require.NoError(t, err) + + require.FileExists(t, filepath.Join(outDir, "opentelemetry-java-contrib-jmx-metrics.jar")) + require.FileExists(t, filepath.Join(outDir, "some-file.txt")) + require.FileExists(t, filepath.Join(outDir, "plugins-dir", "plugin.txt")) + require.NoDirExists(t, filepath.Join(outDir, "tmp-dir")) + require.NoFileExists(t, leftoverFile) + }) +} + +func TestRollbackerRollback(t *testing.T) { + t.Run("Runs rollback actions in the correct order", func(t *testing.T) { + seq := 0 + + rb := &filesystemRollbacker{ + logger: zaptest.NewLogger(t), + } + + for i := 0; i < 10; i++ { + actionNum := i + action := action_mocks.NewRollbackableAction(t) + action.On("Rollback").Run(func(args mock.Arguments) { + // Rollback should be done in reverse order; So action 0 + // should be done last (10th action, seq == 9), while + // the last action (action 9) should be done first (seq == 0) + expectedSeq := 10 - actionNum - 1 + assert.Equal(t, expectedSeq, seq, "Expected action %d to occur at sequence %d", seq, expectedSeq) + seq++ + }).Return(nil) + + rb.AppendAction(action) + } + + rb.Rollback() + }) + + t.Run("Continues despite rollback errors", func(t *testing.T) { + seq := 0 + + rb := &filesystemRollbacker{ + logger: zaptest.NewLogger(t), + } + + for i := 0; i < 10; i++ { + actionNum := i + action := action_mocks.NewRollbackableAction(t) + + call := action.On("Rollback").Run(func(args mock.Arguments) { + // Rollback should be done in reverse order; So action 0 + // should be done last (10th action, seq == 9), while + // the last action (action 9) should be done first (seq == 0) + expectedSeq := 10 - actionNum - 1 + assert.Equal(t, expectedSeq, seq, "Expected action %d to occur at sequence %d", seq, expectedSeq) + seq++ + }) + + if actionNum == 5 { + call.Return(errors.New("failed to rollback")) + } else { + call.Return(nil) + } + + rb.AppendAction(action) + } + + rb.Rollback() + }) +} diff --git a/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt b/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt new file mode 100644 index 000000000..c47f0348d --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/plugins-dir/plugin.txt @@ -0,0 +1 @@ +This is a test file for copying diff --git a/updater/internal/rollback/testdata/rollbacker/some-file.txt b/updater/internal/rollback/testdata/rollbacker/some-file.txt new file mode 100644 index 000000000..9f4b6d8bf --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/some-file.txt @@ -0,0 +1 @@ +This is a test file diff --git a/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt b/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt new file mode 100644 index 000000000..f594928cb --- /dev/null +++ b/updater/internal/rollback/testdata/rollbacker/tmp-dir/tmp-file.txt @@ -0,0 +1 @@ +This file should not be copied, because it is in the tmp-dir diff --git a/updater/internal/service/mocks/service.go b/updater/internal/service/mocks/service.go new file mode 100644 index 000000000..c2810171a --- /dev/null +++ b/updater/internal/service/mocks/service.go @@ -0,0 +1,81 @@ +// Code generated by mockery v2.14.0. DO NOT EDIT. + +package mocks + +import mock "github.com/stretchr/testify/mock" + +// Service is an autogenerated mock type for the Service type +type Service struct { + mock.Mock +} + +// Backup provides a mock function with given fields: +func (_m *Service) Backup() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Start provides a mock function with given fields: +func (_m *Service) Start() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Stop provides a mock function with given fields: +func (_m *Service) Stop() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Update provides a mock function with given fields: +func (_m *Service) Update() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +type mockConstructorTestingTNewService interface { + mock.TestingT + Cleanup(func()) +} + +// NewService creates a new instance of Service. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewService(t mockConstructorTestingTNewService) *Service { + mock := &Service{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/service/service.go b/updater/internal/service/service.go new file mode 100644 index 000000000..fef849be6 --- /dev/null +++ b/updater/internal/service/service.go @@ -0,0 +1,44 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "bytes" + "os" + "path/filepath" +) + +//go:generate mockery --name Service --filename service.go +// Service represents a controllable service +type Service interface { + // Start the service + Start() error + + // Stop the service + Stop() error + + // Updates the old service configuration to the new one + Update() error + + // Backup backs the current service configuration + Backup() error +} + +// replaceInstallDir replaces "[INSTALLDIR]" with the given installDir string. +// This is meant to mimic windows "formatted" string syntax. +func replaceInstallDir(unformattedBytes []byte, installDir string) []byte { + installDirClean := filepath.Clean(installDir) + string(os.PathSeparator) + return bytes.ReplaceAll(unformattedBytes, []byte("[INSTALLDIR]"), []byte(installDirClean)) +} diff --git a/updater/internal/service/service_darwin.go b/updater/internal/service/service_darwin.go new file mode 100644 index 000000000..3b07bbb6b --- /dev/null +++ b/updater/internal/service/service_darwin.go @@ -0,0 +1,148 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin + +package service + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "go.uber.org/zap" +) + +const ( + darwinServiceFilePath = "/Library/LaunchDaemons/com.observiq.collector.plist" +) + +// Option is an extra option for creating a Service +type Option func(darwinSvc *darwinService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(darwinSvc *darwinService) { + darwinSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(logger *zap.Logger, installDir string, opts ...Option) Service { + darwinSvc := &darwinService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(installDir), "com.observiq.collector.plist"), + installedServiceFilePath: darwinServiceFilePath, + installDir: path.DarwinInstallDir, + logger: logger.Named("darwin-service"), + } + + for _, opt := range opts { + opt(darwinSvc) + } + + return darwinSvc +} + +type darwinService struct { + // newServiceFilePath is the file path to the new plist file + newServiceFilePath string + // installedServiceFilePath is the file path to the installed plist file + installedServiceFilePath string + // installDir is the root directory of the main installation + installDir string + logger *zap.Logger +} + +// Start the service +func (d darwinService) Start() error { + // Launchctl exits with error code 0 if the file does not exist. + // We want to ensure that we error in this scenario. + if _, err := os.Stat(d.installedServiceFilePath); err != nil { + return fmt.Errorf("failed to stat installed service file: %w", err) + } + + //#nosec G204 -- installedServiceFilePath is not determined by user input + cmd := exec.Command("launchctl", "load", d.installedServiceFilePath) + if err := cmd.Run(); err != nil { + return fmt.Errorf("running launchctl failed: %w", err) + } + return nil +} + +// Stop the service +func (d darwinService) Stop() error { + // Launchctl exits with error code 0 if the file does not exist. + // We want to ensure that we error in this scenario. + if _, err := os.Stat(d.installedServiceFilePath); err != nil { + return fmt.Errorf("failed to stat installed service file: %w", err) + } + + //#nosec G204 -- installedServiceFilePath is not determined by user input + cmd := exec.Command("launchctl", "unload", d.installedServiceFilePath) + if err := cmd.Run(); err != nil { + return fmt.Errorf("running launchctl failed: %w", err) + } + return nil +} + +// Installs the service +func (d darwinService) install() error { + serviceFileBytes, err := os.ReadFile(d.newServiceFilePath) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + + expandedServiceFileBytes := replaceInstallDir(serviceFileBytes, d.installDir) + if err := os.WriteFile(d.installedServiceFilePath, expandedServiceFileBytes, 0600); err != nil { + return fmt.Errorf("failed to write service file: %w", err) + } + + return d.Start() +} + +// Uninstalls the service +func (d darwinService) uninstall() error { + if err := d.Stop(); err != nil { + return err + } + + if err := os.Remove(d.installedServiceFilePath); err != nil { + return fmt.Errorf("failed to remove service file: %w", err) + } + + return nil +} + +func (d darwinService) Update() error { + if err := d.uninstall(); err != nil { + return fmt.Errorf("failed to uninstall old service: %w", err) + } + + if err := d.install(); err != nil { + return fmt.Errorf("failed to install new service: %w", err) + } + + return nil +} + +func (d darwinService) Backup() error { + if err := file.CopyFileNoOverwrite(d.logger.Named("copy-file"), d.installedServiceFilePath, path.BackupServiceFile(d.installDir)); err != nil { + return fmt.Errorf("failed to copy service file: %w", err) + } + + return nil +} diff --git a/updater/internal/service/service_darwin_test.go b/updater/internal/service/service_darwin_test.go new file mode 100644 index 000000000..9778e3573 --- /dev/null +++ b/updater/internal/service/service_darwin_test.go @@ -0,0 +1,306 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build darwin && integration + +package service + +import ( + "os" + "os/exec" + "path/filepath" + "regexp" + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestDarwinServiceInstall(t *testing.T) { + t.Run("Test install + uninstall", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = d.uninstall() + require.NoError(t, err) + require.NoFileExists(t, installedServicePath) + + // Make sure the service is no longer listed + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test stop + start", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + // TODO: Do this automagically + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = d.Start() + require.NoError(t, err) + + requireServiceRunning(t) + + err = d.Stop() + require.NoError(t, err) + + requireServiceLoadedStatus(t, false) + + err = d.uninstall() + require.NoError(t, err) + require.NoFileExists(t, installedServicePath) + + // Make sure the service is no longer listed + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test invalid path for input file", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "does-not-exist.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.install() + require.ErrorContains(t, err, "failed to open input file") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test invalid path for output file for install", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "directory-does-not-exist", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.install() + require.ErrorContains(t, err, "failed to write service file") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Uninstall fails if not installed", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.uninstall() + require.ErrorContains(t, err, "failed to stat installed service file") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Start fails if service not found", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.Start() + require.ErrorContains(t, err, "failed to stat installed service file") + }) + + t.Run("Stop fails if service not found", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + d := &darwinService{ + newServiceFilePath: filepath.Join("testdata", "darwin-service.plist"), + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := d.Stop() + require.ErrorContains(t, err, "failed to stat installed service file") + }) + + t.Run("Backup installed service succeeds", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + serviceFileContents, err := os.ReadFile(newServiceFile) + require.NoError(t, err) + + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err = d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + require.NoError(t, d.Stop()) + + err = d.Backup() + require.NoError(t, err) + require.FileExists(t, path.BackupServiceFile(installDir)) + + backupServiceContents, err := os.ReadFile(path.BackupServiceFile(installDir)) + + require.Equal(t, serviceFileContents, backupServiceContents) + require.NoError(t, d.uninstall()) + }) + + t.Run("Backup installed service fails if not installed", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err := d.Backup() + require.ErrorContains(t, err, "failed to copy service file") + }) + + t.Run("Backup installed service fails if output file already exists", func(t *testing.T) { + installedServicePath := filepath.Join(os.Getenv("HOME"), "Library", "LaunchAgents", "darwin-service.plist") + + uninstallService(t, installedServicePath) + + newServiceFile := filepath.Join("testdata", "darwin-service.plist") + + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &darwinService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + installDir: installDir, + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + require.NoError(t, d.Stop()) + + // Write the backup file before creating it; Backup should + // not ever overwrite an existing file + os.WriteFile(path.BackupServiceFile(installDir), []byte("file exists"), 0600) + + err = d.Backup() + require.ErrorContains(t, err, "failed to copy service file") + }) +} + +// uninstallService is a helper that uninstalls the service manually for test setup, in case it is somehow leftover. +func uninstallService(t *testing.T, installedPath string) { + t.Helper() + + cmd := exec.Command("launchctl", "unload", installedPath) + // May already be unloaded; We'll ignore the error. + _ = cmd.Run() + + err := os.RemoveAll(installedPath) + require.NoError(t, err) +} + +const exitCodeServiceNotFound = 113 + +func requireServiceLoadedStatus(t *testing.T, loaded bool) { + t.Helper() + + cmd := exec.Command("launchctl", "list", "darwin-service") + err := cmd.Run() + if loaded { + // If the service should be loaded, then we expect a 0 exit code, so no error is given + require.NoError(t, err) + return + } + + eErr, ok := err.(*exec.ExitError) + require.True(t, ok, "launchctl list exited with non-ExitError: %s", eErr) + require.Equal(t, exitCodeServiceNotFound, eErr.ExitCode(), "unexpected exit code when asserting service is unloaded: %d", eErr.ExitCode()) +} + +var descriptionPIDRegex = regexp.MustCompile(`\s*"PID" = \d+;`) + +func requireServiceRunning(t *testing.T) { + t.Helper() + + cmd := exec.Command("launchctl", "list", "darwin-service") + out, err := cmd.Output() + require.NoError(t, err) + matches := descriptionPIDRegex.Match(out) + require.True(t, matches, "Service should be running, but it was not found in launchctl list") +} diff --git a/updater/internal/service/service_linux.go b/updater/internal/service/service_linux.go new file mode 100644 index 000000000..c8804270f --- /dev/null +++ b/updater/internal/service/service_linux.go @@ -0,0 +1,173 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build linux + +package service + +import ( + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + + "github.com/observiq/observiq-otel-collector/updater/internal/file" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "go.uber.org/zap" +) + +const linuxServiceName = "observiq-otel-collector" +const linuxServiceFilePath = "/usr/lib/systemd/system/observiq-otel-collector.service" + +// Option is an extra option for creating a Service +type Option func(linuxSvc *linuxService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(linuxSvc *linuxService) { + linuxSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(logger *zap.Logger, installDir string, opts ...Option) Service { + linuxSvc := &linuxService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(installDir), "observiq-otel-collector.service"), + serviceName: linuxServiceName, + installedServiceFilePath: linuxServiceFilePath, + installDir: installDir, + logger: logger.Named("linux-service"), + } + + for _, opt := range opts { + opt(linuxSvc) + } + + return linuxSvc +} + +type linuxService struct { + // newServiceFilePath is the file path to the new unit file + newServiceFilePath string + // serviceName is the name of the service + serviceName string + // installedServiceFilePath is the file path to the installed unit file + installedServiceFilePath string + installDir string + logger *zap.Logger +} + +// Start the service +func (l linuxService) Start() error { + //#nosec G204 -- serviceName is not determined by user input + cmd := exec.Command("systemctl", "start", l.serviceName) + if err := cmd.Run(); err != nil { + return fmt.Errorf("running systemctl failed: %w", err) + } + return nil +} + +// Stop the service +func (l linuxService) Stop() error { + //#nosec G204 -- serviceName is not determined by user input + cmd := exec.Command("systemctl", "stop", l.serviceName) + if err := cmd.Run(); err != nil { + return fmt.Errorf("running systemctl failed: %w", err) + } + return nil +} + +// installs the service +func (l linuxService) install() error { + inFile, err := os.Open(l.newServiceFilePath) + if err != nil { + return fmt.Errorf("failed to open input file: %w", err) + } + defer func() { + err := inFile.Close() + if err != nil { + log.Default().Printf("Service Install: Failed to close input file: %s", err) + } + }() + + outFile, err := os.OpenFile(l.installedServiceFilePath, os.O_CREATE|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("failed to open output file: %w", err) + } + defer func() { + err := outFile.Close() + if err != nil { + log.Default().Printf("Service Install: Failed to close output file: %s", err) + } + }() + + if _, err := io.Copy(outFile, inFile); err != nil { + return fmt.Errorf("failed to copy service file: %w", err) + } + + cmd := exec.Command("systemctl", "daemon-reload") + if err := cmd.Run(); err != nil { + return fmt.Errorf("reloading systemctl failed: %w", err) + } + + //#nosec G204 -- serviceName is not determined by user input + cmd = exec.Command("systemctl", "enable", l.serviceName) + if err := cmd.Run(); err != nil { + return fmt.Errorf("enabling unit file failed: %w", err) + } + + return nil +} + +// uninstalls the service +func (l linuxService) uninstall() error { + //#nosec G204 -- serviceName is not determined by user input + cmd := exec.Command("systemctl", "disable", l.serviceName) + if err := cmd.Run(); err != nil { + return fmt.Errorf("failed to disable unit: %w", err) + } + + if err := os.Remove(l.installedServiceFilePath); err != nil { + return fmt.Errorf("failed to remove service file: %w", err) + } + + cmd = exec.Command("systemctl", "daemon-reload") + if err := cmd.Run(); err != nil { + return fmt.Errorf("reloading systemctl failed: %w", err) + } + + return nil +} + +func (l linuxService) Update() error { + if err := l.uninstall(); err != nil { + return fmt.Errorf("failed to uninstall old service: %w", err) + } + + if err := l.install(); err != nil { + return fmt.Errorf("failed to install new service: %w", err) + } + + return nil +} + +func (l linuxService) Backup() error { + if err := file.CopyFileNoOverwrite(l.logger.Named("copy-file"), l.installedServiceFilePath, path.BackupServiceFile(l.installDir)); err != nil { + return fmt.Errorf("failed to copy service file: %w", err) + } + + return nil +} diff --git a/updater/internal/service/service_linux_test.go b/updater/internal/service/service_linux_test.go new file mode 100644 index 000000000..7bc8f6922 --- /dev/null +++ b/updater/internal/service/service_linux_test.go @@ -0,0 +1,312 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// an elevated user is needed to run the service tests +//go:build linux && integration + +package service + +import ( + "os" + "os/exec" + "path/filepath" + "testing" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +// NOTE: These tests must run as root in order to pass +func TestLinuxServiceInstall(t *testing.T) { + t.Run("Test install + uninstall", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + //We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = l.uninstall() + require.NoError(t, err) + require.NoFileExists(t, installedServicePath) + + //Make sure the service is no longer listed + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test stop + start", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = l.Start() + require.NoError(t, err) + + requireServiceRunningStatus(t, true) + + err = l.Stop() + require.NoError(t, err) + + requireServiceRunningStatus(t, false) + + err = l.uninstall() + require.NoError(t, err) + require.NoFileExists(t, installedServicePath) + + // Make sure the service is no longer listed + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test invalid path for input file", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "does-not-exist.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.install() + require.ErrorContains(t, err, "failed to open input file") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Test invalid path for output file for install", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/dir-does-not-exist/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.install() + require.ErrorContains(t, err, "failed to open output file") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Uninstall fails if not installed", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.uninstall() + require.ErrorContains(t, err, "failed to disable unit") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Start fails if service not found", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.Start() + require.ErrorContains(t, err, "running systemctl failed") + }) + + t.Run("Stop fails if service not found", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + l := &linuxService{ + newServiceFilePath: filepath.Join("testdata", "linux-service.service"), + serviceName: "linux-service", + installedServiceFilePath: installedServicePath, + logger: zaptest.NewLogger(t), + } + + err := l.Stop() + require.ErrorContains(t, err, "running systemctl failed") + }) + + t.Run("Backup installed service succeeds", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + serviceFileContents, err := os.ReadFile(newServiceFile) + require.NoError(t, err) + + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err = d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = d.Backup() + require.NoError(t, err) + require.FileExists(t, path.BackupServiceFile(installDir)) + + backupServiceContents, err := os.ReadFile(path.BackupServiceFile(installDir)) + + require.Equal(t, serviceFileContents, backupServiceContents) + require.NoError(t, d.uninstall()) + }) + + t.Run("Backup installed service fails if not installed", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err := d.Backup() + require.ErrorContains(t, err, "failed to copy service file") + }) + + t.Run("Backup installed service fails if output file already exists", func(t *testing.T) { + installedServicePath := "/usr/lib/systemd/system/linux-service.service" + uninstallService(t, installedServicePath, "linux-service") + + newServiceFile := filepath.Join("testdata", "linux-service.service") + + installDir := t.TempDir() + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + d := &linuxService{ + newServiceFilePath: newServiceFile, + installedServiceFilePath: installedServicePath, + serviceName: "linux-service", + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + err := d.install() + require.NoError(t, err) + require.FileExists(t, installedServicePath) + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + // Write the backup file before creating it; Backup should + // not ever overwrite an existing file + os.WriteFile(path.BackupServiceFile(installDir), []byte("file exists"), 0600) + + err = d.Backup() + require.ErrorContains(t, err, "failed to copy service file") + }) +} + +// uninstallService is a helper that uninstalls the service manually for test setup, in case it is somehow leftover. +func uninstallService(t *testing.T, installedPath, serviceName string) { + cmd := exec.Command("systemctl", "stop", serviceName) + _ = cmd.Run() + + cmd = exec.Command("systemctl", "disable", serviceName) + _ = cmd.Run() + + err := os.RemoveAll(installedPath) + require.NoError(t, err) + + cmd = exec.Command("systemctl", "daemon-reload") + _ = cmd.Run() +} + +const exitCodeServiceNotFound = 4 +const exitCodeServiceInactive = 3 + +func requireServiceLoadedStatus(t *testing.T, loaded bool) { + t.Helper() + + cmd := exec.Command("systemctl", "status", "linux-service") + err := cmd.Run() + require.Error(t, err, "expected non-zero exit code from 'systemctl status linux-service'") + + eErr, ok := err.(*exec.ExitError) + if loaded { + // If the service should be loaded, then we expect a 0 exit code, so no error is given + require.Equal(t, exitCodeServiceInactive, eErr.ExitCode(), "unexpected exit code when asserting service is unloaded: %d", eErr.ExitCode()) + return + } + + require.True(t, ok, "systemctl status exited with non-ExitError: %s", eErr) + require.Equal(t, exitCodeServiceNotFound, eErr.ExitCode(), "unexpected exit code when asserting service is unloaded: %d", eErr.ExitCode()) +} + +func requireServiceRunningStatus(t *testing.T, running bool) { + cmd := exec.Command("systemctl", "status", "linux-service") + err := cmd.Run() + + if running { + // exit code 0 indicates service is loaded & running + require.NoError(t, err) + return + } + + eErr, ok := err.(*exec.ExitError) + require.True(t, ok, "systemctl status exited with non-ExitError: %s", eErr) + require.Equal(t, exitCodeServiceInactive, eErr.ExitCode(), "unexpected exit code when asserting service is not running: %d", eErr.ExitCode()) +} diff --git a/updater/internal/service/service_test.go b/updater/internal/service/service_test.go new file mode 100644 index 000000000..3d5be7a4e --- /dev/null +++ b/updater/internal/service/service_test.go @@ -0,0 +1,54 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package service + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestReplaceInstallDir(t *testing.T) { + testCases := []struct { + input []byte + installDir string + output []byte + }{ + { + input: []byte("[INSTALLDIR]"), + installDir: "some/install/directory", + output: []byte(filepath.Join("some", "install", "directory") + string(os.PathSeparator)), + }, + { + input: []byte("no install dir"), + installDir: "some/install/directory", + output: []byte("no install dir"), + }, + { + input: []byte("[INSTALLDIR]observiq-otel-collector"), + installDir: "some/install/directory", + output: []byte(filepath.Join("some", "install", "directory", "observiq-otel-collector")), + }, + } + + for _, tc := range testCases { + t.Run(string(tc.input), func(t *testing.T) { + out := replaceInstallDir(tc.input, tc.installDir) + require.Equal(t, tc.output, out) + }) + } +} diff --git a/updater/internal/service/service_windows.go b/updater/internal/service/service_windows.go new file mode 100644 index 000000000..342e27432 --- /dev/null +++ b/updater/internal/service/service_windows.go @@ -0,0 +1,370 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//go:build windows + +package service + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + "strings" + + "go.uber.org/zap" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" + + "github.com/kballard/go-shellquote" + "github.com/observiq/observiq-otel-collector/updater/internal/path" +) + +const ( + defaultProductName = "observIQ Distro for OpenTelemetry Collector" + defaultServiceName = "observiq-otel-collector" +) + +// Option is an extra option for creating a Service +type Option func(winSvc *windowsService) + +// WithServiceFile returns an option setting the service file to use when updating using the service +func WithServiceFile(svcFilePath string) Option { + return func(winSvc *windowsService) { + winSvc.newServiceFilePath = svcFilePath + } +} + +// NewService returns an instance of the Service interface for managing the observiq-otel-collector service on the current OS. +func NewService(logger *zap.Logger, installDir string, opts ...Option) Service { + winSvc := &windowsService{ + newServiceFilePath: filepath.Join(path.ServiceFileDir(installDir), "windows_service.json"), + serviceName: defaultServiceName, + productName: defaultProductName, + installDir: installDir, + logger: logger.Named("windows-service"), + } + + for _, opt := range opts { + opt(winSvc) + } + + return winSvc +} + +type windowsService struct { + // newServiceFilePath is the file path to the new unit file + newServiceFilePath string + // serviceName is the name of the service + serviceName string + // productName is the name of the installed product + productName string + installDir string + logger *zap.Logger +} + +// Start the service +func (w windowsService) Start() error { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %w", err) + } + defer m.Disconnect() + + s, err := m.OpenService(w.serviceName) + if err != nil { + return fmt.Errorf("failed to open service: %w", err) + } + defer s.Close() + + if err := s.Start(); err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + + return nil +} + +// Stop the service +func (w windowsService) Stop() error { + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %w", err) + } + defer m.Disconnect() + + s, err := m.OpenService(w.serviceName) + if err != nil { + return fmt.Errorf("failed to open service: %w", err) + } + defer s.Close() + + if _, err := s.Control(svc.Stop); err != nil { + return fmt.Errorf("failed to start service: %w", err) + } + + return nil +} + +func (w windowsService) Update() error { + // parse the service definition from disk + wsc, err := readWindowsServiceConfig(w.newServiceFilePath) + if err != nil { + return fmt.Errorf("failed to read service config: %w", err) + } + + // expand the arguments to be properly formatted (expand [INSTALLDIR], clean '"' to be '"') + expandArguments(wsc, w.installDir) + + // Get the start type + startType, delayed, err := winapiStartType(wsc.Service.Start) + if err != nil { + return fmt.Errorf("failed to parse start type in service config: %w", err) + } + + m, err := mgr.Connect() + if err != nil { + return fmt.Errorf("failed to connect to service manager: %w", err) + } + defer m.Disconnect() + + // Get the installed service handle + s, err := m.OpenService(w.serviceName) + if err != nil { + return fmt.Errorf("failed to open service: %w", err) + } + defer func() { + if err := s.Close(); err != nil { + w.logger.Error("failed to close service after update", zap.Error(err)) + } + }() + + // Get the current config; We will use the current config as the basis + // for the new config. + newConf, err := s.Config() + if err != nil { + return fmt.Errorf("failed to get current service configuration: %w", err) + } + + // Get the full path the the collector + fullCollectorPath := filepath.Join(w.installDir, wsc.Path) + // binary path is the path to the EXE, then the space separated list of arguments. + // we quote the collector path, in case it contains spaces. + binaryPathName := fmt.Sprintf("\"%s\" %s", fullCollectorPath, wsc.Service.Arguments) + + // Fill in the new config values + newConf.BinaryPathName = binaryPathName + newConf.Description = wsc.Service.Description + newConf.DisplayName = wsc.Service.DisplayName + newConf.StartType = startType + newConf.DelayedAutoStart = delayed + + // Update the service in-place + err = s.UpdateConfig(newConf) + if err != nil { + return fmt.Errorf("failed to updater service: %w", err) + } + + return nil +} + +func (w windowsService) Backup() error { + + wsc, err := w.currentServiceConfig() + if err != nil { + return fmt.Errorf("failed to construct service config: %w", err) + } + + // Marshal config as json + wscBytes, err := json.Marshal(wsc) + if err != nil { + return fmt.Errorf("failed to marshal config: %w", err) + } + + // Open with O_EXCL to fail if the file already exists + f, err := os.OpenFile(path.BackupServiceFile(w.installDir), os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0600) + if err != nil { + return fmt.Errorf("failed to create backup service file: %w", err) + } + defer func() { + if err := f.Close(); err != nil { + w.logger.Error("Failed to close backup service file", zap.Error(err)) + } + }() + + // finally, write the config out so we can rollback. + if _, err := f.Write(wscBytes); err != nil { + return fmt.Errorf("failed to write backup service config: %w", err) + } + + return nil +} + +// windowsServiceConfig defines how the service should be configured, including the entrypoint for the service. +type windowsServiceConfig struct { + // Path is the file that will be executed for the service. It is relative to the install directory. + Path string `json:"path"` + // Configuration for the service (e.g. start type, display name, desc) + Service windowsServiceDefinitionConfig `json:"service"` +} + +// windowsServiceDefinitionConfig defines how the service should be configured. +// Name is a part of the on disk config, but we keep the service name hardcoded; We do not want to use a different service name. +type windowsServiceDefinitionConfig struct { + // Start gives the start type of the service. + // See: https://wixtoolset.org/documentation/manual/v3/xsd/wix/serviceinstall.html + Start string `json:"start"` + // DisplayName is the human-readable name of the service. + DisplayName string `json:"display-name"` + // Description is a human-readable description of the service. + Description string `json:"description"` + // Arguments is a list of space-separated + Arguments string `json:"arguments"` +} + +// readWindowsServiceConfig reads the service config from the file at the given path +func readWindowsServiceConfig(path string) (*windowsServiceConfig, error) { + cleanPath := filepath.Clean(path) + b, err := os.ReadFile(cleanPath) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + + var wsc windowsServiceConfig + err = json.Unmarshal(b, &wsc) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal json: %w", err) + } + + return &wsc, nil +} + +// expandArguments expands [INSTALLDIR] to the actual install directory and +// expands '"' to the literal '"' +func expandArguments(wsc *windowsServiceConfig, installDir string) { + wsc.Service.Arguments = string(replaceInstallDir([]byte(wsc.Service.Arguments), installDir)) + wsc.Service.Arguments = strings.ReplaceAll(wsc.Service.Arguments, """, `"`) +} + +func (w windowsService) currentServiceConfig() (*windowsServiceConfig, error) { + m, err := mgr.Connect() + if err != nil { + return nil, fmt.Errorf("failed to connect to service manager: %w", err) + } + defer m.Disconnect() + + s, err := m.OpenService(w.serviceName) + if err != nil { + return nil, fmt.Errorf("failed to open service: %w", err) + } + defer s.Close() + + // Get the current config of the service + conf, err := s.Config() + if err != nil { + return nil, fmt.Errorf("failed to get service config: %w", err) + } + + fullBinaryPath, argString, err := splitServiceBinaryName(conf.BinaryPathName) + if err != nil { + return nil, fmt.Errorf("failed to split service BinaryPathName: %w", err) + } + + // In the original config, the Path is the main binary path, relative to the install directory. + binaryPath, err := filepath.Rel(w.installDir, fullBinaryPath) + if err != nil { + return nil, fmt.Errorf("could not find service exe relative to install dir: %w", err) + } + + // Convert windows api start type to the config file service type + confStartType, err := configStartType(conf.StartType, conf.DelayedAutoStart) + if err != nil { + return nil, fmt.Errorf("failed to get start type: %w", err) + } + + // Construct the config + return &windowsServiceConfig{ + Path: binaryPath, + Service: windowsServiceDefinitionConfig{ + Start: confStartType, + DisplayName: conf.DisplayName, + Description: conf.Description, + Arguments: argString, + }, + }, nil +} + +func splitServiceBinaryName(binaryPathName string) (binaryPath, argString string, err error) { + // Split the service arguments into an array of arguments + args, err := shellquote.Split(binaryPathName) + if err != nil { + return "", "", fmt.Errorf("failed to split service config args: %w", err) + } + + // The first argument is always the binary name; If the length of the array is 0, we know this is an invalid argument list. + if len(args) < 1 { + return "", "", fmt.Errorf("no binary specified in service config") + } + + // The absolute path to the binary is the first argument + binaryPath = args[0] + + // Stored argument string doesn't include the binary path (first arg) + args = args[1:] + + // Args should end up being a string, where literal quotes are """ + argString = shellquote.Join(args...) + // shellquote uses ' to quote, so we convert those to """ + argString = strings.ReplaceAll(argString, "'", """) + + return binaryPath, argString, nil +} + +// winapiStartType converts the start type from the windowsServiceConfig to a start type recognizable by the windows +// service API +func winapiStartType(cfgStartType string) (startType uint32, delayed bool, err error) { + switch cfgStartType { + case "auto": + // Automatically starts on system bootup. + startType = mgr.StartAutomatic + case "demand": + // Must be started manually + startType = mgr.StartManual + case "disabled": + // Does not start, must be enabled to run. + startType = mgr.StartDisabled + case "delayed": + // Boots automatically on start, but AFTER bootup has completed. + startType = mgr.StartAutomatic + delayed = true + default: + err = fmt.Errorf("invalid start type in service config: %s", cfgStartType) + } + return +} + +func configStartType(winapiStartType uint32, delayed bool) (string, error) { + switch winapiStartType { + case mgr.StartAutomatic: + if delayed { + return "delayed", nil + } + return "auto", nil + case mgr.StartDisabled: + return "disabled", nil + case mgr.StartManual: + return "manual", nil + default: + return "", fmt.Errorf("invalid winapi start type: %d", winapiStartType) + } +} diff --git a/updater/internal/service/service_windows_test.go b/updater/internal/service/service_windows_test.go new file mode 100644 index 000000000..71e9931b6 --- /dev/null +++ b/updater/internal/service/service_windows_test.go @@ -0,0 +1,598 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// an elevated user is needed to run the service tests +//go:build windows && integration + +package service + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "go.uber.org/zap/zaptest" + + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +func TestWindowsServiceInstall(t *testing.T) { + t.Run("Test Update", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + testServiceProgram := filepath.Join(tempDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + installService(t, + testServiceProgram, + testServiceName, + "Test Windows Service - Initial Display Name", + "This is the test windows service; initial desription", + mgr.StartAutomatic, + false) + + t.Cleanup(func() { + uninstallService(t, testServiceName) + time.Sleep(100 * time.Millisecond) + }) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: testServiceName, + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + err = w.Update() + require.NoError(t, err) + + //We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + requireServiceConfigMatches(t, + fmt.Sprintf("\"%s\"", testServiceProgram), + "windows-service", + mgr.StartAutomatic, + "Test Windows Service", + "This is a windows service to test", + true, + []string{ + "--config", + fmt.Sprintf("\"%s\"", filepath.Join(tempDir, "test.yaml")), + }, + ) + }) + + t.Run("Test update (space in install folder)", func(t *testing.T) { + tempDir := filepath.Join(t.TempDir(), "temp dir with spaces") + require.NoError(t, os.MkdirAll(tempDir, 0777)) + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + testServiceProgram := filepath.Join(tempDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + installService(t, + testServiceProgram, + testServiceName, + "Test Windows Service - Initial Display Name", + "This is the test windows service; initial desription", + mgr.StartAutomatic, + false) + + t.Cleanup(func() { + uninstallService(t, testServiceName) + time.Sleep(100 * time.Millisecond) + }) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: testServiceName, + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + err = w.Update() + require.NoError(t, err) + + //We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + requireServiceConfigMatches(t, + testServiceProgram, + "windows-service", + mgr.StartAutomatic, + "Test Windows Service", + "This is a windows service to test", + true, + []string{ + "--config", + filepath.Join(tempDir, "test.yaml"), + }, + ) + }) + + t.Run("Test stop + start", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + testServiceProgram := filepath.Join(tempDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + installService(t, + testServiceProgram, + testServiceName, + "Test Windows Service - Initial Display Name", + "This is the test windows service; initial desription", + mgr.StartManual, + false) + + t.Cleanup(func() { + uninstallService(t, testServiceName) + time.Sleep(100 * time.Millisecond) + }) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: "windows-service", + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + // We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + err = w.Start() + require.NoError(t, err) + + requireServiceRunningStatus(t, true) + + err = w.Stop() + require.NoError(t, err) + + requireServiceRunningStatus(t, false) + }) + + t.Run("Test invalid path for input file on update", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + testServiceProgram := filepath.Join(tempDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + installService(t, + testServiceProgram, + testServiceName, + "Test Windows Service - Initial Display Name", + "This is the test windows service; initial desription", + mgr.StartManual, + false) + + t.Cleanup(func() { + uninstallService(t, testServiceName) + time.Sleep(100 * time.Millisecond) + }) + + w := &windowsService{ + newServiceFilePath: filepath.Join(tempDir, "not-a-valid-service.json"), + serviceName: testServiceName, + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + err = w.Update() + require.ErrorContains(t, err, "The system cannot find the file specified.") + requireServiceLoadedStatus(t, true) + requireServiceRunningStatus(t, false) + }) + + t.Run("Update fails if not installed", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + testServiceProgram := filepath.Join(tempDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: testServiceName, + installDir: tempDir, + productName: testProductName, + } + + err = w.Update() + require.ErrorContains(t, err, "failed to open service") + requireServiceLoadedStatus(t, false) + }) + + t.Run("Start fails if service not found", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: "windows-service", + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + err := w.Start() + require.ErrorContains(t, err, "failed to open service") + }) + + t.Run("Stop fails if service not found", func(t *testing.T) { + tempDir := t.TempDir() + testProductName := "Test Product" + + serviceJSON := filepath.Join(tempDir, "windows-service.json") + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: "windows-service", + productName: testProductName, + installDir: tempDir, + logger: zaptest.NewLogger(t), + } + + err := w.Stop() + require.ErrorContains(t, err, "failed to open service") + }) + + t.Run("Test backup works", func(t *testing.T) { + tempDir := t.TempDir() + installDir, err := filepath.Abs(filepath.Join(tempDir, "install directory")) + require.NoError(t, err) + require.NoError(t, os.MkdirAll(path.BackupDir(installDir), 0775)) + + testProductName := "Test Product" + testServiceName := "windows-service" + + serviceJSON := filepath.Join(installDir, "windows-service.json") + testServiceProgram := filepath.Join(installDir, "windows-service.exe") + serviceGoFile, err := filepath.Abs(filepath.Join("testdata", "test-windows-service.go")) + require.NoError(t, err) + + writeServiceFile(t, serviceJSON, filepath.Join("testdata", "windows-service.json"), serviceGoFile) + compileProgram(t, serviceGoFile, testServiceProgram) + + installService(t, + testServiceProgram, + testServiceName, + "Test Windows Service - Initial Display Name", + "This is the test windows service; initial desription", + mgr.StartManual, + false) + + t.Cleanup(func() { + uninstallService(t, testServiceName) + time.Sleep(100 * time.Millisecond) + }) + + w := &windowsService{ + newServiceFilePath: serviceJSON, + serviceName: "windows-service", + productName: testProductName, + installDir: installDir, + logger: zaptest.NewLogger(t), + } + + require.NoError(t, w.Update()) + + //We want to check that the service was actually loaded + requireServiceLoadedStatus(t, true) + + requireServiceConfigMatches(t, + testServiceProgram, + "windows-service", + mgr.StartAutomatic, + "Test Windows Service", + "This is a windows service to test", + true, + []string{ + "--config", + filepath.Join(installDir, "test.yaml"), + }, + ) + + // Take a backup; Assert the backup makes sense. + // It will not be the same as the original service file due to expansion of INSTALLDIR + // which is OK and expected. + err = w.Backup() + require.NoError(t, err) + + backupSvcFile := path.BackupServiceFile(installDir) + + svcCfg, err := readWindowsServiceConfig(backupSvcFile) + require.NoError(t, err) + + assert.Equal(t, &windowsServiceConfig{ + Path: "windows-service.exe", + Service: windowsServiceDefinitionConfig{ + Start: "delayed", + DisplayName: "Test Windows Service", + Description: "This is a windows service to test", + Arguments: fmt.Sprintf("--config "%s"", filepath.Join(installDir, "test.yaml")), + }, + }, svcCfg) + + }) +} + +func TestStartType(t *testing.T) { + testCases := []struct { + cfgStartType string + startType uint32 + delayed bool + expectedErr string + }{ + { + cfgStartType: "auto", + startType: mgr.StartAutomatic, + delayed: false, + }, + { + cfgStartType: "demand", + startType: mgr.StartManual, + delayed: false, + }, + { + cfgStartType: "disabled", + startType: mgr.StartDisabled, + delayed: false, + }, + { + cfgStartType: "delayed", + startType: mgr.StartAutomatic, + delayed: true, + }, + { + cfgStartType: "not-a-real-start-type", + expectedErr: "invalid start type in service config", + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("cfgStartType: %s", tc.cfgStartType), func(t *testing.T) { + st, d, err := winapiStartType(tc.cfgStartType) + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + } else { + assert.Equal(t, tc.startType, st) + assert.Equal(t, tc.delayed, d) + } + }) + } +} + +func installService(t *testing.T, binPath, serviceName, displayName, description string, startType uint32, delayed bool) { + t.Helper() + + m, err := mgr.Connect() + if err != nil { + require.Fail(t, "failed to connect to service manager", "failed to connect to service manager: %s", err) + } + defer m.Disconnect() + + s, err := m.CreateService(serviceName, binPath, mgr.Config{ + DisplayName: displayName, + Description: description, + StartType: startType, + DelayedAutoStart: delayed, + }) + require.NoError(t, err) + require.NoError(t, s.Close()) +} + +// uninstallService is a helper that uninstalls the service manually for test setup, in case it is somehow leftover. +func uninstallService(t *testing.T, serviceName string) { + m, err := mgr.Connect() + require.NoError(t, err) + defer m.Disconnect() + + s, err := m.OpenService(serviceName) + if err != nil { + // Failed to open the service, we assume it doesn't exist + t.Logf("failed to open service: %s", err) + return + } + defer s.Close() + + status, err := s.Control(svc.Stop) + // If we get an error, the service is likely already in a stopped state. + if err == nil { + for status.State != svc.Stopped { + time.Sleep(100 * time.Millisecond) + status, err = s.Query() + require.NoError(t, err) + } + } else { + t.Logf("failed to stop service: %s", err) + } + + err = s.Delete() + require.NoError(t, err) + + s.Close() + + const serviceNotExistErrStr = "The specified service does not exist as an installed service." + for { + s, err := m.OpenService(serviceName) + if err != nil { + if err.Error() == serviceNotExistErrStr { + // This is expected when the service is uninstalled. + t.Logf("Service no longer exists: %s", err) + break + } + require.FailNow(t, "failed to uninstall service", "got unexpected error when waiting for service deletion: %s", err) + } + + if err := s.Close(); err != nil { + require.FailNow(t, "failed to uninstall service", "got unexpected error when closing service handle: %s", err) + } + // rest with the handle closed to let the service manager remove the service + time.Sleep(50 * time.Millisecond) + } +} + +func requireServiceLoadedStatus(t *testing.T, loaded bool) { + t.Helper() + + m, err := mgr.Connect() + require.NoError(t, err, "failed to connect to service manager") + defer m.Disconnect() + + s, err := m.OpenService("windows-service") + if err != nil { + require.False(t, loaded, "Could not connect open service, but service should be loaded") + return + } + defer s.Close() + + require.True(t, loaded, "Connected to open service, but it should not be loaded") + +} + +func requireServiceConfigMatches(t *testing.T, binaryPath, name string, startType uint32, displayName, description string, delayed bool, args []string) { + t.Helper() + + m, err := mgr.Connect() + require.NoError(t, err, "failed to connect to service manager") + defer m.Disconnect() + + s, err := m.OpenService(name) + require.NoError(t, err, "failed to open service") + defer s.Close() + + cfg, err := s.Config() + require.NoError(t, err) + + expectedBinaryPathName := joinArgs(append([]string{binaryPath}, args...)...) + assert.Equal(t, displayName, cfg.DisplayName) + assert.Equal(t, description, cfg.Description) + assert.Equal(t, delayed, cfg.DelayedAutoStart) + assert.Equal(t, startType, cfg.StartType) + assert.Equal(t, expectedBinaryPathName, cfg.BinaryPathName) + // We always install as LocalSystem, which is the "super user" of the system + assert.Equal(t, "LocalSystem", cfg.ServiceStartName) +} + +func requireServiceRunningStatus(t *testing.T, running bool) { + t.Helper() + + m, err := mgr.Connect() + require.NoError(t, err, "failed to connect to service manager") + defer m.Disconnect() + + s, err := m.OpenService("windows-service") + require.NoError(t, err, "Failed to open service") + defer s.Close() + + status, err := s.Query() + require.NoError(t, err, "Failed to query service state") + + if running { + require.Contains(t, []svc.State{svc.StartPending, svc.Running}, status.State) + } else { + require.Contains(t, []svc.State{svc.StopPending, svc.Stopped}, status.State) + } +} + +func writeServiceFile(t *testing.T, outPath, inPath, serviceGoPath string) { + t.Helper() + + b, err := os.ReadFile(inPath) + require.NoError(t, err) + + fileStr := string(b) + fileStr = os.Expand(fileStr, func(s string) string { + switch s { + case "SERVICE_PATH": + return strings.ReplaceAll(serviceGoPath, `\`, `\\`) + } + return "" + }) + + err = os.WriteFile(outPath, []byte(fileStr), 0666) + require.NoError(t, err) +} + +func compileProgram(t *testing.T, inPath, outPath string) { + t.Helper() + + cmd := exec.Command("go.exe", "build", "-o", outPath, inPath) + err := cmd.Run() + require.NoError(t, err) +} + +func joinArgs(args ...string) string { + sb := strings.Builder{} + for _, arg := range args { + if strings.Contains(arg, " ") { + sb.WriteString(`"`) + sb.WriteString(arg) + sb.WriteString(`"`) + } else { + sb.WriteString(arg) + } + sb.WriteString(" ") + } + + str := sb.String() + return str[:len(str)-1] +} diff --git a/updater/internal/service/testdata/darwin-service.plist b/updater/internal/service/testdata/darwin-service.plist new file mode 100644 index 000000000..f50cabaf7 --- /dev/null +++ b/updater/internal/service/testdata/darwin-service.plist @@ -0,0 +1,15 @@ + + + + + Label + darwin-service + ProgramArguments + + sleep + 1000 + + KeepAlive + + + diff --git a/updater/internal/service/testdata/linux-service.service b/updater/internal/service/testdata/linux-service.service new file mode 100644 index 000000000..5d72584a2 --- /dev/null +++ b/updater/internal/service/testdata/linux-service.service @@ -0,0 +1,10 @@ +[Unit] +Description=Test service +After=network.target +[Service] +Type=simple +User=root +ExecStart=sleep 1000 +SuccessExitStatus=0 +[Install] +WantedBy=multi-user.target diff --git a/updater/internal/service/testdata/test-windows-service.go b/updater/internal/service/testdata/test-windows-service.go new file mode 100644 index 000000000..c7acecef3 --- /dev/null +++ b/updater/internal/service/testdata/test-windows-service.go @@ -0,0 +1,55 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import ( + "log" + + "golang.org/x/sys/windows/svc" +) + +func main() { + winSvc, err := svc.IsWindowsService() + if err != nil { + log.Fatalf("Failed to determine if we were a windows service") + } + + if !winSvc { + log.Fatalf("This program must be run as a windows service") + } + + err = svc.Run("", &windowsService{}) + if err != nil { + log.Fatalf("Failed to run service: %s", err) + } + +} + +type windowsService struct{} + +func (sh *windowsService) Execute(args []string, r <-chan svc.ChangeRequest, s chan<- svc.Status) (bool, uint32) { + s <- svc.Status{State: svc.Running, Accepts: svc.AcceptStop | svc.AcceptShutdown} + for { + req := <-r + switch req.Cmd { + case svc.Interrogate: + s <- req.CurrentStatus + case svc.Stop, svc.Shutdown: + return false, 0 + default: + return false, 1052 + } + } +} diff --git a/updater/internal/service/testdata/windows-service.json b/updater/internal/service/testdata/windows-service.json new file mode 100644 index 000000000..2805e768f --- /dev/null +++ b/updater/internal/service/testdata/windows-service.json @@ -0,0 +1,10 @@ +{ + "path": "windows-service.exe", + "service": { + "name": "windows-service", + "start": "delayed", + "display-name": "Test Windows Service", + "description": "This is a windows service to test", + "arguments": "--config "[INSTALLDIR]test.yaml"" + } +} diff --git a/updater/internal/state/mocks/mock_monitor.go b/updater/internal/state/mocks/mock_monitor.go new file mode 100644 index 000000000..5d772548e --- /dev/null +++ b/updater/internal/state/mocks/mock_monitor.go @@ -0,0 +1,55 @@ +// Code generated by mockery v2.12.2. DO NOT EDIT. + +package mocks + +import ( + context "context" + + protobufs "github.com/open-telemetry/opamp-go/protobufs" + mock "github.com/stretchr/testify/mock" + + testing "testing" +) + +// MockMonitor is an autogenerated mock type for the Monitor type +type MockMonitor struct { + mock.Mock +} + +// MonitorForSuccess provides a mock function with given fields: ctx, packageName +func (_m *MockMonitor) MonitorForSuccess(ctx context.Context, packageName string) error { + ret := _m.Called(ctx, packageName) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string) error); ok { + r0 = rf(ctx, packageName) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// SetState provides a mock function with given fields: packageName, status, statusErr +func (_m *MockMonitor) SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error { + ret := _m.Called(packageName, status, statusErr) + + var r0 error + if rf, ok := ret.Get(0).(func(string, protobufs.PackageStatus_Status, error) error); ok { + r0 = rf(packageName, status, statusErr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// NewMockMonitor creates a new instance of MockMonitor. It also registers the testing.TB interface on the mock and a cleanup function to assert the mocks expectations. +func NewMockMonitor(t testing.TB) *MockMonitor { + mock := &MockMonitor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/updater/internal/state/monitor.go b/updater/internal/state/monitor.go new file mode 100644 index 000000000..834535b81 --- /dev/null +++ b/updater/internal/state/monitor.go @@ -0,0 +1,134 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package state contains structures to monitor and update the state of the collector in the package status +package state + +import ( + "context" + "errors" + "fmt" + "path/filepath" + "time" + + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +var ( + // ErrFailedStatus is the error when the Package status indicates a failure + ErrFailedStatus = errors.New("package status indicates failure") +) + +// Monitor allows checking and setting state of active install +type Monitor interface { + // SetState sets the state for the package. + // If passed in statusErr is not nil it will record the error as the message + SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error + + // MonitorForSuccess will periodically check the state of the package. It will keep checking until the context is canceled or a failed/success state is detected. + // It will return an error if status is Failed or if the context times out. + MonitorForSuccess(ctx context.Context, packageName string) error +} + +// CollectorMonitor implements Monitor interface for monitoring the Collector Package Status file +type CollectorMonitor struct { + stateManager packagestate.StateManager + currentStatus *protobufs.PackageStatuses +} + +// NewCollectorMonitor create a new Monitor specifically for the collector +func NewCollectorMonitor(logger *zap.Logger, installDir string) (Monitor, error) { + namedLogger := logger.Named("collector-monitor") + + // Create a collector monitor + packageStatusPath := filepath.Join(installDir, packagestate.DefaultFileName) + collectorMonitor := &CollectorMonitor{ + stateManager: packagestate.NewFileStateManager(namedLogger, packageStatusPath), + } + + // Load the current status to ensure the package status file exists + var err error + collectorMonitor.currentStatus, err = collectorMonitor.stateManager.LoadStatuses() + if err != nil { + return nil, fmt.Errorf("failed to load package statues: %w", err) + } + + return collectorMonitor, nil + +} + +// SetState sets the status on the specified package and saves it to the package status file +func (c *CollectorMonitor) SetState(packageName string, status protobufs.PackageStatus_Status, statusErr error) error { + // Verify we have package by that name + targetPackage, ok := c.currentStatus.GetPackages()[packageName] + if !ok { + return fmt.Errorf("no package for name %s", packageName) + } + + // Update the status + targetPackage.Status = status + + // If that passed in error is not nil set it as the error message + if statusErr != nil { + targetPackage.ErrorMessage = statusErr.Error() + } + + c.currentStatus.GetPackages()[packageName] = targetPackage + + // Save to updated status to disk + return c.stateManager.SaveStatuses(c.currentStatus) +} + +// MonitorForSuccess intermittently checks the package status file for either an install failed or success status. +// If an InstallFailed status is read this returns ErrFailedStatus error. +// If the context is canceled the context error will be returned. +func (c *CollectorMonitor) MonitorForSuccess(ctx context.Context, packageName string) error { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + packageStatus, err := c.stateManager.LoadStatuses() + switch { + // If there is any error we'll just continue. Some valid reasons we could error and should retry: + // - File was deleted by new collector before it's rewritten + // - File is being written to while we're reading it so we'll get invalid JSON + case err != nil: + continue + default: + targetPackage, ok := packageStatus.GetPackages()[packageName] + // Target package might not exist yet so continue + if !ok { + continue + } + + switch targetPackage.GetStatus() { + case protobufs.PackageStatus_InstallFailed: + return ErrFailedStatus + case protobufs.PackageStatus_Installed: + // Install successful + return nil + default: + // Collector may still be starting up or we may have read the file while it's being written + continue + } + } + } + } +} diff --git a/updater/internal/state/monitor_test.go b/updater/internal/state/monitor_test.go new file mode 100644 index 000000000..82d5c0dbf --- /dev/null +++ b/updater/internal/state/monitor_test.go @@ -0,0 +1,367 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package state + +import ( + "context" + "errors" + "os" + "testing" + + "github.com/observiq/observiq-otel-collector/packagestate/mocks" + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" +) + +func TestCollectorMonitorSetState(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Package not in current status", + testFunc: func(*testing.T) { + mockStateManger := mocks.NewMockStateManager(t) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: make(map[string]*protobufs.PackageStatus), + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.Error(t, err) + }, + }, + { + desc: "Sets Status no error", + testFunc: func(*testing.T) { + pgkName := "my_package" + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + { + desc: "Sets Status w/error", + testFunc: func(*testing.T) { + pgkName := "my_package" + statusErr := errors.New("some error") + + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallFailed, + ErrorMessage: statusErr.Error(), + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_InstallFailed, statusErr) + assert.NoError(t, err) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + { + desc: "StateManager fails to save", + testFunc: func(*testing.T) { + pgkName := "my_package" + expectedErr := errors.New("bad") + expectedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("SaveStatuses", expectedStatus).Return(expectedErr) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + currentStatus: &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + AgentHasVersion: "1.0", + AgentHasHash: []byte("hash1"), + ServerOfferedVersion: "1.2", + ServerOfferedHash: []byte("hash2"), + Status: protobufs.PackageStatus_InstallPending, + }, + }, + }, + } + + err := collectorMonitor.SetState("my_package", protobufs.PackageStatus_Installed, nil) + assert.ErrorIs(t, err, expectedErr) + assert.Equal(t, expectedStatus, collectorMonitor.currentStatus) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} + +func TestCollectorMonitorMonitorForSuccess(t *testing.T) { + testCases := []struct { + desc string + testFunc func(*testing.T) + }{ + { + desc: "Context is canceled", + testFunc: func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + mockStateManger := mocks.NewMockStateManager(t) + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(ctx, "my_package") + assert.ErrorIs(t, err, context.Canceled) + }, + }, + { + desc: "Package Status Indicates Failed Install", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_InstallFailed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.ErrorIs(t, err, ErrFailedStatus) + }, + }, + { + desc: "Package Status Indicates Successful install", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "File does not exist at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(nil, os.ErrNotExist) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Error reading file at first first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + returnedStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(nil, errors.New("bad")) + mockStateManger.On("LoadStatuses").Return(returnedStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Package is not present at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + firstStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{}, + } + secondStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(firstStatus, nil) + mockStateManger.On("LoadStatuses").Return(secondStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + { + desc: "Package is still marked as Installing at first then is successful", + testFunc: func(t *testing.T) { + pgkName := "my_package" + firstStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_InstallPending, + }, + }, + } + secondStatus := &protobufs.PackageStatuses{ + Packages: map[string]*protobufs.PackageStatus{ + pgkName: { + Name: pgkName, + Status: protobufs.PackageStatus_Installed, + }, + }, + } + + mockStateManger := mocks.NewMockStateManager(t) + mockStateManger.On("LoadStatuses").Once().Return(firstStatus, nil) + mockStateManger.On("LoadStatuses").Return(secondStatus, nil) + + collectorMonitor := &CollectorMonitor{ + stateManager: mockStateManger, + } + + err := collectorMonitor.MonitorForSuccess(context.Background(), pgkName) + assert.NoError(t, err) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, tc.testFunc) + } +} diff --git a/updater/internal/updater/testdata/package_statuses.json b/updater/internal/updater/testdata/package_statuses.json new file mode 100644 index 000000000..0967ef424 --- /dev/null +++ b/updater/internal/updater/testdata/package_statuses.json @@ -0,0 +1 @@ +{} diff --git a/updater/internal/updater/updater.go b/updater/internal/updater/updater.go new file mode 100644 index 000000000..fe408c18f --- /dev/null +++ b/updater/internal/updater/updater.go @@ -0,0 +1,146 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package updater + +import ( + "context" + "errors" + "fmt" + "os" + "time" + + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/observiq/observiq-otel-collector/updater/internal/action" + "github.com/observiq/observiq-otel-collector/updater/internal/install" + "github.com/observiq/observiq-otel-collector/updater/internal/path" + "github.com/observiq/observiq-otel-collector/updater/internal/rollback" + "github.com/observiq/observiq-otel-collector/updater/internal/service" + "github.com/observiq/observiq-otel-collector/updater/internal/state" + "github.com/open-telemetry/opamp-go/protobufs" + "go.uber.org/zap" +) + +// Updater is a struct that can be used to perform a collector update +type Updater struct { + installDir string + installer install.Installer + svc service.Service + rollbacker rollback.Rollbacker + monitor state.Monitor + logger *zap.Logger +} + +// NewUpdater creates a new updater which can be used to update the installation based at +// installDir +func NewUpdater(logger *zap.Logger, installDir string) (*Updater, error) { + monitor, err := state.NewCollectorMonitor(logger, installDir) + if err != nil { + return nil, fmt.Errorf("failed to create monitor: %w", err) + } + + svc := service.NewService(logger, installDir) + return &Updater{ + installDir: installDir, + installer: install.NewInstaller(logger, installDir, svc), + svc: svc, + rollbacker: rollback.NewRollbacker(logger, installDir), + monitor: monitor, + logger: logger, + }, nil +} + +// Update performs the update of the collector binary +func (u *Updater) Update() error { + // Stop the service before backing up the install directory; + // We want to stop as early as possible so that we don't hit the collector's timeout + // while it waits to be shutdown. + if err := u.svc.Stop(); err != nil { + return fmt.Errorf("failed to stop service: %w", err) + } + // Record that we stopped the service + u.rollbacker.AppendAction(action.NewServiceStopAction(u.svc)) + + // Now that we stopped the service, it will be our responsibility to cleanup the tmp dir. + // We will do this regardless of whether we succeed or fail after this point. + defer u.removeTmpDir() + + u.logger.Debug("Stopped the service") + + // Create the backup + if err := u.rollbacker.Backup(); err != nil { + u.logger.Error("Failed to backup", zap.Error(err)) + + // Set the state to failed before rollback so collector knows it failed + if setErr := u.monitor.SetState(packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err); setErr != nil { + u.logger.Error("Failed to set state on backup failure", zap.Error(setErr)) + } + + u.rollbacker.Rollback() + + u.logger.Error("Rollback complete") + return fmt.Errorf("failed to backup: %w", err) + } + + // Install artifacts + if err := u.installer.Install(u.rollbacker); err != nil { + u.logger.Error("Failed to install", zap.Error(err)) + + // Set the state to failed before rollback so collector knows it failed + if setErr := u.monitor.SetState(packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err); setErr != nil { + u.logger.Error("Failed to set state on install failure", zap.Error(setErr)) + } + + u.rollbacker.Rollback() + + u.logger.Error("Rollback complete") + return fmt.Errorf("failed to install: %w", err) + } + + // Create a context with timeout to wait for a success or failed status + checkCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + u.logger.Debug("Installation successful, begin monitor for success") + + // Monitor the install state + if err := u.monitor.MonitorForSuccess(checkCtx, packagestate.CollectorPackageName); err != nil { + u.logger.Error("Failed to install", zap.Error(err)) + + // If this is not an error due to the collector setting a failed status we need to set a failed status + if !errors.Is(err, state.ErrFailedStatus) { + // Set the state to failed before rollback so collector knows it failed + if setErr := u.monitor.SetState(packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err); setErr != nil { + u.logger.Error("Failed to set state on install failure", zap.Error(setErr)) + } + } + + u.rollbacker.Rollback() + + u.logger.Error("Rollback complete") + return fmt.Errorf("failed while monitoring for success: %w", err) + } + + // Successful update + u.logger.Info("Update Complete") + return nil +} + +// removeTmpDir removes the temporary directory that holds the update artifacts. +func (u *Updater) removeTmpDir() { + err := os.RemoveAll(path.TempDir(u.installDir)) + if err != nil { + u.logger.Error("failed to remove temporary directory", zap.Error(err)) + } +} diff --git a/updater/internal/updater/updater_test.go b/updater/internal/updater/updater_test.go new file mode 100644 index 000000000..692818448 --- /dev/null +++ b/updater/internal/updater/updater_test.go @@ -0,0 +1,317 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package updater + +import ( + "errors" + "testing" + + "github.com/observiq/observiq-otel-collector/packagestate" + "github.com/observiq/observiq-otel-collector/updater/internal/action" + install_mocks "github.com/observiq/observiq-otel-collector/updater/internal/install/mocks" + rollback_mocks "github.com/observiq/observiq-otel-collector/updater/internal/rollback/mocks" + service_mocks "github.com/observiq/observiq-otel-collector/updater/internal/service/mocks" + "github.com/observiq/observiq-otel-collector/updater/internal/state" + state_mocks "github.com/observiq/observiq-otel-collector/updater/internal/state/mocks" + "github.com/open-telemetry/opamp-go/protobufs" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" +) + +func TestNewUpdater(t *testing.T) { + t.Run("New updater is created successfully", func(t *testing.T) { + installDir := "testdata" + logger := zaptest.NewLogger(t) + updater, err := NewUpdater(logger, installDir) + require.NoError(t, err) + require.NotNil(t, updater) + assert.NotNil(t, updater.installer) + assert.NotNil(t, updater.svc) + assert.NotNil(t, updater.rollbacker) + assert.NotNil(t, updater.monitor) + assert.NotNil(t, updater.logger) + assert.Equal(t, installDir, updater.installDir) + }) + + t.Run("New updater fails due to missing package statuses", func(t *testing.T) { + installDir := t.TempDir() + logger := zaptest.NewLogger(t) + updater, err := NewUpdater(logger, installDir) + require.ErrorContains(t, err, "failed to create monitor") + require.Nil(t, updater) + }) +} + +func TestUpdaterUpdate(t *testing.T) { + t.Run("Update is successful", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(nil) + monitor.On("MonitorForSuccess", mock.Anything, packagestate.CollectorPackageName).Times(1).Return(nil) + + err := updater.Update() + require.NoError(t, err) + }) + + t.Run("Service stop fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + svc.On("Stop").Times(1).Return(errors.New("insufficient permissions")) + + err := updater.Update() + require.ErrorContains(t, err, "failed to stop service") + }) + + t.Run("Backup fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(nil) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed to backup") + }) + + t.Run("Backup fails, set state fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(errors.New("insufficient permissions")) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed to backup") + }) + + t.Run("Install fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(nil) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed to install") + }) + + t.Run("Install fails, set state fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(errors.New("insufficient permissions")) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed to install") + }) + + t.Run("Monitor for success fails to monitor", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(nil) + monitor.On("MonitorForSuccess", mock.Anything, packagestate.CollectorPackageName).Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(nil) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed while monitoring for success") + }) + + t.Run("Monitor for success fails to monitor, set state fails", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + err := errors.New("insufficient permissions") + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(nil) + monitor.On("MonitorForSuccess", mock.Anything, packagestate.CollectorPackageName).Times(1).Return(err) + monitor.On("SetState", packagestate.CollectorPackageName, protobufs.PackageStatus_InstallFailed, err).Times(1).Return(errors.New("insufficient permissions")) + rollbacker.On("Rollback").Times(1).Return() + + err = updater.Update() + require.ErrorContains(t, err, "failed while monitoring for success") + }) + + t.Run("Monitor for success finds error in package statuses", func(t *testing.T) { + installDir := t.TempDir() + + installer := install_mocks.NewInstaller(t) + svc := service_mocks.NewService(t) + rollbacker := rollback_mocks.NewRollbacker(t) + monitor := state_mocks.NewMockMonitor(t) + + updater := &Updater{ + installDir: installDir, + installer: installer, + svc: svc, + rollbacker: rollbacker, + monitor: monitor, + logger: zaptest.NewLogger(t), + } + + svc.On("Stop").Times(1).Return(nil) + rollbacker.On("AppendAction", action.NewServiceStopAction(svc)).Times(1).Return() + rollbacker.On("Backup").Times(1).Return(nil) + installer.On("Install", rollbacker).Times(1).Return(nil) + monitor.On("MonitorForSuccess", mock.Anything, packagestate.CollectorPackageName).Times(1).Return(state.ErrFailedStatus) + rollbacker.On("Rollback").Times(1).Return() + + err := updater.Update() + require.ErrorContains(t, err, "failed while monitoring for success") + }) +} diff --git a/updater/internal/version/version.go b/updater/internal/version/version.go new file mode 100644 index 000000000..beef0ef2e --- /dev/null +++ b/updater/internal/version/version.go @@ -0,0 +1,37 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +// these will be replaced at link time by make. +var ( + version = "latest" // Semantic version, or "latest" by default + gitHash = "unknown" // Commit hash from which this build was generated + date = "unknown" // Date the build was generated +) + +// Version returns the version of the collector. +func Version() string { + return version +} + +// GitHash returns the githash associated with the collector's version. +func GitHash() string { + return gitHash +} + +// Date returns the publish date associated with the collector's version. +func Date() string { + return date +} diff --git a/updater/internal/version/version_test.go b/updater/internal/version/version_test.go new file mode 100644 index 000000000..2c7956d5d --- /dev/null +++ b/updater/internal/version/version_test.go @@ -0,0 +1,27 @@ +// Copyright observIQ, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package version + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestDefaults(t *testing.T) { + require.Equal(t, version, Version()) + require.Equal(t, gitHash, GitHash()) + require.Equal(t, date, Date()) +} diff --git a/windows/scripts/build-msi.sh b/windows/scripts/build-msi.sh index 3c740c72a..a6d1c7fe1 100755 --- a/windows/scripts/build-msi.sh +++ b/windows/scripts/build-msi.sh @@ -22,6 +22,7 @@ mkdir -p storage touch storage/.include cp "$PROJECT_BASE/dist/collector_windows_amd64.exe" "observiq-otel-collector.exe" +cp "$PROJECT_BASE/dist/updater_windows_amd64.exe" "updater.exe" vagrant winrm -c \ "cd C:/vagrant; go-msi.exe make -m observiq-otel-collector.msi --version v0.0.1 --arch amd64" diff --git a/windows/wix.json b/windows/wix.json index f79a7fbde..b29305224 100644 --- a/windows/wix.json +++ b/windows/wix.json @@ -18,6 +18,9 @@ "arguments": "--config "[INSTALLDIR]config.yaml" --logging "[INSTALLDIR]logging.yaml" --manager "[INSTALLDIR]manager.yaml"" } }, + { + "path": "updater.exe" + }, { "path": "config.yaml", "never_overwrite": true