Added support for distributed dataframes from Athena queries

This commit is contained in:
Ari Brown 2023-10-10 21:21:22 -04:00
parent 27a1d75bb3
commit bfb5dda6d9
3 changed files with 30 additions and 8 deletions
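
In practice, the new `distribute_results` hook lets query results land on a dask cluster rather than on the local machine. A minimal usage sketch, assuming a reachable dask scheduler; the scheduler address, `handler`, output path, and SQL are all illustrative:

from dask.distributed import Client

client = Client("scheduler-ip:8786")                      # assumed scheduler address
athena = Athena(handler, "s3://some-bucket/athena-out/")  # illustrative arguments

with athena.query("select user_id, day from events") as q:
    ddf = q.distribute_results(client)   # dask dataframe, partitions on workers
    print(ddf.groupby("user_id").size().compute())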

View file

@@ -6,6 +6,7 @@ import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import dask.dataframe as dd
from minerva import parallel_map
pp = pprint.PrettyPrinter(indent=4)
@@ -16,12 +17,14 @@ class Athena:
        self.client = handler.session.client("athena")
        self.output = output

    # For when you want to receive the results of a query
    def query(self, sql, params=[]):
        q = Query(self, sql, params)
        q.run()
        return q

    # For when you want to send a query to run, but there aren't results
    # (like a DDL query for creating databases and tables, etc.)
    def execute(self, sql, params=[]):
@@ -29,6 +32,7 @@ class Athena:
        e.run()
        return e

class Execute:
    """
    Execute will not return results, but will execute the SQL and return the final state.
@@ -44,10 +48,12 @@ class Execute:
        self.temps = []
        self.ds = None

    # The string of the query
    def query(self):
        return self.sql

    # Send the SQL to Athena for running
    def run(self):
        if self.__class__ == Query:
@@ -66,16 +72,19 @@ class Execute:
        self.query_id = resp['QueryExecutionId']
        return resp

    # The status of the SQL (running, queued, succeeded, etc.)
    def status(self):
        return self.info()['Status']['State']

    # Get the basic information on the SQL
    def info(self):
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        stat = self.status()
@@ -104,6 +113,7 @@ class Query(Execute):
            f"with (format = '{self.DATA_STYLE}')"
        return query

    # Gets the files that are listed in the manifest (from the UNLOAD part of
    # the statement)
    # Blocks until the query has finished (because it calls `self.finish()`)
@@ -128,6 +138,7 @@ class Query(Execute):
            raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage'])
        #return status # canceled or error

    # Blocks until the query has finished running and then returns you a pyarrow
    # dataset of the results.
    # Calls `self.manifest_files()` which blocks via `self.finish()`
@@ -140,17 +151,34 @@ class Query(Execute):
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Blocks until the query has finished (via `self.manifest_files()`), reads
    # each result file into a pandas dataframe, scatters it to a worker, and
    # assembles the pieces into a single distributed dask dataframe.
    def distribute_results(self, client):
        if self.ds:
            return self.ds  # results were already materialized locally
        futures = []
        meta = None
        for fn in self.manifest_files():
            df = pd.read_csv(fn)
            meta = df.head(0)  # empty frame that carries the column schema
            futures.append(client.scatter(df))
        return dd.from_delayed(futures, meta=meta)

    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        if self.temps:
            for file in self.temps:

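The scatter-then-assemble pattern behind `distribute_results` can be exercised on its own. A standalone sketch of the same idea, with made-up file names and a throwaway local cluster:

import pandas as pd
import dask.dataframe as dd
from dask.distributed import Client

client = Client()                        # local cluster, for demonstration
files = ["part-0.csv", "part-1.csv"]     # hypothetical result files

futures = []
meta = None
for fn in files:
    df = pd.read_csv(fn)
    meta = df.head(0)                    # empty frame that carries the schema
    futures.append(client.scatter(df))   # ship this partition to a worker

ddf = dd.from_delayed(futures, meta=meta)
print(len(ddf))                          # triggers a distributed row count
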
View file

@@ -40,9 +40,6 @@ class Cluster:
        self.create()
        self.login()
        self.start_dask()
        #self.connect()
        #return self.client

    # Begin the startup process in the background
@@ -69,11 +66,6 @@ class Cluster:
            w.cmd(f"dask worker {self.scheduler.private_ip}:8786", disown=True)

    def connect(self):
        self.client = Client(self.location)
        return self.client

    def terminate(self):
        self.scheduler.terminate()
        for w in self.workers:

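For reference, the pieces `Cluster` wires together map onto the stock dask CLI; the same topology can be stood up by hand (the scheduler address below is a placeholder):

# on the scheduler machine:  dask scheduler --port 8786
# on each worker machine:    dask worker <scheduler-ip>:8786

from dask.distributed import Client

client = Client("<scheduler-ip>:8786")     # equivalent of Cluster.connect()
print(client.scheduler_info()["workers"])  # confirm the workers registered
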
View file

@@ -75,6 +75,8 @@ class Machine:
        self.thread.join()

    def wait(self, n):
        time.sleep(n)  # give time for AWS to register that the instance has been created
        i = 0
        # Time for the server to show as "running"
        # and time for the server to finish getting daemons running
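
The truncated polling loop above evidently waits for the instance to come up; boto3's built-in waiters express the same wait more directly. A sketch, with a hypothetical instance id, and noting that "status ok" is only a rough proxy for user daemons being ready:

import boto3

ec2 = boto3.client("ec2")
iid = "i-0123456789abcdef0"  # hypothetical instance id

ec2.get_waiter("instance_running").wait(InstanceIds=[iid])    # state == running
ec2.get_waiter("instance_status_ok").wait(InstanceIds=[iid])  # boot-time checks passed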