# forked from bellwether/minerva
import time
|
|
import math
|
|
import datetime
|
|
#from pexpect import pxssh
|
|
from fabric import Connection
|
|
import paramiko.ssh_exception
|
|
import shlex
|
|
import threading
|
|
import os
|
|
import minerva
|
|
import select
|
|
|
|
class Machine(minerva.Remote):
    """A single EC2 instance managed through a ``pier`` (the object that
    owns the boto3 ``ec2`` client plus subnet / security-group / key-pair
    configuration).

    Lifecycle: ``create()`` launches the instance and starts a background
    countdown thread; ``login()`` waits for it, opens an SSH connection via
    Fabric; ``terminate()`` tears it down; ``run_time()`` / ``cost()``
    report usage.
    """

    def __init__(self,
                 pier,
                 ami = "ami-0a538467cc9da9bb2",  # ubuntu 22
                 instance_type = "t2.micro",
                 variables = None,
                 username = None,
                 key_pair = (None, None),
                 name = "Minerva Instance",
                 public = True,
                 disk_size = 8):
        """Configure (but do not launch) an EC2 machine.

        Args:
            pier: owner object providing ``ec2`` client, ``iam``,
                ``subnet_id``, ``groups`` and ``key_pair_name``.
            ami: AMI id to boot from (default: Ubuntu 22).
            instance_type: EC2 instance type string.
            variables: optional dict of user variables attached to this
                machine.  (Was a mutable ``{}`` default — fixed so each
                instance gets its own fresh dict.)
            username: SSH login user; passed through to minerva.Remote.
            key_pair: ``(key_name, key_file_path)`` tuple; either element
                may be None to fall back on the pier's configuration.
            name: value of the instance's ``Name`` tag.
            public: whether to associate a public IP (required for login()).
            disk_size: root EBS volume size in GiB.
        """
        super().__init__(None, username, key_pair[1], name)

        self.pier = pier
        self.ami = ami
        self.instance_type = instance_type
        self.username = username
        self.key_pair = key_pair
        # fresh dict per instance — avoids the shared-mutable-default bug
        self.variables = {} if variables is None else variables
        self.name = name
        self.instance_id = None
        self.ready = False
        self.info = None          # raw run_instances() response entry
        self.ssh = None           # fabric.Connection once login() succeeds
        self.started = False      # becomes a datetime once "running"
        self.terminated = False   # becomes a datetime once terminated
        self.public = public
        self.disk_size = disk_size
        self.ip = None  # tracking which IP we're using for our connection

    def create(self):
        """Launch the EC2 instance and start a background readiness wait.

        Idempotent: a second call on an already-created machine is a no-op.
        Returns ``self`` so calls can be chained.
        """
        if self.info:
            return

        iam = {'Name': self.pier.iam} if self.pier.iam else {}
        res = self.pier.ec2.run_instances(
            ImageId = self.ami,
            InstanceType = self.instance_type,
            KeyName = self.key_pair[0] or self.pier.key_pair_name,
            MinCount = 1,
            MaxCount = 1,
            TagSpecifications = [{'ResourceType': 'instance',
                                  'Tags': [{'Key': 'Name', 'Value': self.name}]}],
            NetworkInterfaces = [{'AssociatePublicIpAddress': self.public,
                                  'SubnetId': self.pier.subnet_id,
                                  'Groups': self.pier.groups,
                                  'DeviceIndex': 0}],
            BlockDeviceMappings = [{'DeviceName': '/dev/sda1',
                                    'Ebs': {'VolumeSize': self.disk_size,
                                            'DeleteOnTermination': True}}],
            IamInstanceProfile = iam,
            Monitoring = {'Enabled': True}
        )

        self.info = res['Instances'][0]
        self.private_ip = self.info['NetworkInterfaces'][0]['PrivateIpAddress']
        self.instance_id = self.info['InstanceId']

        # TODO there should be a check here in case some instances fail to
        # start up in a timely manner
        # Start a countdown in the background
        # to give time for the instance to start up
        wait_time = 180
        self.thread = threading.Thread(target = self.wait,
                                       args = (wait_time,),
                                       daemon = True)
        self.thread.start()

        return self  # allows chaining

    def status(self):
        """Return the instance's current EC2 state name (e.g. "running")."""
        resp = self.pier.ec2.describe_instance_status(InstanceIds=[self.info['InstanceId']],
                                                      IncludeAllInstances=True)
        return resp['InstanceStatuses'][0]['InstanceState']['Name']

    # Only used for joining the initial startup thread
    def join(self):
        """Block until the background readiness wait (see create()) finishes."""
        self.thread.join()

    # Wait until the machine is ready (max 180 seconds)
    def wait(self, n):
        """Poll every 10 s until status() == "running"; record start time.

        Raises:
            Exception: if the instance is not running after ~``n`` seconds.
                NOTE(review): this runs in a daemon thread, so the exception
                is not propagated to the caller — TODO surface it.
        """
        i = 0
        # Time for the server to show as "running"
        # and time for the server to finish getting daemons running
        while self.status() != "running":
            time.sleep(10)
            i += 1

            if i > (n / 10):
                reason = f"{self.info['InstanceId']} took too long to start ({i} attempts)"
                raise Exception(reason)

        self.started = datetime.datetime.now()

    # alternatively, could maybe implement this with SSM so that we can access
    # private subnets? TODO
    def login(self):
        """Open (and retry) an SSH connection to the instance's public IP.

        Blocks until the startup countdown finishes, resolves the public IP,
        then retries ``Connection.open()`` every 10 s (max ~120 s) while the
        host's sshd is still coming up.

        Returns:
            True once connected (also True immediately if already connected).

        Raises:
            Exception: if the machine has no public IP, or sshd never
                becomes reachable within the retry window.
        """
        if self.ssh:
            return True

        if not self.public:
            raise Exception("Can only log into server that has a public IP")

        # Machine must be running first, so we need to wait for the countdown to finish
        self.join()

        resp = self.pier.ec2.describe_instances(InstanceIds=[self.info['InstanceId']])
        self.description = resp['Reservations'][0]['Instances'][0]
        self.public_ip = self.description['PublicIpAddress']

        print(f"\t{self.name} ({self.info['InstanceId']}\t- {self.instance_type}) => {self.public_ip} ({self.private_ip})")

        self.ip = self.public_ip or self.private_ip
        self.ssh = Connection(self.ip,
                              self.username,
                              connect_kwargs = {
                                  "key_filename": self.key_pair[1]
                              })

        i = 0
        max_wait = 120
        # Time for the server to get SSH up and running
        while True:
            try:
                self.ssh.open()
                break

            except paramiko.ssh_exception.NoValidConnectionsError:
                time.sleep(10)
                i += 1

                if i > (max_wait / 10):
                    reason = f"{self.info['InstanceId']} took too long to start ssh ({i} attempts)"
                    raise Exception(reason)

        return True

    def terminate(self):
        """Terminate the instance (no-op if already terminated) and record
        the termination time for run_time()/cost()."""
        if self.terminated:
            return

        self.pier.ec2.terminate_instances(
            InstanceIds=[self.info['InstanceId']],
            DryRun=False
        )
        print(f"terminated {self.name} ({self.info['InstanceId']})")
        self.terminated = datetime.datetime.now()

    def run_time(self):
        """Return a timedelta for how long the machine has been running."""
        now = datetime.datetime.now()
        start_time = self.started or now     # what if AWS hasn't made our start time available?
        end_time = self.terminated or now    # what if we're still running?
        return end_time - start_time

    def cost(self):
        """Estimate the dollar cost of this machine's run so far.

        Returns:
            float cost in dollars, or None if the price for this instance
            type is listed as 'unavailable'.
        """
        # BUG FIX: timedelta.seconds discards whole days (a 25-hour run
        # would be billed as 1 hour) — use total_seconds() instead.
        minutes = math.ceil(self.run_time().total_seconds() / 60)

        instance = list(filter(lambda x: x['Instance'] == self.instance_type,
                               minerva.AWS_INSTANCES))[0]

        per_hour = instance['Price']
        if per_hour == 'unavailable':
            return None
        per_hour = float(per_hour[1:])  # strip the leading $

        return (minutes / 60) * per_hour
|
|
|