Skip to content

Commit

Permalink
feat(tables): update samples to show explainability [(#2523)](GoogleC…
Browse files Browse the repository at this point in the history
…loudPlatform/python-docs-samples#2523)

* show xai

* local feature importance

* use updated client

* use fixed library

* use new model
  • Loading branch information
sirtorry authored Dec 18, 2019
1 parent 61b7d39 commit 4904476
Show file tree
Hide file tree
Showing 7 changed files with 267 additions and 171 deletions.
223 changes: 129 additions & 94 deletions samples/tables/automl_tables_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,38 @@ def list_datasets(project_id, compute_region, filter_=None):
print("Dataset id: {}".format(dataset.name.split("/")[-1]))
print("Dataset display name: {}".format(dataset.display_name))
metadata = dataset.tables_dataset_metadata
print("Dataset primary table spec id: {}".format(
metadata.primary_table_spec_id))
print("Dataset target column spec id: {}".format(
metadata.target_column_spec_id))
print("Dataset target column spec id: {}".format(
metadata.target_column_spec_id))
print("Dataset weight column spec id: {}".format(
metadata.weight_column_spec_id))
print("Dataset ml use column spec id: {}".format(
metadata.ml_use_column_spec_id))
print(
"Dataset primary table spec id: {}".format(
metadata.primary_table_spec_id
)
)
print(
"Dataset target column spec id: {}".format(
metadata.target_column_spec_id
)
)
print(
"Dataset target column spec id: {}".format(
metadata.target_column_spec_id
)
)
print(
"Dataset weight column spec id: {}".format(
metadata.weight_column_spec_id
)
)
print(
"Dataset ml use column spec id: {}".format(
metadata.ml_use_column_spec_id
)
)
print("Dataset example count: {}".format(dataset.example_count))
print("Dataset create time:")
print("\tseconds: {}".format(dataset.create_time.seconds))
print("\tnanos: {}".format(dataset.create_time.nanos))
print("\n")

# [END automl_tables_list_datasets]
# [END automl_tables_list_datasets]
result.append(dataset)

return result
Expand All @@ -119,28 +134,31 @@ def list_table_specs(

# List all the table specs in the dataset by applying filter.
response = client.list_table_specs(
dataset_display_name=dataset_display_name, filter_=filter_)
dataset_display_name=dataset_display_name, filter_=filter_
)

print("List of table specs:")
for table_spec in response:
# Display the table_spec information.
print("Table spec name: {}".format(table_spec.name))
print("Table spec id: {}".format(table_spec.name.split("/")[-1]))
print("Table spec time column spec id: {}".format(
table_spec.time_column_spec_id))
print(
"Table spec time column spec id: {}".format(
table_spec.time_column_spec_id
)
)
print("Table spec row count: {}".format(table_spec.row_count))
print("Table spec column count: {}".format(table_spec.column_count))

# [END automl_tables_list_specs]
# [END automl_tables_list_specs]
result.append(table_spec)

return result


def list_column_specs(project_id,
compute_region,
dataset_display_name,
filter_=None):
def list_column_specs(
project_id, compute_region, dataset_display_name, filter_=None
):
"""List all column specs."""
result = []
# [START automl_tables_list_column_specs]
Expand All @@ -156,7 +174,8 @@ def list_column_specs(project_id,

# List all the table specs in the dataset by applying filter.
response = client.list_column_specs(
dataset_display_name=dataset_display_name, filter_=filter_)
dataset_display_name=dataset_display_name, filter_=filter_
)

print("List of column specs:")
for column_spec in response:
Expand All @@ -166,7 +185,7 @@ def list_column_specs(project_id,
print("Column spec display name: {}".format(column_spec.display_name))
print("Column spec data type: {}".format(column_spec.data_type))

# [END automl_tables_list_column_specs]
# [END automl_tables_list_column_specs]
result.append(column_spec)

return result
Expand Down Expand Up @@ -227,19 +246,20 @@ def get_table_spec(project_id, compute_region, dataset_id, table_spec_id):
# Display the table spec information.
print("Table spec name: {}".format(table_spec.name))
print("Table spec id: {}".format(table_spec.name.split("/")[-1]))
print("Table spec time column spec id: {}".format(
table_spec.time_column_spec_id))
print(
"Table spec time column spec id: {}".format(
table_spec.time_column_spec_id
)
)
print("Table spec row count: {}".format(table_spec.row_count))
print("Table spec column count: {}".format(table_spec.column_count))

# [END automl_tables_get_table_spec]


def get_column_spec(project_id,
compute_region,
dataset_id,
table_spec_id,
column_spec_id):
def get_column_spec(
project_id, compute_region, dataset_id, table_spec_id, column_spec_id
):
"""Get the column spec."""
# [START automl_tables_get_column_spec]
# TODO(developer): Uncomment and set the following variables
Expand Down Expand Up @@ -288,7 +308,7 @@ def import_data(project_id, compute_region, dataset_display_name, path):
client = automl.TablesClient(project=project_id, region=compute_region)

response = None
if path.startswith('bq'):
if path.startswith("bq"):
response = client.import_data(
dataset_display_name=dataset_display_name, bigquery_input_uri=path
)
Expand All @@ -297,7 +317,7 @@ def import_data(project_id, compute_region, dataset_display_name, path):
input_uris = path.split(",")
response = client.import_data(
dataset_display_name=dataset_display_name,
gcs_input_uris=input_uris
gcs_input_uris=input_uris,
)

print("Processing import...")
Expand All @@ -321,8 +341,10 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri):
client = automl.TablesClient(project=project_id, region=compute_region)

# Export the dataset to the output URI.
response = client.export_data(dataset_display_name=dataset_display_name,
gcs_output_uri_prefix=gcs_uri)
response = client.export_data(
dataset_display_name=dataset_display_name,
gcs_output_uri_prefix=gcs_uri,
)

print("Processing export...")
# synchronous check of operation status.
Expand All @@ -331,12 +353,14 @@ def export_data(project_id, compute_region, dataset_display_name, gcs_uri):
# [END automl_tables_export_data]


def update_dataset(project_id,
compute_region,
dataset_display_name,
target_column_spec_name=None,
weight_column_spec_name=None,
test_train_column_spec_name=None):
def update_dataset(
project_id,
compute_region,
dataset_display_name,
target_column_spec_name=None,
weight_column_spec_name=None,
test_train_column_spec_name=None,
):
"""Update dataset."""
# [START automl_tables_update_dataset]
# TODO(developer): Uncomment and set the following variables
Expand All @@ -354,29 +378,31 @@ def update_dataset(project_id,
if target_column_spec_name is not None:
response = client.set_target_column(
dataset_display_name=dataset_display_name,
column_spec_display_name=target_column_spec_name
column_spec_display_name=target_column_spec_name,
)
print("Target column updated. {}".format(response))
if weight_column_spec_name is not None:
response = client.set_weight_column(
dataset_display_name=dataset_display_name,
column_spec_display_name=weight_column_spec_name
column_spec_display_name=weight_column_spec_name,
)
print("Weight column updated. {}".format(response))
if test_train_column_spec_name is not None:
response = client.set_test_train_column(
dataset_display_name=dataset_display_name,
column_spec_display_name=test_train_column_spec_name
column_spec_display_name=test_train_column_spec_name,
)
print("Test/train column updated. {}".format(response))

# [END automl_tables_update_dataset]


def update_table_spec(project_id,
compute_region,
dataset_display_name,
time_column_spec_display_name):
def update_table_spec(
project_id,
compute_region,
dataset_display_name,
time_column_spec_display_name,
):
"""Update table spec."""
# [START automl_tables_update_table_spec]
# TODO(developer): Uncomment and set the following variables
Expand All @@ -391,20 +417,22 @@ def update_table_spec(project_id,

response = client.set_time_column(
dataset_display_name=dataset_display_name,
column_spec_display_name=time_column_spec_display_name
column_spec_display_name=time_column_spec_display_name,
)

# synchronous check of operation status.
print("Table spec updated. {}".format(response))
# [END automl_tables_update_table_spec]


def update_column_spec(project_id,
compute_region,
dataset_display_name,
column_spec_display_name,
type_code,
nullable=None):
def update_column_spec(
project_id,
compute_region,
dataset_display_name,
column_spec_display_name,
type_code,
nullable=None,
):
"""Update column spec."""
# [START automl_tables_update_column_spec]
# TODO(developer): Uncomment and set the following variables
Expand All @@ -423,7 +451,8 @@ def update_column_spec(project_id,
response = client.update_column_spec(
dataset_display_name=dataset_display_name,
column_spec_display_name=column_spec_display_name,
type_code=type_code, nullable=nullable
type_code=type_code,
nullable=nullable,
)

# synchronous check of operation status.
Expand Down Expand Up @@ -546,56 +575,62 @@ def delete_dataset(project_id, compute_region, dataset_display_name):
if args.command == "list_datasets":
list_datasets(project_id, compute_region, args.filter_)
if args.command == "list_table_specs":
list_table_specs(project_id,
compute_region,
args.dataset_display_name,
args.filter_)
list_table_specs(
project_id, compute_region, args.dataset_display_name, args.filter_
)
if args.command == "list_column_specs":
list_column_specs(project_id,
compute_region,
args.dataset_display_name,
args.filter_)
list_column_specs(
project_id, compute_region, args.dataset_display_name, args.filter_
)
if args.command == "get_dataset":
get_dataset(project_id, compute_region, args.dataset_display_name)
if args.command == "get_table_spec":
get_table_spec(project_id,
compute_region,
args.dataset_display_name,
args.table_spec_id)
get_table_spec(
project_id,
compute_region,
args.dataset_display_name,
args.table_spec_id,
)
if args.command == "get_column_spec":
get_column_spec(project_id,
compute_region,
args.dataset_display_name,
args.table_spec_id,
args.column_spec_id)
get_column_spec(
project_id,
compute_region,
args.dataset_display_name,
args.table_spec_id,
args.column_spec_id,
)
if args.command == "import_data":
import_data(project_id,
compute_region,
args.dataset_display_name,
args.path)
import_data(
project_id, compute_region, args.dataset_display_name, args.path
)
if args.command == "export_data":
export_data(project_id,
compute_region,
args.dataset_display_name,
args.gcs_uri)
export_data(
project_id, compute_region, args.dataset_display_name, args.gcs_uri
)
if args.command == "update_dataset":
update_dataset(project_id,
compute_region,
args.dataset_display_name,
args.target_column_spec_name,
args.weight_column_spec_name,
args.ml_use_column_spec_name)
update_dataset(
project_id,
compute_region,
args.dataset_display_name,
args.target_column_spec_name,
args.weight_column_spec_name,
args.ml_use_column_spec_name,
)
if args.command == "update_table_spec":
update_table_spec(project_id,
compute_region,
args.dataset_display_name,
args.time_column_spec_display_name)
update_table_spec(
project_id,
compute_region,
args.dataset_display_name,
args.time_column_spec_display_name,
)
if args.command == "update_column_spec":
update_column_spec(project_id,
compute_region,
args.dataset_display_name,
args.column_spec_display_name,
args.type_code,
args.nullable)
update_column_spec(
project_id,
compute_region,
args.dataset_display_name,
args.column_spec_display_name,
args.type_code,
args.nullable,
)
if args.command == "delete_dataset":
delete_dataset(project_id, compute_region, args.dataset_display_name)
Loading

0 comments on commit 4904476

Please sign in to comment.