479 lines
15 KiB
Python
479 lines
15 KiB
Python
import os
|
|
import re
|
|
import time
|
|
import logging
|
|
import threading
|
|
from datetime import datetime, timezone
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
|
|
import paramiko
|
|
|
|
BUILD_DATE = '__BUILD_DATE__'
|
|
|
|
from flask import Flask, render_template, jsonify
|
|
from flask_sqlalchemy import SQLAlchemy
|
|
|
|
# Application-wide logging: timestamped INFO-level messages to stderr.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)

# SQLite database file; /app/data is expected to be a mounted volume.
DB_PATH = '/app/data/infmap.db'

app = Flask(__name__)
app.config['SQLALCHEMY_DATABASE_URI'] = f'sqlite:///{DB_PATH}'
db = SQLAlchemy(app)

# Tunables, overridable via environment variables.
# Seconds between collection rounds.
COLLECTION_INTERVAL = int(os.environ.get('COLLECTION_INTERVAL', 300))
# Maximum number of simultaneous SSH sessions during a round.
MAX_CONCURRENT_SSH = int(os.environ.get('MAX_CONCURRENT_SSH', 5))
# Private key used to authenticate to every monitored host.
SSH_KEY_PATH = '/app/ssh_key'
# Server inventory file (groups + user@host entries).
INFRA_CONF_PATH = '/app/infrastructure.conf'
|
|
|
|
|
|
# --- Database Model ---
|
|
|
|
class Server(db.Model):
    """One monitored server; uniquely identified by (username, hostname)."""
    __tablename__ = 'servers'
    id = db.Column(db.Integer, primary_key=True)
    # Display group taken from infrastructure.conf (unindented header lines).
    group_name = db.Column(db.String(255), nullable=False)
    username = db.Column(db.String(255), nullable=False)
    hostname = db.Column(db.String(255), nullable=False)
    # Best non-loopback IPv4 found during collection (45 chars fits IPv6 too).
    primary_ip = db.Column(db.String(45), default='')
    # Optional link shown for the server (second column in the conf file).
    url = db.Column(db.String(1024), default='')
    # Whether the most recent SSH collection attempt succeeded.
    is_online = db.Column(db.Boolean, default=False)
    # Timestamp of the last collection attempt; None before the first run.
    last_collected = db.Column(db.DateTime, nullable=True)
    # Full parsed output of gather_info.sh, stored as JSON.
    details = db.Column(db.JSON, nullable=True)
    # Free-form operator notes, editable via the PUT /notes API.
    notes = db.Column(db.Text, default='')
    __table_args__ = (db.UniqueConstraint('username', 'hostname', name='uq_user_host'),)
|
|
|
|
|
|
# --- Config Parsing ---
|
|
|
|
def parse_infrastructure_conf(path=None):
    """Parse an infrastructure config file into a list of server entries.

    File format: an unindented line starts a new group; an indented line is
    a server entry of the form "user@host [url]".  A missing user defaults
    to 'infmap'.  Blank lines and '#' comments are ignored.

    Args:
        path: Config file to read; defaults to INFRA_CONF_PATH.

    Returns:
        List of dicts with keys 'group', 'username', 'hostname', 'url'.
        Empty list when the file does not exist (logged as an error).
    """
    if path is None:
        path = INFRA_CONF_PATH
    servers = []
    current_group = None
    try:
        with open(path) as f:
            for line in f:
                line = line.rstrip('\n')
                stripped = line.strip()
                if not stripped or stripped.startswith('#'):
                    continue
                if line[0] not in (' ', '\t'):
                    # Unindented line: group header.
                    current_group = stripped
                    continue
                # Indented line: "entry [url]" — url is everything after the
                # first whitespace run.
                parts = stripped.split(None, 1)
                entry = parts[0] if parts else ''
                url = parts[1] if len(parts) > 1 else ''
                if '@' in entry:
                    user, host = entry.split('@', 1)
                else:
                    user, host = 'infmap', entry
                if host:
                    servers.append({
                        'group': current_group or 'Default',
                        'username': user.strip(),
                        'hostname': host.strip(),
                        'url': url.strip(),
                    })
    except FileNotFoundError:
        logger.error("infrastructure.conf not found at %s", path)
    return servers
