Skip to content

Commit

Permalink
Merge branch 'poc/vizro_ai_charts_refactor' into docs/vizro_ai_docs
Browse files Browse the repository at this point in the history
  • Loading branch information
maxschulz-COL committed Sep 6, 2024
2 parents f8cf486 + 99a27c3 commit f94fe52
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 51 deletions.
4 changes: 2 additions & 2 deletions vizro-ai/src/vizro_ai/_vizro_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def plot(
df: The dataframe to be analyzed.
user_input: User questions or descriptions of the desired visual.
max_debug_retry: Maximum number of retries to debug errors. Defaults to `1`.
return_elements: Flag to return ChartPlanStatic pydantic model that includes all
return_elements: Flag to return ChartPlan pydantic model that includes all
possible elements generated. Defaults to `False`.
validate_code: Flag if produced code should be executed to validate it. Defaults to `True`.
Returns:
go.Figure or ChartPlanStatic pydantic model
go.Figure or ChartPlan pydantic model
"""
response_model = ChartPlanFactory(data_frame=df) if validate_code else ChartPlan
Expand Down
47 changes: 0 additions & 47 deletions vizro-ai/src/vizro_ai/plot/_response_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from typing import List, Optional, Union

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

from vizro_ai.plot._utils._safeguard import _safeguard_check
Expand Down Expand Up @@ -186,49 +185,3 @@ def _test_execute_chart_code(v):
},
__base__=ChartPlan,
)


if __name__ == "__main__":
# Write docs
# Formulate long term todos

from dotenv import find_dotenv, load_dotenv

load_dotenv(find_dotenv(usecwd=True))
from vizro_ai import VizroAI
from vizro_ai._llm_models import _get_llm_model

# df = px.data.iris()
df = px.data.gapminder()

model = _get_llm_model()

query = "the trend of gdp over years in the US"
# query = "show me the geo distribution of life expectancy and set year as animation "
# query = "describe the composition of gdp in continents in 2007, and add horizontal line for avg gdp in 2007"
# query = """plot a bubble chart to shows the changes in life expectancy gdp per capita over time.
# Animate the chart by year"""

# print(df.sample(10).to_string())

# res = _get_pydantic_model(
# query=query, llm_model=model, response_model=dummy_model, df_info=df.sample(10).to_string())
# code = res._get_complete_code(lint=True)
# fig = res._get_fig_object(data_frame=df)
# fig.show()

# llm = ChatAnthropic(
# model="claude-3-5-sonnet-20240620",
# temperature=0,
# max_tokens=1024,
# timeout=None,
# max_retries=2,
# api_key = os.environ.get("ANTHROPIC_API_KEY"),
# base_url= os.environ.get("ANTHROPIC_API_BASE")
# )

############################################################################################################

vizro_ai = VizroAI(model=model)
res2 = vizro_ai.plot(df=df, user_input=query, return_elements=True)
res2.get_fig_object(data_frame=df).show()
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import re

import pytest
import vizro.models as vm
from vizro_ai.dashboard._response_models.components import ComponentPlan

import pytest


class TestComponentCreate:
"""Tests component creation."""
Expand Down

0 comments on commit f94fe52

Please sign in to comment.