#include "runner.h"

#include <fcntl.h>
#include <sys/select.h>
#include <sys/wait.h>
#include <termios.h>
#include <unistd.h>

#include <cerrno>
#include <cstdlib>
#include <iostream>
#include <sstream>

#include <libssh/callbacks.h>
#include <libssh/libssh.h>

namespace runner {

namespace {

// Forward declarations of the file-local helpers defined below.
// (ssh_build_command_only() is defined before its only use and is not declared here.)
ssh_session ssh_connect_and_auth(const sSSHInfo* sshinfo, const std::map<std::string, std::string>& env, std::string* error);
std::string ssh_build_remote_command(const std::string& command, const std::vector<std::string>& args, const std::string& working_dir, const std::map<std::string, std::string>& env);
std::string escape_shell_arg(const std::string& arg);
int ssh_interactive_shell_session(ssh_session session, ssh_channel channel, const std::string& remote_cmd_str, const std::string& command, std::string* output);
int ssh_exec_command(ssh_session session, ssh_channel channel, const std::string& remote_cmd_str, bool silent, std::string* output, const std::map<std::string, std::string>& env, const std::string& working_dir);
int local_execute_cmd(const std::string& command, const std::vector<std::string>& args, const std::string& working_dir, const std::map<std::string, std::string>& env, bool silent, bool interactive, std::string* output);

// Creates an SSH session for sshinfo->host (plus sshinfo->port / sshinfo->user when
// non-empty), connects, and authenticates: public-key auth via
// ssh_userauth_publickey_auto() first, then, if that fails and env contains
// "SSHPASS", password auth with that value.
// Returns the connected, authenticated session (the caller releases it with
// ssh_disconnect() + ssh_free()), or nullptr with *error set when error is non-null.
// The session is released on every failure path; no exception escapes for a bad port.
// NOTE(review): the server host key is never checked (no ssh_session_is_known_server()
// call), so connections are not protected against MITM — confirm that is acceptable.
ssh_session ssh_connect_and_auth(const sSSHInfo* sshinfo, const std::map<std::string, std::string>& env, std::string* error) {
    ssh_session session = ssh_new();
    if (!session) {
        if (error) *error = "Failed to create SSH session.";
        return nullptr;
    }
    // Records message, releases the (not yet connected) session and returns nullptr.
    auto fail = [&](const std::string& message) -> ssh_session {
        if (error) *error = message;
        ssh_free(session);
        return nullptr;
    };

    if (ssh_options_set(session, SSH_OPTIONS_HOST, sshinfo->host.c_str()) != SSH_OK) {
        return fail(std::string("Invalid SSH host: ") + ssh_get_error(session));
    }
    if (!sshinfo->port.empty()) {
        // strtol + explicit validation instead of std::stoi: stoi throws on bad
        // input, which would escape this function and leak the session.
        const char* text = sshinfo->port.c_str();
        char* end = nullptr;
        const long parsed = std::strtol(text, &end, 10);
        if (end == text || *end != '\0' || parsed <= 0 || parsed > 65535) {
            return fail("Invalid SSH port: " + sshinfo->port);
        }
        int port = static_cast<int>(parsed);
        if (ssh_options_set(session, SSH_OPTIONS_PORT, &port) != SSH_OK) {
            return fail(std::string("Failed to set SSH port: ") + ssh_get_error(session));
        }
    }
    if (!sshinfo->user.empty()) {
        if (ssh_options_set(session, SSH_OPTIONS_USER, sshinfo->user.c_str()) != SSH_OK) {
            return fail(std::string("Failed to set SSH user: ") + ssh_get_error(session));
        }
    }

    if (ssh_connect(session) != SSH_OK) {
        return fail(std::string("SSH connection failed: ") + ssh_get_error(session));
    }

    int rc = ssh_userauth_publickey_auto(session, nullptr, nullptr);
    if (rc != SSH_AUTH_SUCCESS) {
        // Fall back to password auth only when the caller supplied one.
        const auto it = env.find("SSHPASS");
        if (it != env.end()) {
            rc = ssh_userauth_password(session, nullptr, it->second.c_str());
        }
    }
    if (rc != SSH_AUTH_SUCCESS) {
        // Capture the libssh error text before the session is torn down.
        const std::string message = std::string("SSH authentication failed: ") + ssh_get_error(session);
        ssh_disconnect(session);
        return fail(message);
    }
    return session;
}

// Builds the line that is typed into an interactive remote shell:
//     [cd 'DIR' && ][NAME='value' ...] command 'arg1' 'arg2' ...
// command is emitted verbatim (it may itself be a shell snippet). working_dir,
// env values and args are single-quoted, with embedded single quotes written as
// '\'' so they cannot terminate the quoting. The NAME=value prefix is placed after
// the cd because a shell assignment prefix only applies to the command it directly
// precedes; placed before "cd", it would never reach command. SSHPASS is never emitted.
std::string ssh_build_remote_command(const std::string& command, const std::vector<std::string>& args, const std::string& working_dir, const std::map<std::string, std::string>& env) {
    // POSIX shell single-quoting: "it's" becomes 'it'\''s'.
    auto quote = [](const std::string& text) {
        std::string quoted = "'";
        for (char c : text) {
            if (c == '\'') {
                quoted += "'\\''";
            } else {
                quoted += c;
            }
        }
        quoted += '\'';
        return quoted;
    };

    std::ostringstream remote_cmd;
    if (!working_dir.empty()) {
        remote_cmd << "cd " << quote(working_dir) << " && ";
    }
    for (const auto& kv : env) {
        if (kv.first == "SSHPASS") continue;  // never expose the password
        remote_cmd << kv.first << '=' << quote(kv.second) << ' ';
    }
    remote_cmd << command;
    for (const auto& arg : args) {
        remote_cmd << ' ' << quote(arg);
    }
    return remote_cmd.str();
}

// Wraps arg in double quotes for a POSIX shell, putting a backslash in front of
// each '"', '\\', '$' and '`' so those characters are taken literally.
std::string escape_shell_arg(const std::string& arg) {
    std::string quoted;
    quoted.reserve(arg.size() + 2);
    quoted.push_back('"');
    for (const char ch : arg) {
        switch (ch) {
            case '"':
            case '\\':
            case '$':
            case '`':
                quoted.push_back('\\');
                break;
            default:
                break;
        }
        quoted.push_back(ch);
    }
    quoted.push_back('"');
    return quoted;
}

// Non-interactive SSH: returns command followed by each argument, space-separated
// and quoted via escape_shell_arg(). No working directory or environment is added
// here; execute_cmd() hands those to ssh_exec_command() separately.
std::string ssh_build_command_only(const std::string& command, const std::vector<std::string>& args) {
    std::string line(command);
    for (const std::string& each : args) {
        line += ' ';
        line += escape_shell_arg(each);
    }
    return line;
}

// Runs an interactive remote shell on an already-opened channel.
// Requests a pty and a shell. If command is non-empty, it types remote_cmd_str
// plus a newline into that shell. It then relays local stdin to the channel and
// channel stdout to local stdout. It stops when the channel reaches EOF or is
// closed, or when local stdin reaches EOF or fails (an EOF is then sent to the
// channel). When stdin is a terminal, it is put into raw mode for the duration
// and restored afterwards.
// Returns 0 on normal termination. Returns -1 if the pty/shell request fails
// (then *output, if non-null, holds the error) or if select()/the channel read
// fails.
int ssh_interactive_shell_session(ssh_session session, ssh_channel channel, const std::string& remote_cmd_str, const std::string& command, std::string* output) {
    int rc = ssh_channel_request_pty(channel);
    if (rc != SSH_OK) {
        if (output) *output = std::string("Failed to request pty: ") + ssh_get_error(session);
        return -1;
    }
    rc = ssh_channel_request_shell(channel);
    if (rc != SSH_OK) {
        if (output) *output = std::string("Failed to request shell: ") + ssh_get_error(session);
        return -1;
    }

    // Only touch terminal settings when stdin is a terminal whose state could be
    // saved; otherwise the final tcsetattr() would "restore" an uninitialized struct.
    struct termios orig_termios;
    const bool raw_mode = isatty(STDIN_FILENO) && tcgetattr(STDIN_FILENO, &orig_termios) == 0;
    if (raw_mode) {
        struct termios raw_termios = orig_termios;
        cfmakeraw(&raw_termios);
        tcsetattr(STDIN_FILENO, TCSANOW, &raw_termios);
    }

    if (!command.empty()) {
        ssh_channel_write(channel, remote_cmd_str.c_str(), remote_cmd_str.size());
        ssh_channel_write(channel, "\n", 1);
    }

    // Writes all of data to fd, retrying short writes and EINTR.
    auto write_all = [](int fd, const char* data, size_t len) {
        while (len > 0) {
            const ssize_t written = write(fd, data, len);
            if (written < 0) {
                if (errno == EINTR) continue;
                return;  // local fd unusable; drop the rest of this chunk
            }
            data += written;
            len -= static_cast<size_t>(written);
        }
    };

    const int ssh_fd = ssh_get_fd(session);
    const int maxfd = ssh_fd > STDIN_FILENO ? ssh_fd : STDIN_FILENO;
    char out_buffer[4096];
    char in_buffer[4096];
    int result = 0;
    bool done = false;
    while (!done) {
        // Read from libssh without blocking before waiting in select(). Data
        // libssh already pulled off the socket would not wake select() again.
        // A blocking ssh_channel_read() could also hang when the socket traffic
        // was not stdout data (e.g. a window adjust), starving stdin.
        const int n = ssh_channel_read_nonblocking(channel, out_buffer, sizeof(out_buffer), 0);
        if (n > 0) {
            write_all(STDOUT_FILENO, out_buffer, static_cast<size_t>(n));
        } else if (n == SSH_EOF || ssh_channel_is_closed(channel) || ssh_channel_is_eof(channel)) {
            break;  // remote side finished and its buffered output has been written
        } else if (n < 0) {
            result = -1;  // SSH_ERROR: retrying would only spin
            break;
        }

        fd_set fds_read;
        FD_ZERO(&fds_read);
        FD_SET(STDIN_FILENO, &fds_read);
        FD_SET(ssh_fd, &fds_read);
        // After forwarding output, just poll stdin so a busy remote cannot starve it.
        struct timeval no_wait = {0, 0};
        const int ready = select(maxfd + 1, &fds_read, nullptr, nullptr, n > 0 ? &no_wait : nullptr);
        if (ready < 0) {
            if (errno == EINTR) continue;  // e.g. SIGWINCH/SIGCHLD; not a failure
            result = -1;
            break;
        }
        if (FD_ISSET(STDIN_FILENO, &fds_read)) {
            const ssize_t got = read(STDIN_FILENO, in_buffer, sizeof(in_buffer));
            if (got > 0) {
                ssh_channel_write(channel, in_buffer, static_cast<uint32_t>(got));
            } else if (got == 0 || errno != EINTR) {
                // Local EOF or a real read error ends the session.
                ssh_channel_send_eof(channel);
                done = true;
            }
        }
        // A readable SSH socket is handled by the next ssh_channel_read_nonblocking().
    }

    if (raw_mode) {
        tcsetattr(STDIN_FILENO, TCSANOW, &orig_termios);
    }
    return result;
}

// Runs remote_cmd_str non-interactively (no pty) on an already-opened channel.
//
// The command line sent to the server has the form
//     [env 'NAME=value' ...] bash -c '[cd DIR && ]REMOTE_CMD'
// Every entry of env except SSHPASS is passed through env(1). All caller-supplied
// pieces are single-quoted with embedded quotes written as '\''. As a result,
// spaces, quotes and metacharacters in env values, working_dir or remote_cmd_str
// cannot break out of the quoting. A leading "~" or "~/" in working_dir is left
// unquoted so the remote shell still expands it; nothing else in working_dir is
// expanded.
//
// Remote output handling:
//   * output != nullptr          : remote stdout, then remote stderr, stored in *output
//   * output == nullptr, !silent : remote stdout/stderr copied to local fd 1 / fd 2
//   * output == nullptr, silent  : remote output read and discarded
// Both streams are always drained to EOF. The command has therefore finished
// before this returns, and it is not cut off when the caller closes the channel.
//
// Returns the remote exit status. Returns -1 if the exec request or a read
// fails, or if the server reported no exit status (e.g. the command was killed
// by a signal). On exec failure *output (when non-null) receives the error text.
int ssh_exec_command(ssh_session session, ssh_channel channel, const std::string& remote_cmd_str, bool silent, std::string* output, const std::map<std::string, std::string>& env, const std::string& working_dir) {
    // POSIX single-quoting: wraps text in '...' and turns each ' into '\''.
    auto single_quote = [](const std::string& text) {
        std::string quoted = "'";
        for (char c : text) {
            if (c == '\'') {
                quoted += "'\\''";
            } else {
                quoted += c;
            }
        }
        quoted += '\'';
        return quoted;
    };

    // Script executed by `bash -c` on the remote side.
    std::string script;
    if (!working_dir.empty()) {
        std::string dir;
        if (working_dir == "~") {
            dir = "~";
        } else if (working_dir.compare(0, 2, "~/") == 0) {
            dir = "~/" + single_quote(working_dir.substr(2));
        } else {
            dir = single_quote(working_dir);
        }
        script += "cd " + dir + " && ";
    }
    script += remote_cmd_str;

    std::string final_cmd;
    for (const auto& kv : env) {
        if (kv.first == "SSHPASS") continue;  // never forward the password
        // env(1) receives NAME=value as one argument, so the whole pair is quoted.
        final_cmd += single_quote(kv.first + "=" + kv.second) + " ";
    }
    if (!final_cmd.empty()) {
        final_cmd.insert(0, "env ");
    }
    final_cmd += "bash -c " + single_quote(script);

    if (ssh_channel_request_exec(channel, final_cmd.c_str()) != SSH_OK) {
        const std::string error = std::string("Failed to exec remote command: ") + ssh_get_error(session);
        if (!silent) std::cerr << "SSH exec error: " << error << '\n';
        if (output) *output = error;
        return -1;
    }

    // Reads one remote stream (0 = stdout, 1 = stderr) until EOF, passing each
    // chunk to sink. Returns false if libssh reports a read error.
    auto drain = [channel](int is_stderr, auto&& sink) {
        char buffer[4096];
        int nbytes;
        while ((nbytes = ssh_channel_read(channel, buffer, sizeof(buffer), is_stderr)) > 0) {
            sink(buffer, static_cast<size_t>(nbytes));
        }
        return nbytes == 0;
    };
    // Makes a sink that writes each whole chunk to fd, retrying short writes and EINTR.
    auto to_fd = [](int fd) {
        return [fd](const char* data, size_t len) {
            while (len > 0) {
                const ssize_t written = write(fd, data, len);
                if (written < 0) {
                    if (errno == EINTR) continue;
                    return;  // local fd unusable; drop the rest of this chunk
                }
                data += written;
                len -= static_cast<size_t>(written);
            }
        };
    };

    bool ok = true;
    if (output) {
        std::string captured;
        auto append = [&captured](const char* data, size_t len) { captured.append(data, len); };
        ok = drain(0, append);
        ok = drain(1, append) && ok;
        *output = std::move(captured);
    } else if (!silent) {
        ok = drain(0, to_fd(STDOUT_FILENO));
        ok = drain(1, to_fd(STDERR_FILENO)) && ok;
    } else {
        auto discard = [](const char*, size_t) {};
        ok = drain(0, discard);
        ok = drain(1, discard) && ok;
    }
    if (!ok) {
        if (!silent) std::cerr << "SSH read error: " << ssh_get_error(session) << '\n';
        return -1;
    }

    // Report the remote command's own status, as local_execute_cmd() does locally.
    const int exit_status = ssh_channel_get_exit_status(channel);
    return exit_status >= 0 ? exit_status : -1;
}

// Runs command with args as a local child process and waits for it. The child
// uses fork + execvp, so command is searched in PATH.
//   working_dir : if non-empty, the child chdir()s there first.
//   env         : each entry is setenv()'d in the child, overriding inherited values.
//   output      : if non-null and !interactive, the child's stdout and stderr are
//                 both captured through one pipe into *output.
//   silent      : if set and output is not being captured (and !interactive), the
//                 child's stdout/stderr go to /dev/null.
//   interactive : the child keeps the caller's stdio; otherwise it starts a new
//                 session via setsid().
// Returns the child's exit code, 255 if chdir/exec failed in the child, or -1 if
// pipe/fork/waitpid failed or the child did not exit normally (e.g. was killed by a
// signal).
int local_execute_cmd(
    const std::string& command,
    const std::vector<std::string>& args,
    const std::string& working_dir,
    const std::map<std::string, std::string>& env,
    bool silent,
    bool interactive,
    std::string* output
) {
    // Build argv before fork(): between fork and exec the child should avoid heap
    // allocation, which is not async-signal-safe in a multithreaded parent.
    std::vector<char*> argv;
    argv.reserve(args.size() + 2);
    argv.push_back(const_cast<char*>(command.c_str()));
    for (const auto& arg : args) {
        argv.push_back(const_cast<char*>(arg.c_str()));
    }
    argv.push_back(nullptr);

    int pipefd[2] = {-1, -1};
    const bool use_pipe = output && !interactive;
    if (use_pipe && pipe(pipefd) == -1) {
        perror("pipe");
        return -1;
    }
    const pid_t pid = fork();
    if (pid == -1) {
        perror("fork");
        if (use_pipe) {
            close(pipefd[0]);
            close(pipefd[1]);
        }
        return -1;
    }

    if (pid == 0) {
        // Child. Leave only through _exit(): exit() would run the parent's atexit
        // handlers / static destructors and re-flush stdio buffers copied from it.
        // Status 255 matches what the former exit(-1) produced.
        if (!working_dir.empty() && chdir(working_dir.c_str()) != 0) {
            perror("chdir");
            _exit(255);
        }
        // NOTE(review): setenv() is not async-signal-safe either; if the parent is
        // multithreaded, consider building an envp before fork() and using an
        // exec variant that takes it.
        for (const auto& kv : env) {
            setenv(kv.first.c_str(), kv.second.c_str(), 1);
        }
        if (use_pipe) {
            close(pipefd[0]);
            dup2(pipefd[1], STDOUT_FILENO);
            dup2(pipefd[1], STDERR_FILENO);
            close(pipefd[1]);
        } else if (silent && !interactive) {
            const int devnull = open("/dev/null", O_WRONLY);
            if (devnull != -1) {
                dup2(devnull, STDOUT_FILENO);
                dup2(devnull, STDERR_FILENO);
                close(devnull);
            }
        }
        if (!interactive) {
            setsid();  // detach from the caller's controlling terminal/session
        }
        execvp(command.c_str(), argv.data());
        perror("execvp");
        _exit(255);
    }

    // Parent.
    if (use_pipe) {
        close(pipefd[1]);
        std::string captured;
        char buf[4096];
        for (;;) {
            const ssize_t n = read(pipefd[0], buf, sizeof(buf));
            if (n > 0) {
                captured.append(buf, static_cast<size_t>(n));
                continue;
            }
            if (n < 0 && errno == EINTR) continue;  // interrupted, not end of output
            break;
        }
        close(pipefd[0]);
        *output = std::move(captured);
    }

    int status = 0;
    while (waitpid(pid, &status, 0) == -1) {
        // Without this check an interrupted/failed waitpid left status == 0 and
        // was reported as a successful exit.
        if (errno != EINTR) return -1;
    }
    return WIFEXITED(status) ? WEXITSTATUS(status) : -1;
}

} // anonymous namespace

int execute_cmd(
    const std::string& command,
    const std::vector<std::string>& args,
    const std::string& working_dir,
    const std::map<std::string, std::string>& env,
    bool silent,
    bool interactive,
    sSSHInfo* sshinfo,
    std::string* output
) {
    if (sshinfo) {
        std::string error;
        ssh_session session = ssh_connect_and_auth(sshinfo, env, &error);
        if (!session) {
            if (output) *output = error;
            return -1;
        }
        ssh_channel channel = ssh_channel_new(session);
        if (!channel) {
            if (output) *output = "Failed to create SSH channel.";
            ssh_disconnect(session);
            ssh_free(session);
            return -1;
        }
        int rc = ssh_channel_open_session(channel);
        if (rc != SSH_OK) {
            if (output) *output = std::string("Failed to open SSH channel: ") + ssh_get_error(session);
            ssh_channel_free(channel);
            ssh_disconnect(session);
            ssh_free(session);
            return -1;
        }
        
        int ret = 0;
        if (interactive) {
            std::string remote_cmd_str = ssh_build_remote_command(command, args, working_dir, {});
            ret = ssh_interactive_shell_session(session, channel, remote_cmd_str, command, output);
        } else {
            // For non-interactive, handle working directory in ssh_exec_command
            std::string remote_cmd_str = ssh_build_command_only(command, args);
            ret = ssh_exec_command(session, channel, remote_cmd_str, silent, output, env, working_dir);
        }
        
        ssh_channel_send_eof(channel);
        ssh_channel_close(channel);
        ssh_channel_free(channel);
        ssh_disconnect(session);
        ssh_free(session);
        return ret;
    } else {
        return local_execute_cmd(command, args, working_dir, env, silent, interactive, output);
    }
}

} // namespace runner