forked from bellwether/minerva
187 lines
5.5 KiB
Python
import boto3
|
|
import os
|
|
import random
|
|
import time
|
|
import pyarrow as pa
|
|
import pyarrow.dataset
|
|
import pprint
|
|
import datetime
|
|
import json
|
|
from minerva import parallel_map
|
|
|
|
# Shared pretty-printer for debug output; not referenced in this chunk — TODO confirm external callers.
pp = pprint.PrettyPrinter(indent=4)
|
|
|
|
class Redshift:
    """Thin wrapper around the Redshift Data API.

    Targets either a provisioned cluster (`cluster`) or a serverless
    workgroup (`workgroup`) — supply exactly one. `output` is the S3
    prefix where query results are unloaded; `secret` is an optional
    Secrets Manager ARN used for authentication.
    """

    def __init__(self, handler, output, db=None, cluster=None, workgroup=None, secret=None):
        self.handler = handler
        self.client = handler.session.client("redshift-data")
        self.output = output
        self.database = db
        self.cluster = cluster
        self.workgroup = workgroup
        self.secret = secret

        # Always define rpus so cost accounting in Execute.finish() can
        # test it safely: previously it was only set in the workgroup
        # branch, so cluster-based instances raised AttributeError there.
        self.rpus = None
        if workgroup:
            cli = handler.session.client("redshift-serverless")
            wg = cli.get_workgroup(workgroupName=workgroup)
            self.rpus = wg['workgroup']['baseCapacity']

    def query(self, sql):
        """Run `sql` as a result-returning query; returns a started Query."""
        q = Query(self, sql)
        q.run()
        return q

    def execute(self, sql):
        """Run `sql` as a DDL/DML statement (no results); returns a started Execute."""
        e = Execute(self, sql)
        e.run()
        return e
|
|
|
|
|
|
class Execute:
    """
    Execute will not return results, but will execute the SQL and return the final state.

    Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE
    """

    # Terminal statement states reported by the Redshift Data API.
    TERMINAL_STATES = ('FINISHED', 'ABORTED', 'FAILED')

    def __init__(self, redshift, sql):
        self.redshift = redshift
        self.handler = redshift.handler
        self.client = redshift.client
        self.sql = sql
        self.info_cache = None    # last describe_statement response
        self.status_cache = None  # last observed status string
        self.ds = None            # pyarrow dataset (populated by Query subclass)
        self.files = None         # unloaded result files (Query subclass)
        self.temps = []           # downloaded temp files to clean up

    def query(self):
        """Return the SQL to submit; subclasses may wrap self.sql."""
        return self.sql

    def run(self):
        """Submit the statement via the Data API and remember its Id.

        Returns the raw execute_statement response.
        """
        params = {}
        if self.redshift.cluster:
            params['ClusterIdentifier'] = self.redshift.cluster
        else:
            params['WorkgroupName'] = self.redshift.workgroup
        # SecretArn is valid for both provisioned clusters and serverless
        # workgroups; the original only passed it in the workgroup branch.
        if self.redshift.secret:
            params['SecretArn'] = self.redshift.secret

        resp = self.client.execute_statement(Sql=self.query(),
                                             Database=self.redshift.database,
                                             **params)
        self.query_id = resp['Id']
        return resp

    def status(self):
        """Return the statement's current status string."""
        return self.info()['Status']

    def info(self):
        """Describe the statement, caching once a terminal state is reached."""
        if self.status_cache in self.TERMINAL_STATES:
            return self.info_cache

        res = self.client.describe_statement(Id=self.query_id)
        self.info_cache = res
        self.status_cache = res['Status']

        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        """Poll (every 5s) until the statement reaches a terminal state.

        Sets self.runtime and, for serverless targets, self.cost.
        Returns the final status string.
        """
        stat = self.status()
        while stat in ['SUBMITTED', 'PICKED', 'STARTED']:
            time.sleep(5)
            stat = self.status()

        self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']

        # getattr: cluster-based Redshift objects may not define rpus.
        rpus = getattr(self.redshift, 'rpus', None)
        if rpus:
            # $0.36 / RPU-hour. total_seconds(), not .seconds — the latter
            # drops whole days and would undercount long runs.
            self.cost = 0.36 * rpus * self.runtime.total_seconds() / 3600.0

        return stat  # finalized state
|
|
|
|
|
class Query(Execute):
    """Execute variant for result-returning SQL.

    Wraps the user's SQL in an UNLOAD to S3 (parquet + manifest), then
    downloads the unloaded files and exposes them as a pyarrow dataset.
    Use as a context manager (or call close()) to remove downloaded temps.
    """

    DATA_STYLE = 'parquet'

    def query(self):
        """Build the UNLOAD wrapper around self.sql; records the S3 prefix in self.out."""
        self.out = os.path.join(self.redshift.output,
                                "results",
                                str(random.random()) + ".")

        # Materialize the user SQL into a temp table first so the UNLOAD's
        # inner query is a fixed literal and needs no quoting of user SQL.
        query = f"""
        create temp table temp_data as {self.sql};
        unload ('select * from temp_data') to {repr(self.out)}
        iam_role default
        format as {self.DATA_STYLE}
        manifest;
        drop table temp_data;
        """
        print(query)
        return query

    def manifest_files(self):
        """Return the list of unloaded S3 URLs, or the failure status string.

        Blocks until the statement finishes. The result is cached; an empty
        result set caches as [] (truthiness check would re-poll every call).
        """
        if self.files is not None:
            return self.files

        status = self.finish()

        if status != "FINISHED":
            return status  # canceled or error

        if self.info_cache['ResultRows'] != 0:
            # Because we're using `UNLOAD`, we get a manifest of the files
            # that make up our data.
            manif = self.out + "manifest"
            tmp = self.handler.s3.download(manif)
            with open(tmp, "r") as f:
                js = json.load(f)

            # Filter empty strings
            self.files = [e['url'].strip() for e in js['entries'] if e['url'].strip()]
        else:  # no results returned, so no manifest file was created
            self.files = []

        return self.files

    def results(self):
        """Download the unloaded files and return them as a pyarrow dataset.

        Raises Exception if the statement did not finish successfully.
        """
        files = self.manifest_files()
        # a status string (not a list) means the query failed or was canceled
        if not isinstance(files, list):
            raise Exception(f"""Query has status {self.status()} did not complete and
                            thus has no results""")

        self.temps = [self.handler.s3.download(f) for f in files]
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        """Return row 0, column 0 of the results as a plain Python value."""
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        """Delete downloaded temp files; safe to call more than once."""
        for file in self.temps:
            os.remove(file)
        # Clear so a second close() doesn't re-remove already-deleted files.
        self.temps = []