|
|
|
|
|
|
# --- SSH Collection ---
|
|
|
|
def load_ssh_key():
    """Load the collector's private key, trying each supported key format.

    Returns the first paramiko key object that can parse SSH_KEY_PATH;
    raises RuntimeError when no key class accepts the file.
    """
    candidates = (paramiko.Ed25519Key, paramiko.RSAKey, paramiko.ECDSAKey)
    for candidate in candidates:
        try:
            return candidate.from_private_key_file(SSH_KEY_PATH)
        except Exception:
            pass  # wrong key type or unreadable in this format — try the next
    raise RuntimeError(f"Could not load SSH key from {SSH_KEY_PATH}")
|
|
|
|
|
|
def collect_one(entry, ssh_key):
    """SSH into a single server, run the gather script, return parsed data.

    Args:
        entry: dict with 'username' and 'hostname' (from parse_infrastructure_conf).
        ssh_key: paramiko private-key object used for authentication.

    Returns:
        Parsed gather-output dict with 'is_online': True on success, or
        {'is_online': False, 'error': <message>} on any failure (logged).
    """
    ssh = None
    try:
        ssh = paramiko.SSHClient()
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(
            entry['hostname'],
            username=entry['username'],
            pkey=ssh_key,
            timeout=15,
            banner_timeout=15,
            auth_timeout=15,
        )

        with open('/app/gather_info.sh') as f:
            script = f.read()

        # Pipe the script over stdin ('bash -s') so nothing has to exist on
        # the remote host.
        stdin, stdout, stderr = ssh.exec_command('bash -s', timeout=60)
        stdin.write(script)
        stdin.channel.shutdown_write()
        output = stdout.read().decode('utf-8', errors='replace')

        data = parse_gather_output(output)
        data['is_online'] = True
        return data
    except Exception as e:
        logger.warning("Failed to collect from %s@%s: %s", entry['username'], entry['hostname'], e)
        return {'is_online': False, 'error': str(e)}
    finally:
        # Always release the connection; the previous version leaked the
        # socket when exec/read raised after a successful connect.
        if ssh is not None:
            ssh.close()
|
|
|
|
|
|
def parse_gather_output(output):
    """Parse the [section] key=value output from gather_info.sh.

    Plain "[name]" sections become dicts; "[name:id]" sections accumulate
    into a list of dicts (each tagged with '_name').  Repeated keys inside a
    plain section are collected into a list.  Parsing stops at "[end]".
    """
    parsed = {}
    target = None  # ('dict'|'list', container) for the section being filled

    for raw in output.split('\n'):
        text = raw.strip()
        if not text:
            continue

        # Section header: [name] or [name:id]
        header = re.match(r'^\[(.+)\]$', text)
        if header:
            name = header.group(1)
            if name == 'end':
                break
            if ':' in name:
                base, label = name.split(':', 1)
                record = {'_name': label}
                parsed.setdefault(base, []).append(record)
                target = ('list', record)
            else:
                target = ('dict', parsed.setdefault(name, {}))
            continue

        # Key=value lines; ignored before the first section header.
        if '=' not in text or target is None:
            continue
        key, _, value = text.partition('=')
        key = key.strip()
        value = value.strip()

        mode, container = target
        if mode == 'list':
            container[key] = value
        elif key in container:
            # Repeated key (e.g. dns server=): promote to a list of values.
            current = container[key]
            if isinstance(current, list):
                current.append(value)
            else:
                container[key] = [current, value]
        else:
            container[key] = value

    return parsed
