# minerva/minerva/athena.py

import datetime
import os
import pprint
import random
import time

import boto3
import pyarrow as pa
import pyarrow.dataset

from minerva import parallel_map

pp = pprint.PrettyPrinter(indent=4)


class Athena:
    """Wraps a boto3 Athena client plus the S3 output location used for query results."""

    def __init__(self, handler, output):
        self.handler = handler
        self.client = handler.session.client("athena")
        self.output = output

    def query(self, sql, params=[]):
        """Run a SELECT-style statement and return a Query whose results can be downloaded."""
        q = Query(self, sql, params)
        q.run()
        return q

    def execute(self, sql, params=[]):
        """Run a statement for its side effects only and return the Execute handle."""
        e = Execute(self, sql, params)
        e.run()
        return e


class Execute:
    """
    Execute does not return results; it runs the SQL and reports the final state.
    It is meant for DDL statements such as CREATE DATABASE or CREATE TABLE.
    See the usage sketch at the bottom of this module.
    """

    def __init__(self, athena, sql, params=[]):
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        self.params = [str(p) for p in params]
        self.info_cache = None

    def query(self):
        """Return the SQL string to submit; subclasses may wrap it."""
        return self.sql

    def run(self):
        """Start the query execution and remember its id."""
        config = {"OutputLocation": self.athena.output}
        if self.params:
            resp = self.client.start_query_execution(QueryString=self.query(),
                                                     ResultConfiguration=config,
                                                     ExecutionParameters=self.params)
        else:
            resp = self.client.start_query_execution(QueryString=self.query(),
                                                     ResultConfiguration=config)
        self.query_id = resp['QueryExecutionId']
        return resp

    def status(self):
        return self.info()['Status']['State']

    def info(self):
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    def finish(self):
        """Poll until the execution reaches a final state and record its runtime."""
        stat = self.status()
        while stat in ['QUEUED', 'RUNNING']:
            time.sleep(5)
            stat = self.status()
        ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(seconds=ms / 1000)
        return stat  # final state: SUCCEEDED, FAILED, or CANCELLED


class Query(Execute):
    """An Execute that UNLOADs its result set to S3 so the data can be read with pyarrow."""

    DATA_STYLE = 'parquet'

    def query(self):
        # Wrap the SQL in UNLOAD so results are written as Parquet files
        # under a random prefix inside the configured output location.
        out = os.path.join(self.athena.output,
                           str(random.random()))
        query = f"unload ({self.sql}) to {repr(out)} " + \
                f"with (format = '{self.DATA_STYLE}')"
        return query

    def manifest_files(self):
        """Wait for the query to finish and return the S3 paths of its data files."""
        status = self.finish()
        if status == "SUCCEEDED":
            # Track the runtime
            ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
            self.runtime = datetime.timedelta(seconds=ms / 1000)
            # Because we're using `UNLOAD`, we get a manifest of the files
            # that make up our data.
            manif = self.info_cache['Statistics']['DataManifestLocation']
            tmp = self.handler.s3.download(manif)
            with open(tmp, "r") as f:
                txt = f.read()
            files = txt.strip().split("\n")
            files = [f.strip() for f in files if f.strip()]  # filter empty lines
            return files
        else:
            msg = self.info_cache['Status']['AthenaError']['ErrorMessage']
            raise RuntimeError(f"Athena query did not succeed ({status}): {msg}")

    def results(self):
        """Download every file in the manifest (in parallel) and open them as one dataset."""
        local = parallel_map(self.handler.s3.download, self.manifest_files())
        self.ds = pa.dataset.dataset(local)
        return self.ds
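

# Minimal usage sketch, assuming a handler shaped like the one minerva passes
# in above: it must expose a boto3 `session` and an `s3.download(uri)` helper
# that fetches an S3 object to a local path. The `_DemoHandler`/`_DemoS3`
# classes, the bucket, and the SQL below are illustrative placeholders, not
# part of minerva's API.
if __name__ == "__main__":
    import tempfile

    class _DemoS3:
        """Hypothetical stand-in for handler.s3 with a download() helper."""

        def __init__(self, session):
            self.client = session.client("s3")

        def download(self, uri):
            # uri looks like s3://bucket/key; save it to a temp file and
            # return the local path.
            bucket, key = uri.replace("s3://", "", 1).split("/", 1)
            fd, path = tempfile.mkstemp()
            os.close(fd)
            self.client.download_file(bucket, key, path)
            return path

    class _DemoHandler:
        """Hypothetical stand-in for minerva's handler object."""

        def __init__(self):
            self.session = boto3.Session()
            self.s3 = _DemoS3(self.session)

    athena = Athena(_DemoHandler(), "s3://example-bucket/athena-output/")
    # DDL goes through execute(); finish() blocks until a final state.
    athena.execute("CREATE DATABASE IF NOT EXISTS demo_db").finish()
    # SELECTs go through query(); results() yields a pyarrow dataset.
    q = athena.query("SELECT 1 AS one")
    print(q.results().to_table())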