diff --git a/api/custom/purge.go b/api/custom/purge.go index 9b5f40a5..5bb507d8 100644 --- a/api/custom/purge.go +++ b/api/custom/purge.go @@ -1,13 +1,19 @@ package custom import ( + "database/sql" "net/http" "strconv" "github.com/gorilla/mux" "github.com/sirupsen/logrus" "github.com/turt2live/matrix-media-repo/api" + "github.com/turt2live/matrix-media-repo/common" "github.com/turt2live/matrix-media-repo/controllers/maintenance_controller" + "github.com/turt2live/matrix-media-repo/matrix" + "github.com/turt2live/matrix-media-repo/storage" + "github.com/turt2live/matrix-media-repo/types" + "github.com/turt2live/matrix-media-repo/util" ) type MediaPurgedResponse struct { @@ -39,7 +45,8 @@ func PurgeRemoteMedia(r *http.Request, log *logrus.Entry, user api.UserInfo) int } func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - // TODO: Allow non-repo-admins to delete things + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) + localServerName := r.Host params := mux.Vars(r) @@ -51,7 +58,32 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo "mediaId": mediaId, }) + // If the user is NOT a global admin, ensure they are speaking to the right server + if !isGlobalAdmin { + if server != localServerName { + return api.AuthFailed() + } + // If the user is NOT a local admin, ensure they uploaded the content in the first place + if !isLocalAdmin { + db := storage.GetDatabase().GetMediaStore(r.Context(), log) + m, err := db.Get(server, mediaId) + if err == sql.ErrNoRows { + return api.NotFoundError() + } + if err != nil { + log.Error("Error checking ownership of media: " + err.Error()) + return api.InternalServerError("error checking media ownership") + } + if m.UserId != user.UserId { + return api.AuthFailed() + } + } + } + err := maintenance_controller.PurgeMedia(server, mediaId, r.Context(), log) + if err == sql.ErrNoRows || err == common.ErrMediaNotFound { + return api.NotFoundError() + } if err != nil { log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") @@ -61,9 +93,20 @@ func PurgeIndividualRecord(r *http.Request, log *logrus.Entry, user api.UserInfo } func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) interface{} { - // TODO: Allow non-repo-admins to delete things + isGlobalAdmin, isLocalAdmin := getPurgeRequestInfo(r, log, user) + localServerName := r.Host + + var affected []*types.Media + var err error + + if isGlobalAdmin { + affected, err = maintenance_controller.PurgeQuarantined(r.Context(), log) + } else if isLocalAdmin { + affected, err = maintenance_controller.PurgeQuarantinedFor(localServerName, r.Context(), log) + } else { + return api.AuthFailed() + } - affected, err := maintenance_controller.PurgeQuarantined(r.Context(), log) if err != nil { log.Error("Error purging media: " + err.Error()) return api.InternalServerError("error purging media") @@ -76,3 +119,14 @@ func PurgeQurantined(r *http.Request, log *logrus.Entry, user api.UserInfo) inte return &api.DoNotCacheResponse{Payload: map[string]interface{}{"purged": true, "affected": mxcs}} } + +func getPurgeRequestInfo(r *http.Request, log *logrus.Entry, user api.UserInfo) (bool, bool) { + isGlobalAdmin := util.IsGlobalAdmin(user.UserId) + isLocalAdmin, err := matrix.IsUserAdmin(r.Context(), r.Host, user.AccessToken, r.RemoteAddr) + if err != nil { + log.Error("Error verifying local admin: " + err.Error()) + return isGlobalAdmin, false + } + + return isGlobalAdmin, isLocalAdmin +} diff --git a/api/webserver/webserver.go b/api/webserver/webserver.go index 7274ee27..c3cf298c 100644 --- a/api/webserver/webserver.go +++ b/api/webserver/webserver.go @@ -32,8 +32,8 @@ func Init() { previewUrlHandler := handler{api.AccessTokenRequiredRoute(r0.PreviewUrl), "url_preview", counter, false} identiconHandler := handler{api.AccessTokenOptionalRoute(r0.Identicon), "identicon", counter, false} purgeRemote := handler{api.RepoAdminRoute(custom.PurgeRemoteMedia), "purge_remote_media", counter, false} - purgeOneHandler := handler{api.RepoAdminRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false} - purgeQuarantinedHandler := handler{api.RepoAdminRoute(custom.PurgeQurantined), "purge_quarantined", counter, false} + purgeOneHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeIndividualRecord), "purge_individual_media", counter, false} + purgeQuarantinedHandler := handler{api.AccessTokenRequiredRoute(custom.PurgeQurantined), "purge_quarantined", counter, false} quarantineHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineMedia), "quarantine_media", counter, false} quarantineRoomHandler := handler{api.AccessTokenRequiredRoute(custom.QuarantineRoomMedia), "quarantine_room", counter, false} localCopyHandler := handler{api.AccessTokenRequiredRoute(unstable.LocalCopy), "local_copy", counter, false} diff --git a/controllers/maintenance_controller/maintainance_controller.go b/controllers/maintenance_controller/maintainance_controller.go index c9b701e3..70e1579a 100644 --- a/controllers/maintenance_controller/maintainance_controller.go +++ b/controllers/maintenance_controller/maintainance_controller.go @@ -245,6 +245,23 @@ func PurgeQuarantined(ctx context.Context, log *logrus.Entry) ([]*types.Media, e return records, nil } +func PurgeQuarantinedFor(serverName string, ctx context.Context, log *logrus.Entry) ([]*types.Media, error) { + mediaDb := storage.GetDatabase().GetMediaStore(ctx, log) + records, err := mediaDb.GetQuarantinedMediaFor(serverName) + if err != nil { + return nil, err + } + + for _, r := range records { + err = doPurge(r, ctx, log) + if err != nil { + return nil, err + } + } + + return records, nil +} + func PurgeMedia(origin string, mediaId string, ctx context.Context, log *logrus.Entry) error { media, err := download_controller.FindMediaRecord(origin, mediaId, false, ctx, log) if err != nil { diff --git a/storage/stores/media_store.go b/storage/stores/media_store.go index e13b2a6a..2f88bd0f 100644 --- a/storage/stores/media_store.go +++ b/storage/stores/media_store.go @@ -27,6 +27,7 @@ const selectAllMediaForServer = "SELECT origin, media_id, upload_name, content_t const selectAllMediaForServerUsers = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND user_id = ANY($2)" const selectAllMediaForServerIds = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE origin = $1 AND media_id = ANY($2)" const selectQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true;" +const selectServerQuarantinedMedia = "SELECT origin, media_id, upload_name, content_type, user_id, sha256_hash, size_bytes, datastore_id, location, creation_ts, quarantined FROM media WHERE quarantined = true AND origin = $1;" var dsCacheByPath = sync.Map{} // [string] => Datastore var dsCacheById = sync.Map{} // [string] => Datastore @@ -50,6 +51,7 @@ type mediaStoreStatements struct { selectAllMediaForServerUsers *sql.Stmt selectAllMediaForServerIds *sql.Stmt selectQuarantinedMedia *sql.Stmt + selectServerQuarantinedMedia *sql.Stmt } type MediaStoreFactory struct { @@ -121,6 +123,9 @@ func InitMediaStore(sqlDb *sql.DB) (*MediaStoreFactory, error) { if store.stmts.selectQuarantinedMedia, err = store.sqlDb.Prepare(selectQuarantinedMedia); err != nil { return nil, err } + if store.stmts.selectServerQuarantinedMedia, err = store.sqlDb.Prepare(selectServerQuarantinedMedia); err != nil { + return nil, err + } return &store, nil } @@ -525,3 +530,34 @@ func (s *MediaStore) GetAllQuarantinedMedia() ([]*types.Media, error) { return results, nil } + +func (s *MediaStore) GetQuarantinedMediaFor(serverName string) ([]*types.Media, error) { + rows, err := s.statements.selectServerQuarantinedMedia.QueryContext(s.ctx, serverName) + if err != nil { + return nil, err + } + + var results []*types.Media + for rows.Next() { + obj := &types.Media{} + err = rows.Scan( + &obj.Origin, + &obj.MediaId, + &obj.UploadName, + &obj.ContentType, + &obj.UserId, + &obj.Sha256Hash, + &obj.SizeBytes, + &obj.DatastoreId, + &obj.Location, + &obj.CreationTs, + &obj.Quarantined, + ) + if err != nil { + return nil, err + } + results = append(results, obj) + } + + return results, nil +}