minerva/minerva/redshift.py
2024-03-14 09:13:59 -04:00

211 lines
6.2 KiB
Python

import boto3
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import datetime
import json
from minerva import parallel_map
pp = pprint.PrettyPrinter(indent=4)
class Redshift:
    """Thin wrapper around the AWS Redshift Data API.

    Talks to either a provisioned cluster (``cluster``) or a serverless
    workgroup (``workgroup``); exactly one should be supplied.

    :param handler: project object exposing a boto3 ``session`` and an ``s3`` helper
    :param output:  S3 prefix under which UNLOAD results are written
    :param db:      database name passed to the Data API
    :param secret:  optional SecretArn used for serverless authentication
    """

    def __init__(self, handler, output, db=None, cluster=None, workgroup=None, secret=None):
        self.handler = handler
        self.client = handler.session.client("redshift-data")
        self.output = output
        self.database = db
        self.cluster = cluster
        self.workgroup = workgroup
        self.secret = secret
        # Always define rpus so downstream cost accounting can safely test it
        # even when we talk to a provisioned cluster instead of serverless.
        self.rpus = None
        if workgroup:
            cli = handler.session.client("redshift-serverless")
            wg = cli.get_workgroup(workgroupName=workgroup)
            self.rpus = wg['workgroup']['maxCapacity']  # provide an upper bound

    def query(self, sql):
        """Run ``sql`` and return a started Query whose results can be fetched."""
        q = Query(self, sql)
        q.run()
        return q

    def execute(self, sql):
        """Run DDL/DML ``sql`` (CREATE TABLE, etc.); returns a started Execute."""
        e = Execute(self, sql)
        e.run()
        return e

    def history(self, limit=100, status='FINISHED', next_token=''):
        """Return up to ``limit`` past statements with ``status`` as Query objects."""
        resp = self.client.list_statements(MaxResults=limit,
                                           Status=status,
                                           NextToken=next_token)
        res = []
        for statement in resp['Statements']:
            q = Query(self, statement['QueryString'])
            # Seed the cache from the listing so update_values needs no API call.
            q.info_cache = statement
            q.update_values()
            res.append(q)
        return res
class Execute:
    """
    Execute will not return results, but will execute the SQL and return the final state.
    Execute is meant to be used for DML statements such as CREATE DATABASE/TABLE.
    """

    # Statuses after which describe_statement output can never change again.
    TERMINAL_STATES = ('FINISHED', 'ABORTED', 'FAILED')

    def __init__(self, redshift, sql):
        self.redshift = redshift
        self.handler = redshift.handler
        self.client = redshift.client
        self.sql = sql
        self.info_cache = None      # last describe_statement response
        self.status_cache = None    # last observed Status string
        self.ds = None              # pyarrow dataset (populated by Query subclass)
        self.files = None           # unloaded file list (populated by Query subclass)
        self.temps = []             # locally downloaded temp files to clean up

    def query(self):
        """The SQL text actually submitted; Query overrides this to add UNLOAD."""
        return self.sql

    def run(self):
        """Submit the statement via the Data API; stores the id in self.query_id."""
        if self.redshift.cluster:
            resp = self.client.execute_statement(Sql=self.query(),
                                                 Database=self.redshift.database,
                                                 ClusterIdentifier=self.redshift.cluster)
        else:
            # Serverless path: workgroup, with an optional secret for auth.
            params = {"WorkgroupName": self.redshift.workgroup}
            if self.redshift.secret:
                params['SecretArn'] = self.redshift.secret
            resp = self.client.execute_statement(Sql=self.query(),
                                                 Database=self.redshift.database,
                                                 **params)
        self.query_id = resp['Id']
        return resp

    def status(self):
        """Current statement status string (e.g. STARTED, FINISHED, FAILED)."""
        return self.info()['Status']

    def info(self):
        """describe_statement response, served from cache once terminal."""
        if self.status_cache in self.TERMINAL_STATES:
            return self.info_cache
        res = self.client.describe_statement(Id=self.query_id)
        self.info_cache = res
        self.status_cache = res['Status']
        return self.info_cache

    # Block until the SQL has finished running
    def finish(self):
        """Poll every 5 seconds until the statement leaves a running state."""
        stat = self.status()
        while stat in ['SUBMITTED', 'PICKED', 'STARTED']:
            time.sleep(5)
            stat = self.status()
        self.update_values()
        return self.status_cache

    def update_values(self):
        """Derive status, runtime and (serverless only) cost from info_cache."""
        self.status_cache = self.info_cache['Status']
        self.runtime = self.info_cache['UpdatedAt'] - self.info_cache['CreatedAt']
        # rpus is only known for serverless workgroups; use getattr so a
        # provisioned-cluster Redshift without the attribute doesn't blow up.
        rpus = getattr(self.redshift, 'rpus', None)
        if rpus:
            # Serverless pricing: $0.36 / RPU-hour.
            # total_seconds() (not .seconds) so runtimes >= 1 day price correctly.
            self.cost = 0.36 * rpus * self.runtime.total_seconds() / 3600.0
