Files
infmap/app/app.py
j 95c20df073
All checks were successful
Build-Publish / build (linux/amd64) (push) Successful in 4s
Build-Publish / build (linux/arm64) (push) Successful in 12s
Build-Publish / create-manifest (push) Successful in 2s
Build-Publish / publish-template (push) Successful in 9s
Clean GPU names in dashboard and remove stale servers from DB
2026-03-08 13:29:37 +13:00

426 lines
13 KiB
Python

import logging
import os
import re
import threading
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

import paramiko
from flask import Flask, jsonify, render_template, request
from flask_sqlalchemy import SQLAlchemy
# Module-wide logging: timestamped INFO-level lines to stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)
# SQLite file lives on the container's data volume (created in __main__).
DB_PATH = '/app/data/infmap.db'
app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{DB_PATH}'
db = SQLAlchemy(app)
# Seconds between collection passes (env override; default 5 minutes).
COLLECTION_INTERVAL = int(os.environ.get('COLLECTION_INTERVAL', 300))
# Cap on simultaneous SSH sessions in the collector thread pool.
MAX_CONCURRENT_SSH = int(os.environ.get('MAX_CONCURRENT_SSH', 5))
# Private key used for all SSH connections (mounted into the container).
SSH_KEY_PATH = '/app/ssh_key'
# Server inventory file parsed by parse_infrastructure_conf().
INFRA_CONF_PATH = '/app/infrastructure.conf'
# --- Database Model ---
class Server(db.Model):
    """One row per configured server, uniquely keyed on (username, hostname)."""
    __tablename__ = 'servers'
    id = db.Column(db.Integer, primary_key=True)
    # Display group taken from infrastructure.conf section headers.
    group_name = db.Column(db.String(255), nullable=False)
    # SSH login user for this host.
    username = db.Column(db.String(255), nullable=False)
    # SSH target hostname (or IP) as written in the config file.
    hostname = db.Column(db.String(255), nullable=False)
    # Best-guess IPv4 derived from collected interface data; 45 chars
    # leaves room for an IPv6 textual form if that is ever stored.
    primary_ip = db.Column(db.String(45), default='')
    # Optional link shown on the dashboard (second column in the config).
    url = db.Column(db.String(1024), default='')
    # True when the most recent collection attempt succeeded.
    is_online = db.Column(db.Boolean, default=False)
    # UTC timestamp of the last collection attempt (success or failure).
    last_collected = db.Column(db.DateTime, nullable=True)
    # Full parsed gather_info.sh output (or {'is_online': False, 'error': ...}).
    details = db.Column(db.JSON, nullable=True)
    # Free-text notes, edited via PUT /api/servers/<id>/notes.
    notes = db.Column(db.Text, default='')
    __table_args__ = (db.UniqueConstraint('username', 'hostname', name='uq_user_host'),)
# --- Config Parsing ---
def parse_infrastructure_conf(path=None):
    """Parse the infrastructure config into a list of server entries.

    File format:
        - A line starting at column 0 begins a new group (its text is the name).
        - Indented lines are servers: "user@host [url]" or "host [url]"
          (username defaults to 'infmap').
        - Blank lines and lines whose content starts with '#' are ignored.

    Args:
        path: config file to read; defaults to INFRA_CONF_PATH. The parameter
            is new and optional, so existing callers are unaffected.

    Returns:
        List of dicts with 'group', 'username', 'hostname' and 'url' keys.
        An empty list when the file does not exist.
    """
    conf_path = path if path is not None else INFRA_CONF_PATH
    servers = []
    current_group = None
    try:
        with open(conf_path) as f:
            for line in f:
                line = line.rstrip('\n')
                if not line.strip() or line.strip().startswith('#'):
                    continue
                if line[0] not in (' ', '\t'):
                    # Unindented line -> group header.
                    current_group = line.strip()
                else:
                    parts = line.strip().split(None, 1)
                    entry = parts[0] if parts else ''
                    url = parts[1] if len(parts) > 1 else ''
                    if '@' in entry:
                        user, host = entry.split('@', 1)
                    else:
                        user, host = 'infmap', entry
                    if host:
                        servers.append({
                            'group': current_group or 'Default',
                            'username': user.strip(),
                            'hostname': host.strip(),
                            'url': url.strip(),
                        })
    except FileNotFoundError:
        logger.error("infrastructure.conf not found at %s", conf_path)
    return servers
