Files
infmap/app/app.py
j 2c7f2e2c7c
All checks were successful
Build-Publish / build (linux/amd64) (push) Successful in 4s
Build-Publish / build (linux/arm64) (push) Successful in 12s
Build-Publish / create-manifest (push) Successful in 2s
Build-Publish / publish-template (push) Successful in 8s
Stream real-time section progress during single-host refresh
2026-03-09 18:39:54 +13:00

829 lines
27 KiB
Python

# --- Imports & module configuration ---
# Fixed: BUILD_DATE was wedged between import statements; imports are now
# grouped stdlib / third-party per PEP 8.
import os
import re
import time
import logging
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime, timezone

import paramiko
from flask import Flask, render_template, jsonify, Response, stream_with_context
from flask_sqlalchemy import SQLAlchemy

# Replaced at image build time with the actual build timestamp.
BUILD_DATE = '__BUILD_DATE__'

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

DB_PATH = '/app/data/infmap.db'

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{DB_PATH}'
db = SQLAlchemy(app)

# Background collector cadence (seconds) and SSH fan-out limit,
# both overridable via environment.
COLLECTION_INTERVAL = int(os.environ.get('COLLECTION_INTERVAL', 300))
MAX_CONCURRENT_SSH = int(os.environ.get('MAX_CONCURRENT_SSH', 5))
SSH_KEY_PATH = '/app/ssh_key'
INFRA_CONF_PATH = '/app/infrastructure.conf'
# --- Database Model ---
class Server(db.Model):
    """One monitored host; rows are keyed by the (username, hostname) pair."""
    __tablename__ = 'servers'
    id = db.Column(db.Integer, primary_key=True)
    # Display group, taken from header lines in infrastructure.conf.
    group_name = db.Column(db.String(255), nullable=False)
    username = db.Column(db.String(255), nullable=False)
    hostname = db.Column(db.String(255), nullable=False)
    # IPv4 of the default-route interface (45 chars leaves room for IPv6).
    primary_ip = db.Column(db.String(45), default='')
    url = db.Column(db.String(1024), default='')
    is_online = db.Column(db.Boolean, default=False)
    # UTC timestamp of the last successful or attempted collection.
    last_collected = db.Column(db.DateTime, nullable=True)
    # Full parsed gather_info.sh output (see parse_gather_output).
    details = db.Column(db.JSON, nullable=True)
    notes = db.Column(db.Text, default='')
    links = db.Column(db.JSON, default=list)  # [{"label": "...", "url": "..."}]
    # Hostname of the parent host for nested VMs; '' for top-level servers.
    parent_hostname = db.Column(db.String(255), default='')
    __table_args__ = (db.UniqueConstraint('username', 'hostname', name='uq_user_host'),)
# --- Config Parsing ---
def parse_infrastructure_conf():
    """Parse infrastructure.conf into a list of server entry dicts.

    Format: unindented lines are group headers; '- host' lines are servers;
    '-- host' lines are children of the preceding top-level server. Each
    host may be 'user@host' (default user 'infmap') with an optional URL
    after whitespace.
    """
    result = []
    group = None
    last_host = None  # most recent top-level host; parent for '--' entries
    try:
        with open(INFRA_CONF_PATH) as fh:
            for raw in fh:
                raw = raw.rstrip('\n')
                text = raw.strip()
                if not text or text.startswith('#'):
                    continue
                if text.startswith('--'):
                    # Nested child VM of the last seen host.
                    child, body = True, text[2:].strip()
                elif text.startswith('-'):
                    child, body = False, text[1:].strip()
                elif raw[0] not in (' ', '\t') and not text.startswith('-'):
                    # Unindented, dash-free line: new group header.
                    group = text
                    last_host = None
                    continue
                else:
                    # Legacy format: indented line without a dash is a host.
                    child, body = False, text
                fields = body.split(None, 1)
                spec = fields[0] if fields else ''
                link = fields[1] if len(fields) > 1 else ''
                if not spec:
                    continue
                if '@' in spec:
                    user, host = spec.split('@', 1)
                else:
                    user, host = 'infmap', spec
                if not host:
                    continue
                parent = ''
                if child and last_host:
                    parent = last_host
                else:
                    last_host = host.strip()
                result.append({
                    'group': group or 'Default',
                    'username': user.strip(),
                    'hostname': host.strip(),
                    'url': link.strip(),
                    'parent_hostname': parent,
                })
    except FileNotFoundError:
        logger.error("infrastructure.conf not found at %s", INFRA_CONF_PATH)
    return result
