Commit
Merge pull request #23 from moj-analytical-services/add_temp_default_schema

Add temp default schema
pjrh-moj authored Apr 18, 2023
2 parents 4ac6257 + 42b0960 commit e48da61
Showing 4 changed files with 91 additions and 36 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -1,6 +1,6 @@
Package: Rdbtools
Title: Connects the MoJ Analytical Platform to Athena
Version: 0.3.0
Version: 0.4.0
Authors@R:
person(given = "First",
family = "Last",
37 changes: 32 additions & 5 deletions R/athena_connection.R
@@ -22,13 +22,16 @@ setClass(
#' methods from noctua's AthenaConnection class, which in turn are DBI
#' methods.
#' In general the expected usage is to run the function with no arguments to
#' get a standard database connection, which should work for most purposes.
#' get a standard database connection, which should work for most basic data
#' access purposes.
#'
#' @param aws_region This is the region where the database is held. If unset or NULL then it will default to the AP's region.
#' @param staging_dir This is the s3 location where outputs of queries can be held. If unset or NULL then it will default to a session-specific temporary dir.
#' @param rstudio_conn_tab Set this to TRUE to show this connection in your RStudio connections pane (warning: this takes a long time to load because of the number of databases in the AP's Athena)
#' @param session_duration The number of seconds which the session should last before needing new authentication. Minimum of 900.
#' @param role_session_name This is a parameter for authentication, and should be left to NULL in normal operation.
#' @param schema_name This is the default database in which tables that do not specify a database will be looked up. If this is set to the string `__temp__` then it will use (and create if required) the temporary database based on your username - this is useful for using dbplyr, which does not understand the `__temp__` keyword, alongside the DBI commands.
#' @param ... Other arguments passed to `dbConnect`
#'
#' @examples
#' con <- connect_athena() # creates a connection with sensible defaults
@@ -42,6 +45,7 @@ connect_athena <- function(aws_region = NULL,
rstudio_conn_tab = FALSE,
session_duration = 3600,
role_session_name = NULL,
schema_name = "default",
...
) {

@@ -97,6 +101,15 @@ connect_athena <- function(aws_region = NULL,
staging_dir = get_staging_dir_from_userid(user_id)
}

# this works out the temp db name from the user id
temp_db_name <- get_database_name_from_userid(user_id)

if (schema_name == "__temp__") {
schema_name_set <- temp_db_name
} else {
schema_name_set <- schema_name
}

# connect to athena
# returns an AthenaConnection object, see noctua docs for details
con <- dbConnect(noctua::athena(),
@@ -106,6 +119,7 @@
aws_access_key_id = credentials$AccessKeyId,
aws_secret_access_key = credentials$SecretAccessKey,
aws_session_token = credentials$SessionToken,
schema_name = schema_name_set,
...)
} else {

@@ -121,18 +135,26 @@
staging_dir = get_staging_dir_from_userid(user_id)
}

# this works out the temp db name from the user id
temp_db_name <- get_database_name_from_userid(user_id)

if (schema_name == "__temp__") {
schema_name_set <- temp_db_name
} else {
schema_name_set <- schema_name
}

# connect to athena
# returns an AthenaConnection object, see noctua docs for details
con <- dbConnect(noctua::athena(),
region_name = aws_region,
s3_staging_dir = staging_dir,
rstudio_conn_tab = rstudio_conn_tab)
rstudio_conn_tab = rstudio_conn_tab,
schema_name = schema_name_set,
...)

}

# this works out the temp db name from the user id
temp_db_name <- get_database_name_from_userid(user_id)

# coerce the AthenaConnection object to be a MoJAthenaConnection object
# this just adds the slot MoJdetails, as defined in setClass above
con <- as(con,"MoJAthenaConnection")
@@ -146,6 +168,11 @@
con@MoJdetails$temp_db_name <- temp_db_name
con@MoJdetails$temp_db_exists <- NA # Don't know if the temp db exists yet

# this checks that the temp database exists if it is set as the default db
if (schema_name == "__temp__") {
result <- athena_temp_db(con, check_exists = TRUE)
}

return(con)

}
80 changes: 51 additions & 29 deletions README.md
@@ -28,7 +28,9 @@ You can use the same command to update the package, if it is changed on Github l

## How to use

### Connecting a session and querying
### Connecting a session and basic querying

#### With SQL commands (using DBI)

See https://dyfanjones.github.io/noctua/reference/index.html for the full list of functions you can call to interact with Athena.

@@ -41,10 +43,45 @@ data <- dbGetQuery(con, "SELECT * FROM database.table") # queries and puts data
dbDisconnect(con) # disconnects the connection
```

#### Using dbplyr

See https://dbplyr.tidyverse.org/index.html

As an example:
```
library(tidyverse)
library(dbplyr)
library(Rdbtools)
con <- connect_athena()
datadb <- tbl(con, sql("select * from database.name")) # create the dbplyr link
# use dplyr as usual on this dataframe link
datadb %>%
  filter(size < 10) %>%
  group_by() %>%
  summarise(n = n(),
            total = sum(total))
dbDisconnect(con) # disconnects the connection
```

Note that if you need any function within dbplyr which does a copy (e.g. joining a local table to a remote table)
then you need to ensure you have the right permissions for the staging directory you are using.
See the help page for `dbWriteTable` by running `?dbWriteTable` in the console.
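
For instance, a minimal sketch of joining a small local lookup table to a remote Athena table with `copy = TRUE` (the table and column names here are hypothetical):

```
library(tidyverse)
library(dbplyr)
library(Rdbtools)

con <- connect_athena()
remote <- tbl(con, sql("select * from database.table")) # link to a remote table
local_lookup <- tibble(id = c(1, 2, 3), label = c("a", "b", "c")) # small local table

# copy = TRUE uploads the local table via the connection's staging directory,
# so the staging directory permissions must allow writes
joined <- remote %>%
  inner_join(local_lookup, by = "id", copy = TRUE) %>%
  collect()

dbDisconnect(con)
```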

### The temporary database

Wherever you put the special string `__temp__` then this will refer to a database which is specific to your user and where you can write temporary tables before you read them out.
This works with both the noctua functions (which are updated in this package for connections made via `connect_athena()`) and the convenience functions (e.g. `read_sql()`).
Each user can have a database which can store temporary tables.

Note that the tables created here will have their underlying data stored in the default staging directory
(which is different for each new connection) or that specified by the staging directory argument
(which will remain the same for each new connection).
The permissions of the staging directory will determine who can access the data in the temporary tables.

#### With SQL commands (using DBI)

Wherever you put the special string `__temp__` in SQL commands, it will refer to a database which is specific to your user and where you can write temporary tables before you read them out.
This works with the DBI functions (which are updated in this package for connections made via `connect_athena()`) and the convenience functions (e.g. `read_sql()`).

```
library(Rdbtools)
@@ -67,7 +104,16 @@ The `__temp__` string substitution is implemented for:

If there are further noctua/DBI functions where the `__temp__` string substitution would be useful, then open up an issue or pull request and the Rdbtools community can try to arrange an implementation.
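
As an illustration, a minimal sketch of the pattern (the database, table and column names are hypothetical, and this assumes `dbExecute` and `dbGetQuery` are among the functions with the substitution implemented):

```
library(Rdbtools)

con <- connect_athena()

# write an intermediate result into the user-specific temporary database
dbExecute(con, "CREATE TABLE __temp__.small_table AS
                SELECT * FROM database.table WHERE size < 10")

# read it back out later in the same session
data <- dbGetQuery(con, "SELECT * FROM __temp__.small_table")

dbDisconnect(con)
```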

Additionally, the `athena_temp_db` function will return a string with the name of the temporary database if required to create specific SQL commands, or in use in other functions not listed above.
#### Using dbplyr (or other packages)

The `__temp__` string is not understood by dbplyr functions, so to use the temporary database for this or other packages you have two options:

+ When creating the connection, you can specify the temporary database as the default schema: `connect_athena(schema_name = "__temp__")`. In this case dbplyr commands which do not specify a database will default to the temporary database (e.g. `compute("temp_tbl")` at the end of a dbplyr chain will create a table called "temp_tbl" in the temporary database); see the sketch after this list.
+ Alternatively, the `athena_temp_db` function will return a string with the name of the temporary database, which can be used to manually construct specific SQL commands or passed to other functions not listed above.
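
For example, a minimal sketch of the first option (the source table and column names here are made up for illustration):

```
library(tidyverse)
library(dbplyr)
library(Rdbtools)

con <- connect_athena(schema_name = "__temp__") # the temporary database becomes the default schema

tbl(con, sql("select * from database.table")) %>%
  filter(size < 10) %>%
  compute("temp_tbl") # materialises the result as "temp_tbl" in the temporary database

dbDisconnect(con)
```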

The temporary database is the same each way, so you can mix dbplyr, DBI, and other packages in the same code.

## Advanced use

### The connection object

@@ -76,31 +122,7 @@ By default the authenticated session will last for one hour, after which you wil
For most purposes creating a new connection will be sufficient; however, you will lose access to any tables created in the `__temp__` database (as these are only accessible within the same session).
To refresh a connection, please use the `refresh_athena_connection()` function, or in a long script the `refresh_if_expired()` function may also be useful (see the help pages in RStudio for further details of these functions).
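
For example, a minimal sketch of the simplest approach in a long script:

```
library(Rdbtools)

con <- connect_athena() # authenticated session lasts one hour by default

# ... long-running work ...

# simplest option: start a fresh connection (any tables written to __temp__
# under the old session will no longer be accessible)
dbDisconnect(con)
con <- connect_athena()

# alternatively, refresh_athena_connection() or refresh_if_expired() can renew
# the existing session - see their help pages for the exact usage
```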

### Using dbplyr

See https://dbplyr.tidyverse.org/index.html

As an example:
```
library(tidyverse)
library(dbplyr)
library(Rdbtools)

con <- connect_athena()
datadb <- tbl(con, sql("select * from database.name")) # create the dbplyr link
# use dplyr as usual on this dataframe link
datadb %>%
  filter(size < 10) %>%
  group_by() %>%
  summarise(n = n(),
            total = sum(total))
dbDisconnect(con) # disconnects the connection
```

Note that if you need any function within dbplyr which does a copy (e.g. joining a local table to a remote table)
then you need to ensure you have the right permissions for the staging directory you are using.
See the help page for `dbWriteTable` by running `?dbWriteTable` in the console.

### The region argument when creating a connection object

Expand All @@ -118,7 +140,7 @@ othewise use `eu-west-1` as the default

In most cases, you do not need to worry about the region: the default region (`AWS_DEFAULT_REGION` and `AWS_REGION`) should be the one used for running queries and the one where your staging dir is. When there is a cross-region situation in your running environment and you want to save the time of passing the region every time you create a connection, you can use `AWS_ATHENA_QUERY_REGION` to specify it.
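
For example, a sketch of the two ways to point a connection at a specific region (the region value here is only illustrative):

```
library(Rdbtools)

# option 1: pass the region explicitly for this connection
con <- connect_athena(aws_region = "eu-west-1")

# option 2: set the query region once for the session and rely on the default
Sys.setenv(AWS_ATHENA_QUERY_REGION = "eu-west-1")
con2 <- connect_athena()
```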

### Single queries (deprecated)
## Single queries (deprecated)

The function `read_sql` is provided to replicate the same function from `dbtools` - this is kept for backwards compatibility only.
This creates a database connection, reads the data and then closes the connection on every call.
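
A minimal sketch of a single query, assuming `read_sql` takes the SQL string as its first argument (the table name is hypothetical; check the help page for the full set of arguments):

```
library(Rdbtools)

# opens a connection, runs the query, and closes the connection again
data <- read_sql("SELECT * FROM database.table")
```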
8 changes: 7 additions & 1 deletion man/connect_athena.Rd

Some generated files are not rendered by default.