# --- SSH Collection ---
def load_ssh_key():
    """Load the private key at SSH_KEY_PATH, trying each supported key type.

    Raises:
        RuntimeError: when no supported key class can parse the file.
    """
    key_classes = (paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey)
    for key_class in key_classes:
        try:
            return key_class.from_private_key_file(SSH_KEY_PATH)
        except Exception:
            # Wrong format for this class (or unreadable) — try the next one.
            continue
    raise RuntimeError(f"Could not load SSH key from {SSH_KEY_PATH}")
def collect_one(entry, ssh_key):
    """SSH into a single server, run the gather script, return parsed data.

    Args:
        entry: dict with at least 'username' and 'hostname' keys
            (as produced by parse_infrastructure_conf).
        ssh_key: paramiko private key used for authentication.

    Returns:
        Parsed gather output with 'is_online': True on success, or
        {'is_online': False, 'error': ...} on any failure. Never raises.
    """
    ssh = None
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(
            entry['hostname'],
            username=entry['username'],
            pkey=ssh_key,
            timeout=15,
            banner_timeout=15,
            auth_timeout=15,
        )
        with open('/app/gather_info.sh') as f:
            script = f.read()
        # Stream the script over stdin so nothing needs to exist on the remote host.
        stdin, stdout, stderr = ssh.exec_command('bash -s', timeout=60)
        stdin.write(script)
        stdin.channel.shutdown_write()
        output = stdout.read().decode('utf-8', errors='replace')
        data = parse_gather_output(output)
        data['is_online'] = True
        return data
    except Exception as e:
        logger.warning("Failed to collect from %s@%s: %s", entry['username'], entry['hostname'], e)
        return {'is_online': False, 'error': str(e)}
    finally:
        # Always release the connection — the original closed it only on the
        # success path, leaking sockets whenever exec/read failed after connect.
        if ssh is not None:
            ssh.close()
def parse_gather_output(output):
    """Parse the [section] key=value output from gather_info.sh.

    Plain sections ([name]) become dicts; named sections ([base:id])
    accumulate as a list under 'base', each item tagged with '_name'.
    An '[end]' header terminates parsing. Repeated keys inside a plain
    section are collapsed into a list of values.
    """
    parsed = {}
    # Active section: ('dict', mapping) or ('list', item); None until a header is seen.
    target = None
    header_re = re.compile(r'^\[(.+)\]$')
    for raw in output.split('\n'):
        text = raw.strip()
        if not text:
            continue
        header = header_re.match(text)
        if header:
            label = header.group(1)
            if label == 'end':
                break
            if ':' in label:
                base, ident = label.split(':', 1)
                item = {'_name': ident}
                parsed.setdefault(base, []).append(item)
                target = ('list', item)
            else:
                target = ('dict', parsed.setdefault(label, {}))
            continue
        if '=' in text and target:
            key, _, value = text.partition('=')
            key, value = key.strip(), value.strip()
            kind, bucket = target
            if kind == 'dict':
                if key in bucket:
                    # Repeated key (e.g. dns server=) -> collect into a list.
                    if not isinstance(bucket[key], list):
                        bucket[key] = [bucket[key]]
                    bucket[key].append(value)
                else:
                    bucket[key] = value
            else:
                bucket[key] = value
    return parsed