|
|
|
|
|
|
# --- Collection Loop ---
|
|
|
|
def collect_all():
    """Run one full collection round and sync results into the database.

    SSHes every configured server in parallel (bounded by
    MAX_CONCURRENT_SSH), then performs all database writes here in the
    collector thread inside a single app context and commit.
    """
    entries = parse_infrastructure_conf()
    if not entries:
        logger.info("No servers configured in infrastructure.conf")
        return

    try:
        ssh_key = load_ssh_key()
    except Exception as e:
        # Without a key nothing can be collected; skip this round entirely.
        logger.error("SSH key error: %s", e)
        return

    logger.info("Collecting from %d servers (max %d concurrent)", len(entries), MAX_CONCURRENT_SSH)
    results = {}

    with ThreadPoolExecutor(max_workers=MAX_CONCURRENT_SSH) as pool:
        futures = {pool.submit(collect_one, e, ssh_key): e for e in entries}
        for future in as_completed(futures):
            entry = futures[future]
            key = f"{entry['username']}@{entry['hostname']}"
            try:
                results[key] = (entry, future.result(timeout=90))
            except Exception as e:
                # collect_one reports its own failures; this guards against
                # unexpected worker crashes so one host can't sink the round.
                results[key] = (entry, {'is_online': False, 'error': str(e)})

    # Update database (all in main collector thread)
    with app.app_context():
        # Remove servers no longer in config
        config_keys = {(e['username'], e['hostname']) for e in entries}
        for server in Server.query.all():
            if (server.username, server.hostname) not in config_keys:
                db.session.delete(server)

        for key, (entry, result) in results.items():
            server = Server.query.filter_by(
                username=entry['username'],
                hostname=entry['hostname'],
            ).first()

            if not server:
                # First time we've seen this user@host: create the row.
                server = Server(
                    group_name=entry['group'],
                    username=entry['username'],
                    hostname=entry['hostname'],
                )
                db.session.add(server)

            # Refresh config-derived and collected fields on every round.
            server.group_name = entry['group']
            server.url = entry.get('url', '')
            server.is_online = result.get('is_online', False)
            server.last_collected = datetime.now(timezone.utc)
            # Full parsed gather output, stored as JSON.
            server.details = result

            # Extract primary IP: prefer the interface carrying the default route
            default_iface = ''
            routing = result.get('routing', {})
            if isinstance(routing, dict):
                default_iface = routing.get('interface', '')

            primary_ip = ''
            # NOTE(review): assumes 'net' is a list of per-interface dicts as
            # produced by [net:<name>] sections in parse_gather_output.
            for iface in result.get('net', []):
                ipv4 = iface.get('ipv4', '')
                if not ipv4 or ipv4.startswith('127.'):
                    continue
                iface_name = iface.get('name', '') or iface.get('_name', '')
                if iface_name == default_iface:
                    primary_ip = ipv4
                    break
                if not primary_ip:
                    # Fall back to the first non-loopback address seen.
                    primary_ip = ipv4
            server.primary_ip = primary_ip

        db.session.commit()
    logger.info("Collection complete, updated %d servers", len(results))
|
|
|
|
|
|
def collector_loop():
    """Background thread body: collect forever at COLLECTION_INTERVAL."""
    time.sleep(2)  # Brief pause to let Flask start
    while True:
        try:
            collect_all()
        except Exception as exc:
            # Never let one bad round kill the background thread.
            logger.error("Collection loop error: %s", exc)
        time.sleep(COLLECTION_INTERVAL)
|
|
|
|
|
|
# --- Web Routes ---
|
|
|
|
@app.route('/')
def index():
    """Render the dashboard with servers grouped, each group IP-sorted."""
    servers = Server.query.order_by(Server.group_name, Server.primary_ip).all()

    groups = {}
    for server in servers:
        groups.setdefault(server.group_name or 'Default', []).append(server)

    # Sort servers within each group by IP (numerically)
    for members in groups.values():
        members.sort(key=lambda s: _ip_sort_key(s.primary_ip))

    return render_template('index.html', groups=groups, build_date=BUILD_DATE)
