minerva/minerva/athena.py

import boto3
import math
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
from minerva import load_template
pp = pprint.PrettyPrinter(indent=4)
def local(path):
return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)
class Athena:
def __init__(self, handler, output=None):
self.handler = handler
self.client = handler.session.client("athena")
self.output = output
    # For when you want the results of a query back (e.g. a SELECT)
def query(self, sql, params=[], format='parquet'):
q = Query(self, sql, params, format)
q.run()
return q
    # For when you want to run a statement that produces no result set
    # (e.g. DDL such as CREATE DATABASE/TABLE)
def execute(self, sql, params=[], format=None):
e = Execute(self, sql, params, format)
e.run()
return e
def parallelize(self, *args, **kwargs):
p = Parallelize(self, *args, **kwargs)
return p
def delete_table(self, table, join=True):
e = Execute(self, f"drop table {table}")
e.run()
if join:
e.finish()
s3_uri = os.path.join(self.output, table, "")
self.handler.s3.rm(s3_uri)
return e
def cancel(self, query_id):
return self.client.stop_query_execution(QueryExecutionId=query_id)
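# A minimal usage sketch, not from the original codebase: `handler` is assumed
# to expose a boto3 session and an S3 helper, as the methods above expect, and
# the bucket/table names are made up:
#
#   athena = Athena(handler, output="s3://my-bucket/athena/")
#   athena.execute("create database if not exists scratch").finish()
#   q = athena.query("select count(*) from scratch.events")
#   print(q.scalar())  # blocks until the query finishes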
class Execute:
"""
Execute will not return results, but will execute the SQL and return the final state.
Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
"""
def __init__(self, athena, sql, params=[], format='parquet'):
self.athena = athena
self.handler = athena.handler
self.client = athena.client
self.sql = sql
self.params = [str(p) for p in params]
self.info_cache = None
self.temps = []
self.ds = None
self.files = None
self.format = format
self.finished = None
    # The SQL text to send to Athena (subclasses may wrap it)
def query(self):
return self.sql
# Send the SQL to Athena for running
def run(self):
config = {"OutputLocation": os.path.join(self.athena.output, "results")}
if self.params:
resp = self.client.start_query_execution(QueryString=self.query(),
ResultConfiguration=config,
ExecutionParameters=self.params)
else:
resp = self.client.start_query_execution(QueryString=self.query(),
ResultConfiguration=config)
self.query_id = resp['QueryExecutionId']
return resp
# The status of the SQL (running, queued, succeeded, etc.)
def status(self):
return self.info()['Status']['State']
# Get the basic information on the SQL
def info(self):
res = self.client.get_query_execution(QueryExecutionId=self.query_id)
self.info_cache = res['QueryExecution']
return self.info_cache
# Block until the SQL has finished running
def finish(self):
if self.finished:
return self.finished
stat = self.status()
while stat in ['QUEUED', 'RUNNING']:
time.sleep(5)
stat = self.status()
ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(milliseconds=ms)
scanned = self.info_cache['Statistics']['DataScannedInBytes']
self.cost = 5.0 * scanned / (1024 ** 4) # $5/TB scanned
self.finished = stat
return stat # finalized state
class Query(Execute):
    """Automatically wraps the SQL in UNLOAD so the results are written
    out to S3 (Parquet by default)."""
def query(self):
out = os.path.join(self.athena.output,
"results",
str(random.random()))
query = f"unload ({self.sql}) to {repr(out)} " + \
f"with (format = '{self.format}', compression ='zstd', compression_level = 4)"
return query
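    # For example, with the default format the wrapped statement renders
    # roughly as (the random output path differs per query):
    #
    #   unload (select * from events) to 's3://my-bucket/athena/results/0.123'
    #   with (format = 'parquet', compression ='zstd', compression_level = 4)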
# Gets the files that are listed in the manifest (from the UNLOAD part of
# the statement)
# Blocks until the query has finished (because it calls `self.finish()`)
def manifest_files(self):
if self.files:
return self.files
status = self.finish()
if status == "SUCCEEDED":
# Because we're using `UNLOAD`, we get a manifest of the files
# that make up our data.
manif = self.info_cache['Statistics']['DataManifestLocation']
files = self.handler.s3.read(manif).split("\n")
files = [f.strip() for f in files if f.strip()] # filter empty
self.files = files
return files
        else:
            message = self.info_cache['Status']['AthenaError']['ErrorMessage']
            print(f"Error: {message}")
            raise Exception(message)
# Blocks until the query has finished running and then returns you a pyarrow
# dataset of the results.
# Calls `self.manifest_files()` which blocks via `self.finish()`
def results(self):
if self.ds:
return self.ds
        # Download each result file locally; close() removes them afterwards
        self.temps = [self.handler.s3.download(f) for f in self.manifest_files()]
self.ds = pa.dataset.dataset(self.temps)
return self.ds
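    # Sketch of materializing the results with pyarrow (hypothetical query):
    #
    #   q = athena.query("select user_id, ts from events")
    #   table = q.results().to_table()   # pyarrow.Table
    #   df = table.to_pandas()           # pandas DataFrame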
    # Blocks until the query has finished, then loads the result files into a
    # Dask DataFrame distributed across the cluster behind `client`.
    def distribute_results(self, client):
        import dask.dataframe as dd
        import pandas as pd
        files = self.manifest_files()
        print(f"{len(files)} files in manifest")
        futures = []
        meta = None
        for fn in files:
            print(f"reading {fn}...")
            df = pd.read_parquet(fn)
            if meta is None:
                meta = df.head(0)  # empty frame carrying the schema
            futures.append(client.scatter(df))  # ship this partition to a worker
            print("\tscattered")
        return dd.from_delayed(futures, meta=meta)
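    # Sketch, assuming a running dask.distributed cluster (the scheduler
    # address is illustrative):
    #
    #   from dask.distributed import Client
    #   client = Client("tcp://scheduler:8786")
    #   ddf = q.distribute_results(client)
    #   print(len(ddf))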
    # Return the single scalar value of a one-row, one-column result
    # (saves a bunch of keystrokes)
def scalar(self):
return self.results().head(1)[0][0].as_py()
def __enter__(self):
return self
def __exit__(self, exception_type, exception_value, exception_traceback):
self.close()
    def close(self):
        for file in self.temps:
            os.remove(file)
        self.temps = []  # guard against double-removal on repeated close()
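# Query is a context manager, so the downloaded temp files are cleaned up
# automatically. A sketch with a made-up table:
#
#   with athena.query("select user_id from events limit 10") as q:
#       print(q.results().to_table())
#   # temp files removed here by close()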
class Parallelize:
UNION_TABLES = local("athena/union_tables.sql")
def __init__(self, athena, dest=None, data=[], n=1):
self.athena = athena
self.dest = dest
self.data = list(data)
self.n = n
self.tables = []
self.queries = []
self.runtime = None
self.cost = None
def __lshift__(self, res):
self.queries.append(res)
def __iter__(self):
n = min(self.n, len(self.data))
size = math.ceil(len(self.data) / n)
self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)]
self.current = 0
return self
def __next__(self):
if self.current >= len(self.groups):
if self.dest:
self.union_tables(self.dest).finish()
self.clear_temp_tables()
raise StopIteration
        # temp table name for this chunk; tracked so it can be dropped later
tmp = "temp_" + str(round(random.random() * 10_000_000))
self.tables.append(tmp)
obj = self.groups[self.current]
self.current += 1
return tmp, obj
    # Drops every temp table and removes its S3 data
def clear_temp_tables(self):
qs = [self.athena.delete_table(table, join=False) for table in self.tables]
for q in qs:
q.finish()
def finish(self):
for q in self.queries:
q.finish()
self.cost = sum([q.cost for q in self.queries])
self.runtime = max([q.runtime for q in self.queries])
    def results(self):
        self.finish()
        # Flatten the per-query file lists into one list for pyarrow
        files = [f for q in self.queries for f in q.results().files]
        return pa.dataset.dataset(files)
def union_tables(self, dest):
self.finish()
lines = [f"select * from {table}" for table in self.tables]
tables = ' union all '.join(lines)
out = os.path.join(self.athena.output, dest)
        sql = load_template(self.UNION_TABLES, dest=dest,
                            output=out,
                            tables=tables)
return self.athena.execute(sql)
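# Sketch of the intended driving loop: iterate to get (temp table, chunk)
# pairs, feed each statement back in with `<<`, and on the final iteration the
# temp tables are unioned into `dest` and cleaned up. Names are illustrative:
#
#   par = athena.parallelize(dest="events_all", data=event_ids, n=8)
#   for table, chunk in par:
#       ids = ", ".join(str(i) for i in chunk)
#       par << athena.execute(f"create table {table} as "
#                             f"select * from events where id in ({ids})")
#   print(par.cost, par.runtime)
#
#   (a real CTAS would likely set external_location under `athena.output`
#   so that the cleanup in delete_table can find the files)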