import os
import re
import time
import logging
import threading
from datetime import datetime, timezone
from concurrent.futures import ThreadPoolExecutor, as_completed

import paramiko
from flask import Flask, render_template, jsonify, Response, stream_with_context
from flask_sqlalchemy import SQLAlchemy

# Placeholder substituted at image-build time with the actual build timestamp.
BUILD_DATE = '__BUILD_DATE__'

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

DB_PATH = '/app/data/infmap.db'

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{DB_PATH}'
db = SQLAlchemy(app)

# Tunables (environment-overridable) and fixed in-container paths.
COLLECTION_INTERVAL = int(os.environ.get('COLLECTION_INTERVAL', 300))  # seconds between runs
MAX_CONCURRENT_SSH = int(os.environ.get('MAX_CONCURRENT_SSH', 5))
SSH_KEY_PATH = '/app/ssh_key'
INFRA_CONF_PATH = '/app/infrastructure.conf'


# --- Database Model ---

class Server(db.Model):
    """One monitored host, uniquely identified by (username, hostname)."""
    __tablename__ = 'servers'
    id = db.Column(db.Integer, primary_key=True)
    group_name = db.Column(db.String(255), nullable=False)
    username = db.Column(db.String(255), nullable=False)
    hostname = db.Column(db.String(255), nullable=False)
    primary_ip = db.Column(db.String(45), default='')
    url = db.Column(db.String(1024), default='')
    is_online = db.Column(db.Boolean, default=False)
    last_collected = db.Column(db.DateTime, nullable=True)
    details = db.Column(db.JSON, nullable=True)  # full parsed gather_info output
    notes = db.Column(db.Text, default='')
    links = db.Column(db.JSON, default=list)  # [{"label": "...", "url": "..."}]
    parent_hostname = db.Column(db.String(255), default='')  # non-empty for child VMs

    __table_args__ = (db.UniqueConstraint('username', 'hostname', name='uq_user_host'),)


# --- Config Parsing ---

def parse_infrastructure_conf():
    """Parse infrastructure.conf into a list of server entry dicts.

    Line format:
      GroupName              no leading whitespace, no dash -> starts a group
      - [user@]host [url]    top-level server entry
      -- [user@]host [url]   child of the most recent top-level entry
      host                   legacy: indented, no dash -> treated as a host
    Blank lines and '#' comments are skipped.  The default username is
    'infmap' when no 'user@' prefix is given.

    Returns a list of dicts with keys group/username/hostname/url/
    parent_hostname, or [] if the file is missing.
    """
    servers = []
    current_group = None
    current_host = None  # most recent top-level host; parent for '--' entries
    try:
        with open(INFRA_CONF_PATH) as f:
            for line in f:
                line = line.rstrip('\n')
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                # '--' must be tested before '-' since it shares the prefix.
                if stripped.startswith('--'):
                    is_child = True
                    rest = stripped[2:].strip()
                elif stripped.startswith('-'):
                    is_child = False
                    rest = stripped[1:].strip()
                elif line[0] not in (' ', '\t') and not stripped.startswith('-'):
                    # Group header: no leading whitespace, no dash prefix.
                    current_group = stripped
                    current_host = None
                    continue
                else:
                    # Legacy: indented without dash, treat as host.
                    is_child = False
                    rest = stripped
                parts = rest.split(None, 1)
                entry = parts[0] if parts else ''
                url = parts[1] if len(parts) > 1 else ''
                if not entry:
                    continue
                if '@' in entry:
                    user, host = entry.split('@', 1)
                else:
                    user, host = 'infmap', entry
                if host:
                    parent = ''
                    if is_child and current_host:
                        parent = current_host
                    else:
                        current_host = host.strip()
                    servers.append({
                        'group': current_group or 'Default',
                        'username': user.strip(),
                        'hostname': host.strip(),
                        'url': url.strip(),
                        'parent_hostname': parent,
                    })
    except FileNotFoundError:
        logger.error("infrastructure.conf not found at %s", INFRA_CONF_PATH)
    return servers


