import boto3
import math
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import dask.dataframe as dd

from minerva import parallel_map, load_template


pp = pprint.PrettyPrinter(indent=4)


# Resolve a path relative to this file's directory.
def local(path):
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), path)


class Athena:
    def __init__(self, handler, output=None):
        # `handler` supplies the boto3 session and an S3 helper;
        # `output` is the s3:// prefix where query results land.
        self.handler = handler
        self.client = handler.session.client("athena")
        self.output = output

    # For when you want the results of the query back
    def query(self, sql, params=[], format='parquet'):
        q = Query(self, sql, params, format)
        q.run()
        return q

    # For when you want to send a query to run, but there are no results
    # to fetch (like a DDL query for creating databases and tables, etc.)
    def execute(self, sql, params=[], format=None):
        e = Execute(self, sql, params, format)
        e.run()
        return e

    def parallelize(self, *args, **kwargs):
        p = Parallelize(self, *args, **kwargs)
        return p

    # Drop the table in Athena and remove its backing data from S3.
    def delete_table(self, table, join=True):
        e = Execute(self, f"drop table {table}")
        e.run()

        if join:
            e.finish()

        s3_uri = os.path.join(self.output, table, "")
        self.handler.s3.rm(s3_uri)

        return e

    def cancel(self, query_id):
        return self.client.stop_query_execution(QueryExecutionId=query_id)


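# A minimal usage sketch (not part of the original API): `handler` is
# assumed to expose a boto3 `session` and an `s3` helper with read/rm/
# download methods, matching how it is used throughout this module.
def _example_count_rows(handler):
    athena = Athena(handler, output="s3://example-bucket/athena/")
    # query() UNLOADs results to Parquet on S3; scalar() blocks and
    # returns the single value, and close() removes local temp files.
    with athena.query("select count(*) from example_table") as q:
        return q.scalar()

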
class Execute:
    """
    Execute does not return results; it runs the SQL and reports the final state.
    Execute is meant for DDL statements such as CREATE DATABASE/TABLE.
    """
    def __init__(self, athena, sql, params=[], format='parquet'):
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        self.params = [str(p) for p in params]
        self.info_cache = None  # last get_query_execution response
        self.temps = []         # local temp files downloaded for results
        self.ds = None          # cached pyarrow dataset of the results
        self.files = None       # cached manifest file list
        self.format = format
        self.finished = None    # final state once finish() completes

    # The string of the query
    def query(self):
        return self.sql

    # Send the SQL to Athena for running
    def run(self):
        config = {"OutputLocation": os.path.join(self.athena.output, "results")}

        if self.params:
            resp = self.client.start_query_execution(QueryString=self.query(),
                                                     ResultConfiguration=config,
                                                     ExecutionParameters=self.params)
        else:
            resp = self.client.start_query_execution(QueryString=self.query(),
                                                     ResultConfiguration=config)

        self.query_id = resp['QueryExecutionId']
        return resp

    # The status of the SQL (running, queued, succeeded, etc.)
    def status(self):
        return self.info()['Status']['State']

    # Get the basic information on the SQL
    def info(self):
        res = self.client.get_query_execution(QueryExecutionId=self.query_id)
        self.info_cache = res['QueryExecution']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        if self.finished:
            return self.finished

        # Poll every five seconds until Athena reports a terminal state.
        stat = self.status()
        while stat in ['QUEUED', 'RUNNING']:
            time.sleep(5)
            stat = self.status()

        ms = self.info_cache['Statistics']['TotalExecutionTimeInMillis']
        self.runtime = datetime.timedelta(seconds=ms / 1000)

        scanned = self.info_cache['Statistics']['DataScannedInBytes']
        self.cost = 5.0 * scanned / (1024 ** 4)  # $5/TB scanned

        self.finished = stat

        return stat  # finalized state


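# Sketch of the fire-then-block pattern for DDL (names here are
# illustrative, not from the original): execute() returns immediately,
# while finish() blocks and then exposes the runtime and estimated
# scan cost computed above.
def _example_create_database(athena):
    e = athena.execute("create database if not exists scratch")
    state = e.finish()
    print(state, e.runtime, f"${e.cost:.4f}")
    return e

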
class Query(Execute):

    # Automatically wraps the SQL in UNLOAD so the results are written to
    # S3 as Parquet (or `self.format`) under a fresh random prefix.
    def query(self):
        out = os.path.join(self.athena.output,
                           "results",
                           str(random.random()))
        query = f"unload ({self.sql}) to {repr(out)} " + \
                f"with (format = '{self.format}', compression = 'zstd', compression_level = 4)"
        return query

    # Gets the files that are listed in the manifest (from the UNLOAD part
    # of the statement).
    # Blocks until the query has finished (because it calls `self.finish()`)
    def manifest_files(self):
        if self.files:
            return self.files

        status = self.finish()

        if status == "SUCCEEDED":
            # Because we're using `UNLOAD`, we get a manifest of the files
            # that make up our data.
            manif = self.info_cache['Statistics']['DataManifestLocation']
            files = self.handler.s3.read(manif).split("\n")
            files = [f.strip() for f in files if f.strip()]  # filter empty

            self.files = files
            return files
        else:
            # Query was cancelled or failed; surface Athena's error message.
            msg = self.info_cache['Status']['AthenaError']['ErrorMessage']
            print("Error")
            print(msg)
            raise Exception(msg)

    # Blocks until the query has finished running and then returns a pyarrow
    # dataset of the results.
    # Calls `self.manifest_files()`, which blocks via `self.finish()`
    def results(self):
        if self.ds:
            return self.ds

        self.temps = [self.handler.s3.download(f) for f in self.manifest_files()]
        #local = parallel_map(self.handler.s3.download, self.manifest_files())
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Read each result file, scatter it to the dask cluster, and assemble
    # the scattered pieces into a distributed dataframe.
    def distribute_results(self, client, size=10000):
        import pandas as pd

        if self.ds:
            return self.ds

        futures = []
        meta = None
        print(f"{len(self.manifest_files())} files in manifest")
        for fn in self.manifest_files():
            print(f"reading {fn}...")
            df = pd.read_parquet(fn)
            if meta is None:
                meta = df.iloc[:0]  # empty frame with the right schema
            print("\tloaded")
            # scatter() ships the pandas frame to a worker and returns a
            # future that from_delayed() can assemble into a dask dataframe
            future = client.scatter(df)
            print("\tscattered")
            futures.append(future)

        return dd.from_delayed(futures, meta=meta)

    # Return a scalar result (first column of the first row)
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    # Remove any local temp files downloaded by results()
    def close(self):
        if self.temps:
            for file in self.temps:
                os.remove(file)


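# Sketch of pulling a full result set (illustrative only; `example_table`
# is a placeholder): results() yields a pyarrow dataset over the UNLOADed
# Parquet files, which converts cleanly to an arrow Table or pandas frame.
def _example_fetch_table(athena):
    with athena.query("select * from example_table limit 100") as q:
        return q.results().to_table().to_pandas()

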
class Parallelize:
    UNION_TABLES = local("athena/union_tables.sql")

    def __init__(self, athena, dest=None, data=[], n=1):
        self.athena = athena
        self.dest = dest        # final table to union the temp tables into
        self.data = list(data)  # items to split into up to `n` groups
        self.n = n
        self.tables = []        # temp table names handed out by __next__
        self.queries = []       # running queries registered via <<
        self.runtime = None
        self.cost = None

    # Register a running query whose temp table feeds the final union
    def __lshift__(self, res):
        self.queries.append(res)

    def __iter__(self):
        # Split the data into at most `n` groups (but at least one, so an
        # empty data list doesn't divide by zero below).
        n = max(1, min(self.n, len(self.data)))
        size = math.ceil(len(self.data) / n)

        self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)]
        self.current = 0

        return self

    def __next__(self):
        if self.current >= len(self.groups):
            # All groups handed out: union the temp tables into `dest`
            # (if requested) and clean up before stopping.
            if self.dest:
                self.union_tables(self.dest).finish()
            self.clear_temp_tables()
            raise StopIteration

        # temp table name, in case it's needed
        tmp = "temp_" + str(round(random.random() * 10_000_000))
        self.tables.append(tmp)

        obj = self.groups[self.current]
        self.current += 1

        return tmp, obj

    # Drops the temp tables and removes their S3 data
    def clear_temp_tables(self):
        qs = [self.athena.delete_table(table, join=False) for table in self.tables]
        for q in qs:
            q.finish()

    def finish(self):
        for q in self.queries:
            q.finish()

        if self.queries:
            self.cost = sum([q.cost for q in self.queries])
            self.runtime = max([q.runtime for q in self.queries])

    def results(self):
        self.finish()

        # Flatten the per-query file lists into a single dataset.
        files = [f for q in self.queries for f in q.results().files]
        return pa.dataset.dataset(files)

    def union_tables(self, dest):
        self.finish()

        lines = [f"select * from {table}" for table in self.tables]
        tables = ' union all '.join(lines)

        out = os.path.join(self.athena.output, dest)
        sql = load_template(self.UNION_TABLES, dest=dest,
                            output=out,
                            tables=tables)
        return self.athena.execute(sql)


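# Sketch of the Parallelize protocol (illustrative; `make_chunk_sql` is a
# placeholder for SQL that writes a chunk's results into its temp table):
# iterating yields (temp_table, chunk) pairs, `<<` registers each running
# query, and exhausting the iterator unions the temp tables into `dest`
# and drops them.
def _example_parallel(athena, keys, make_chunk_sql):
    par = athena.parallelize(dest="combined", data=keys, n=4)
    for tmp, chunk in par:
        par << athena.execute(make_chunk_sql(tmp, chunk))
    return par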