forked from bellwether/minerva
added dask clustering support
This commit is contained in:
parent f4dd130266
commit 5eb2e39c69
7 changed files with 625 additions and 62 deletions
79	test.py
@@ -1,39 +1,56 @@
import minerva
import pprint
import dask.distributed as distributed
import dask_cloudprovider.aws as aws
import configparser
import os
import contextlib

pp = pprint.PrettyPrinter(indent=4)

# altered /Users/ari/opt/miniconda3/envs/mamba_oa_env/lib/python3.10/site-packages/aiobotocore/endpoint.py:96

m = minerva.Minerva("hay")
athena = m.athena("s3://haystac-pmo-athena/")
# needs a [default] AWS credentials profile
# `security = False` (can't use TLS because otherwise the UserData param is too long)
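# For reference (a sketch; the key values here are made up), ~/.aws/credentials
# is the standard INI-format AWS credentials file, e.g.:
#
#   [default]
#   aws_access_key_id = AKIA...
#   aws_secret_access_key = ...
#   region = us-east-1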
#query = athena.query(
#"""SELECT *
#FROM trajectories.kitware
#WHERE ST_Disjoint(
#    ST_GeometryFromText('POLYGON((103.6 1.2151693, 103.6 1.5151693, 104.14797 1.5151693, 104.14797 1.2151693, 103.6 1.2151693))'),
#    ST_Point(longitude, latitude)
#)
#""")
def aws_profile(profile):
    """Read `profile` from ~/.aws/credentials and return it as a dict of
    environment variables suitable for EC2Cluster's `env_vars`."""
    parser = configparser.RawConfigParser()
    parser.read(os.path.expanduser("~/.aws/credentials"))
    config = parser.items(profile)
    config = {key.upper(): value for key, value in config}
    config['AWS_REGION'] = config.pop('REGION')
    return config
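# Shape of the returned mapping (a sketch; values are made up):
#
#   aws_profile("hay")
#   => {'AWS_ACCESS_KEY_ID': 'AKIA...',
#       'AWS_SECRET_ACCESS_KEY': '...',
#       'AWS_REGION': 'us-east-1'}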
# Everything *needs* to have a column in order for parquet to work, so scalar
# values have to be assigned something; here we use `as count` to create
# a temporary column called `count`.
#print(athena.query("select count(*) as count from trajectories.kitware").results().head(1))

query = athena.query(
    """
    select round(longitude, 3) as lon, count(*) as count
    from trajectories.kitware
    where agent = 4
    group by round(longitude, 3)
    order by count(*) desc
    """
)
data = query.results()
pp.pprint(data.head(10))
print(query.runtime)
# Create a cluster
cluster = aws.EC2Cluster(
    env_vars = aws_profile("hay"),
    key_name = "Ari-Brown-HAY",
    vpc = "vpc-0823964489ecc1e85",
    subnet_id = "subnet-05eb26d8649a093e1",  # project-subnet-public1-us-east-1a
    n_workers = 2,
    region = "us-east-1",
    bootstrap = True,
    security_groups = ["sg-0f9e555954e863954",   # ssh
                       "sg-0b34a3f7398076545",   # default
                       "sg-04cd2626d91ac093c"],  # dask (8786, 8787)
    #worker_module = "dask_cuda.cli.dask_cuda_worker",  # for running GPU clusters
    #import IPython
    #IPython.embed()
    #iam_instance_profile = "S3+SSM+CloudWatch+ECR",  # this is actually a dict? what contents???
    worker_instance_type = "t3.small",
    ami = "ami-0b0cd81283738558a",  # ubuntu 22.04 x86
    security = False)  # can't use TLS: the UserData param would otherwise be too long
print(cluster)
exit()
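# (Sketch, not part of the original flow: `contextlib` is imported above but
# unused; one way to use it would be to tie the cluster's lifetime to a block,
# since dask deployment objects expose close() and Client works as a context
# manager.)
#
# with contextlib.closing(aws.EC2Cluster(...)) as cluster, \
#      distributed.Client(cluster) as client:
#     ...  # work happens here; client and cluster are torn down on exit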

# Connect to the cluster
client = distributed.Client(cluster)
print(client)
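# A minimal smoke test (a sketch, not in the original): round-trip one task
# to confirm the scheduler and workers are reachable before heavier work.
future = client.submit(lambda x: x + 1, 41)
print(future.result())  # expect 42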

# Practice with a big array
import dask.array as da

# A 10^6 x 10^6 float64 array is ~8 TB, so building it with np.random.rand and
# wrapping it in da.from_array would exhaust local memory; generate it lazily,
# chunk by chunk, on the cluster instead.
dask_array = da.random.random((1000000, 1000000), chunks=(1000, 1000))
dask_array = dask_array.persist()  # non-blocking

mean = dask_array.mean().compute()
print(mean)
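# Because persist() returns immediately, dask.distributed's wait() can be used
# to block until the persisted chunks are resident on the workers (a sketch;
# it would go between the persist() and compute() calls above):
#
# distributed.wait(dask_array)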