# --- SSH Collection ---

def load_ssh_key():
    """Load the private key at SSH_KEY_PATH, trying each supported key type.

    Raises RuntimeError if no key class can parse the file.
    """
    for key_class in [paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]:
        try:
            return key_class.from_private_key_file(SSH_KEY_PATH)
        except Exception:
            continue
    raise RuntimeError(f"Could not load SSH key from {SSH_KEY_PATH}")


def collect_one(entry, ssh_key, progress_cb=None):
    """SSH into a single server, run the gather script, return parsed data.

    On success returns the parse_gather_output() dict with 'is_online': True;
    on any failure returns {'is_online': False, 'error': <message>}.  When
    progress_cb is given, remote output is streamed line-by-line and each new
    '[section]' header (except '[end]') is reported via progress_cb(section).
    """
    ssh = None
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(
            entry['hostname'],
            username=entry['username'],
            pkey=ssh_key,
            timeout=15,
            banner_timeout=15,
            auth_timeout=15,
        )
        with open('/app/gather_info.sh') as f:
            script = f.read()
        # Pipe the script over stdin so nothing needs to exist on the target.
        stdin, stdout, stderr = ssh.exec_command('bash -s', timeout=60)
        stdin.write(script)
        stdin.channel.shutdown_write()
        if progress_cb:
            # Stream line-by-line and report section headers as progress.
            output_lines = []
            for raw_line in stdout:
                line = raw_line.rstrip('\n')
                output_lines.append(line)
                stripped = line.strip()
                if stripped.startswith('[') and stripped.endswith(']') and stripped != '[end]':
                    section = stripped[1:-1].split(':')[0]
                    progress_cb(section)
            output = '\n'.join(output_lines)
        else:
            output = stdout.read().decode('utf-8', errors='replace')
        data = parse_gather_output(output)
        data['is_online'] = True
        return data
    except Exception as e:
        logger.warning("Failed to collect from %s@%s: %s",
                       entry['username'], entry['hostname'], e)
        return {'is_online': False, 'error': str(e)}
    finally:
        # BUGFIX: previously the client was only closed on the success path,
        # leaking the SSH connection whenever collection failed.
        if ssh is not None:
            ssh.close()


def parse_gather_output(output):
    """Parse the [section] key=value output from gather_info.sh.

    Plain '[name]' sections become dicts (repeated keys collapse into lists);
    '[name:id]' sections append {'_name': id, ...} items to a list under
    'name'.  Parsing stops at '[end]'.
    """
    data = {}
    current_section = None
    for line in output.split('\n'):
        line = line.strip()
        if not line:
            continue
        # Section header: [name] or [name:id]
        m = re.match(r'^\[(.+)\]$', line)
        if m:
            section = m.group(1)
            if section == 'end':
                break
            if ':' in section:
                base, name = section.split(':', 1)
                if base not in data:
                    data[base] = []
                item = {'_name': name}
                data[base].append(item)
                current_section = ('list', item)
            else:
                if section not in data:
                    data[section] = {}
                current_section = ('dict', data[section])
            continue
        # Key=value
        if '=' in line and current_section:
            key, _, value = line.partition('=')
            key = key.strip()
            value = value.strip()
            if current_section[0] == 'dict':
                section_data = current_section[1]
                # Handle repeated keys (e.g. dns server=) by promoting to list.
                if key in section_data:
                    if not isinstance(section_data[key], list):
                        section_data[key] = [section_data[key]]
                    section_data[key].append(value)
                else:
                    section_data[key] = value
            elif current_section[0] == 'list':
                current_section[1][key] = value
    return data


# --- Collection Loop ---

