Skip to content

Commit

Permalink
format: Added formatter and linter (#529)
Browse files Browse the repository at this point in the history
* format: Format py files

* format: Update formating integ tests.

---------

Co-authored-by: Bigad Soleiman <bigadsoleiman@gmail.com>
  • Loading branch information
charles-marion and bigadsoleiman authored Aug 12, 2024
1 parent 917f838 commit cf795a3
Show file tree
Hide file tree
Showing 80 changed files with 338 additions and 280 deletions.
10 changes: 10 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
[flake8]
max-line-length = 88
# Based on https://black.readthedocs.io/en/stable/guides/using_black_with_other_tools.html#flake8
# E711 and E711 flags the assert in the tests
extend-ignore = E203,E701,E711,E712
exclude =
lib/user-interface/react-app/node_modules/
cdk.out/
dist/
node_modules/
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
# Suppression of pip audit failure until langchain is upgraded.
run: |
pip install -r pytest_requirements.txt
flake8 .
bandit -r .
pip-audit -r pytest_requirements.txt || true
pip-audit -r lib/shared/web-crawler-batch-job/requirements.txt || true
Expand Down
2 changes: 1 addition & 1 deletion integtests/chatbot-api/cross_encoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ def test_ranking(client: AppSyncClient, config):

assert len(result) == 2
assert result[0].get("score") > result[1].get("score")
assert result[0].get("passage") == "A cat is an animal."
assert result[0].get("passage") == "A cat is an animal."
7 changes: 5 additions & 2 deletions integtests/chatbot-api/opensearch_workspace_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def test_add_text(client: AppSyncClient):
"content": "The Integ Test flower is green.",
}
)
# This test can take several minutes because it's waiting for AWSBatch to start a host
# This test can take several minutes because it's waiting for
# AWSBatch to start a host
ready = False
retries = 0
while not ready and retries < 50:
Expand Down Expand Up @@ -100,7 +101,9 @@ def test_search_document(client: AppSyncClient):
if len(result.get("items")) == 1:
ready = True
assert result.get("engine") == "opensearch"
assert result.get("items")[0].get("documentId") == pytest.document.get("documentId")
assert result.get("items")[0].get("documentId") == pytest.document.get(
"documentId"
)
assert ready == True


Expand Down
40 changes: 21 additions & 19 deletions integtests/chatbot-api/session_test.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,29 @@
import json
import json
import uuid
import time

import pytest


def test_create_session(client, default_model, default_provider, session_id):
request = {
"action": "run",
"modelInterface": "langchain",
"data": {
"action": "run",
"modelInterface": "langchain",
"data": {
"mode": "chain",
"text": "test",
"files": [],
"modelName": default_model,
"provider": default_provider,
"sessionId": session_id,
},
}
},
}

client.send_query(json.dumps(request))
# Need a second sessions to verify the delete all
request["data"]["sessionId"] = str(uuid.uuid4())
client.send_query(json.dumps(request))

found = False
sessionFound = None
retries = 0
Expand All @@ -39,27 +40,30 @@ def test_create_session(client, default_model, default_provider, session_id):
assert found == True
assert sessionFound.get("title") == request.get("data").get("text")

def test_get_session(client, session_id, default_model):

def test_get_session(client, session_id, default_model):
session = client.get_session(session_id)
assert session.get("id") == session_id
assert session.get("title") == "test"
assert len(session.get("history")) == 2
assert session.get("history")[0].get("type") == "human"
assert session.get("history")[1].get("type") == "ai"

def test_delete_session(client, session_id):


def test_delete_session(client, session_id):
session = client.delete_session(session_id)
assert session.get("id") == session_id
assert session.get("deleted") == True

session = client.get_session(session_id)
assert session == None

def test_delete_user_sessions(client):


def test_delete_user_sessions(client):
sessions = client.delete_user_sessions()
assert len(sessions) > 0
assert sessions[0].get("deleted") == True

sessions = client.list_sessions()
retries = 0
while True:
Expand All @@ -69,10 +73,8 @@ def test_delete_user_sessions(client):
pytest.fail()
elif len(client.list_sessions()) == 0:
break



@pytest.fixture(scope="package")
def session_id():
return str(uuid.uuid4())



45 changes: 14 additions & 31 deletions integtests/clients/appsync_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from gql import Client
from gql.transport.aiohttp import AIOHTTPTransport
from gql.dsl import DSLMutation, DSLSchema, DSLQuery, dsl_gql
from graphql import print_ast


class AppSyncClient:
Expand Down Expand Up @@ -122,18 +121,7 @@ def create_opensearch_workspace(self, input):
)
)
return self.client.execute(query).get("createOpenSearchWorkspace")

def list_workspaces(self):
query = dsl_gql(
DSLQuery(
self.schema.Query.listWorkspaces.select(
self.schema.Workspace.id,
self.schema.Workspace.name,
)
)
)
return self.client.execute(query).get("listWorkspaces")


def list_workspaces(self):
query = dsl_gql(
DSLQuery(
Expand All @@ -145,7 +133,7 @@ def list_workspaces(self):
)
)
return self.client.execute(query).get("listWorkspaces")

def get_workspace(self, id):
query = dsl_gql(
DSLQuery(
Expand All @@ -157,39 +145,36 @@ def get_workspace(self, id):
)
)
return self.client.execute(query).get("getWorkspace")

def delete_workspace(self, id):
query = dsl_gql(
DSLMutation(
self.schema.Mutation.deleteWorkspace.args(workspaceId=id)
)
DSLMutation(self.schema.Mutation.deleteWorkspace.args(workspaceId=id))
)
return self.client.execute(query)

