forked from bellwether/minerva
adding parallelization helpers and query cancellation
This commit is contained in:
parent 5eb8471081
commit 9442c33d14

4 changed files with 150 additions and 9 deletions

@@ -1,4 +1,5 @@
import boto3
import math
import os
import random
import time
@@ -8,9 +9,20 @@ import pprint
import datetime
import dask.dataframe as dd
from minerva import parallel_map
from mako.template import Template

pp = pprint.PrettyPrinter(indent=4)


# Get full path of fname
def local(fname):
    return os.path.join(os.path.abspath(os.path.dirname(__file__)), fname)


def load_sql(path, **params):
    with open(path, 'r') as f:
        query = f.read()

    return Template(query).render(**params)


class Athena:
    def __init__(self, handler, output=None):
        self.handler = handler
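For context, load_sql reads a .sql file and renders it through Mako, so queries can be parameterized. A minimal sketch of a call, assuming a hypothetical template file that is not part of this commit:

# Suppose athena/row_count.sql contains the Mako template:
#     select count(*) from ${table} where ds >= '${start}'
sql = load_sql(local("athena/row_count.sql"),
               table="events", start="2023-01-01")
# sql == "select count(*) from events where ds >= '2023-01-01'"
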
@@ -32,6 +44,25 @@ class Athena:
        e.run()
        return e

    def parallelize(self, *args, **kwargs):
        p = Parallelize(self, *args, **kwargs)
        return p

    def delete_table(self, table, join=True):
        e = Execute(self, f"drop table {table}")
        e.run()

        if join:
            e.finish()

        s3_uri = os.path.join(self.output, table, "")
        self.handler.s3.rm(s3_uri)

        return e

    def cancel(self, query_id):
        return self.client.stop_query_execution(QueryExecutionId=query_id)


class Execute:
    """
@@ -39,16 +70,17 @@ class Execute:
    Execute is meant to be used for DDL statements such as CREATE DATABASE/TABLE
    """
    def __init__(self, athena, sql, params=[], format='parquet'):
        self.athena = athena
        self.handler = athena.handler
        self.client = athena.client
        self.sql = sql
        self.params = [str(p) for p in params]
        self.info_cache = None
        self.temps = []
        self.ds = None
        self.files = None
        self.format = format
        self.finished = None


    # The string of the query
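Athena.cancel (added above) wraps boto3's stop_query_execution. A hedged usage sketch; the attribute holding the QueryExecutionId on Execute is an assumption, since this diff doesn't show where run() stores it:

e = athena.execute("create table big_copy as select * from huge_table")
# Assumed attribute name; use whatever Execute.run() actually records.
athena.cancel(e.query_id)
stat = e.status()  # cancellation is asynchronous; expect 'CANCELLED' shortly
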
@@ -86,6 +118,9 @@ class Execute:

    # Block until the SQL has finished running
    def finish(self):
        if self.finished:
            return self.finished

        stat = self.status()
        while stat in ['QUEUED', 'RUNNING']:
            time.sleep(5)
@@ -97,6 +132,8 @@
        scanned = self.info_cache['Statistics']['DataScannedInBytes']
        self.cost = 5.0 * scanned / (1024 ** 4)  # $5/TB scanned

        self.finished = stat

        return stat  # finalized state
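A quick check of the cost math, which treats a TB as 1024**4 bytes:

scanned = 200 * 1024 ** 3           # a query that scanned 200 GiB
cost = 5.0 * scanned / (1024 ** 4)
print(round(cost, 2))               # 0.98
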
@@ -190,3 +227,81 @@ class Query(Execute):
        for file in self.temps:
            os.remove(file)


class Parallelize:
    UNION_TABLES = local("athena/union_tables.sql")

    def __init__(self, athena, dest=None, data=[], n=1):
        self.athena = athena
        self.dest = dest
        self.data = list(data)
        self.n = n
        self.tables = []
        self.queries = []
        self.runtime = None
        self.cost = None

    def __lshift__(self, res):
        self.queries.append(res)

    def __iter__(self):
        n = min(self.n, len(self.data))
        size = math.ceil(len(self.data) / n)

        self.groups = [self.data[i:i + size] for i in range(0, len(self.data), size)]
        self.current = 0

        return self

    def __next__(self):
        if self.current >= len(self.groups):
            if self.dest:
                self.union_tables(self.dest).finish()
            self.clear_temp_tables()
            raise StopIteration

        # temp table name, in case it's needed
        tmp = "temp_" + str(round(random.random() * 10_000_000))
        self.tables.append(tmp)

        obj = self.groups[self.current]
        self.current += 1

        return tmp, obj

    # Drop the temp tables and remove their S3 data
    def clear_temp_tables(self):
        qs = [self.athena.delete_table(table, join=False) for table in self.tables]
        for q in qs:
            q.finish()

    def finish(self):
        for q in self.queries:
            q.finish()

        self.cost = sum(q.cost for q in self.queries)
        self.runtime = max(q.runtime for q in self.queries)

    def results(self):
        self.finish()

        return pa.dataset.dataset([q.results().files for q in self.queries])

    def union_tables(self, dest):
        self.finish()

        lines = [f"select * from {table}" for table in self.tables]
        tables = ' union all '.join(lines)

        out = os.path.join(self.athena.output, dest)
        sql = load_sql(self.UNION_TABLES, dest=dest,
                       output=out,
                       tables=tables)
        return self.athena.execute(sql)
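Parallelize implements a fan-out/fan-in pattern: iterating yields (temp_table, data_slice) pairs, the caller shovels each resulting query handle back in with <<, and exhausting the iterator unions the temp tables into dest and drops them. A hedged usage sketch; the source table, slicing predicate, and athena setup are assumptions, not part of this commit:

account_ids = [101, 102, 103, 104, 105, 106, 107, 108]

# Fan out over 4 parallel CTAS statements, fan in to one events_all table.
p = athena.parallelize(dest="events_all", data=account_ids, n=4)
for table, ids in p:
    p << athena.execute(
        f"create table {table} as "
        f"select * from events where account_id in {tuple(ids)}"
    )
# By the time the loop ends, __next__ has unioned the temp tables into
# events_all, dropped them, and populated p.cost and p.runtime.
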
@@ -18,3 +18,7 @@ class Minerva:
    def pier(self, *args, **kwargs):
        return m.Pier(self, *args, **kwargs)

    def s3(self, *args, **kwargs):
        return m.S3(self, *args, **kwargs)
@@ -36,3 +36,8 @@ class S3:
        #self.s3.list_objects_v2(Bucket=bucket, Prefix=key)
        return self.resource.Bucket(bucket).objects.filter(Prefix=key, **kwargs)

    def rm(self, uri, **kwargs):
        fs = self.ls(uri)
        fs.delete()
        return fs
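S3.rm leans on the fact that the boto3 collection returned by ls supports a batch delete(). A hedged sketch; the Minerva handle and bucket name are assumptions:

handler = Minerva()                      # setup assumed
s3 = handler.s3()                        # factory added in this commit
s3.rm("s3://my-bucket/temp_4821906/")    # deletes every object under the prefix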