def _upsert_server(entry, result):
    """Create or update the Server row for a config entry (caller commits).

    Shared by collect_all() and the SSE refresh endpoint so the upsert logic
    lives in one place.  Must be called inside an app context.
    """
    server = Server.query.filter_by(
        username=entry['username'],
        hostname=entry['hostname'],
    ).first()
    if not server:
        server = Server(
            group_name=entry['group'],
            username=entry['username'],
            hostname=entry['hostname'],
        )
        db.session.add(server)
    server.group_name = entry['group']
    server.parent_hostname = entry.get('parent_hostname', '')
    _update_server_from_result(server, entry, result)


def collect_all():
    """Collect from every configured server concurrently, then persist results."""
    entries = parse_infrastructure_conf()
    if not entries:
        logger.info("No servers configured in infrastructure.conf")
        return
    try:
        ssh_key = load_ssh_key()
    except Exception as e:
        logger.error("SSH key error: %s", e)
        return
    logger.info("Collecting from %d servers (max %d concurrent)",
                len(entries), MAX_CONCURRENT_SSH)
    results = {}
    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
        futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
        for future in as_completed(futures):
            entry = futures[future]
            key = f"{entry['username']}@{entry['hostname']}"
            try:
                results[key] = (entry, future.result(timeout=90))
            except Exception as e:
                results[key] = (entry, {'is_online': False, 'error': str(e)})
    # Update database (all in main collector thread).
    with app.app_context():
        for key, (entry, result) in results.items():
            _upsert_server(entry, result)
        db.session.commit()
    logger.info("Collection complete, updated %d servers", len(results))


_collect_event = threading.Event()


def trigger_collect():
    """Wake up the collector loop to run immediately."""
    _collect_event.set()


def collector_loop():
    """Background thread body: run collect_all() every COLLECTION_INTERVAL
    seconds, or sooner when trigger_collect() sets the wakeup event."""
    time.sleep(2)  # Brief pause to let Flask start
    while True:
        try:
            collect_all()
        except Exception as e:
            logger.error("Collection loop error: %s", e)
        _collect_event.wait(timeout=COLLECTION_INTERVAL)
        _collect_event.clear()


# --- Web Routes ---

@app.route('/')
def index():
    """Render the dashboard: servers grouped and ordered per the config file."""
    # Get group order from config file and filter to configured servers only.
    config_entries = parse_infrastructure_conf()
    config_keys = {(e['username'], e['hostname']) for e in config_entries}
    group_order = []
    for e in config_entries:
        g = e['group']
        if g not in group_order:
            group_order.append(g)
    all_servers = Server.query.order_by(Server.group_name, Server.primary_ip).all()
    # Separate parents and children (only servers in the current config).
    children_map = {}  # parent_hostname -> [child_servers]
    parents = []
    for s in all_servers:
        if (s.username, s.hostname) not in config_keys:
            continue
        if s.parent_hostname:
            children_map.setdefault(s.parent_hostname, []).append(s)
        else:
            parents.append(s)
    # Sort children numerically by IP.
    for hostname in children_map:
        children_map[hostname].sort(key=lambda s: _ip_sort_key(s.primary_ip))
    # Group parents, preserving config order.
    groups = {}
    for s in parents:
        g = s.group_name or 'Default'
        if g not in groups:
            groups[g] = []
        groups[g].append(s)
    # Re-order groups to match config, then append any leftovers.
    ordered_groups = {}
    for g in group_order:
        if g in groups:
            ordered_groups[g] = groups.pop(g)
    for g in groups:
        ordered_groups[g] = groups[g]
    groups = ordered_groups
    # Sort servers within each group by IP (numerically).
    for g in groups:
        groups[g].sort(key=lambda s: _ip_sort_key(s.primary_ip))
    return render_template('index.html', groups=groups,
                           children_map=children_map, build_date=BUILD_DATE)


@app.route('/api/version')
def api_version():
    """Return the build timestamp baked into the image."""
    return jsonify({'build_date': BUILD_DATE})


