Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow specifying ETag value while setting ACLs #38

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions tailscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@ func WithBaseURL(baseURL string) ClientOption {
}
}

func (c *Client) buildRequest(ctx context.Context, method, uri string, body interface{}) (*http.Request, error) {
// TODO: consider setting `headers` and `body` via opts to decrease the number of arguments.
func (c *Client) buildRequest(ctx context.Context, method, uri string, headers map[string]string, body interface{}) (*http.Request, error) {
knyar marked this conversation as resolved.
Show resolved Hide resolved
u, err := c.baseURL.Parse(uri)
if err != nil {
return nil, err
Expand All @@ -101,6 +102,10 @@ func (c *Client) buildRequest(ctx context.Context, method, uri string, body inte
return nil, err
}

for k, v := range headers {
req.Header.Set(k, v)
}

switch {
case body == nil:
req.Header.Set("Accept", contentType)
Expand Down Expand Up @@ -164,7 +169,7 @@ func (err APIError) Error() string {
func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
"searchPaths": searchPaths,
})
if err != nil {
Expand All @@ -178,7 +183,7 @@ func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) er
func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -196,7 +201,7 @@ func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
"dns": dns,
})
if err != nil {
Expand All @@ -210,7 +215,7 @@ func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
func (c *Client) DNSNameservers(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -225,14 +230,18 @@ func (c *Client) DNSNameservers(ctx context.Context) ([]string, error) {

type (
ACL struct {
ACLs []ACLEntry `json:"acls" hujson:"ACLs,omitempty"`
AutoApprovers *ACLAutoApprovers `json:"autoapprovers,omitempty" hujson:"AutoApprovers,omitempty"`
Groups map[string][]string `json:"groups,omitempty" hujson:"Groups,omitempty"`
Hosts map[string]string `json:"hosts,omitempty" hujson:"Hosts,omitempty"`
TagOwners map[string][]string `json:"tagowners,omitempty" hujson:"TagOwners,omitempty"`
DERPMap *ACLDERPMap `json:"derpMap,omitempty" hujson:"DerpMap,omitempty"`
Tests []ACLTest `json:"tests,omitempty" hujson:"Tests,omitempty"`
SSH []ACLSSH `json:"ssh,omitempty" hujson:"SSH,omitempty"`
ACLs []ACLEntry `json:"acls" hujson:"ACLs,omitempty"`
AutoApprovers *ACLAutoApprovers `json:"autoapprovers,omitempty" hujson:"AutoApprovers,omitempty"`
Groups map[string][]string `json:"groups,omitempty" hujson:"Groups,omitempty"`
Hosts map[string]string `json:"hosts,omitempty" hujson:"Hosts,omitempty"`
TagOwners map[string][]string `json:"tagowners,omitempty" hujson:"TagOwners,omitempty"`
DERPMap *ACLDERPMap `json:"derpMap,omitempty" hujson:"DerpMap,omitempty"`
Tests []ACLTest `json:"tests,omitempty" hujson:"Tests,omitempty"`
SSH []ACLSSH `json:"ssh,omitempty" hujson:"SSH,omitempty"`
NodeAttrs []NodeAttrGrant `json:"nodeAttrs,omitempty" hujson:"NodeAttrs,omitempty"`
DisableIPv4 bool `json:"disableIPv4,omitempty" hujson:"DisableIPv4,omitempty"`
OneCGNATRoute string `json:"oneCGNATRoute,omitempty" hujson:"OneCGNATRoute,omitempty"`
RandomizeClientPort bool `json:"randomizeClientPort,omitempty" hujson:"RandomizeClientPort,omitempty"`
}

ACLAutoApprovers struct {
Expand Down Expand Up @@ -291,13 +300,18 @@ type (
Destination []string `json:"dst" hujson:"Dst"`
CheckPeriod Duration `json:"checkPeriod" hujson:"CheckPeriod"`
}

NodeAttrGrant struct {
Target []string `json:"target,omitempty" hujson:"Target,omitempty"`
Attr []string `json:"attr,omitempty" hujson:"Attr,omitempty"`
}
)

// ACL retrieves the ACL that is currently set for the given tailnet.
func (c *Client) ACL(ctx context.Context) (*ACL, error) {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -310,11 +324,29 @@ func (c *Client) ACL(ctx context.Context) (*ACL, error) {
return &resp, nil
}

type setACLParams struct {
headers map[string]string
}
type SetACLOption func(p *setACLParams)

// WithETag allows passing an ETag value with Set ACL API call that
// will be used in the `If-Match` HTTP request header.
func WithETag(etag string) SetACLOption {
return func(p *setACLParams) {
p.headers["If-Match"] = fmt.Sprintf("%q", etag)
}
}

// SetACL sets the ACL for the given tailnet.
func (c *Client) SetACL(ctx context.Context, acl ACL) error {
func (c *Client) SetACL(ctx context.Context, acl ACL, opts ...SetACLOption) error {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), acl)
p := &setACLParams{headers: make(map[string]string)}
for _, opt := range opts {
opt(p)
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), p.headers, acl)
if err != nil {
return err
}
Expand All @@ -326,7 +358,7 @@ func (c *Client) SetACL(ctx context.Context, acl ACL) error {
func (c *Client) ValidateACL(ctx context.Context, acl ACL) error {
const uriFmt = "/api/v2/tailnet/%s/acl/validate"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), acl)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, acl)
if err != nil {
return err
}
Expand All @@ -343,7 +375,7 @@ type DNSPreferences struct {
func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -361,7 +393,7 @@ func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
func (c *Client) SetDNSPreferences(ctx context.Context, preferences DNSPreferences) error {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), preferences)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, preferences)
if err != nil {
return nil
}
Expand All @@ -381,7 +413,7 @@ type (
func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, routes []string) error {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
"routes": routes,
})
if err != nil {
Expand All @@ -397,7 +429,7 @@ func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, rou
func (c *Client) DeviceSubnetRoutes(ctx context.Context, deviceID string) (*DeviceRoutes, error) {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID), nil, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -459,7 +491,7 @@ type Device struct {
func (c *Client) Devices(ctx context.Context) ([]Device, error) {
const uriFmt = "/api/v2/tailnet/%s/devices"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -476,7 +508,7 @@ func (c *Client) Devices(ctx context.Context) ([]Device, error) {
func (c *Client) AuthorizeDevice(ctx context.Context, deviceID string) error {
const uriFmt = "/api/v2/device/%s/authorized"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string]bool{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string]bool{
"authorized": true,
})
if err != nil {
Expand All @@ -489,7 +521,7 @@ func (c *Client) AuthorizeDevice(ctx context.Context, deviceID string) error {
// DeleteDevice deletes the device given its deviceID.
func (c *Client) DeleteDevice(ctx context.Context, deviceID string) error {
const uriFmt = "/api/v2/device/%s"
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID), nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID), nil, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -552,7 +584,7 @@ func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities, op
}
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), ckr)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, ckr)
if err != nil {
return Key{}, err
}
Expand All @@ -566,7 +598,7 @@ func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities, op
func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
if err != nil {
return Key{}, err
}
Expand All @@ -580,7 +612,7 @@ func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
func (c *Client) Keys(ctx context.Context) ([]Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -597,7 +629,7 @@ func (c *Client) Keys(ctx context.Context) ([]Key, error) {
func (c *Client) DeleteKey(ctx context.Context, id string) error {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id), nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
if err != nil {
return err
}
Expand All @@ -609,7 +641,7 @@ func (c *Client) DeleteKey(ctx context.Context, id string) error {
func (c *Client) SetDeviceTags(ctx context.Context, deviceID string, tags []string) error {
const uriFmt = "/api/v2/device/%s/tags"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
"tags": tags,
})
if err != nil {
Expand All @@ -631,7 +663,7 @@ type (
func (c *Client) SetDeviceKey(ctx context.Context, deviceID string, key DeviceKey) error {
const uriFmt = "/api/v2/device/%s/key"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), key)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, key)
if err != nil {
return err
}
Expand Down
26 changes: 26 additions & 0 deletions tailscale/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ func TestClient_SetACL(t *testing.T) {
assert.NoError(t, client.SetACL(context.Background(), expectedACL))
assert.Equal(t, http.MethodPost, server.Method)
assert.Equal(t, "/api/v2/tailnet/example.com/acl", server.Path)
assert.Equal(t, "", server.Header.Get("If-Match"))

var actualACL tailscale.ACL
assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL))
assert.EqualValues(t, expectedACL, actualACL)
}

func TestClient_SetACLWithETag(t *testing.T) {
t.Parallel()

client, server := NewTestHarness(t)
server.ResponseCode = http.StatusOK
expectedACL := tailscale.ACL{
ACLs: []tailscale.ACLEntry{
{
Action: "accept",
Ports: []string{"*:*"},
Users: []string{"*"},
},
},
}

assert.NoError(t, client.SetACL(context.Background(), expectedACL, tailscale.WithETag("test-etag")))
assert.Equal(t, http.MethodPost, server.Method)
assert.Equal(t, "/api/v2/tailnet/example.com/acl", server.Path)
assert.Equal(t, `"test-etag"`, server.Header.Get("If-Match"))

var actualACL tailscale.ACL
assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL))
Expand Down
2 changes: 2 additions & 0 deletions tailscale/tailscale_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type TestServer struct {
Method string
Path string
Body *bytes.Buffer
Header http.Header

ResponseCode int
ResponseBody interface{}
Expand Down Expand Up @@ -61,6 +62,7 @@ func NewTestHarness(t *testing.T) (*tailscale.Client, *TestServer) {
func (t *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t.Method = r.Method
t.Path = r.URL.Path
t.Header = r.Header

t.Body = bytes.NewBuffer([]byte{})
_, err := io.Copy(t.Body, r.Body)
Expand Down