# --- SSH Collection ---
def load_ssh_key():
    """Load the deploy key, probing each supported key type in turn.

    Raises RuntimeError when no type can parse the file.
    """
    for key_type in (paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey):
        try:
            return key_type.from_private_key_file(SSH_KEY_PATH)
        except Exception:  # wrong key format for this type; try the next
            pass
    raise RuntimeError(f"Could not load SSH key from {SSH_KEY_PATH}")
def collect_one(entry, ssh_key, progress_cb=None):
    """SSH into a single server, run the gather script, return parsed data.

    Args:
        entry: dict with at least 'hostname' and 'username'.
        ssh_key: paramiko private key used for authentication.
        progress_cb: optional callable(section_name); invoked as each
            '[section]' header appears in the streamed script output.

    Returns:
        Parsed data dict with 'is_online': True on success; on any failure
        a dict {'is_online': False, 'error': str} — this never raises.
    """
    ssh = None
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(
            entry['hostname'],
            username=entry['username'],
            pkey=ssh_key,
            timeout=15,
            banner_timeout=15,
            auth_timeout=15,
        )
        with open('/app/gather_info.sh') as f:
            script = f.read()
        # Pipe the script over stdin so nothing needs to exist remotely.
        stdin, stdout, stderr = ssh.exec_command('bash -s', timeout=60)
        stdin.write(script)
        stdin.channel.shutdown_write()
        if progress_cb:
            # Stream line-by-line and report section headers as progress.
            output_lines = []
            for raw_line in stdout:
                line = raw_line.rstrip('\n')
                output_lines.append(line)
                stripped = line.strip()
                if stripped.startswith('[') and stripped.endswith(']') and stripped != '[end]':
                    section = stripped[1:-1].split(':')[0]
                    progress_cb(section)
            output = '\n'.join(output_lines)
        else:
            output = stdout.read().decode('utf-8', errors='replace')
        data = parse_gather_output(output)
        data['is_online'] = True
        return data
    except Exception as e:
        logger.warning("Failed to collect from %s@%s: %s", entry['username'], entry['hostname'], e)
        return {'is_online': False, 'error': str(e)}
    finally:
        # BUGFIX: the original only closed on the success path, leaking the
        # SSH connection whenever connect/exec raised.
        if ssh is not None:
            ssh.close()
def parse_gather_output(output):
    """Parse the [section] key=value output from gather_info.sh.

    '[name]' opens a plain dict section; '[name:id]' appends a
    {'_name': id, ...} record to a list section; '[end]' stops parsing.
    Repeated keys inside a dict section collect into a list.
    """
    parsed = {}
    target = None  # ('dict', mapping) or ('list', record) for current section
    for raw in output.split('\n'):
        raw = raw.strip()
        if not raw:
            continue
        header = re.match(r'^\[(.+)\]$', raw)
        if header:
            name = header.group(1)
            if name == 'end':
                # Explicit terminator: ignore anything that follows.
                break
            if ':' in name:
                base, item_name = name.split(':', 1)
                record = {'_name': item_name}
                parsed.setdefault(base, []).append(record)
                target = ('list', record)
            else:
                target = ('dict', parsed.setdefault(name, {}))
            continue
        if '=' not in raw or target is None:
            continue
        key, _, value = raw.partition('=')
        key, value = key.strip(), value.strip()
        kind, dest = target
        if kind == 'dict':
            if key in dest:
                # Repeated keys (e.g. dns server=) collect into a list.
                if not isinstance(dest[key], list):
                    dest[key] = [dest[key]]
                dest[key].append(value)
            else:
                dest[key] = value
        else:
            dest[key] = value
    return parsed
