minerva/minerva/athena.py
2023-08-18 13:44:41 -04:00

144 lines
4.7 KiB
Python

import boto3
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
from minerva import parallel_map
# Module-level pretty-printer configured with a 4-space indent.
pp = pprint.PrettyPrinter(indent=4)
class Athena:
    """Entry point for submitting SQL statements to AWS Athena.

    Args:
        handler: Object exposing a ``session`` whose ``client("athena")``
            call yields the Athena client used for every request.
        output: Output location handed to the statements created by
            :meth:`query` and :meth:`execute` (read by them as
            ``athena.output``).
    """

    def __init__(self, handler, output=None):
        self.handler = handler
        self.client = handler.session.client("athena")
        self.output = output

    def query(self, sql, params=None):
        """Start a statement whose results you want to receive.

        Args:
            sql: SQL text to run.
            params: Optional iterable of execution parameters.

        Returns:
            The ``Query`` object, already submitted via ``run()``.
        """
        # ``None`` replaces the former mutable ``[]`` default; a fresh list is
        # forwarded so no list object is ever shared between calls.
        q = Query(self, sql, [] if params is None else params)
        q.run()
        return q

    def execute(self, sql, params=None):
        """Start a statement that produces no results to read back.

        Intended for DDL such as creating databases and tables.

        Args:
            sql: SQL text to run.
            params: Optional iterable of execution parameters.

        Returns:
            The ``Execute`` object, already submitted via ``run()``.
        """
        e = Execute(self, sql, [] if params is None else params)
        e.run()
        return e
class Execute:
    """
    Execute will not return results, but will execute the SQL and return the final state.
    Execute is meant to be used for DDL statements such as CREATE DATABASE/TABLE

    Args:
        athena: Object providing ``handler``, ``client`` and ``output``.
        sql: SQL text to run.
        params: Optional iterable of execution parameters; each one is
            converted with ``str()``.
    """

    def __init__(self, athena, sql, params=None):
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        # ``None`` (rather than a shared mutable ``[]``) is the default.
        self.params = [str(p) for p in (() if params is None else params)]
        self.info_cache = None
        self.files = []
        # Populated by run() and finish() respectively.
        self.query_id = None
        self.runtime = None

    # The string of the query
    def query(self):
        """Return the SQL text that will be sent to Athena."""
        return self.sql

    # Send the SQL to Athena for running
    def run(self):
        """Submit the statement and remember its execution id.

        Returns:
            The raw ``start_query_execution`` response.
        """
        request = {"QueryString": self.query()}
        # Optional keys are only sent when they carry a value: passing
        # ``OutputLocation=None`` or an empty parameter list is rejected by
        # the client instead of falling back to the defaults.
        if self.athena.output is not None:
            request["ResultConfiguration"] = {
                "OutputLocation": self.athena.output,
            }
        if self.params:
            request["ExecutionParameters"] = self.params
        resp = self.client.start_query_execution(**request)
        self.query_id = resp['QueryExecutionId']
        return resp

    # The status of the SQL (running, queued, succeeded, etc.)
    def status(self):
        """Fetch fresh execution info and return its state string."""
        return self.info()['Status']['State']

    # Get the basic information on the SQL
    def info(self):
        """Fetch, cache (in ``info_cache``) and return the execution info.

        Raises:
            RuntimeError: If the statement has not been submitted yet.
        """
        if self.query_id is None:
            raise RuntimeError("run() must be called before inspecting the query")
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self, poll_interval=5):
        """Block until the statement leaves the QUEUED/RUNNING states.

        Also records the total execution time in ``self.runtime``.

        Args:
            poll_interval: Seconds to sleep between status checks.

        Returns:
            The final state string.
        """
        stat = self.status()
        while stat in ('QUEUED', 'RUNNING'):
            time.sleep(poll_interval)
            stat = self.status()
        ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(milliseconds=ms)
        return stat  # finalized state
class Query(Execute):
    """Statement whose results are unloaded by Athena and read back locally.

    The SQL is wrapped in an ``UNLOAD`` so that Athena writes the result set
    as ``DATA_STYLE`` files under ``athena.output``. Usable as a context
    manager; leaving the ``with`` block calls :meth:`close`.
    """

    DATA_STYLE = 'parquet'

    # Local copies created by results(). The immutable class-level default
    # keeps close() safe when nothing has been downloaded; results() replaces
    # it with a per-instance list.
    _local_files = ()

    # Automatically includes unloading the results to Parquet format
    def query(self):
        """Return the SQL wrapped in an ``UNLOAD`` to a random sub-path."""
        # Joined with "/" explicitly: this is an object-store location, not a
        # filesystem path, so the separator must not depend on the host OS.
        out = f"{self.athena.output.rstrip('/')}/{random.random()}"
        return (f"unload ({self.sql}) to {repr(out)} "
                f"with (format = '{self.DATA_STYLE}')")

    # Gets the files that are listed in the manifest (from the UNLOAD part of
    # the statement)
    # Blocks until the query has finished (because it calls `self.finish()`)
    def manifest_files(self):
        """Return the data files listed in the query's manifest.

        Blocks via ``finish()``, which also records ``self.runtime``.

        Returns:
            List of non-empty, stripped manifest lines (also stored in
            ``self.files``).

        Raises:
            RuntimeError: If the query did not end in ``SUCCEEDED``; the
                message carries the reason reported by Athena, if any.
        """
        status = self.finish()
        if status != "SUCCEEDED":
            state = self.info_cache['Status']
            # ``AthenaError`` is not guaranteed to be present (e.g. for a
            # cancelled query), so fall back to the state-change reason.
            reason = (state.get('AthenaError', {}).get('ErrorMessage')
                      or state.get('StateChangeReason')
                      or "no error details reported")
            raise RuntimeError(
                f"Athena query {self.query_id} ended in state {status}: {reason}"
            )

        # Because we're using `UNLOAD`, we get a manifest of the files
        # that make up our data.
        manif = self.info_cache['Statistics']['DataManifestLocation']
        lines = self.handler.s3.read(manif).split("\n")
        self.files = [line.strip() for line in lines if line.strip()]
        return self.files

    # Blocks until the query has finished running and then returns you a pyarrow
    # dataset of the results.
    # Calls `self.manifest_files()` which blocks via `self.finish()`
    def results(self):
        """Download the result files and return them as a pyarrow dataset.

        Prints the elapsed wall-clock seconds (waiting plus downloading).
        """
        start = time.time()
        # Keep paths from any earlier call so close() still removes them.
        tracked = list(self._local_files)
        self._local_files = tracked
        local = []
        for remote in self.manifest_files():
            # NOTE(review): assumes `s3.download` returns the local path of
            # the downloaded copy -- confirm against the handler.
            path = self.handler.s3.download(remote)
            # Recorded immediately so a failure on a later file does not
            # leave earlier downloads untracked.
            tracked.append(path)
            local.append(path)
        print(time.time() - start)
        self.ds = pa.dataset.dataset(local)
        return self.ds

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        """Delete the local copies downloaded by :meth:`results`.

        Safe to call repeatedly. ``self.files`` (the manifest entries) is left
        untouched: those are the download sources, not local files.
        """
        pending, self._local_files = self._local_files, ()
        for path in pending:
            try:
                os.remove(path)
            except FileNotFoundError:
                # Already gone; cleanup is best-effort and idempotent.
                pass