|
|
|
|
|
|
@app.route('/api/servers')
def api_servers():
    """Return every server as JSON, including collected details."""
    def serialize(s):
        # Flatten one Server row into a JSON-safe dict.
        return {
            'id': s.id,
            'group_name': s.group_name,
            'username': s.username,
            'hostname': s.hostname,
            'primary_ip': s.primary_ip,
            'url': s.url,
            'is_online': s.is_online,
            'last_collected': s.last_collected.isoformat() if s.last_collected else None,
            'notes': s.notes,
            'details': s.details,
        }

    return jsonify([serialize(s) for s in Server.query.all()])
|
|
|
|
|
|
@app.route('/api/servers/<int:server_id>/notes', methods=['PUT'])
def api_update_notes(server_id):
    """Replace the free-form notes for one server.

    Returns 404 for an unknown id.  A missing or invalid JSON body is
    treated as empty notes rather than raising (get_json() returns None
    in that case, which would otherwise crash with AttributeError).
    """
    from flask import request
    server = Server.query.get_or_404(server_id)
    payload = request.get_json(silent=True) or {}
    server.notes = payload.get('notes', '')
    db.session.commit()
    return jsonify({'ok': True})
|
|
|
|
|
|
def _ip_sort_key(ip_str):
|
|
if not ip_str:
|
|
return [999, 999, 999, 999]
|
|
try:
|
|
return [int(x) for x in ip_str.split('.')]
|
|
except (ValueError, AttributeError):
|
|
return [999, 999, 999, 999]
|
|
|
|
|
|
# --- Jinja2 Helpers ---
|
|
|
|
@app.template_filter('format_bytes')
|
|
def format_bytes(value):
    """Render a byte count as a human-readable size ("1.5 KB", "5.0 MB")."""
    try:
        size = int(value)
    except (TypeError, ValueError):
        return value  # non-numeric input is passed through untouched
    units = ['B', 'KB', 'MB', 'GB', 'TB']
    idx = 0
    while idx < len(units) - 1 and abs(size) >= 1024:
        size /= 1024
        idx += 1
    if abs(size) >= 1024:
        # Still >= 1024 after TB: report petabytes.
        return f"{size / 1024:.1f} PB"
    return f"{size:.1f} {units[idx]}"
|
|
|
|
|
|
@app.template_filter('format_mb')
|
|
def format_mb(value):
    """Render a megabyte count, switching to GB at 1024 MB."""
    try:
        mb = int(value)
    except (TypeError, ValueError):
        return value  # non-numeric input is passed through untouched
    return f"{mb / 1024:.1f} GB" if mb >= 1024 else f"{mb} MB"
|
|
|
|
|
|
@app.template_filter('format_uptime')
|
|
def format_uptime(seconds):
    """Render an uptime in seconds as "Nd Nh" (>= 1 day) or "Nh Nm"."""
    try:
        total = int(seconds)
    except (TypeError, ValueError):
        return 'Unknown'
    days, remainder = divmod(total, 86400)
    hours, remainder = divmod(remainder, 3600)
    if days > 0:
        return f"{days}d {hours}h"
    minutes = remainder // 60
    return f"{hours}h {minutes}m"
|
|
|
|
|
|
@app.template_filter('temp_color')
|
|
def temp_color(temp_c):
    """Map a temperature in Celsius to a traffic-light hex color."""
    try:
        t = float(temp_c)
    except (TypeError, ValueError):
        return '#64748b'  # grey for unknown/unparseable readings
    # Thresholds checked hottest-first: red, orange, yellow.
    for floor, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if t >= floor:
            return color
    return '#22c55e'  # green below all thresholds
