import os import re import time import logging import threading from datetime import datetime, timezone from concurrent.futures import ThreadPoolExecutor, as_completed import paramiko BUILD_DATE = '__BUILD_DATE__' from flask import Flask, render_template, jsonify from flask_sqlalchemy import SQLAlchemy logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s') logger = logging.getLogger(__name__) DB_PATH = '/app/data/infmap.db' app = Flask(__name__) app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{DB_PATH}' db = SQLAlchemy(app) COLLECTION_INTERVAL = int(os.environ.get('COLLECTION_INTERVAL', 300)) MAX_CONCURRENT_SSH = int(os.environ.get('MAX_CONCURRENT_SSH', 5)) SSH_KEY_PATH = '/app/ssh_key' INFRA_CONF_PATH = '/app/infrastructure.conf' # --- Database Model --- class Server(db.Model): __tablename__ = 'servers' id = db.Column(db.Integer, primary_key=True) group_name = db.Column(db.String(255), nullable=False) username = db.Column(db.String(255), nullable=False) hostname = db.Column(db.String(255), nullable=False) primary_ip = db.Column(db.String(45), default='') url = db.Column(db.String(1024), default='') is_online = db.Column(db.Boolean, default=False) last_collected = db.Column(db.DateTime, nullable=True) details = db.Column(db.JSON, nullable=True) notes = db.Column(db.Text, default='') __table_args__ = (db.UniqueConstraint('username', 'hostname', name='uq_user_host'),) # --- Config Parsing --- def parse_infrastructure_conf(): servers = [] current_group = None try: with open(INFRA_CONF_PATH) as f: for line in f: line = line.rstrip('\n') if not line.strip() or line.strip().startswith('#'): continue if line[0] not in (' ', '\t'): current_group = line.strip() else: parts = line.strip().split(None, 1) entry = parts[0] if parts else '' url = parts[1] if len(parts) > 1 else '' if '@' in entry: user, host = entry.split('@', 1) else: user, host = 'infmap', entry if host: servers.append({ 'group': current_group or 'Default', 'username': user.strip(), 'hostname': host.strip(), 'url': url.strip(), }) except FileNotFoundError: logger.error("infrastructure.conf not found at %s", INFRA_CONF_PATH) return servers # --- SSH Collection --- def load_ssh_key(): for key_class in [paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey]: try: return key_class.from_private_key_file(SSH_KEY_PATH) except Exception: continue raise RuntimeError(f"Could not load SSH key from {SSH_KEY_PATH}") def collect_one(entry, ssh_key): """SSH into a single server, run the gather script, return parsed data.""" try: ssh = paramiko.SSHClient() ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) ssh.connect( entry['hostname'], username=entry['username'], pkey=ssh_key, timeout=15, banner_timeout=15, auth_timeout=15, ) with open('/app/gather_info.sh') as f: script = f.read() stdin, stdout, stderr = ssh.exec_command('bash -s', timeout=60) stdin.write(script) stdin.channel.shutdown_write() output = stdout.read().decode('utf-8', errors='replace') ssh.close() data = parse_gather_output(output) data['is_online'] = True return data except Exception as e: logger.warning("Failed to collect from %s@%s: %s", entry['username'], entry['hostname'], e) return {'is_online': False, 'error': str(e)} def parse_gather_output(output): """Parse the [section] key=value output from gather_info.sh.""" data = {} current_section = None for line in output.split('\n'): line = line.strip() if not line: continue # Section header: [name] or [name:id] m = re.match(r'^\[(.+)\]$', line) if m: section = m.group(1) if section == 'end': break if ':' in section: base, name = section.split(':', 1) if base not in data: data[base] = [] item = {'_name': name} data[base].append(item) current_section = ('list', item) else: if section not in data: data[section] = {} current_section = ('dict', data[section]) continue # Key=value if '=' in line and current_section: key, _, value = line.partition('=') key = key.strip() value = value.strip() if current_section[0] == 'dict': section_data = current_section[1] # Handle repeated keys (e.g., dns server=) if key in section_data: if not isinstance(section_data[key], list): section_data[key] = [section_data[key]] section_data[key].append(value) else: section_data[key] = value elif current_section[0] == 'list': current_section[1][key] = value return data # --- Collection Loop --- def collect_all(): entries = parse_infrastructure_conf() if not entries: logger.info("No servers configured in infrastructure.conf") return try: ssh_key = load_ssh_key() except Exception as e: logger.error("SSH key error: %s", e) return logger.info("Collecting from %d servers (max %d concurrent)", len(entries), MAX_CONCURRENT_SSH) results = {} with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool: futures = {pool.submit(collect_one, e, ssh_key): e for e in entries} for future in as_completed(futures): entry = futures[future] key = f"{entry['username']}@{entry['hostname']}" try: results[key] = (entry, future.result(timeout=90)) except Exception as e: results[key] = (entry, {'is_online': False, 'error': str(e)}) # Update database (all in main collector thread) with app.app_context(): # Remove servers no longer in config config_keys = {(e['username'], e['hostname']) for e in entries} for server in Server.query.all(): if (server.username, server.hostname) not in config_keys: db.session.delete(server) for key, (entry, result) in results.items(): server = Server.query.filter_by( username=entry['username'], hostname=entry['hostname'], ).first() if not server: server = Server( group_name=entry['group'], username=entry['username'], hostname=entry['hostname'], ) db.session.add(server) server.group_name = entry['group'] server.url = entry.get('url', '') server.is_online = result.get('is_online', False) server.last_collected = datetime.now(timezone.utc) server.details = result # Extract primary IP: prefer the interface carrying the default route default_iface = '' routing = result.get('routing', {}) if isinstance(routing, dict): default_iface = routing.get('interface', '') primary_ip = '' for iface in result.get('net', []): ipv4 = iface.get('ipv4', '') if not ipv4 or ipv4.startswith('127.'): continue iface_name = iface.get('name', '') or iface.get('_name', '') if iface_name == default_iface: primary_ip = ipv4 break if not primary_ip: primary_ip = ipv4 server.primary_ip = primary_ip db.session.commit() logger.info("Collection complete, updated %d servers", len(results)) def collector_loop(): time.sleep(10) # Let the app start up while True: try: collect_all() except Exception as e: logger.error("Collection loop error: %s", e) time.sleep(COLLECTION_INTERVAL) # --- Web Routes --- @app.route('/') def index(): servers = Server.query.order_by(Server.group_name, Server.primary_ip).all() groups = {} for s in servers: g = s.group_name or 'Default' if g not in groups: groups[g] = [] groups[g].append(s) # Sort servers within each group by IP (numerically) for g in groups: groups[g].sort(key=lambda s: _ip_sort_key(s.primary_ip)) return render_template('index.html', groups=groups, build_date=BUILD_DATE) @app.route('/api/servers') def api_servers(): servers = Server.query.all() result = [] for s in servers: result.append({ 'id': s.id, 'group_name': s.group_name, 'username': s.username, 'hostname': s.hostname, 'primary_ip': s.primary_ip, 'url': s.url, 'is_online': s.is_online, 'last_collected': s.last_collected.isoformat() if s.last_collected else None, 'notes': s.notes, 'details': s.details, }) return jsonify(result) @app.route('/api/servers//notes', methods=['PUT']) def api_update_notes(server_id): from flask import request server = Server.query.get_or_404(server_id) data = request.get_json() server.notes = data.get('notes', '') db.session.commit() return jsonify({'ok': True}) def _ip_sort_key(ip_str): if not ip_str: return [999, 999, 999, 999] try: return [int(x) for x in ip_str.split('.')] except (ValueError, AttributeError): return [999, 999, 999, 999] # --- Jinja2 Helpers --- @app.template_filter('format_bytes') def format_bytes(value): try: b = int(value) except (TypeError, ValueError): return value for unit in ['B', 'KB', 'MB', 'GB', 'TB']: if abs(b) < 1024: return f"{b:.1f} {unit}" b /= 1024 return f"{b:.1f} PB" @app.template_filter('format_mb') def format_mb(value): try: mb = int(value) except (TypeError, ValueError): return value if mb >= 1024: return f"{mb / 1024:.1f} GB" return f"{mb} MB" @app.template_filter('format_uptime') def format_uptime(seconds): try: s = int(seconds) except (TypeError, ValueError): return 'Unknown' days = s // 86400 hours = (s % 86400) // 3600 if days > 0: return f"{days}d {hours}h" minutes = (s % 3600) // 60 return f"{hours}h {minutes}m" @app.template_filter('temp_color') def temp_color(temp_c): try: t = float(temp_c) except (TypeError, ValueError): return '#64748b' if t >= 90: return '#ef4444' if t >= 75: return '#f97316' if t >= 60: return '#eab308' return '#22c55e' @app.template_filter('clean_gpu') def clean_gpu(description): if not description: return '-' s = str(description) # Strip PCI address prefix (e.g. "01:00.0 ") import re s = re.sub(r'^[0-9a-f:.]+\s+', '', s, flags=re.IGNORECASE) # Strip type prefix for prefix in ['VGA compatible controller: ', '3D controller: ', 'Display controller: ']: if s.startswith(prefix): s = s[len(prefix):] # Strip common vendor prefixes for vendor in ['NVIDIA Corporation ', 'Advanced Micro Devices, Inc. ', 'AMD ', 'Intel Corporation ', 'Advanced Micro Devices Inc. ', 'Matrox Electronics Systems Ltd. ']: if s.startswith(vendor): s = s[len(vendor):] # Strip revision suffix s = re.sub(r'\s*\(rev [0-9a-f]+\)\s*$', '', s, flags=re.IGNORECASE) # Prefer bracketed name if present (e.g. "GA106 [GeForce RTX 3060]" -> "GeForce RTX 3060") bracket = re.search(r'\[(.+)\]', s) if bracket: s = bracket.group(1) return s.strip() @app.template_filter('usage_color') def usage_color(percent): try: p = float(percent) except (TypeError, ValueError): return '#64748b' if p >= 90: return '#ef4444' if p >= 75: return '#f97316' if p >= 60: return '#eab308' return '#22c55e' # --- Main --- def migrate_db(): """Add any missing columns to existing tables.""" import sqlite3 conn = sqlite3.connect(DB_PATH) cursor = conn.cursor() cursor.execute("PRAGMA table_info(servers)") existing = {row[1] for row in cursor.fetchall()} migrations = { 'url': "ALTER TABLE servers ADD COLUMN url VARCHAR(1024) DEFAULT ''", 'notes': "ALTER TABLE servers ADD COLUMN notes TEXT DEFAULT ''", } for col, sql in migrations.items(): if col not in existing: cursor.execute(sql) logger.info("Added column '%s' to servers table", col) conn.commit() conn.close() if __name__ == '__main__': os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) with app.app_context(): db.create_all() migrate_db() logger.info("Database ready at %s", DB_PATH) # Start collector thread collector_thread = threading.Thread(target=collector_loop, daemon=True) collector_thread.start() app.run(host='0.0.0.0', port=5000, threaded=True)