Skip to content

Commit

Permalink
Allow specifying ETag value while setting ACLs
Browse files Browse the repository at this point in the history
This adds a `WithETag` option that allows specifying an ETag value while
writing ACL contents.

Also, a few missing ACL file fields have been added based on
https://tailscale.com/kb/1018/acls/

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
  • Loading branch information
knyar committed Nov 29, 2022
1 parent 84f6c03 commit 717b0b8
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 30 deletions.
91 changes: 61 additions & 30 deletions tailscale/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func WithBaseURL(baseURL string) ClientOption {
}
}

func (c *Client) buildRequest(ctx context.Context, method, uri string, body interface{}) (*http.Request, error) {
func (c *Client) buildRequest(ctx context.Context, method, uri string, headers map[string]string, body interface{}) (*http.Request, error) {
u, err := c.baseURL.Parse(uri)
if err != nil {
return nil, err
Expand All @@ -108,6 +108,10 @@ func (c *Client) buildRequest(ctx context.Context, method, uri string, body inte
req.Header.Set("Content-Type", contentType)
}

for k, v := range headers {
req.Header.Set(k, v)
}

req.SetBasicAuth(c.apiKey, "")
return req, nil
}
Expand Down Expand Up @@ -164,7 +168,7 @@ func (err APIError) Error() string {
func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
"searchPaths": searchPaths,
})
if err != nil {
Expand All @@ -178,7 +182,7 @@ func (c *Client) SetDNSSearchPaths(ctx context.Context, searchPaths []string) er
func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/searchpaths"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -196,7 +200,7 @@ func (c *Client) DNSSearchPaths(ctx context.Context) ([]string, error) {
func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string][]string{
"dns": dns,
})
if err != nil {
Expand All @@ -210,7 +214,7 @@ func (c *Client) SetDNSNameservers(ctx context.Context, dns []string) error {
func (c *Client) DNSNameservers(ctx context.Context) ([]string, error) {
const uriFmt = "/api/v2/tailnet/%v/dns/nameservers"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -225,14 +229,18 @@ func (c *Client) DNSNameservers(ctx context.Context) ([]string, error) {

type (
ACL struct {
ACLs []ACLEntry `json:"acls" hujson:"ACLs,omitempty"`
AutoApprovers *ACLAutoApprovers `json:"autoapprovers,omitempty" hujson:"AutoApprovers,omitempty"`
Groups map[string][]string `json:"groups,omitempty" hujson:"Groups,omitempty"`
Hosts map[string]string `json:"hosts,omitempty" hujson:"Hosts,omitempty"`
TagOwners map[string][]string `json:"tagowners,omitempty" hujson:"TagOwners,omitempty"`
DERPMap *ACLDERPMap `json:"derpMap,omitempty" hujson:"DerpMap,omitempty"`
Tests []ACLTest `json:"tests,omitempty" hujson:"Tests,omitempty"`
SSH []ACLSSH `json:"ssh,omitempty" hujson:"SSH,omitempty"`
ACLs []ACLEntry `json:"acls" hujson:"ACLs,omitempty"`
AutoApprovers *ACLAutoApprovers `json:"autoapprovers,omitempty" hujson:"AutoApprovers,omitempty"`
Groups map[string][]string `json:"groups,omitempty" hujson:"Groups,omitempty"`
Hosts map[string]string `json:"hosts,omitempty" hujson:"Hosts,omitempty"`
TagOwners map[string][]string `json:"tagowners,omitempty" hujson:"TagOwners,omitempty"`
DERPMap *ACLDERPMap `json:"derpMap,omitempty" hujson:"DerpMap,omitempty"`
Tests []ACLTest `json:"tests,omitempty" hujson:"Tests,omitempty"`
SSH []ACLSSH `json:"ssh,omitempty" hujson:"SSH,omitempty"`
NodeAttrs []NodeAttrGrant `json:"nodeAttrs,omitempty" hujson:"NodeAttrs,omitempty"`
DisableIPv4 bool `json:"disableIPv4,omitempty" hujson:"DisableIPv4,omitempty"`
OneCGNATRoute string `json:"oneCGNATRoute,omitempty" hujson:"OneCGNATRoute,omitempty"`
RandomizeClientPort bool `json:"randomizeClientPort,omitempty" hujson:"RandomizeClientPort,omitempty"`
}

ACLAutoApprovers struct {
Expand Down Expand Up @@ -291,13 +299,18 @@ type (
Destination []string `json:"dst" hujson:"Dst"`
CheckPeriod Duration `json:"checkPeriod" hujson:"CheckPeriod"`
}

NodeAttrGrant struct {
Target []string `json:"target,omitempty" hujson:"Target,omitempty"`
Attr []string `json:"attr,omitempty" hujson:"Attr,omitempty"`
}
)

// ACL retrieves the ACL that is currently set for the given tailnet.
func (c *Client) ACL(ctx context.Context) (*ACL, error) {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -310,11 +323,29 @@ func (c *Client) ACL(ctx context.Context) (*ACL, error) {
return &resp, nil
}

type setACLParams struct {
headers map[string]string
}
type SetACLOption func(p *setACLParams)

// WithETag allows passing an ETag value with Set ACL API call that
// will be used in the `If-Match` HTTP request header.
func WithETag(etag string) SetACLOption {
return func(p *setACLParams) {
p.headers["If-Match"] = fmt.Sprintf(`"%s"`, etag)
}
}

// SetACL sets the ACL for the given tailnet.
func (c *Client) SetACL(ctx context.Context, acl ACL) error {
func (c *Client) SetACL(ctx context.Context, acl ACL, opts ...SetACLOption) error {
const uriFmt = "/api/v2/tailnet/%s/acl"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), acl)
p := &setACLParams{headers: make(map[string]string)}
for _, opt := range opts {
opt(p)
}

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), p.headers, acl)
if err != nil {
return err
}
Expand All @@ -326,7 +357,7 @@ func (c *Client) SetACL(ctx context.Context, acl ACL) error {
func (c *Client) ValidateACL(ctx context.Context, acl ACL) error {
const uriFmt = "/api/v2/tailnet/%s/acl/validate"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), acl)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, acl)
if err != nil {
return err
}
Expand All @@ -343,7 +374,7 @@ type DNSPreferences struct {
func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -361,7 +392,7 @@ func (c *Client) DNSPreferences(ctx context.Context) (*DNSPreferences, error) {
func (c *Client) SetDNSPreferences(ctx context.Context, preferences DNSPreferences) error {
const uriFmt = "/api/v2/tailnet/%s/dns/preferences"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), preferences)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, preferences)
if err != nil {
return nil
}
Expand All @@ -381,7 +412,7 @@ type (
func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, routes []string) error {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
"routes": routes,
})
if err != nil {
Expand All @@ -397,7 +428,7 @@ func (c *Client) SetDeviceSubnetRoutes(ctx context.Context, deviceID string, rou
func (c *Client) DeviceSubnetRoutes(ctx context.Context, deviceID string) (*DeviceRoutes, error) {
const uriFmt = "/api/v2/device/%s/routes"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, deviceID), nil, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -459,7 +490,7 @@ type Device struct {
func (c *Client) Devices(ctx context.Context) ([]Device, error) {
const uriFmt = "/api/v2/tailnet/%s/devices"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -476,7 +507,7 @@ func (c *Client) Devices(ctx context.Context) ([]Device, error) {
func (c *Client) AuthorizeDevice(ctx context.Context, deviceID string) error {
const uriFmt = "/api/v2/device/%s/authorized"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string]bool{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string]bool{
"authorized": true,
})
if err != nil {
Expand All @@ -489,7 +520,7 @@ func (c *Client) AuthorizeDevice(ctx context.Context, deviceID string) error {
// DeleteDevice deletes the device given its deviceID.
func (c *Client) DeleteDevice(ctx context.Context, deviceID string) error {
const uriFmt = "/api/v2/device/%s"
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID), nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, deviceID), nil, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -526,7 +557,7 @@ type (
func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities) (Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), map[string]KeyCapabilities{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, c.tailnet), nil, map[string]KeyCapabilities{
"capabilities": capabilities,
})
if err != nil {
Expand All @@ -542,7 +573,7 @@ func (c *Client) CreateKey(ctx context.Context, capabilities KeyCapabilities) (K
func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
if err != nil {
return Key{}, err
}
Expand All @@ -556,7 +587,7 @@ func (c *Client) GetKey(ctx context.Context, id string) (Key, error) {
func (c *Client) Keys(ctx context.Context) ([]Key, error) {
const uriFmt = "/api/v2/tailnet/%s/keys"

req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil)
req, err := c.buildRequest(ctx, http.MethodGet, fmt.Sprintf(uriFmt, c.tailnet), nil, nil)
if err != nil {
return nil, err
}
Expand All @@ -573,7 +604,7 @@ func (c *Client) Keys(ctx context.Context) ([]Key, error) {
func (c *Client) DeleteKey(ctx context.Context, id string) error {
const uriFmt = "/api/v2/tailnet/%s/keys/%s"

req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id), nil)
req, err := c.buildRequest(ctx, http.MethodDelete, fmt.Sprintf(uriFmt, c.tailnet, id), nil, nil)
if err != nil {
return err
}
Expand All @@ -585,7 +616,7 @@ func (c *Client) DeleteKey(ctx context.Context, id string) error {
func (c *Client) SetDeviceTags(ctx context.Context, deviceID string, tags []string) error {
const uriFmt = "/api/v2/device/%s/tags"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), map[string][]string{
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, map[string][]string{
"tags": tags,
})
if err != nil {
Expand All @@ -607,7 +638,7 @@ type (
func (c *Client) SetDeviceKey(ctx context.Context, deviceID string, key DeviceKey) error {
const uriFmt = "/api/v2/device/%s/key"

req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), key)
req, err := c.buildRequest(ctx, http.MethodPost, fmt.Sprintf(uriFmt, deviceID), nil, key)
if err != nil {
return err
}
Expand Down
26 changes: 26 additions & 0 deletions tailscale/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,32 @@ func TestClient_SetACL(t *testing.T) {
assert.NoError(t, client.SetACL(context.Background(), expectedACL))
assert.Equal(t, http.MethodPost, server.Method)
assert.Equal(t, "/api/v2/tailnet/example.com/acl", server.Path)
assert.Equal(t, "", server.Header.Get("If-Match"))

var actualACL tailscale.ACL
assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL))
assert.EqualValues(t, expectedACL, actualACL)
}

func TestClient_SetACLWithETag(t *testing.T) {
t.Parallel()

client, server := NewTestHarness(t)
server.ResponseCode = http.StatusOK
expectedACL := tailscale.ACL{
ACLs: []tailscale.ACLEntry{
{
Action: "accept",
Ports: []string{"*:*"},
Users: []string{"*"},
},
},
}

assert.NoError(t, client.SetACL(context.Background(), expectedACL, tailscale.WithETag("test-etag")))
assert.Equal(t, http.MethodPost, server.Method)
assert.Equal(t, "/api/v2/tailnet/example.com/acl", server.Path)
assert.Equal(t, `"test-etag"`, server.Header.Get("If-Match"))

var actualACL tailscale.ACL
assert.NoError(t, json.Unmarshal(server.Body.Bytes(), &actualACL))
Expand Down
2 changes: 2 additions & 0 deletions tailscale/tailscale_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ type TestServer struct {
Method string
Path string
Body *bytes.Buffer
Header http.Header

ResponseCode int
ResponseBody interface{}
Expand Down Expand Up @@ -61,6 +62,7 @@ func NewTestHarness(t *testing.T) (*tailscale.Client, *TestServer) {
func (t *TestServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
t.Method = r.Method
t.Path = r.URL.Path
t.Header = r.Header

t.Body = bytes.NewBuffer([]byte{})
_, err := io.Copy(t.Body, r.Body)
Expand Down

0 comments on commit 717b0b8

Please sign in to comment.