added support for distributed dataframes from athena queries

This commit is contained in:
Ari Brown 2023-10-10 21:21:22 -04:00
parent 27a1d75bb3
commit bfb5dda6d9
3 changed files with 30 additions and 8 deletions

View file

@ -6,6 +6,7 @@ import pyarrow as pa
import pyarrow.dataset import pyarrow.dataset
import pprint import pprint
import datetime import datetime
import dask.dataframe as dd
from minerva import parallel_map from minerva import parallel_map
pp = pprint.PrettyPrinter(indent=4) pp = pprint.PrettyPrinter(indent=4)
@ -16,12 +17,14 @@ class Athena:
self.client = handler.session.client("athena") self.client = handler.session.client("athena")
self.output = output self.output = output
# For when you want to receive the results of something # For when you want to receive the results of something
def query(self, sql, params=[]): def query(self, sql, params=[]):
q = Query(self, sql, params) q = Query(self, sql, params)
q.run() q.run()
return q return q
# For when you want to send a query to run, but there aren't results # For when you want to send a query to run, but there aren't results
# (like a DML query for creating databases and tables etc) # (like a DML query for creating databases and tables etc)
def execute(self, sql, params=[]): def execute(self, sql, params=[]):
@ -29,6 +32,7 @@ class Athena:
e.run() e.run()
return e return e
class Execute: class Execute:
""" """
Execute will not return results, but will execute the SQL and return the final state. Execute will not return results, but will execute the SQL and return the final state.
@ -44,10 +48,12 @@ class Execute:
self.temps = [] self.temps = []
self.ds = None self.ds = None
# The string of the query # The string of the query
def query(self): def query(self):
return self.sql return self.sql
# Send the SQL to Athena for running # Send the SQL to Athena for running
def run(self): def run(self):
if self.__class__ == Query: if self.__class__ == Query:
@ -66,16 +72,19 @@ class Execute:
self.query_id = resp['QueryExecutionId'] self.query_id = resp['QueryExecutionId']
return resp return resp
# The status of the SQL (running, queued, succeeded, etc.) # The status of the SQL (running, queued, succeeded, etc.)
def status(self): def status(self):
return self.info()['Status']['State'] return self.info()['Status']['State']
# Get the basic information on the SQL # Get the basic information on the SQL
def info(self): def info(self):
res = self.client.get_query_execution(QueryExecutionId=self.query_id) res = self.client.get_query_execution(QueryExecutionId=self.query_id)
self.info_cache = res['QueryExecution'] self.info_cache = res['QueryExecution']
return self.info_cache return self.info_cache
# Block until the SQL has finished running # Block until the SQL has finished running
def finish(self): def finish(self):
stat = self.status() stat = self.status()
@ -104,6 +113,7 @@ class Query(Execute):
f"with (format = '{self.DATA_STYLE}')" f"with (format = '{self.DATA_STYLE}')"
return query return query
# Gets the files that are listed in the manifest (from the UNLOAD part of # Gets the files that are listed in the manifest (from the UNLOAD part of
# the statement) # the statement)
# Blocks until the query has finished (because it calls `self.finish()`) # Blocks until the query has finished (because it calls `self.finish()`)
@ -128,6 +138,7 @@ class Query(Execute):
raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage']) raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage'])
#return status # canceled or error #return status # canceled or error
# Blocks until the query has finished running and then returns you a pyarrow # Blocks until the query has finished running and then returns you a pyarrow
# dataset of the results. # dataset of the results.
# Calls `self.manifest_files()` which blocks via `self.finish()` # Calls `self.manifest_files()` which blocks via `self.finish()`
@ -140,17 +151,34 @@ class Query(Execute):
self.ds = pa.dataset.dataset(self.temps) self.ds = pa.dataset.dataset(self.temps)
return self.ds return self.ds
def distribute_results(self, client):
    """Block until the query finishes, then return its results as a
    distributed dask dataframe, with each result file scattered to the
    cluster as one partition.

    Parameters
    ----------
    client : distributed.Client
        Dask client connected to the cluster that should hold the data.

    Returns
    -------
    dask.dataframe.DataFrame
        One partition per file listed in the query's manifest.

    Raises
    ------
    Exception
        If the query produced no result files (nothing to distribute).
    """
    # Local import: `pd` was referenced but never imported at module
    # level (only `dask.dataframe as dd` was added there).
    import pandas as pd

    # Cache the distributed frame under its own attribute. The previous
    # version returned self.ds here, but self.ds is the *pyarrow* dataset
    # cached by results() -- the wrong type for this method's contract.
    if getattr(self, "ddf", None) is not None:
        return self.ddf

    futures = []
    meta = None
    # manifest_files() blocks via self.finish() until the query is done.
    for fn in self.manifest_files():
        df = pd.read_csv(fn)
        if meta is None:
            # meta should be an empty frame carrying only the schema,
            # not a whole partition's worth of data.
            meta = df.iloc[0:0]
        futures.append(client.scatter(df))

    if meta is None:
        # Previously `meta=df` raised NameError when the loop never ran;
        # fail with an explicit message instead.
        raise Exception("query produced no result files to distribute")

    self.ddf = dd.from_delayed(futures, meta=meta)
    return self.ddf
# Return scalar results # Return scalar results
# Abstracts away a bunch of keystrokes # Abstracts away a bunch of keystrokes
def scalar(self): def scalar(self):
return self.results().head(1)[0][0].as_py() return self.results().head(1)[0][0].as_py()
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exception_type, exception_value, exception_traceback): def __exit__(self, exception_type, exception_value, exception_traceback):
self.close() self.close()
def close(self): def close(self):
if self.temps: if self.temps:
for file in self.temps: for file in self.temps:

View file

@ -40,9 +40,6 @@ class Cluster:
self.create() self.create()
self.login() self.login()
self.start_dask() self.start_dask()
#self.connect()
#return self.client
# Begin the startup process in the background # Begin the startup process in the background
@ -69,11 +66,6 @@ class Cluster:
w.cmd(f"dask worker {self.scheduler.private_ip}:8786", disown=True) w.cmd(f"dask worker {self.scheduler.private_ip}:8786", disown=True)
def connect(self):
self.client = Client(self.location)
return self.client
def terminate(self): def terminate(self):
self.scheduler.terminate() self.scheduler.terminate()
for w in self.workers: for w in self.workers:

View file

@ -75,6 +75,8 @@ class Machine:
self.thread.join() self.thread.join()
def wait(self, n): def wait(self, n):
time.sleep(n) # give time for AWS to register that the instance has been created
i = 0 i = 0
# Time for the server to show as "running" # Time for the server to show as "running"
# and time for the server to finish getting daemons running # and time for the server to finish getting daemons running