Skip to content

Commit

Permalink
Tighter ruff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasjansson committed Aug 27, 2024
1 parent 03e8395 commit 40db045
Show file tree
Hide file tree
Showing 23 changed files with 296 additions and 126 deletions.
11 changes: 4 additions & 7 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,11 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements-test.txt
pip install .
- name: Run ruff
run: |
ruff check --ignore=F403,F405
- name: Run black
- name: Lint
run: |
black --check .
./script/lint
unit-test:
runs-on: ubuntu-latest
Expand All @@ -48,4 +45,4 @@ jobs:
- name: Run pytest
run: |
pytest test/
./script/unit-test
19 changes: 13 additions & 6 deletions cog_safe_push/ai.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import base64
import json
import mimetypes
from pathlib import Path
import os
import json
from pathlib import Path
from typing import cast

import anthropic

from . import log
from .exceptions import AIError
from .retry import retry
from . import log


@retry(3)
Expand Down Expand Up @@ -47,7 +49,9 @@ def call(system_prompt: str, prompt: str, files: list[Path] | None = None) -> st
content = prompt
log.vvv(f"Claude prompt: {prompt}")

messages = [{"role": "user", "content": content}]
messages: list[anthropic.types.MessageParam] = [
{"role": "user", "content": content}
]

response = client.messages.create(
model=model,
Expand All @@ -57,12 +61,15 @@ def call(system_prompt: str, prompt: str, files: list[Path] | None = None) -> st
stream=False,
temperature=0.9,
)
output = response.content[0].text
content = cast(anthropic.types.TextBlock, response.content[0])
output = content.text
log.vvv(f"Claude response: {output}")
return output


def create_content_list(files: list[Path]):
def create_content_list(
files: list[Path],
) -> list[anthropic.types.ImageBlockParam | anthropic.types.TextBlockParam]:
content = []
for path in files:
with path.open("rb") as f:
Expand Down
8 changes: 5 additions & 3 deletions cog_safe_push/cog.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import subprocess
import re
import replicate
import subprocess

from replicate.model import Model

from . import log


def push(model: replicate.model.Model) -> str:
def push(model: Model) -> str:
url = f"r8.im/{model.owner}/{model.name}"
log.info(f"Pushing to {url}")
process = subprocess.Popen(
Expand All @@ -16,6 +17,7 @@ def push(model: replicate.model.Model) -> str:
)

sha256_id = None
assert process.stdout
for line in process.stdout:
log.v(line.rstrip()) # Print output in real-time
if "latest: digest: sha256:" in line:
Expand Down
8 changes: 4 additions & 4 deletions cog_safe_push/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,23 @@ class SchemaLintError(Exception):
pass


class IncompatibleSchema(Exception):
class IncompatibleSchemaError(Exception):
pass


class OutputsDontMatch(Exception):
class OutputsDontMatchError(Exception):
pass


class FuzzError(Exception):
pass


class PredictionTimeout(Exception):
class PredictionTimeoutError(Exception):
pass


class PredictionFailed(Exception):
class PredictionFailedError(Exception):
pass


Expand Down
17 changes: 10 additions & 7 deletions cog_safe_push/lint.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from pathlib import Path
import subprocess
from pathlib import Path
from typing import Any

import yaml

from .exceptions import CodeLintError


def lint_predict():
with open("cog.yaml", "r") as f:
cog_config = yaml.safe_load(f)

cog_config = load_cog_config()
predict_config = cog_config.get("predict", "")
predict_filename = predict_config.split(":")[0]

Expand All @@ -19,9 +19,7 @@ def lint_predict():


def lint_train():
with open("cog.yaml", "r") as f:
cog_config = yaml.safe_load(f)

cog_config = load_cog_config()
train_config = cog_config.get("train", "")
train_filename = train_config.split(":")[0]

Expand All @@ -31,6 +29,11 @@ def lint_train():
lint_file(train_filename)


def load_cog_config() -> dict[str, Any]:
with Path("cog.yaml").open() as f:
return yaml.safe_load(f)


def lint_file(filename: str):
if not Path(filename).exists():
raise CodeLintError(f"{filename} doesn't exist")
Expand Down
19 changes: 10 additions & 9 deletions cog_safe_push/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections import defaultdict
import re
import argparse
import re
from collections import defaultdict

import replicate
from replicate.exceptions import ReplicateException
from replicate.exceptions import ReplicateError
from replicate.model import Model

from . import cog, lint, schema, predict, log
from . import cog, lint, log, predict, schema


def main():
Expand Down Expand Up @@ -244,8 +246,7 @@ def parse_inputs(inputs_list: list[str]) -> dict[str, list[predict.WeightedInput
except ValueError:
raise ValueError(f"Invalid input format: {input_str}")

inputs = make_weighted_inputs(input_values, input_weights)
return inputs
return make_weighted_inputs(input_values, input_weights)


def make_weighted_inputs(
Expand Down Expand Up @@ -310,7 +311,7 @@ def parse_input_weight_percent(value_str: str) -> tuple[str, float | None]:
return value_str, None


def get_or_create_model(model_owner, model_name, hardware) -> replicate.model.Model:
def get_or_create_model(model_owner, model_name, hardware) -> Model:
model = get_model(model_owner, model_name)

if not model:
Expand All @@ -329,10 +330,10 @@ def get_or_create_model(model_owner, model_name, hardware) -> replicate.model.Mo
return model


def get_model(owner, name) -> replicate.model.Model:
def get_model(owner, name) -> Model | None:
try:
model = replicate.models.get(f"{owner}/{name}")
except ReplicateException as e:
except ReplicateError as e:
if e.status == 404:
return None
raise
Expand Down
Loading

0 comments on commit 40db045

Please sign in to comment.