# --- Collection Loop ---
def collect_all():
    """Collect from every configured server and sync results into the DB.

    Fans SSH collection out over a bounded thread pool, then applies all
    database writes from this single thread. Servers removed from the
    config are deleted; newly configured ones are inserted.
    """
    entries = parse_infrastructure_conf()
    if not entries:
        logger.info("No servers configured in infrastructure.conf")
        return
    try:
        ssh_key = load_ssh_key()
    except Exception as e:
        logger.error("SSH key error: %s", e)
        return
    logger.info("Collecting from %d servers (max %d concurrent)", len(entries), MAX_CONCURRENT_SSH)
    results = {}
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
        futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
        for future in as_completed(futures):
            entry = futures[future]
            key = f"{entry['username']}@{entry['hostname']}"
            try:
                results[key] = (entry, future.result(timeout=90))
            except Exception as e:
                # collect_one normally swallows errors itself; this guards
                # against pool-level failures so one host can't abort the pass.
                results[key] = (entry, {'is_online': False, 'error': str(e)})
    # Update database (all in main collector thread)
    with app.app_context():
        # Remove servers no longer in config
        config_keys = {(e['username'], e['hostname']) for e in entries}
        for server in Server.query.all():
            if (server.username, server.hostname) not in config_keys:
                db.session.delete(server)
        for key, (entry, result) in results.items():
            server = Server.query.filter_by(
                username=entry['username'],
                hostname=entry['hostname'],
            ).first()
            if not server:
                server = Server(
                    group_name=entry['group'],
                    username=entry['username'],
                    hostname=entry['hostname'],
                )
                db.session.add(server)
            # Group/parent may have moved in the config; refresh both.
            server.group_name = entry['group']
            server.parent_hostname = entry.get('parent_hostname', '')
            _update_server_from_result(server, entry, result)
        db.session.commit()
    logger.info("Collection complete, updated %d servers", len(results))
# Event used to interrupt the collector's sleep between passes.
_collect_event = threading.Event()


def trigger_collect():
    """Signal the collector loop to start a collection pass right away."""
    _collect_event.set()
def collector_loop():
    """Background thread body: run collect_all() forever on a timer.

    Sleeps COLLECTION_INTERVAL seconds between passes, but wakes early
    whenever trigger_collect() sets _collect_event.
    """
    time.sleep(2)  # Brief pause to let Flask start
    while True:
        try:
            collect_all()
        except Exception as e:
            # Keep the loop alive no matter what a pass throws.
            logger.error("Collection loop error: %s", e)
        # Wait for the next interval or an explicit trigger, whichever first.
        _collect_event.wait(timeout=COLLECTION_INTERVAL)
        _collect_event.clear()
# --- Web Routes ---
@app.route('/')
def index():
    """Render the dashboard: servers grouped per config, children nested."""
    all_servers = Server.query.order_by(Server.group_name, Server.primary_ip).all()
    # Split into top-level hosts and children keyed by parent hostname.
    children_map = {}
    parents = []
    for srv in all_servers:
        if srv.parent_hostname:
            children_map.setdefault(srv.parent_hostname, []).append(srv)
        else:
            parents.append(srv)
    # Children sort numerically by IP within each parent.
    for child_list in children_map.values():
        child_list.sort(key=lambda s: _ip_sort_key(s.primary_ip))
    # Group ordering follows first appearance in infrastructure.conf.
    group_order = []
    for cfg in parse_infrastructure_conf():
        if cfg['group'] not in group_order:
            group_order.append(cfg['group'])
    # Bucket top-level servers by group name.
    by_group = {}
    for srv in parents:
        by_group.setdefault(srv.group_name or 'Default', []).append(srv)
    # Config-ordered groups first, then any leftovers in insertion order.
    groups = {g: by_group.pop(g) for g in group_order if g in by_group}
    groups.update(by_group)
    # Sort servers within each group by IP (numerically).
    for members in groups.values():
        members.sort(key=lambda s: _ip_sort_key(s.primary_ip))
    return render_template('index.html', groups=groups, children_map=children_map, build_date=BUILD_DATE)