# --- Collection Loop ---
def collect_all():
    """Run one full collection pass: SSH every configured server, then sync the DB.

    Reads infrastructure.conf, fans SSH collection out across a thread pool,
    then (in this single collector thread) upserts a Server row per entry and
    deletes rows that are no longer present in the config.
    """
    entries = parse_infrastructure_conf()
    if not entries:
        logger.info("No servers configured in infrastructure.conf")
        return
    try:
        ssh_key = load_ssh_key()
    except Exception as e:
        # Without a key no host can be reached, so skip this pass entirely.
        logger.error("SSH key error: %s", e)
        return
    logger.info("Collecting from %d servers (max %d concurrent)", len(entries), MAX_CONCURRENT_SSH)
    results = {}
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
        futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
        for future in as_completed(futures):
            entry = futures[future]
            key = f"{entry['username']}@{entry['hostname']}"
            try:
                results[key] = (entry, future.result(timeout=90))
            except Exception as e:
                # collect_one catches its own errors; this guards the future machinery itself.
                results[key] = (entry, {'is_online': False, 'error': str(e)})
    # Update database (all in main collector thread)
    with app.app_context():
        # Remove servers no longer in config
        config_keys = {(e['username'], e['hostname']) for e in entries}
        for server in Server.query.all():
            if (server.username, server.hostname) not in config_keys:
                db.session.delete(server)
        for key, (entry, result) in results.items():
            # Upsert keyed on (username, hostname) — matches the uq_user_host constraint.
            server = Server.query.filter_by(
                username=entry['username'],
                hostname=entry['hostname'],
            ).first()
            if not server:
                server = Server(
                    group_name=entry['group'],
                    username=entry['username'],
                    hostname=entry['hostname'],
                )
                db.session.add(server)
            server.group_name = entry['group']
            server.url = entry.get('url', '')
            server.is_online = result.get('is_online', False)
            # Timestamp of this attempt, recorded whether or not collection succeeded.
            server.last_collected = datetime.now(timezone.utc)
            server.details = result
            # Extract primary IP: prefer the interface carrying the default route
            default_iface = ''
            routing = result.get('routing', {})
            if isinstance(routing, dict):
                default_iface = routing.get('interface', '')
            primary_ip = ''
            # NOTE(review): assumes 'net' entries are dicts from [net:<name>]
            # sections of parse_gather_output — confirm gather_info.sh never
            # emits a plain [net] section (that would make this a dict of strings).
            for iface in result.get('net', []):
                ipv4 = iface.get('ipv4', '')
                if not ipv4 or ipv4.startswith('127.'):
                    continue  # skip loopback / address-less interfaces
                iface_name = iface.get('name', '') or iface.get('_name', '')
                if iface_name == default_iface:
                    primary_ip = ipv4
                    break
                if not primary_ip:
                    # Fallback: first non-loopback IPv4 seen, unless the
                    # default-route interface shows up later.
                    primary_ip = ipv4
            server.primary_ip = primary_ip
        db.session.commit()
    logger.info("Collection complete, updated %d servers", len(results))
def collector_loop():
    """Background thread body: run collect_all() forever at a fixed interval.

    Never returns; any error in a collection pass is logged (with traceback)
    and the loop continues on the next interval.
    """
    time.sleep(10)  # Let the app start up
    while True:
        try:
            collect_all()
        except Exception as e:
            # logger.exception keeps the traceback that logger.error discarded,
            # which is essential for diagnosing failures in a daemon thread.
            logger.exception("Collection loop error: %s", e)
        time.sleep(COLLECTION_INTERVAL)
# --- Web Routes ---
@app.route('/')
def index():
    """Render the dashboard: servers bucketed by group, each group IP-sorted."""
    servers = Server.query.order_by(Server.group_name, Server.primary_ip).all()
    groups = {}
    for srv in servers:
        groups.setdefault(srv.group_name or 'Default', []).append(srv)
    # Re-sort each group numerically by IP (the SQL ordering is lexicographic).
    for members in groups.values():
        members.sort(key=lambda srv: _ip_sort_key(srv.primary_ip))
    return render_template('index.html', groups=groups)
@app.route('/api/servers')
def api_servers():
    """Return every server row as a JSON array of plain dicts."""
    def serialize(s):
        # Flatten one ORM row into JSON-safe primitives.
        return {
            'id': s.id,
            'group_name': s.group_name,
            'username': s.username,
            'hostname': s.hostname,
            'primary_ip': s.primary_ip,
            'url': s.url,
            'is_online': s.is_online,
            'last_collected': s.last_collected.isoformat() if s.last_collected else None,
            'notes': s.notes,
            'details': s.details,
        }
    return jsonify([serialize(s) for s in Server.query.all()])
@app.route('/api/servers/<int:server_id>/notes', methods=['PUT'])
def api_update_notes(server_id):
    """Update the free-text notes for one server.

    Expects a JSON body like {"notes": "..."}. A missing or non-JSON body
    clears the notes instead of raising (request.get_json() would otherwise
    return None / abort 400, and .get() on None raised AttributeError).

    Returns:
        JSON {'ok': True}; 404 when the server id does not exist.
    """
    server = Server.query.get_or_404(server_id)
    payload = request.get_json(silent=True) or {}
    server.notes = payload.get('notes', '')
    db.session.commit()
    return jsonify({'ok': True})