@app.route('/api/servers')
def api_servers():
    """Return every server row as JSON (including unconfigured leftovers)."""
    servers = Server.query.all()
    result = []
    for s in servers:
        result.append({
            'id': s.id,
            'group_name': s.group_name,
            'username': s.username,
            'hostname': s.hostname,
            'primary_ip': s.primary_ip,
            'url': s.url,
            'is_online': s.is_online,
            'last_collected': s.last_collected.isoformat() if s.last_collected else None,
            'notes': s.notes,
            'links': s.links or [],
            'parent_hostname': s.parent_hostname,
            'details': s.details,
        })
    return jsonify(result)


def _update_server_from_result(server, entry, result):
    """Apply a collection result to a server record (caller commits).

    Picks primary_ip as the first non-loopback IPv4 on the default-route
    interface, falling back to the first non-loopback IPv4 seen.
    """
    server.is_online = result.get('is_online', False)
    server.last_collected = datetime.now(timezone.utc)
    server.details = result
    server.url = entry.get('url', server.url)
    default_iface = ''
    routing = result.get('routing', {})
    if isinstance(routing, dict):
        default_iface = routing.get('interface', '')
    primary_ip = ''
    for iface in result.get('net', []):
        ipv4 = iface.get('ipv4', '')
        if not ipv4 or ipv4.startswith('127.'):
            continue
        iface_name = iface.get('name', '') or iface.get('_name', '')
        if iface_name == default_iface:
            # Exact match on the default-route interface wins outright.
            primary_ip = ipv4
            break
        if not primary_ip:
            primary_ip = ipv4  # remember the first candidate as a fallback
    server.primary_ip = primary_ip


@app.route('/api/refresh', methods=['POST'])
def api_refresh():
    """Fire-and-forget: wake the collector loop and return immediately."""
    trigger_collect()
    return jsonify({'ok': True, 'message': 'Collection triggered'})


