# blueshift.py -- thin wrapper around the Amazon Redshift Data API that runs
# a SQL statement via UNLOAD to S3 and reads the result back as a pyarrow
# dataset.  (Reconstructed from a whitespace-mangled git patch:
# "added support for redshift", Ari Brown, 2023-08-01.)

import boto3
import os
import random
import time
import pyarrow as pa
import pyarrow.dataset
import pprint
import json

pp = pprint.PrettyPrinter(indent=4)


class Redshift:
    """Entry point for running Redshift queries.

    Owns the boto3 session/clients, the S3 output prefix that results are
    unloaded to, and the database/cluster identifiers passed to the
    Data API.
    """

    def __init__(self, profile, output, db=None, cluster=None):
        self.session = boto3.session.Session(profile_name=profile)
        self.redshift = self.session.client("redshift-data")
        self.output = output      # s3:// prefix used as the UNLOAD target
        self.database = db
        self.cluster = cluster

    def query(self, sql):
        """Create and submit a Query; returns it without waiting for it
        to finish (call .results() to block and fetch the data)."""
        q = Query(self, sql)
        q.run()
        return q

    def download(self, s3):
        """Download an s3://bucket/key object to a unique /tmp file and
        return the local path.

        NOTE(review): the temp file is never cleaned up, and
        random.random() is not collision-proof -- tempfile would be a
        safer way to allocate the local name.
        """
        bucket = s3.split("/")[2]
        file = os.path.join(*s3.split("/")[3:])
        tmp = f"/tmp/{random.random()}.bin"
        self.session.client('s3').download_file(bucket, file, tmp)
        return tmp


class Query:
    # Format handed to UNLOAD; parquet keeps column types and is splittable.
    DATA_STYLE = 'parquet'

    def __init__(self, handler, sql):
        self.handler = handler
        self.redshift = handler.redshift
        self.sql = sql
        self.out = None   # s3 prefix this query unloads to (set by run())

    def run(self):
        """Submit the statement as `UNLOAD (...) TO '<prefix>' ... MANIFEST`
        and record the Data API statement id for later polling.

        NOTE(review): repr() is used to single-quote the SQL and the S3
        path for embedding in the UNLOAD text; this breaks for SQL that
        contains quote characters repr escapes differently -- confirm
        inputs are trusted.
        """
        self.out = os.path.join(self.handler.output, str(random.random()))
        query = (f"unload ({repr(self.sql)}) to {repr(self.out)} "
                 f"iam_role default "
                 f"format as {self.DATA_STYLE} "
                 f"manifest")

        resp = self.redshift.execute_statement(
            Sql=query,
            Database=self.handler.database,
            ClusterIdentifier=self.handler.cluster)
        self.query_id = resp['Id']
        return resp

    def status(self):
        """Current execution status string (e.g. STARTED, FINISHED)."""
        return self.info()['Status']

    def info(self):
        """Full describe_statement payload for the submitted statement.
        Only valid after run() has set self.query_id."""
        return self.redshift.describe_statement(Id=self.query_id)
tiedot = self.info() + status = tiedot['Status'] + + while status in ['SUBMITTED', 'PICKED', 'STARTED']: + time.sleep(5) + tiedot = self.info() + status = tiedot['Status'] + + if status == "FINISHED": + # Because we're using `UNLOAD`, we get a manifest of the files + # that make up our data. + files = self.manifest(tiedot) + files = [f.strip() for f in files if f.strip()] # filter empty + + # TODO parallelize this + local = [self.handler.download(f) for f in files] + self.ds = pa.dataset.dataset(local) + + return self.ds + else: + print("Error:") + pp.pprint(tiedot) + return status # canceled or error + + def manifest(self, tiedot): + manif = self.out + "manifest" + tmp = self.handler.download(manif) + with open(tmp, "r") as f: + js = json.load(f) + + return [e['url'] for e in js['entries']] + + + diff --git a/test.py b/test.py index dffa287..2a63c5a 100644 --- a/test.py +++ b/test.py @@ -4,16 +4,17 @@ import pprint pp = pprint.PrettyPrinter(indent=4) athena = m.Athena("hay", "s3://haystac-pmo-athena/") -query = athena.query( -"""SELECT * -FROM trajectories.kitware -WHERE ST_Disjoint( - ST_GeometryFromText('POLYGON((103.6 1.2151693, 103.6 1.5151693, 104.14797 1.5151693, 104.14797 1.2151693, 103.6 1.2151693))'), - ST_Point(longitude, latitude) -) -""") +#query = athena.query( +#"""SELECT * +#FROM trajectories.kitware +#WHERE ST_Disjoint( +# ST_GeometryFromText('POLYGON((103.6 1.2151693, 103.6 1.5151693, 104.14797 1.5151693, 104.14797 1.2151693, 103.6 1.2151693))'), +# ST_Point(longitude, latitude) +#) +#""") +query = athena.query("select count(*) as count from trajectories.kitware") data = query.results() -print(data.head(10)) +pp.pprint(query.info()['Statistics']) # Everything *needs* to have a column in order for parquet to work, so scalar # values have to be assigned something, so here we use `as count` to create diff --git a/test2.py b/test2.py new file mode 100644 index 0000000..2f0d8af --- /dev/null +++ b/test2.py @@ -0,0 +1,13 @@ +import blueshift as b 
# test2.py (continued) -- smoke test for the blueshift.Redshift wrapper.
# `b` is bound by "import blueshift as b" at the top of this script
# (previous chunk of the patch).
import pprint

pp = pprint.PrettyPrinter(indent=4)

# "hay" is the AWS profile name; the second argument is the s3:// prefix
# that UNLOAD results are written under.
red = b.Redshift("hay", "s3://haystac-pmo-athena/",
                 db="dev",
                 cluster="redshift-cluster-1")
query = red.query("select count(*) from myspectrum_schema.kitware")
# results() blocks until the statement finishes; it returns a pyarrow
# dataset on success but a status *string* on error, in which case the
# .head() call below would fail -- TODO confirm intended for a smoke test.
res = query.results()
pp.pprint(res.head(10))
pp.pprint(query.info())