forked from bellwether/minerva
Added support for distributed dataframes from Athena queries
This commit is contained in:
parent 27a1d75bb3
commit bfb5dda6d9
3 changed files with 30 additions and 8 deletions
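
For orientation, here is a rough end-to-end sketch of how the new distributed path fits together, using only names that appear in the diff below; the module paths, `handler` helper, constructor arguments, bucket, and SQL are illustrative assumptions, not part of the commit:

# Illustrative sketch only -- module paths, handler setup, and SQL are assumed.
from minerva.athena import Athena    # assumed module path
from minerva.cluster import Cluster  # assumed module path

handler = make_aws_handler()         # hypothetical helper; the diff only shows handler.session.client("athena")

cluster = Cluster()                  # assumed constructor; provisions a dask scheduler and workers
client = cluster.connect()           # connect() returns a dask.distributed.Client (see the Cluster diff)

athena = Athena(handler, output="s3://example-bucket/athena-results/")  # arguments assumed from the diff

with athena.query("select * from example_db.example_table") as q:  # placeholder SQL
    ddf = q.distribute_results(client)  # new in this commit: a dask dataframe spread across the workers
    print(ddf.head())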
@@ -6,6 +6,7 @@ import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import dask.dataframe as dd
from minerva import parallel_map

pp = pprint.PrettyPrinter(indent=4)
@@ -16,12 +17,14 @@ class Athena:
        self.client = handler.session.client("athena")
        self.output = output


    # For when you want to receive the results of something
    def query(self, sql, params=[]):
        q = Query(self, sql, params)
        q.run()
        return q


    # For when you want to send a query to run, but there aren't results
    # (like a DML query for creating databases and tables etc)
    def execute(self, sql, params=[]):
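
A short sketch of the split between the two entry points, assuming an `athena` instance like the one in the sketch above (the SQL statements are placeholders):

# execute(): statements that produce no result set (DDL/DML), per the comments in the diff.
athena.execute("create database if not exists reporting")      # placeholder SQL

# query(): returns a Query object whose results can be read back or distributed.
q = athena.query("select count(*) from reporting.events")      # placeholder SQL
print(q.scalar())  # blocks until the query finishes, then returns the single value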
@@ -29,6 +32,7 @@ class Athena:
        e.run()
        return e


class Execute:
    """
    Execute will not return results, but will execute the SQL and return the final state.
@@ -44,10 +48,12 @@ class Execute:
        self.temps = []
        self.ds = None


    # The string of the query
    def query(self):
        return self.sql


    # Send the SQL to Athena for running
    def run(self):
        if self.__class__ == Query:
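
The middle of `run()` falls outside the hunks shown here; given that the next hunk ends by storing `resp['QueryExecutionId']`, the underlying call is presumably boto3's `start_query_execution`, roughly along these lines (the `ResultConfiguration` and the `self.output` attribute are assumptions):

    # Sketch of the unshown portion of run(); details are assumed, not taken from the diff.
    def run(self):
        resp = self.client.start_query_execution(
            QueryString=self.sql,
            ResultConfiguration={"OutputLocation": self.output},  # assumed attribute
        )
        self.query_id = resp['QueryExecutionId']
        return resp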
@@ -66,16 +72,19 @@ class Execute:
        self.query_id = resp['QueryExecutionId']
        return resp


    # The status of the SQL (running, queued, succeeded, etc.)
    def status(self):
        return self.info()['Status']['State']


    # Get the basic information on the SQL
    def info(self):
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache


    # Block until the SQL has finished running
    def finish(self):
        stat = self.status()
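
The body of `finish()` is also cut off by the hunk; a typical polling loop over these helpers looks roughly like the following (assumes `time` is imported; the poll interval is an assumption, and the state names are the ones the Athena API reports):

    # Sketch of a finish()-style polling loop; the actual body is not shown in the diff.
    def finish(self):
        stat = self.status()
        while stat not in ("SUCCEEDED", "FAILED", "CANCELLED"):  # terminal Athena query states
            time.sleep(1)   # poll interval is an assumption
            stat = self.status()
        return stat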
@@ -104,6 +113,7 @@ class Query(Execute):
            f"with (format = '{self.DATA_STYLE}')"
        return query


    # Gets the files that are listed in the manifest (from the UNLOAD part of
    # the statement)
    # Blocks until the query has finished (because it calls `self.finish()`)
@@ -128,6 +138,7 @@ class Query(Execute):
            raise Exception(self.info_cache['Status']['AthenaError']['ErrorMessage'])
        #return status # canceled or error


    # Blocks until the query has finished running and then returns you a pyarrow
    # dataset of the results.
    # Calls `self.manifest_files()` which blocks via `self.finish()`
@@ -140,17 +151,34 @@ class Query(Execute):
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds


    def distribute_results(self, client):
        if self.ds:
            return self.ds

        futures = []
        for fn in self.manifest_files():
            df = pd.read_csv(fn)
            future = client.scatter(df)
            futures.append(future)

        return dd.from_delayed(futures, meta=df)


    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        return self.results().head(1)[0][0].as_py()


    def __enter__(self):
        return self


    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()


    def close(self):
        if self.temps:
            for file in self.temps:
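
The new `distribute_results` pairs `Client.scatter` with `dask.dataframe.from_delayed`, which accepts futures as well as delayed objects, so each result file becomes one partition of the returned dask dataframe. In isolation the pattern looks like this (the file names and the local cluster are illustrative):

import pandas as pd
import dask.dataframe as dd
from dask.distributed import Client

client = Client()  # small local cluster, for illustration only

files = ["part-000.csv", "part-001.csv"]  # stand-ins for the UNLOAD manifest files
futures = [client.scatter(pd.read_csv(fn)) for fn in files]  # one pandas frame per file, held on workers

ddf = dd.from_delayed(futures)  # one dask partition per scattered frame
print(ddf.npartitions)          # -> 2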
@@ -40,9 +40,6 @@ class Cluster:
        self.create()
        self.login()
        self.start_dask()
        #self.connect()

        #return self.client


    # Begin the startup process in the background
@@ -69,11 +66,6 @@ class Cluster:
            w.cmd(f"dask worker {self.scheduler.private_ip}:8786", disown=True)


    def connect(self):
        self.client = Client(self.location)
        return self.client


    def terminate(self):
        self.scheduler.terminate()
        for w in self.workers:
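
For reference, an external process can reach the same scheduler the workers are pointed at above; dask's default scheduler port 8786 matches the worker command in the diff (the host address here is illustrative):

from dask.distributed import Client

client = Client("tcp://10.0.0.12:8786")  # illustrative scheduler address; port matches the worker command
print(list(client.scheduler_info()["workers"]))  # addresses of the connected workers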
@@ -75,6 +75,8 @@ class Machine:
        self.thread.join()

    def wait(self, n):
        time.sleep(n)  # give time for AWS to register that the instance has been created

        i = 0
        # Time for the server to show as "running"
        # and time for the server to finish getting daemons running