@app.route('/api/refresh/stream')
def api_refresh_stream():
    """SSE endpoint: collect all servers with progress updates."""
    def generate():
        entries = parse_infrastructure_conf()
        if not entries:
            yield "data: No servers configured\n\n"
            yield "data: [DONE]\n\n"
            return
        try:
            ssh_key = load_ssh_key()
        except Exception as e:
            yield f"data: SSH key error: {e}\n\n"
            yield "data: [DONE]\n\n"
            return
        yield f"data: Collecting from {len(entries)} servers...\n\n"
        with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
            futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
            results = {}
            for future in as_completed(futures):
                entry = futures[future]
                key = f"{entry['username']}@{entry['hostname']}"
                try:
                    result = future.result(timeout=90)
                    results[key] = (entry, result)
                    status = 'online' if result.get('is_online') else 'offline'
                    yield f"data: {entry['hostname']} - {status}\n\n"
                except Exception as e:
                    results[key] = (entry, {'is_online': False, 'error': str(e)})
                    yield f"data: {entry['hostname']} - error: {e}\n\n"
        # Update database once all futures have resolved.
        with app.app_context():
            for key, (entry, result) in results.items():
                _upsert_server(entry, result)
            db.session.commit()
        yield f"data: Collection complete - {len(results)} servers updated\n\n"
        yield "data: [DONE]\n\n"
    return Response(stream_with_context(generate()),
                    mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'})
def _collect_with_progress(entry, ssh_key):
    """Collect from one server in a worker thread, yielding progress tuples.

    Yields ('progress', '<host>: <section>') as the remote gather script
    reaches each new section, then exactly one final ('result', data) tuple.
    Shared by the host and child-VM passes of the single-server SSE refresh,
    which previously duplicated this queue-and-thread machinery inline.
    """
    import queue
    q = queue.Queue()
    done = threading.Event()
    hostname = entry['hostname']

    def worker():
        last_reported = [None]

        def on_progress(section):
            # Only forward a section name the first time it appears.
            if section != last_reported[0]:
                last_reported[0] = section
                q.put(('progress', f"{hostname}: {section}"))

        try:
            q.put(('result', collect_one(entry, ssh_key, progress_cb=on_progress)))
        except Exception as e:
            q.put(('result', {'is_online': False, 'error': str(e)}))
        done.set()

    t = threading.Thread(target=worker)
    t.start()
    result = None
    # Poll the queue so progress reaches the SSE client while the worker runs.
    while not done.is_set() or not q.empty():
        try:
            kind, val = q.get(timeout=0.3)
        except queue.Empty:
            continue
        if kind == 'progress':
            yield kind, val
        elif kind == 'result':
            result = val
    t.join()
    # Drain anything that raced in between the last get and done being set.
    while not q.empty():
        kind, val = q.get_nowait()
        if kind == 'progress':
            yield kind, val
        elif kind == 'result':
            result = val
    if result is None:
        result = {'is_online': False, 'error': 'unknown'}
    yield 'result', result


# BUGFIX: the route rules below previously read '/api/servers//...' with no
# '<int:server_id>' converter, so Flask could never bind the required
# server_id argument and every request to these endpoints failed.
@app.route('/api/servers/<int:server_id>/refresh/stream')
def api_refresh_one_stream(server_id):
    """SSE endpoint: collect a single server (and its child VMs) with progress updates."""
    def generate():
        with app.app_context():
            server = Server.query.get(server_id)
            if not server:
                yield "data: Server not found\n\n"
                yield "data: [DONE]\n\n"
                return
            hostname = server.hostname
            entry = {
                'group': server.group_name,
                'username': server.username,
                'hostname': server.hostname,
                'url': server.url,
            }
            # Find child VMs configured under this host.
            children = Server.query.filter_by(parent_hostname=hostname).all()
            child_entries = [{
                'id': c.id,
                'group': c.group_name,
                'username': c.username,
                'hostname': c.hostname,
                'url': c.url,
            } for c in children]
        try:
            ssh_key = load_ssh_key()
        except Exception as e:
            yield f"data: SSH key error: {e}\n\n"
            yield "data: [DONE]\n\n"
            return

        # Collect the host itself, relaying section-level progress.
        result = None
        for kind, val in _collect_with_progress(entry, ssh_key):
            if kind == 'progress':
                yield f"data: {val}\n\n"
            elif kind == 'result':
                result = val
        with app.app_context():
            server = Server.query.get(server_id)
            _update_server_from_result(server, entry, result)
            db.session.commit()
        if result.get('is_online'):
            sys_info = result.get('system', {})
            ct_count = len(result.get('container', []))
            msg = f"{hostname} - online"
            if sys_info.get('os_pretty'):
                msg += f" ({sys_info['os_pretty']})"
            if ct_count:
                msg += f", {ct_count} containers"
            yield f"data: {msg}\n\n"
        else:
            yield f"data: {hostname} - offline: {result.get('error', 'unknown')}\n\n"

        # Collect child VMs the same way, one at a time.
        for child_entry in child_entries:
            child_host = child_entry['hostname']
            child_result = None
            for kind, val in _collect_with_progress(child_entry, ssh_key):
                if kind == 'progress':
                    yield f"data: {val}\n\n"
                elif kind == 'result':
                    child_result = val
            with app.app_context():
                child_server = Server.query.get(child_entry['id'])
                _update_server_from_result(child_server, child_entry, child_result)
                db.session.commit()
            if child_result.get('is_online'):
                child_ct = len(child_result.get('container', []))
                msg = f"{child_host} - online"
                if child_ct:
                    msg += f", {child_ct} containers"
                yield f"data: {msg}\n\n"
            else:
                yield f"data: {child_host} - offline: {child_result.get('error', 'unknown')}\n\n"
        yield "data: [DONE]\n\n"
    return Response(stream_with_context(generate()),
                    mimetype='text/event-stream',
                    headers={'Cache-Control': 'no-cache', 'X-Accel-Buffering': 'no'})


@app.route('/api/servers/<int:server_id>/notes', methods=['PUT'])
def api_update_notes(server_id):
    """Replace a server's free-form notes from the JSON body."""
    from flask import request
    server = Server.query.get_or_404(server_id)
    data = request.get_json()
    server.notes = data.get('notes', '')
    db.session.commit()
    return jsonify({'ok': True})


@app.route('/api/servers/<int:server_id>/links', methods=['PUT'])
def api_update_links(server_id):
    """Replace a server's link list from the JSON body."""
    from flask import request
    server = Server.query.get_or_404(server_id)
    data = request.get_json()
    server.links = data.get('links', [])
    db.session.commit()
    return jsonify({'ok': True})


@app.route('/api/all-notes')
def api_all_notes():
    """Return all servers that have non-empty notes."""
    # '!= None' (not 'is not None') is required for SQLAlchemy column filters.
    servers = Server.query.filter(Server.notes != '', Server.notes != None).all()
    return jsonify([{
        'id': s.id,
        'hostname': s.hostname,
        'group_name': s.group_name,
        'notes': s.notes,
    } for s in servers])


@app.route('/api/all-links')
def api_all_links():
    """Return all servers that have at least one saved link."""
    servers = Server.query.filter(Server.links != None, Server.links != '[]').all()
    result = []
    for s in servers:
        links = s.links or []
        if links:
            result.append({
                'id': s.id,
                'hostname': s.hostname,
                'group_name': s.group_name,
                'links': links,
            })
    return jsonify(result)


def _ip_sort_key(ip_str):
    """Numeric sort key for dotted-quad IPv4 strings; empties/garbage sort last."""
    if not ip_str:
        return [999, 999, 999, 999]
    try:
        return [int(x) for x in ip_str.split('.')]
    except (ValueError, AttributeError):
        return [999, 999, 999, 999]


# --- Jinja2 Helpers ---

@app.template_filter('format_bytes')
def format_bytes(value):
    """Render a byte count as a human-readable size; pass through non-numbers."""
    try:
        b = int(value)
    except (TypeError, ValueError):
        return value
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if abs(b) < 1024:
            return f"{b:.1f} {unit}"
        b /= 1024
    return f"{b:.1f} PB"


@app.template_filter('format_mb')
def format_mb(value):
    """Render a megabyte count, switching to GB at 1024; pass through non-numbers."""
    try:
        mb = int(value)
    except (TypeError, ValueError):
        return value
    if mb >= 1024:
        return f"{mb / 1024:.1f} GB"
    return f"{mb} MB"


@app.template_filter('format_uptime')
def format_uptime(seconds):
    """Render an uptime in seconds as 'Nd Nh' (or 'Nh Nm' under a day)."""
    try:
        s = int(seconds)
    except (TypeError, ValueError):
        return 'Unknown'
    days = s // 86400
    hours = (s % 86400) // 3600
    if days > 0:
        return f"{days}d {hours}h"
    minutes = (s % 3600) // 60
    return f"{hours}h {minutes}m"


@app.template_filter('temp_color')
def temp_color(temp_c):
    """Map a Celsius temperature to a status color (grey for non-numbers)."""
    try:
        t = float(temp_c)
    except (TypeError, ValueError):
        return '#64748b'
    if t >= 90:
        return '#ef4444'
    if t >= 75:
        return '#f97316'
    if t >= 60:
        return '#eab308'
    return '#22c55e'


@app.template_filter('clean_gpu')
def clean_gpu(description):
    """Reduce an lspci GPU line to 'Manufacturer Model' for display.

    E.g. '01:00.0 VGA compatible controller: NVIDIA Corporation GA106
    [GeForce RTX 3060] (rev a1)' -> 'Nvidia GeForce RTX 3060'.
    """
    if not description:
        return '-'
    s = str(description)
    # Strip PCI address prefix (e.g. "01:00.0 ").
    s = re.sub(r'^[0-9a-f:.]+\s+', '', s, flags=re.IGNORECASE)
    # Strip device-type prefix.
    for prefix in ['VGA compatible controller: ', '3D controller: ',
                   'Display controller: ']:
        if s.startswith(prefix):
            s = s[len(prefix):]
    # Strip revision suffix.
    s = re.sub(r'\s*\(rev [0-9a-f]+\)\s*$', '', s, flags=re.IGNORECASE)
    # Detect manufacturer before the vendor name is stripped.
    manufacturer = ''
    s_lower = s.lower()
    if 'nvidia' in s_lower:
        manufacturer = 'Nvidia'
    elif 'advanced micro' in s_lower or 'amd' in s_lower or 'radeon' in s_lower:
        manufacturer = 'AMD'
    elif 'intel' in s_lower:
        manufacturer = 'Intel'
    elif 'matrox' in s_lower:
        manufacturer = 'Matrox'
    # Strip vendor name from string.
    for vendor in ['NVIDIA Corporation ', 'Advanced Micro Devices, Inc. ', 'AMD ',
                   'Intel Corporation ', 'Advanced Micro Devices Inc. ',
                   'Matrox Electronics Systems Ltd. ']:
        if s.startswith(vendor):
            s = s[len(vendor):]
    # Prefer bracketed name if present (e.g. "GA106 [GeForce RTX 3060]").
    bracket = re.search(r'\[(.+)\]', s)
    if bracket:
        s = bracket.group(1)
    # Remove "Lite Hash Rate" and similar marketing suffixes.
    s = re.sub(r'\s+Lite Hash Rate', '', s)
    # Remove slash alternatives (e.g. "RX 6800/6800 XT / 6900 XT" -> "RX 6800").
    s = re.sub(r'\s*/[\w\s/]+$', '', s)
    s = s.strip()
    # Don't duplicate manufacturer if already in the model name.
    if manufacturer and s:
        if manufacturer.lower() in s.lower():
            return s
        return f"{manufacturer} {s}"
    return s or '-'


@app.template_filter('usage_color')
def usage_color(percent):
    """Map a usage percentage to a status color (grey for non-numbers)."""
    try:
        p = float(percent)
    except (TypeError, ValueError):
        return '#64748b'
    if p >= 90:
        return '#ef4444'
    if p >= 75:
        return '#f97316'
    if p >= 60:
        return '#eab308'
    return '#22c55e'


# --- Main ---

def migrate_db():
    """Add any missing columns to existing tables (simple forward migration)."""
    import sqlite3
    conn = sqlite3.connect(DB_PATH)
    cursor = conn.cursor()
    cursor.execute("PRAGMA table_info(servers)")
    existing = {row[1] for row in cursor.fetchall()}
    migrations = {
        'url': "ALTER TABLE servers ADD COLUMN url VARCHAR(1024) DEFAULT ''",
        'notes': "ALTER TABLE servers ADD COLUMN notes TEXT DEFAULT ''",
        'parent_hostname': "ALTER TABLE servers ADD COLUMN parent_hostname VARCHAR(255) DEFAULT ''",
        'links': "ALTER TABLE servers ADD COLUMN links JSON DEFAULT '[]'",
    }
    for col, sql in migrations.items():
        if col not in existing:
            cursor.execute(sql)
            logger.info("Added column '%s' to servers table", col)
    conn.commit()
    conn.close()


if __name__ == '__main__':
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    with app.app_context():
        db.create_all()
        migrate_db()
    logger.info("Database ready at %s", DB_PATH)
    # Start collector thread (daemon so it dies with the server process).
    collector_thread = threading.Thread(target=collector_loop, daemon=True)
    collector_thread.start()
    app.run(host='0.0.0.0', port=5000, threaded=True)