minerva/blueshift.py
2023-08-01 12:38:16 -04:00

97 lines
2.8 KiB
Python

import json
import os
import pprint
import random
import tempfile
import time

import boto3
import pyarrow as pa
import pyarrow.dataset
# Shared pretty-printer (4-space indent); Query.results uses it to dump the
# describe_statement response when a statement does not finish successfully.
_PP_INDENT = 4
pp = pprint.PrettyPrinter(indent=_PP_INDENT)
class Redshift:
    """Thin wrapper around the Redshift Data API and S3 for one AWS profile.

    Queries are submitted through :class:`Query`, which writes result files
    under the S3 prefix ``output``; :meth:`download` fetches S3 objects to
    local temporary files.
    """

    def __init__(self, profile, output, db=None, cluster=None):
        """Create the boto3 session and Redshift Data API client.

        Args:
            profile: Name of the AWS credentials profile to use.
            output: S3 prefix (e.g. ``s3://bucket/path``) that query results
                are written under.
            db: Database name, read by :class:`Query` when it submits a
                statement.
            cluster: Cluster identifier, read by :class:`Query` when it
                submits a statement.
        """
        self.session = boto3.session.Session(profile_name=profile)
        self.redshift = self.session.client("redshift-data")
        self.output = output
        self.database = db
        self.cluster = cluster

    def query(self, sql):
        """Submit *sql* via a new :class:`Query` and return that query.

        The statement is only submitted here; call ``Query.results()`` on the
        returned object to wait for and load the output.
        """
        q = Query(self, sql)
        q.run()
        return q

    def download(self, s3):
        """Download the S3 object at URL *s3* into a fresh local temp file.

        Args:
            s3: Object URL of the form ``s3://bucket/key/...``.

        Returns:
            Path of the newly created local file. Nothing here deletes it; the
            caller is responsible for removing it when done.
        """
        # "s3://bucket/a/b" splits into ['s3:', '', 'bucket', 'a', 'b'].
        parts = s3.split("/")
        bucket = parts[2]
        # S3 keys always use '/', independent of the local OS separator, and
        # may legitimately contain '//' (which os.path.join would collapse).
        key = "/".join(parts[3:])
        # mkstemp atomically creates a uniquely named file, unlike a
        # predictable /tmp/<random>.bin path.
        fd, tmp = tempfile.mkstemp(suffix=".bin")
        os.close(fd)
        try:
            self.session.client('s3').download_file(bucket, key, tmp)
        except BaseException:
            # Don't leave an empty/partial temp file behind on failure.
            os.remove(tmp)
            raise
        return tmp
class Query:
    """One Redshift Data API statement, executed as ``UNLOAD ... MANIFEST``.

    The SQL is wrapped in an UNLOAD that writes ``DATA_STYLE`` files under a
    random prefix below ``handler.output``. :meth:`results` waits for the
    statement, reads the manifest, downloads the listed files and opens them
    as a pyarrow dataset.
    """
    DATA_STYLE = 'parquet'
    # Data API statuses meaning the statement is still queued or running.
    _PENDING = ('SUBMITTED', 'PICKED', 'STARTED')

    def __init__(self, handler, sql):
        """Bind *sql* to *handler* (a Redshift wrapper); nothing is submitted yet."""
        self.handler = handler
        self.redshift = handler.redshift
        self.sql = sql
        self.out = None

    @staticmethod
    def _quote(text):
        """Return *text* as a single-quoted SQL string literal.

        Embedded single quotes are doubled (the form the Redshift UNLOAD docs
        prescribe for quotes inside the wrapped query). Backslashes are
        doubled, as the previous repr()-based quoting also did.
        """
        return "'" + text.replace("\\", "\\\\").replace("'", "''") + "'"

    def run(self):
        """Submit the UNLOAD statement and remember its Data API id.

        Returns:
            The raw ``execute_statement`` response.
        """
        base = self.handler.output
        # S3 URLs always use '/', so don't join with the OS path separator.
        if base and not base.endswith("/"):
            base += "/"
        self.out = base + str(random.random())
        # repr() is not SQL quoting: for text containing ' but no " it emits
        # a double-quoted string, which Redshift parses as an identifier.
        query = (f"unload ({self._quote(self.sql)}) to {self._quote(self.out)} "
                 f"iam_role default "
                 f"format as {self.DATA_STYLE} "
                 f"manifest")
        # botocore's parameter validation rejects None values, so only pass
        # the identifiers that were actually configured.
        params = {"Sql": query}
        if self.handler.database is not None:
            params["Database"] = self.handler.database
        if self.handler.cluster is not None:
            params["ClusterIdentifier"] = self.handler.cluster
        resp = self.redshift.execute_statement(**params)
        self.query_id = resp['Id']
        return resp

    def status(self):
        """Return the current Data API status string of the statement."""
        return self.info()['Status']

    def info(self):
        """Return the raw ``describe_statement`` response for the statement."""
        res = self.redshift.describe_statement(Id=self.query_id)
        return res

    def results(self, poll_interval=5):
        """Block until the statement completes, then load its output.

        Args:
            poll_interval: Seconds to sleep between ``describe_statement``
                polls while the statement is still pending.

        Returns:
            On status FINISHED, a pyarrow dataset over the downloaded files
            (also stored on ``self.ds``). Otherwise the final status string,
            after printing the statement description.
        """
        desc = self.info()
        while desc['Status'] in self._PENDING:
            time.sleep(poll_interval)
            desc = self.info()
        status = desc['Status']
        if status != "FINISHED":
            print("Error:")
            pp.pprint(desc)
            return status  # canceled or error
        # Because we're using `UNLOAD`, we get a manifest of the files
        # that make up our data.
        files = [f.strip() for f in self.manifest(desc)]
        files = [f for f in files if f]  # filter empty entries
        # TODO parallelize this
        local = [self.handler.download(f) for f in files]
        self.ds = pa.dataset.dataset(local)
        return self.ds

    def manifest(self, tiedot):
        """Return the file URLs listed in this query's UNLOAD manifest.

        The manifest is read from ``<self.out>manifest``, downloaded to a
        temporary file that is removed once parsed.

        Args:
            tiedot: The ``describe_statement`` response (not used here).

        Returns:
            The ``url`` of every entry in the manifest's ``entries`` list.
        """
        tmp = self.handler.download(self.out + "manifest")
        try:
            with open(tmp, "r", encoding="utf-8") as f:
                js = json.load(f)
        finally:
            # The manifest is only needed for parsing; don't leak temp files.
            os.remove(tmp)
        return [e['url'] for e in js['entries']]