Skip to content

Commit

Permalink
Record username when flagging (#4135)
Browse files Browse the repository at this point in the history
* record username

* fix

* changelog fix

* format

* fix hf saver

* fix deserialization

* fixes
  • Loading branch information
abidlabs authored and dawoodkhan82 committed Jun 2, 2023
1 parent a839912 commit a17d286
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

## Bug Fixes:

- Records username when flagging by [@abidlabs](https://github.com/abidlabs) in [PR 4135](https://github.com/gradio-app/gradio/pull/4135)
- Fix website build issue by [@aliabd](https://github.com/aliabd) in [PR 4142](https://github.com/gradio-app/gradio/pull/4142)

## Documentation Changes:
Expand Down
39 changes: 29 additions & 10 deletions gradio/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import uuid
import warnings
from abc import ABC, abstractmethod
from collections import OrderedDict
from distutils.version import StrictVersion
from pathlib import Path
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -113,7 +114,7 @@ def flag(
writer.writerow(utils.sanitize_list_for_csv(csv_data))

with open(log_filepath) as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
line_count = len(list(csv.reader(csvfile))) - 1
return line_count


Expand Down Expand Up @@ -187,7 +188,7 @@ def flag(
writer.writerow(utils.sanitize_list_for_csv(csv_data))

with open(log_filepath, encoding="utf-8") as csvfile:
line_count = len([None for row in csv.reader(csvfile)]) - 1
line_count = len(list(csv.reader(csvfile))) - 1
return line_count


Expand Down Expand Up @@ -286,7 +287,12 @@ def setup(self, components: list[IOComponent], flagging_dir: str):
except huggingface_hub.utils.EntryNotFoundError:
pass

def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
def flag(
self,
flag_data: list[Any],
flag_option: str = "",
username: str | None = None,
) -> int:
if self.separate_dirs:
# JSONL files to support dataset preview on the Hub
unique_id = str(uuid.uuid4())
Expand All @@ -305,6 +311,7 @@ def flag(self, flag_data: list[Any], flag_option: str = "") -> int:
path_in_repo=path_in_repo,
flag_data=flag_data,
flag_option=flag_option,
username=username or "",
)

def _flag_in_dir(
Expand All @@ -314,10 +321,11 @@ def _flag_in_dir(
path_in_repo: str | None,
flag_data: list[Any],
flag_option: str = "",
username: str = "",
) -> int:
# Deserialize components (write images/audio to files)
features, row = self._deserialize_components(
components_dir, flag_data, flag_option
components_dir, flag_data, flag_option, username
)

# Write generic info to dataset_infos.json + upload
Expand Down Expand Up @@ -394,18 +402,21 @@ def _save_as_jsonl(data_file: Path, headers: list[str], row: list[Any]) -> str:
return data_file.parent.name

def _deserialize_components(
self, data_dir: Path, flag_data: list[Any], flag_option: str = ""
self,
data_dir: Path,
flag_data: list[Any],
flag_option: str = "",
username: str = "",
) -> tuple[dict[Any, Any], list[Any]]:
"""Deserialize components and return the corresponding row for the flagged sample.
Images/audio are saved to disk as individual files.
"""
# Components that can have a preview on dataset repos
# NOTE: not at root level to avoid circular imports
file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}

# Generate the row corresponding to the flagged sample
features = {}
features = OrderedDict()
row = []
for component, sample in zip(self.components, flag_data):
# Get deserialized object (will save sample to disk if applicable -file, audio, image,...-)
Expand All @@ -415,7 +426,11 @@ def _deserialize_components(

# Add deserialized object to row
features[label] = {"dtype": "string", "_type": "Value"}
row.append(Path(deserialized).name)
try:
assert Path(deserialized).exists()
row.append(Path(deserialized).name)
except (AssertionError, TypeError, ValueError):
row.append(str(deserialized))

# If component is eligible for a preview, add the URL of the file
if isinstance(component, tuple(file_preview_types)): # type: ignore
Expand All @@ -436,7 +451,9 @@ def _deserialize_components(
)
)
features["flag"] = {"dtype": "string", "_type": "Value"}
features["username"] = {"dtype": "string", "_type": "Value"}
row.append(flag_option)
row.append(username)
return features, row


Expand Down Expand Up @@ -483,9 +500,11 @@ def __init__(
self.__name__ = "Flag"
self.visual_feedback = visual_feedback

def __call__(self, *flag_data):
def __call__(self, request: gr.Request, *flag_data):
try:
self.flagging_callback.flag(list(flag_data), flag_option=self.value)
self.flagging_callback.flag(
list(flag_data), flag_option=self.value, username=request.username
)
except Exception as e:
print(f"Error while flagging: {e}")
if self.visual_feedback:
Expand Down
10 changes: 5 additions & 5 deletions gradio/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def special_args(
updated inputs, progress index, event data index.
"""
signature = inspect.signature(fn)
type_hints = utils.get_type_hints(fn)
positional_args = []
for param in signature.parameters.values():
if param.kind not in (param.POSITIONAL_ONLY, param.POSITIONAL_OR_KEYWORD):
Expand All @@ -619,19 +620,18 @@ def special_args(
progress_index = None
event_data_index = None
for i, param in enumerate(positional_args):
type_hint = type_hints.get(param.name)
if isinstance(param.default, Progress):
progress_index = i
if inputs is not None:
inputs.insert(i, param.default)
elif param.annotation == routes.Request:
elif type_hint == routes.Request:
if inputs is not None:
inputs.insert(i, request)
elif isinstance(param.annotation, type) and issubclass(
param.annotation, EventData
):
elif type_hint and issubclass(type_hint, EventData):
event_data_index = i
if inputs is not None and event_data is not None:
inputs.insert(i, param.annotation(event_data.target, event_data._data))
inputs.insert(i, type_hint(event_data.target, event_data._data))
elif (
param.default is not param.empty and inputs is not None and len(inputs) <= i
):
Expand Down
7 changes: 3 additions & 4 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,9 @@ async def fn(*args):

extra_output = [submit_btn, stop_btn]

cleanup = lambda: [
Button.update(visible=True),
Button.update(visible=False),
]
def cleanup():
return [Button.update(visible=True), Button.update(visible=False)]

for i, trigger in enumerate(triggers):
predict_event = trigger(
lambda: (
Expand Down

0 comments on commit a17d286

Please sign in to comment.