|
|
|
|
|
|
@app.template_filter('clean_gpu')
|
|
def clean_gpu(description):
    """Normalize an lspci GPU description into a short "Vendor Model" label.

    e.g. "01:00.0 VGA compatible controller: NVIDIA Corporation GA106
    [GeForce RTX 3060] (rev a1)" -> "Nvidia GeForce RTX 3060".

    Returns '-' for empty input or when nothing usable remains.
    (Fix: dropped the redundant function-local `import re`; the module
    already imports re at the top.)
    """
    if not description:
        return '-'
    s = str(description)

    # Strip PCI address prefix (e.g. "01:00.0 ")
    s = re.sub(r'^[0-9a-f:.]+\s+', '', s, flags=re.IGNORECASE)
    # Strip type prefix
    for prefix in ['VGA compatible controller: ', '3D controller: ', 'Display controller: ']:
        if s.startswith(prefix):
            s = s[len(prefix):]
    # Strip revision suffix
    s = re.sub(r'\s*\(rev [0-9a-f]+\)\s*$', '', s, flags=re.IGNORECASE)

    # Detect manufacturer
    manufacturer = ''
    s_lower = s.lower()
    if 'nvidia' in s_lower:
        manufacturer = 'Nvidia'
    elif 'advanced micro' in s_lower or 'amd' in s_lower or 'radeon' in s_lower:
        manufacturer = 'AMD'
    elif 'intel' in s_lower:
        manufacturer = 'Intel'
    elif 'matrox' in s_lower:
        manufacturer = 'Matrox'

    # Strip vendor name from string
    for vendor in ['NVIDIA Corporation ', 'Advanced Micro Devices, Inc. ', 'AMD ',
                   'Intel Corporation ', 'Advanced Micro Devices Inc. ',
                   'Matrox Electronics Systems Ltd. ']:
        if s.startswith(vendor):
            s = s[len(vendor):]

    # Prefer bracketed name if present (e.g. "GA106 [GeForce RTX 3060]" -> "GeForce RTX 3060")
    bracket = re.search(r'\[(.+)\]', s)
    if bracket:
        s = bracket.group(1)

    # Clean up model string
    # Remove "Lite Hash Rate" and similar marketing suffixes
    s = re.sub(r'\s+Lite Hash Rate', '', s)
    # Remove slash alternatives (e.g. "RX 6800/6800 XT / 6900 XT" -> "RX 6800")
    s = re.sub(r'\s*/[\w\s/]+$', '', s)
    s = s.strip()

    # Don't duplicate manufacturer if already in the model name
    if manufacturer and s:
        if manufacturer.lower() in s.lower():
            return s
        return f"{manufacturer} {s}"
    return s or '-'
|
|
|
|
|
|
@app.template_filter('usage_color')
|
|
def usage_color(percent):
    """Map a utilization percentage to a traffic-light hex color."""
    try:
        p = float(percent)
    except (TypeError, ValueError):
        return '#64748b'  # grey for unknown/unparseable values
    # Thresholds checked highest-first: red, orange, yellow.
    for floor, color in ((90, '#ef4444'), (75, '#f97316'), (60, '#eab308')):
        if p >= floor:
            return color
    return '#22c55e'  # green below all thresholds
|
|
|
|
|
|
# --- Main ---
|
|
|
|
def migrate_db():
    """Add any missing columns to existing tables.

    Lightweight, idempotent migration for databases created before the
    'url' and 'notes' columns existed (create_all only creates new tables,
    it does not alter existing ones).
    """
    import sqlite3
    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute("PRAGMA table_info(servers)")
        existing = {row[1] for row in cursor.fetchall()}
        migrations = {
            'url': "ALTER TABLE servers ADD COLUMN url VARCHAR(1024) DEFAULT ''",
            'notes': "ALTER TABLE servers ADD COLUMN notes TEXT DEFAULT ''",
        }
        for col, sql in migrations.items():
            if col not in existing:
                cursor.execute(sql)
                logger.info("Added column '%s' to servers table", col)
        conn.commit()
    finally:
        # Release the handle even if a migration statement fails
        # (the original leaked the connection on that path).
        conn.close()
|
|
|
|
|
|
if __name__ == '__main__':
    # Ensure the SQLite data directory exists before SQLAlchemy touches it.
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    with app.app_context():
        db.create_all()
        # create_all won't alter existing tables, so add missing columns here.
        migrate_db()
        logger.info("Database ready at %s", DB_PATH)

    # Start collector thread (daemon so it dies with the main process).
    collector_thread = threading.Thread(target=collector_loop, daemon=True)
    collector_thread.start()

    # threaded=True lets requests be served while collections run in background.
    app.run(host='0.0.0.0', port=5000, threaded=True)
|