Skip to content

Commit

Permalink
Allow local admins and individual users to delete scoped media
Browse files Browse the repository at this point in the history
  • Loading branch information
turt2live committed Sep 5, 2019
1 parent f7ac22a commit f53e853
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 5 deletions.
60 changes: 57 additions & 3 deletions api/custom/purge.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package custom

import (
"database/sql"
"net/http"
"strconv"

"github.com/gorilla/mux"
"github.com/sirupsen/logrus"
"github.com/turt2live/matrix-media-repo/api"
"github.com/turt2live/matrix-media-repo/common"
"github.com/turt2live/matrix-media-repo/controllers/maintenance_controller"
"github.com/turt2live/matrix-media-repo/matrix"
"github.com/turt2live/matrix-media-repo/storage"
"github.com/turt2live/matrix-media-repo/types"
"github.com/turt2live/matrix-media-repo/util"
)

type MediaPurgedResponse struct {
Expand Down Expand Up @@ -39,7 +45,8 @@ func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int
}

func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
// TODO: Allow non-repo-admins to delete things
isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user)
localServerName := r.Host

params := mux.Vars(r)

Expand All @@ -51,7 +58,32 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo
"mediaId": mediaId,
})

// If the user is NOT a global admin, ensure they are speaking to the right server
if !isGlobalAdmin {
if server != localServerName {
return api.AuthFailed()
}
// If the user is NOT a local admin, ensure they uploaded the content in the first place
if !isLocalAdmin {
db := storage.GetDatabase().GetMediaStore(r.Context(), log)
m, err := db.Get(server, mediaId)
if err == sql.ErrNoRows {
return api.NotFoundError()
}
if err != nil {
log.Error("Error checking ownership of media: " + err.Error())
return api.InternalServerError("error checking media ownership")
}
if m.UserId != user.UserId {
return api.AuthFailed()
}
}
}

err := maintenance_controller.PurgeMedia(server, mediaId, r.Context(), log)
if err == sql.ErrNoRows || err == common.ErrMediaNotFound {
return api.NotFoundError()
}
if err != nil {
log.Error("Error purging media: " + err.Error())
return api.InternalServerError("error purging media")
Expand All @@ -61,9 +93,20 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo
}

func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} {
// TODO: Allow non-repo-admins to delete things
isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user)
localServerName := r.Host

var affected []*types.Media
var err error

if isGlobalAdmin {
affected, err = maintenance_controller.PurgeQuarantined(r.Context(), log)
} else if isLocalAdmin {
affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, r.Context(), log)
} else {
return api.AuthFailed()
}

affected, err := maintenance_controller.PurgeQuarantined(r.Context(), log)
if err != nil {
log.Error("Error purging media: " + err.Error())
return api.InternalServerError("error purging media")
Expand All @@ -76,3 +119,14 @@ func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) inte

return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}}
}

func getPurgeRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool) {
isGlobalAdmin := util.IsGlobalAdmin(user.UserId)
isLocalAdmin, err := matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr)
if err != nil {
log.Error("Error verifying local admin: " + err.Error())
return isGlobalAdmin, false
}

return isGlobalAdmin, isLocalAdmin
}
4 changes: 2 additions & 2 deletions api/webserver/webserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ func Init() {
previewUrlHandler := handler{api.AccessTokenRequiredRoute(r0.PreviewUrl), "url_preview", counter, false}
identiconHandler := handler{api.AccessTokenOptionalRoute(r0.Identicon), "identicon", counter, false}
purgeRemote := handler{api.RepoAdminRoute(custom.PurgeRemoteMedia), "purge_remote_media", counter, false}
purgeOneHandler := handler{api.RepoAdminRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false}
purgeQuarantinedHandler := handler{api.RepoAdminRoute(custom.PurgeQurantined), "purge_quarantined", counter, false}
purgeOneHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false}
purgeQuarantinedHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeQurantined), "purge_quarantined", counter, false}
quarantineHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineMedia), "quarantine_media", counter, false}
quarantineRoomHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineRoomMedia), "quarantine_room", counter, false}
localCopyHandler := handler{api.AccessTokenRequiredRoute(unstable.LocalCopy), "local_copy", counter, false}
Expand Down
17 changes: 17 additions & 0 deletions controllers/maintenance_controller/maintainance_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,23 @@ func PurgeQuarantined(ctx context.Context, log *logrus.Entry) ([]*types.Media, e
return records, nil
}

func PurgeQuarantinedFor(serverName string, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) {
mediaDb := storage.GetDatabase().GetMediaStore(ctx, log)
records, err := mediaDb.GetQuarantinedMediaFor(serverName)
if err != nil {
return nil, err
}

for _, r := range records {
err = doPurge(r, ctx, log)
if err != nil {
return nil, err
}
}

return records, nil
}

func PurgeMedia(origin string, mediaId string, ctx context.Context, log *logrus.Entry) error {
media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx, log)
if err != nil {
Expand Down
36 changes: 36 additions & 0 deletions storage/stores/media_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ const selectAllMediaForServer = "SELECT origin, media_id, upload_name, content_t
const selectAllMediaForServerUsers = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND user_id = ANY($2)"
const selectAllMediaForServerIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND media_id = ANY($2)"
const selectQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true;"
const selectServerQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true AND origin = $1;"

var dsCacheByPath = sync.Map{} // [string] => Datastore
var dsCacheById = sync.Map{} // [string] => Datastore
Expand All @@ -50,6 +51,7 @@ type mediaStoreStatements struct {
selectAllMediaForServerUsers *sql.Stmt
selectAllMediaForServerIds *sql.Stmt
selectQuarantinedMedia *sql.Stmt
selectServerQuarantinedMedia *sql.Stmt
}

type MediaStoreFactory struct {
Expand Down Expand Up @@ -121,6 +123,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) {
if store.stmts.selectQuarantinedMedia, err = store.sqlDb.Prepare(selectQuarantinedMedia); err != nil {
return nil, err
}
if store.stmts.selectServerQuarantinedMedia, err = store.sqlDb.Prepare(selectServerQuarantinedMedia); err != nil {
return nil, err
}

return &store, nil
}
Expand Down Expand Up @@ -525,3 +530,34 @@ func (s *MediaStore) GetAllQuarantinedMedia() ([]*types.Media, error) {

return results, nil
}

func (s *MediaStore) GetQuarantinedMediaFor(serverName string) ([]*types.Media, error) {
rows, err := s.statements.selectServerQuarantinedMedia.QueryContext(s.ctx, serverName)
if err != nil {
return nil, err
}

var results []*types.Media
for rows.Next() {
obj := &types.Media{}
err = rows.Scan(
&obj.Origin,
&obj.MediaId,
&obj.UploadName,
&obj.ContentType,
&obj.UserId,
&obj.Sha256Hash,
&obj.SizeBytes,
&obj.DatastoreId,
&obj.Location,
&obj.CreationTs,
&obj.Quarantined,
)
if err != nil {
return nil, err
}
results = append(results, obj)
}

return results, nil
}

0 comments on commit f53e853

Please sign in to comment.