@app.route('/api/version')
def api_version():
    """Expose the image build timestamp for the frontend update check."""
    payload = {'build_date': BUILD_DATE}
    return jsonify(payload)
@app.route('/api/servers')
def api_servers():
    """Return every server record as a JSON array."""
    def serialize(s):
        # Flatten one Server row into a JSON-safe dict.
        return {
            'id': s.id,
            'group_name': s.group_name,
            'username': s.username,
            'hostname': s.hostname,
            'primary_ip': s.primary_ip,
            'url': s.url,
            'is_online': s.is_online,
            'last_collected': s.last_collected.isoformat() if s.last_collected else None,
            'notes': s.notes,
            'links': s.links or [],
            'parent_hostname': s.parent_hostname,
            'details': s.details,
        }
    return jsonify([serialize(s) for s in Server.query.all()])
def _update_server_from_result(server, entry, result):
    """Apply collection result to a server record.

    Sets online state, timestamp, raw details, URL, and picks a primary
    IP: the default-route interface wins, otherwise the first
    non-loopback IPv4 seen.
    """
    server.is_online = result.get('is_online', False)
    server.last_collected = datetime.now(timezone.utc)
    server.details = result
    server.url = entry.get('url', server.url)
    routing = result.get('routing', {})
    default_iface = routing.get('interface', '') if isinstance(routing, dict) else ''
    chosen = ''
    for iface in result.get('net', []):
        addr = iface.get('ipv4', '')
        if not addr or addr.startswith('127.'):
            continue  # skip loopback and address-less interfaces
        name = iface.get('name', '') or iface.get('_name', '')
        if name == default_iface:
            chosen = addr
            break
        if not chosen:
            chosen = addr
    server.primary_ip = chosen
@app.route('/api/refresh', methods=['POST'])
def api_refresh():
    """Kick off a background collection pass and return immediately."""
    trigger_collect()
    response = {'ok': True, 'message': 'Collection triggered'}
    return jsonify(response)