def add_text(self, input):
query = dsl_gql(
DSLMutation(
self.schema.Mutation.addTextDocument.args(input=input).select(
self.schema.DocumentResult.documentId,
self.schema.DocumentResult.status
self.schema.DocumentResult.status,
)
)
)
return self.client.execute(query).get("addTextDocument")



def get_document(self, input):
query = dsl_gql(
DSLQuery(
self.schema.Query.getDocument.args(input=input).select(
self.schema.Document.workspaceId,
self.schema.Document.id,
self.schema.Document.status
self.schema.Document.status,
)
)
)
return self.client.execute(query).get("getDocument")

def semantic_search(self, input):
query = dsl_gql(
DSLQuery(
Expand All @@ -199,13 +184,13 @@ def semantic_search(self, input):
self.schema.SemanticSearchResult.items.select(
self.schema.SemanticSearchItem.content,
self.schema.SemanticSearchItem.documentId,
self.schema.SemanticSearchItem.score
),
self.schema.SemanticSearchItem.score,
),
)
)
)
return self.client.execute(query).get("performSemanticSearch")

def delete_document(self, input):
query = dsl_gql(
DSLMutation(
Expand All @@ -216,7 +201,7 @@ def delete_document(self, input):
)
)
return self.client.execute(query)

def calculate_embeding(self, input):
query = dsl_gql(
DSLQuery(
Expand All @@ -227,8 +212,7 @@ def calculate_embeding(self, input):
)
)
return self.client.execute(query).get("calculateEmbeddings")



def rank_passages(self, input):
query = dsl_gql(
DSLQuery(
Expand All @@ -239,4 +223,3 @@ def rank_passages(self, input):
)
)
return self.client.execute(query).get("rankPassages")

39 changes: 15 additions & 24 deletions integtests/clients/cognito_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

import boto3
import string
import random


class CognitoClient:
def __init__(self, region: str, user_pool_id: str, client_id: str) -> None:
self.user_pool_id = user_pool_id
self.client_id = client_id
self.cognito_idp_client = boto3.client('cognito-idp', region_name=region)
self.cognito_idp_client = boto3.client("cognito-idp", region_name=region)

def get_token(self, email: str) -> None:
try:
Expand All @@ -20,42 +20,33 @@ def get_token(self, email: str) -> None:
UserPoolId=self.user_pool_id,
Username=email,
UserAttributes=[
{
'Name': 'email',
'Value': email
},
{
'Name': 'email_verified',
'Value': 'True'
}
{"Name": "email", "Value": email},
{"Name": "email_verified", "Value": "True"},
],
MessageAction="SUPPRESS",
)

password = self.get_password()
self.cognito_idp_client.admin_set_user_password(
UserPoolId=self.user_pool_id,
Username=email,
Password=password,
Permanent=True
Permanent=True,
)

response = self.cognito_idp_client.admin_initiate_auth(
UserPoolId=self.user_pool_id,
ClientId=self.client_id,
AuthFlow="ADMIN_NO_SRP_AUTH",
AuthParameters={
"USERNAME": email,
"PASSWORD": password
}
AuthParameters={"USERNAME": email, "PASSWORD": password},
)

return response["AuthenticationResult"]["IdToken"]

def get_password(self):
return "".join(
random.choices(string.ascii_uppercase, k=10) +
random.choices(string.ascii_lowercase, k=10) +
random.choices(string.digits, k=5) +
random.choices(string.punctuation, k=3)
)
random.choices(string.ascii_uppercase, k=10)
+ random.choices(string.ascii_lowercase, k=10)
+ random.choices(string.digits, k=5)
+ random.choices(string.punctuation, k=3)
)
18 changes: 13 additions & 5 deletions integtests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,38 +8,46 @@
from clients.cognito_client import CognitoClient
from clients.appsync_client import AppSyncClient


@pytest.fixture(scope="session")
def client(config):
user_pool_id = config.get("aws_user_pools_id")
region = config.get("aws_cognito_region")
user_pool_client_id = config.get("aws_user_pools_web_client_id")
endpoint = config.get("aws_appsync_graphqlEndpoint")

cognito = CognitoClient(region=region, user_pool_id=user_pool_id, client_id=user_pool_client_id)

cognito = CognitoClient(
region=region, user_pool_id=user_pool_id, client_id=user_pool_client_id
)
email = "integ-test-user@example.local"

return AppSyncClient(endpoint=endpoint, id_token=cognito.get_token(email=email))


@pytest.fixture(scope="session")
def unauthenticated_client(config):
def unauthenticated_client(config):
endpoint = config.get("aws_appsync_graphqlEndpoint")
return AppSyncClient(endpoint=endpoint, id_token=None)


@pytest.fixture(scope="session")
def default_model():
return "anthropic.claude-instant-v1"


@pytest.fixture(scope="session")
def default_embed_model():
return "amazon.titan-embed-text-v1"


@pytest.fixture(scope="session")
def default_provider():
return "bedrock"


@pytest.fixture(scope="session")
def config():
if "REACT_APP_URL" not in os.environ:
raise IndexError("Please set the environment variable REACT_APP_URL")
response = urlopen(os.environ['REACT_APP_URL'] + "/aws-exports.json")
response = urlopen(os.environ["REACT_APP_URL"] + "/aws-exports.json")
return json.loads(response.read())
1 change: 1 addition & 0 deletions lib/chatbot-api/functions/api-handler/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from aws_lambda_powertools.event_handler import (
AppSyncResolver,
)

from routes.health import router as health_router
from routes.embeddings import router as embeddings_router
from routes.cross_encoders import router as cross_encoders_router
Expand Down
Loading

0 comments on commit cf795a3

Please sign in to comment.