class Query(Execute):
    """
    A statement whose results we fetch: the SQL is wrapped in an UNLOAD to S3
    and the resulting files are read back as a pyarrow dataset.

    Usable as a context manager; close() removes downloaded temp files.
    """

    # Output format passed to UNLOAD's `format as` clause.
    DATA_STYLE = 'parquet'

    def query(self):
        """Wrap self.sql in an UNLOAD writing under a random prefix in self.out."""
        self.out = os.path.join(self.redshift.output,
                                "results",
                                str(random.random()) + ".")
        # NOTE(review): self.sql is interpolated directly into the statement —
        # callers must not pass untrusted SQL here.
        query = f"""
        create temp table temp_data as {self.sql};
        unload ('select * from temp_data') to {repr(self.out)}
        iam_role default
        format as {self.DATA_STYLE}
        parallel off
        manifest;
        drop table temp_data;
        """
        print(query)
        return query

    def manifest_files(self):
        """Block until done, then return the list of unloaded S3 file urls.

        Returns the terminal status string instead when the query did not
        finish (ABORTED/FAILED) — callers must check the type.
        """
        # `is not None` so an empty (0-row) result list is still a cache hit.
        if self.files is not None:
            return self.files
        status = self.finish()
        if status != "FINISHED":
            return status  # canceled or error
        # Because we're using `UNLOAD ... manifest`, we get a manifest of the
        # files that make up our data.
        manif = self.out + "manifest"
        # Redshift doesn't create a manifest when there are 0 results.
        if list(self.handler.s3.ls(manif)):
            tmp = self.handler.s3.download(manif)
            with open(tmp, "r") as f:
                js = json.load(f)
            # Filter entries with empty urls.
            self.files = [e['url'].strip() for e in js['entries'] if e['url'].strip()]
        else:  # no results returned, so no manifest file was created
            self.files = []
        return self.files

    def results(self):
        """Download the unloaded files and return them as a pyarrow dataset."""
        files = self.manifest_files()
        # A non-list return means the query terminated without results.
        if not isinstance(files, list):
            raise Exception(f"""Query has status {self.status()} did not complete and
            thus has no results. Error is {self.info_cache.get('Error', '')}""")
        self.temps = [self.handler.s3.download(f) for f in files]
        self.ds = pa.dataset.dataset(self.temps)
        return self.ds

    # Return scalar results
    # Abstracts away a bunch of keystrokes
    def scalar(self):
        """Return the single value of a one-row, one-column result."""
        return self.results().head(1)[0][0].as_py()

    def __enter__(self):
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback):
        self.close()

    def close(self):
        """Delete downloaded temp files; safe to call more than once."""
        for file in self.temps:
            try:
                os.remove(file)
            except FileNotFoundError:
                # Already gone (e.g. a prior close()); nothing to do.
                pass
        # Clear the list so a second close() is a no-op.
        self.temps = []