# forked from bellwether/minerva
import datetime
import os
import pprint
import random
import time

import boto3
import pyarrow as pa
import pyarrow.dataset

from minerva import parallel_map
# Module-level pretty-printer, available for debugging/inspection output.
pp = pprint.PrettyPrinter(indent=4)
class Athena:
    """Thin front-end for running SQL on AWS Athena.

    Dispatches SQL either as a ``Query`` (a result-returning statement,
    unloaded to parquet) or an ``Execute`` (a statement with no result
    set, such as CREATE DATABASE/TABLE).
    """

    def __init__(self, handler, output):
        """
        Args:
            handler: project connection object; must expose ``.session``
                (a boto3 session used to build the Athena client). The
                same handler's ``.s3`` is used later by Query downloads.
            output: S3 URI prefix where Athena writes query output.
        """
        self.handler = handler
        self.client = handler.session.client("athena")
        self.output = output

    def query(self, sql, params=None):
        """Start *sql* as a result-returning query and return the Query.

        Args:
            sql: statement text, optionally containing ``?`` placeholders.
            params: optional positional execution parameters.
        """
        # BUG FIX: default was a shared mutable list (params=[]);
        # use a None sentinel and normalize to a fresh list.
        q = Query(self, sql, params or [])
        q.run()
        return q

    def execute(self, sql, params=None):
        """Start *sql* as a statement with no result set and return the Execute.

        Args:
            sql: statement text (e.g. CREATE TABLE ...).
            params: optional positional execution parameters.
        """
        # Same mutable-default fix as query() above.
        e = Execute(self, sql, params or [])
        e.run()
        return e
class Execute:
    """Run a SQL statement on Athena and track it to completion.

    Execute will not return results, but will execute the SQL and return
    the final state.  It is meant for DML/DDL statements such as
    CREATE DATABASE/TABLE; use Query for statements that return rows.
    """

    def __init__(self, athena, sql, params=None):
        """
        Args:
            athena: owning Athena wrapper (supplies client and output location).
            sql: statement text, optionally with ``?`` execution parameters.
            params: optional positional execution parameters; each value is
                stringified because the Athena API accepts strings only.
        """
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        # BUG FIX: default was a shared mutable list (params=[]);
        # None sentinel avoids cross-call state leakage.
        self.params = [str(p) for p in (params or [])]
        # Most recent QueryExecution record fetched from the API.
        self.info_cache = None

    def query(self):
        """Return the SQL text to submit; subclasses may transform it."""
        return self.sql

    def run(self):
        """Start the statement; store the execution id and return the raw response."""
        config = {"OutputLocation": self.athena.output}
        # ExecutionParameters may not be passed empty, hence the branch.
        if self.params:
            resp = self.client.start_query_execution(
                QueryString=self.query(),
                ResultConfiguration=config,
                ExecutionParameters=self.params)
        else:
            resp = self.client.start_query_execution(
                QueryString=self.query(),
                ResultConfiguration=config)

        self.query_id = resp['QueryExecutionId']
        return resp

    def status(self):
        """Return the current execution state (e.g. QUEUED/RUNNING/SUCCEEDED)."""
        return self.info()['Status']['State']

    def info(self):
        """Fetch, cache, and return the full QueryExecution record."""
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    def finish(self):
        """Block until the statement reaches a terminal state.

        Polls every 5 seconds.  On exit, records total execution time as a
        ``datetime.timedelta`` in ``self.runtime``.

        Returns:
            The finalized state string (e.g. SUCCEEDED, FAILED, CANCELLED).
        """
        stat = self.status()
        while stat in ('QUEUED', 'RUNNING'):
            time.sleep(5)
            stat = self.status()

        # status() -> info() just refreshed info_cache, so this is current.
        ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(seconds=ms / 1000)
        return stat  # finalized state
class Query(Execute):
    """An Athena query whose result set is fetched back as parquet.

    Wraps the user's SQL in ``UNLOAD (...) TO <s3 path>`` so Athena writes
    the rows as parquet files, then downloads those files and exposes them
    as a pyarrow dataset.
    """

    # Format passed to UNLOAD's ``format`` option.
    DATA_STYLE = 'parquet'

    def query(self):
        """Wrap the SQL in an UNLOAD statement targeting a random output prefix."""
        # Random subdirectory so concurrent queries don't collide.
        out = os.path.join(self.athena.output,
                           str(random.random()))
        # NOTE(review): self.sql is interpolated directly into the statement;
        # callers must not pass untrusted SQL here.  repr() supplies the
        # single quotes Athena expects around the S3 path.
        query = f"unload ({self.sql}) to {repr(out)} " + \
                f"with (format = '{self.DATA_STYLE}')"
        return query

    def manifest_files(self):
        """Wait for the query to finish and return the S3 URIs of its data files.

        Returns:
            List of S3 URI strings, one per parquet file produced by UNLOAD.

        Raises:
            RuntimeError: if the query ended in any state other than SUCCEEDED.
        """
        status = self.finish()

        if status != "SUCCEEDED":
            # BUG FIX: the original used a bare `raise` with no active
            # exception, which itself raises an unrelated RuntimeError and
            # loses Athena's error text.  Surface it explicitly instead,
            # tolerating a missing AthenaError record.
            err = self.info_cache['Status'] \
                      .get('AthenaError', {}) \
                      .get('ErrorMessage', 'unknown error')
            print("Error")
            print(err)
            raise RuntimeError(f"Athena query {status}: {err}")

        # finish() already recorded self.runtime from the same statistics
        # block, so the original's recomputation here was redundant.

        # Because we're using `UNLOAD`, Athena writes a manifest listing
        # the data files that make up our result.
        manif = self.info_cache['Statistics']['DataManifestLocation']
        tmp = self.handler.s3.download(manif)
        with open(tmp, "r") as f:
            txt = f.read()

        # One S3 URI per line; drop blank lines.
        return [line.strip() for line in txt.splitlines() if line.strip()]

    def results(self):
        """Download all result files in parallel and return them as a pyarrow dataset."""
        local = parallel_map(self.handler.s3.download, self.manifest_files())
        self.ds = pa.dataset.dataset(local)
        return self.ds