@app.route('/api/refresh/stream')
def api_refresh_stream():
    """SSE endpoint: collect all servers with progress updates."""
    def generate():
        # Mirrors collect_all(), but yields an SSE 'data:' line as each
        # server finishes so the frontend can show live progress.
        entries = parse_infrastructure_conf()
        if not entries:
            yield f"data: No servers configured\n\n"
            yield f"data: [DONE]\n\n"
            return
        try:
            ssh_key = load_ssh_key()
        except Exception as e:
            yield f"data: SSH key error: {e}\n\n"
            yield f"data: [DONE]\n\n"
            return
        yield f"data: Collecting from {len(entries)} servers...\n\n"
        with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
            futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
            results = {}
            for future in as_completed(futures):
                entry = futures[future]
                key = f"{entry['username']}@{entry['hostname']}"
                try:
                    result = future.result(timeout=90)
                    results[key] = (entry, result)
                    status = 'online' if result.get('is_online') else 'offline'
                    yield f"data: {entry['hostname']} - {status}\n\n"
                except Exception as e:
                    # Pool-level failure for this host; report it and move on.
                    results[key] = (entry, {'is_online': False, 'error': str(e)})
                    yield f"data: {entry['hostname']} - error: {e}\n\n"
        # Update database
        with app.app_context():
            # Drop servers that were removed from the config file.
            config_keys = {(e['username'], e['hostname']) for e in entries}
            for server in Server.query.all():
                if (server.username, server.hostname) not in config_keys:
                    db.session.delete(server)
            for key, (entry, result) in results.items():
                server = Server.query.filter_by(
                    username=entry['username'], hostname=entry['hostname'],
                ).first()
                if not server:
                    server = Server(
                        group_name=entry['group'],
                        username=entry['username'],
                        hostname=entry['hostname'],
                    )
                    db.session.add(server)
                # Group/parent may have moved in the config; refresh both.
                server.group_name = entry['group']
                server.parent_hostname = entry.get('parent_hostname', '')
                _update_server_from_result(server, entry, result)
            db.session.commit()
        yield f"data: Collection complete - {len(results)} servers updated\n\n"
        yield f"data: [DONE]\n\n"
    # X-Accel-Buffering: no keeps nginx from buffering the event stream.
    return Response(stream_with_context(generate()),
                    mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'})
@app.route('/api/servers/<int:server_id>/refresh/stream')
def api_refresh_one_stream(server_id):
    """SSE endpoint: collect a single server (and its child VMs) with progress updates.

    Changes from the original:
    - BUGFIX: when the main loop dequeued the ('result', ...) item before
      the done-event was observed, the result was discarded on `break` and
      the server was falsely reported offline with error 'unknown'. The
      dequeued result is now kept.
    - Removed the dead, never-called `collect_with_progress` helper.
    - The identical host/child collect-with-queue logic is deduplicated
      into the `_stream_collect` generator.
    """
    import queue as _queue

    def _stream_collect(srv_entry, label, ssh_key):
        """Run collect_one in a worker thread.

        Yields ('progress', message) tuples while the gather script runs,
        then exactly one ('result', data) tuple.
        """
        q = _queue.Queue()
        done = threading.Event()

        def _worker():
            last = [None]

            def on_progress(section):
                # Only report each section once as it starts.
                if section != last[0]:
                    last[0] = section
                    q.put(('progress', f"{label}: {section}"))

            try:
                q.put(('result', collect_one(srv_entry, ssh_key, progress_cb=on_progress)))
            except Exception as e:
                q.put(('result', {'is_online': False, 'error': str(e)}))
            done.set()

        worker = threading.Thread(target=_worker)
        worker.start()
        result = None
        # Relay progress while the worker runs.
        while not done.is_set():
            try:
                kind, val = q.get(timeout=0.3)
            except _queue.Empty:
                continue
            if kind == 'progress':
                yield ('progress', val)
            else:
                result = val  # keep the dequeued result (was dropped before)
                break
        worker.join()
        # Drain anything left in the queue after the worker finished.
        while not q.empty():
            kind, val = q.get_nowait()
            if kind == 'progress':
                yield ('progress', val)
            else:
                result = val
        if result is None:
            result = {'is_online': False, 'error': 'unknown'}
        yield ('result', result)

    def generate():
        with app.app_context():
            server = Server.query.get(server_id)
            if not server:
                yield f"data: Server not found\n\n"
                yield f"data: [DONE]\n\n"
                return
            hostname = server.hostname
            entry = {
                'group': server.group_name,
                'username': server.username,
                'hostname': server.hostname,
                'url': server.url,
            }
            # Find child VMs configured under this host
            children = Server.query.filter_by(parent_hostname=hostname).all()
            child_entries = [{
                'id': c.id,
                'group': c.group_name,
                'username': c.username,
                'hostname': c.hostname,
                'url': c.url,
            } for c in children]
        try:
            ssh_key = load_ssh_key()
        except Exception as e:
            yield f"data: SSH key error: {e}\n\n"
            yield f"data: [DONE]\n\n"
            return

        # Collect the host itself, streaming section progress.
        result = None
        for kind, val in _stream_collect(entry, hostname, ssh_key):
            if kind == 'progress':
                yield f"data: {val}\n\n"
            else:
                result = val
        with app.app_context():
            server = Server.query.get(server_id)
            _update_server_from_result(server, entry, result)
            db.session.commit()
        if result.get('is_online'):
            sys_info = result.get('system', {})
            ct_count = len(result.get('container', []))
            msg = f"{hostname} - online"
            if sys_info.get('os_pretty'):
                msg += f" ({sys_info['os_pretty']})"
            if ct_count:
                msg += f", {ct_count} containers"
            yield f"data: {msg}\n\n"
        else:
            yield f"data: {hostname} - offline: {result.get('error', 'unknown')}\n\n"

        # Collect each child VM the same way.
        for child_entry in child_entries:
            child_host = child_entry['hostname']
            child_result = None
            for kind, val in _stream_collect(child_entry, child_host, ssh_key):
                if kind == 'progress':
                    yield f"data: {val}\n\n"
                else:
                    child_result = val
            with app.app_context():
                child_server = Server.query.get(child_entry['id'])
                _update_server_from_result(child_server, child_entry, child_result)
                db.session.commit()
            if child_result.get('is_online'):
                child_ct = len(child_result.get('container', []))
                msg = f"{child_host} - online"
                if child_ct:
                    msg += f", {child_ct} containers"
                yield f"data: {msg}\n\n"
            else:
                yield f"data: {child_host} - offline: {child_result.get('error', 'unknown')}\n\n"
        yield f"data: [DONE]\n\n"

    return Response(stream_with_context(generate()),
                    mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'})
@app.route('/api/servers/<int:server_id>/notes', methods=['PUT'])
def api_update_notes(server_id):
    """Replace a server's free-form notes from a JSON body {'notes': str}."""
    from flask import request
    server = Server.query.get_or_404(server_id)
    # silent=True: a missing/invalid JSON body becomes None instead of
    # raising mid-request; fall back to an empty dict (original crashed
    # with AttributeError on data.get when the body was absent).
    data = request.get_json(silent=True) or {}
    server.notes = data.get('notes', '')
    db.session.commit()
    return jsonify({'ok': True})
@app.route('/api/servers/<int:server_id>/links', methods=['PUT'])
def api_update_links(server_id):
    """Replace a server's link list from a JSON body {'links': [...]}."""
    from flask import request
    server = Server.query.get_or_404(server_id)
    # silent=True: a missing/invalid JSON body becomes None instead of
    # raising mid-request; fall back to an empty dict (original crashed
    # with AttributeError on data.get when the body was absent).
    data = request.get_json(silent=True) or {}
    server.links = data.get('links', [])
    db.session.commit()
    return jsonify({'ok': True})
def _ip_sort_key(ip_str):
    """Numeric sort key for dotted-quad IPs; empty/non-IP values sort last."""
    sentinel = [999, 999, 999, 999]
    if not ip_str:
        return sentinel
    try:
        return [int(octet) for octet in ip_str.split('.')]
    except (ValueError, AttributeError):
        return sentinel
# --- Jinja2 Helpers ---
@app.template_filter('format_bytes')
def format_bytes(value):
    """Render a byte count as a human-readable size, e.g. '1.5 GB'."""
    try:
        size = int(value)
    except (TypeError, ValueError):
        return value  # pass non-numeric input through untouched
    for unit in ('B', 'KB', 'MB', 'GB', 'TB'):
        if abs(size) < 1024:
            return f"{size:.1f} {unit}"
        size /= 1024
    return f"{size:.1f} PB"
@app.template_filter('format_mb')
def format_mb(value):
    """Render a megabyte count, switching to GB at 1024 MB."""
    try:
        mb = int(value)
    except (TypeError, ValueError):
        return value  # pass non-numeric input through untouched
    return f"{mb / 1024:.1f} GB" if mb >= 1024 else f"{mb} MB"
@app.template_filter('format_uptime')
def format_uptime(seconds):
    """Render an uptime in seconds as 'Nd Nh' (>= 1 day) or 'Nh Nm'."""
    try:
        total = int(seconds)
    except (TypeError, ValueError):
        return 'Unknown'
    days, remainder = divmod(total, 86400)
    hours = remainder // 3600
    if days > 0:
        return f"{days}d {hours}h"
    minutes = (total % 3600) // 60
    return f"{hours}h {minutes}m"
@app.template_filter('temp_color')
def temp_color(temp_c):
    """Map a temperature (Celsius) to a severity color hex code."""
    try:
        reading = float(temp_c)
    except (TypeError, ValueError):
        return '#64748b'  # neutral grey for unknown readings
    # Highest matching threshold wins: red / orange / yellow.
    for floor, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if reading >= floor:
            return color
    return '#22c55e'  # green: nominal
@app.template_filter('clean_gpu')
def clean_gpu(description):
    """Reduce a raw lspci GPU line to a short 'Manufacturer Model' string.

    e.g. '01:00.0 VGA compatible controller: NVIDIA Corporation GA106
    [GeForce RTX 3060] (rev a1)' -> 'Nvidia GeForce RTX 3060'.

    Fixed: removed the redundant per-call `import re` — the module already
    imports `re` at the top of the file.
    """
    if not description:
        return '-'
    s = str(description)
    # Strip PCI address prefix (e.g. "01:00.0 ")
    s = re.sub(r'^[0-9a-f:.]+\s+', '', s, flags=re.IGNORECASE)
    # Strip type prefix
    for prefix in ['VGA compatible controller: ', '3D controller: ', 'Display controller: ']:
        if s.startswith(prefix):
            s = s[len(prefix):]
    # Strip revision suffix
    s = re.sub(r'\s*\(rev [0-9a-f]+\)\s*$', '', s, flags=re.IGNORECASE)
    # Detect manufacturer
    manufacturer = ''
    s_lower = s.lower()
    if 'nvidia' in s_lower:
        manufacturer = 'Nvidia'
    elif 'advanced micro' in s_lower or 'amd' in s_lower or 'radeon' in s_lower:
        manufacturer = 'AMD'
    elif 'intel' in s_lower:
        manufacturer = 'Intel'
    elif 'matrox' in s_lower:
        manufacturer = 'Matrox'
    # Strip vendor name from string
    for vendor in ['NVIDIA Corporation ', 'Advanced Micro Devices, Inc. ', 'AMD ',
                   'Intel Corporation ', 'Advanced Micro Devices Inc. ',
                   'Matrox Electronics Systems Ltd. ']:
        if s.startswith(vendor):
            s = s[len(vendor):]
    # Prefer bracketed name if present (e.g. "GA106 [GeForce RTX 3060]" -> "GeForce RTX 3060")
    bracket = re.search(r'\[(.+)\]', s)
    if bracket:
        s = bracket.group(1)
    # Clean up model string
    # Remove "Lite Hash Rate" and similar marketing suffixes
    s = re.sub(r'\s+Lite Hash Rate', '', s)
    # Remove slash alternatives (e.g. "RX 6800/6800 XT / 6900 XT" -> "RX 6800")
    s = re.sub(r'\s*/[\w\s/]+$', '', s)
    # Remove trailing whitespace
    s = s.strip()
    # Don't duplicate manufacturer if already in the model name
    if manufacturer and s:
        if manufacturer.lower() in s.lower():
            return s
        return f"{manufacturer} {s}"
    return s or '-'
@app.template_filter('usage_color')
def usage_color(percent):
    """Map a usage percentage to a severity color hex code."""
    try:
        level = float(percent)
    except (TypeError, ValueError):
        return '#64748b'  # neutral grey when the value is missing/garbled
    # Highest matching threshold wins: red / orange / yellow.
    for floor, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if level >= floor:
            return color
    return '#22c55e'  # green: nominal
# --- Main ---
def migrate_db():
    """Add any missing columns to existing tables.

    SQLite has no declarative migrations here, so we diff
    PRAGMA table_info against the columns the model expects and
    ALTER in whatever is missing.
    """
    import sqlite3
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(servers)")
        existing = {row[1] for row in cursor.fetchall()}
        migrations = {
            'url': "ALTER TABLE servers ADD COLUMN url VARCHAR(1024) DEFAULT ''",
            'notes': "ALTER TABLE servers ADD COLUMN notes TEXT DEFAULT ''",
            'parent_hostname': "ALTER TABLE servers ADD COLUMN parent_hostname VARCHAR(255) DEFAULT ''",
            'links': "ALTER TABLE servers ADD COLUMN links JSON DEFAULT '[]'",
        }
        for col, sql in migrations.items():
            if col not in existing:
                cursor.execute(sql)
                logger.info("Added column '%s' to servers table", col)
        conn.commit()
    finally:
        # BUGFIX: always release the connection — the original leaked it
        # whenever an ALTER raised.
        conn.close()
if __name__ == '__main__':
    # Ensure the SQLite data directory exists before SQLAlchemy touches it.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    with app.app_context():
        db.create_all()
        # Patch in any columns added since this DB file was first created.
        migrate_db()
        logger.info("Database ready at %s", DB_PATH)
    # Start collector thread (daemon: it dies with the main process).
    collector_thread = threading.Thread(target=collector_loop, daemon=True)
    collector_thread.start()
    app.run(host='0.0.0.0', port=5000, threaded=True)