adding parallelization helpers and query cancelation

This commit is contained in:
Ari Brown 2024-01-18 12:38:42 -05:00
parent 5eb8471081
commit 9442c33d14
4 changed files with 150 additions and 9 deletions

17
examples/cancel_query.py Normal file
View file

@ -0,0 +1,17 @@
import minerva
m = minerva.Minerva("hay-te")
athena = m.athena("s3://haystac-te-athena/")
file = "/tmp/queries.txt"
with open(file, 'r') as f:
txt = f.read()
for line in txt.split("\n"):
if not line:
continue
print(line)
athena.cancel(line)

View file

@ -1,4 +1,5 @@
import boto3
import math
import os
import random
import time
@ -8,9 +9,20 @@ import pprint
import datetime
import dask.dataframe as dd
from minerva import parallel_map
from mako.template import Template
pp = pprint.PrettyPrinter(indent=4)
# Get full path of fname
def local(fname):
return os.path.join(os.path.abspath(os.path.dirname(__file__)), fname)
def load_sql(path, **params):
with open(path, 'r') as f:
query = f.read()
return Template(query).render(**params)
class Athena:
def __init__(self, handler, output=None):
self.handler = handler
@ -32,6 +44,25 @@ class Athena:
e.run()
return e
def parallelize(self, *args, **kwargs):
p = Parallelize(self, *args, **kwargs)
return p
def delete_table(self, table, join=True):
e = Execute(self, f"drop table {table}")
e.run()
if join:
e.finish()
s3_uri = os.path.join(self.output, table, "")
self.handler.s3.rm(s3_uri)
return e
def cancel(self, query_id):
return self.client.stop_query_execution(QueryExecutionId=query_id)
class Execute:
"""
@ -39,16 +70,17 @@ class Execute:
Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
"""
def __init__(self, athena, sql, params=[], format='parquet'):
self.athena = athena
self.handler = athena.handler
self.client = athena.client
self.sql = sql
self.params = [str(p) for p in params]
self.athena = athena
self.handler = athena.handler
self.client = athena.client
self.sql = sql
self.params = [str(p) for p in params]
self.info_cache = None
self.temps = []
self.ds = None
self.files = None
self.format = format
self.temps = []
self.ds = None
self.files = None
self.format = format
self.finished = None
# The string of the query
@ -86,6 +118,9 @@ class Execute:
# Block until the SQL has finished running
def finish(self):
if self.finished:
return self.finished
stat = self.status()
while stat in ['QUEUED', 'RUNNING']:
time.sleep(5)
@ -97,6 +132,8 @@ class Execute:
scanned = self.info_cache['Statistics']['DataScannedInBytes']
self.cost = 5.0 * scanned / (1024 ** 4) # $5/TB scanned
self.finished = stat
return stat # finalized state
@ -190,3 +227,81 @@ class Query(Execute):
for file in self.temps:
os.remove(file)
class Parallelize:
UNION_TABLES = local("athena/union_tables.sql")
def __init__(self, athena, dest=None, data=[], n=1):
self.athena = athena
self.dest = dest
self.data = list(data)
self.n = n
self.tables = []
self.queries = []
self.runtime = None
self.cost = None
def __lshift__(self, res):
self.queries.append(res)
def __iter__(self):
n = min(self.n, len(self.data))
size = math.ceil(len(self.data) / n)
self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)]
self.current = 0
return self
def __next__(self):
if self.current >= len(self.groups):
if self.dest:
self.union_tables(self.dest).finish()
self.clear_temp_tables()
raise StopIteration
# temp table name, in case it's needed
tmp = "temp_" + str(round(random.random() * 10_000_000))
self.tables.append(tmp)
obj = self.groups[self.current]
self.current += 1
return tmp, obj
# runs the SQL and removes the S3 data
def clear_temp_tables(self):
qs = [self.athena.delete_table(table, join=False) for table in self.tables]
for q in qs:
q.finish()
def finish(self):
for q in self.queries:
q.finish()
self.cost = sum([q.cost for q in self.queries])
self.runtime = max([q.runtime for q in self.queries])
def results(self):
self.finish()
return pa.dataset.dataset([q.results().files for q in self.queries])
def union_tables(self, dest):
self.finish()
lines = [f"select * from {table}" for table in self.tables]
tables = ' union all '.join(lines)
out = os.path.join(self.athena.output, dest)
sql = load_sql(self.UNION_TABLES, dest = dest,
output = out,
tables = tables)
return self.athena.execute(sql)

View file

@ -18,3 +18,7 @@ class Minerva:
def pier(self, *args, **kwargs):
return m.Pier(self, *args, **kwargs)
def s3(self, *args, **kwargs):
return m.S3(self, *args, **kwargs)

View file

@ -36,3 +36,8 @@ class S3:
#self.s3.list_objects_v2(Bucket=bucket, Prefix=key)
return self.resource.Bucket(bucket).objects.filter(Prefix=key, **kwargs)
def rm(self, uri, **kwargs):
fs = self.ls(uri)
fs.delete()
return fs