def _ip_sort_key(ip_str):
if not ip_str:
return [999, 999, 999, 999]
try:
return [int(x) for x in ip_str.split('.')]
except (ValueError, AttributeError):
return [999, 999, 999, 999]
# --- Jinja2 Helpers ---
@app.template_filter('format_bytes')
def format_bytes(value):
    """Jinja filter: render a byte count human-readably (e.g. 2048 -> '2.0 KB').

    Non-numeric input is returned unchanged.
    """
    try:
        amount = int(value)
    except (TypeError, ValueError):
        return value
    units = ['B', 'KB', 'MB', 'GB', 'TB']
    idx = 0
    while idx < len(units):
        if abs(amount) < 1024:
            return f"{amount:.1f} {units[idx]}"
        amount /= 1024
        idx += 1
    return f"{amount:.1f} PB"
@app.template_filter('format_mb')
def format_mb(value):
    """Jinja filter: render a megabyte count as 'N MB' or 'N.N GB'.

    Non-numeric input is returned unchanged.
    """
    try:
        megabytes = int(value)
    except (TypeError, ValueError):
        return value
    if megabytes < 1024:
        return f"{megabytes} MB"
    return f"{megabytes / 1024:.1f} GB"
@app.template_filter('format_uptime')
def format_uptime(seconds):
    """Jinja filter: render an uptime in seconds as 'Nd Nh' or 'Nh Nm'.

    Non-numeric input yields 'Unknown'.
    """
    try:
        total = int(seconds)
    except (TypeError, ValueError):
        return 'Unknown'
    days, remainder = divmod(total, 86400)
    hours = remainder // 3600
    if days > 0:
        return f"{days}d {hours}h"
    minutes = (total % 3600) // 60
    return f"{hours}h {minutes}m"
@app.template_filter('temp_color')
def temp_color(temp_c):
    """Jinja filter: map a temperature in Celsius to a status color hex code."""
    try:
        t = float(temp_c)
    except (TypeError, ValueError):
        return '#64748b'  # neutral grey for unknown values
    # Highest matching threshold wins: red / orange / yellow, green otherwise.
    for limit, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if t >= limit:
            return color
    return '#22c55e'
@app.template_filter('clean_gpu')
def clean_gpu(description):
    """Jinja filter: reduce a raw lspci GPU line to a short model name.

    Example:
        "01:00.0 VGA compatible controller: NVIDIA Corporation GA106
         [GeForce RTX 3060] (rev a1)"  ->  "GeForce RTX 3060"

    Empty/None input yields '-'. Uses the module-level `re` import
    (the previous function-local `import re` was redundant).
    """
    if not description:
        return '-'
    s = str(description)
    # Strip PCI address prefix (e.g. "01:00.0 ")
    s = re.sub(r'^[0-9a-f:.]+\s+', '', s, flags=re.IGNORECASE)
    # Strip type prefix
    for prefix in ['VGA compatible controller: ', '3D controller: ', 'Display controller: ']:
        if s.startswith(prefix):
            s = s[len(prefix):]
    # Strip common vendor prefixes
    for vendor in ['NVIDIA Corporation ', 'Advanced Micro Devices, Inc. ', 'AMD ', 'Intel Corporation ',
                   'Advanced Micro Devices Inc. ', 'Matrox Electronics Systems Ltd. ']:
        if s.startswith(vendor):
            s = s[len(vendor):]
    # Strip revision suffix
    s = re.sub(r'\s*\(rev [0-9a-f]+\)\s*$', '', s, flags=re.IGNORECASE)
    # Prefer bracketed name if present (e.g. "GA106 [GeForce RTX 3060]" -> "GeForce RTX 3060")
    bracket = re.search(r'\[(.+)\]', s)
    if bracket:
        s = bracket.group(1)
    return s.strip()
@app.template_filter('usage_color')
def usage_color(percent):
    """Jinja filter: map a usage percentage to a status color hex code."""
    try:
        p = float(percent)
    except (TypeError, ValueError):
        return '#64748b'  # neutral grey for unknown values
    # Highest matching threshold wins: red / orange / yellow, green otherwise.
    for limit, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if p >= limit:
            return color
    return '#22c55e'
# --- Main ---
if __name__ == '__main__':
    # Ensure the SQLite data directory exists before SQLAlchemy opens the file.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    with app.app_context():
        db.create_all()
        logger.info("Database ready at %s", DB_PATH)
    # Start collector thread (daemon: exits automatically with the main process)
    collector_thread = threading.Thread(target=collector_loop, daemon=True)
    collector_thread.start()
    # threaded=True so web requests keep being served while collection runs.
    app.run(host='0.0.0.0', port=5000, threaded=True)