
#include "pch.h"

#include <boost/bind.hpp>
#include <boost/tokenizer.hpp>
#include <boost/regex.hpp>

#ifdef WIN32
#include <ws2tcpip.h>
#else
#include <sys/socket.h>
#include <sys/types.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <fcntl.h>
#include <netdb.h>
#endif

#include "usb_server.h"
#include "usb_criteria.h"
#include "usb_debug.h"
#include "debug.h"

#ifndef SEM_FAILED
#define SEM_FAILED NULL
#endif

#ifdef WIN32

#include "../Private/oem_private_data.h"
#include "../Private/oem_private_key.h"
#include "../Private/private_api.h"

// Fill `buffer` with `length` pseudo-random bytes.
//
// Used by do_usb_authorization() to pre-fill its handshake buffers.
// rand() is not a cryptographically secure source.
static void generate_random_buffer(unsigned char *buffer, unsigned int length)
{
    // Seed only once per process. Re-seeding with GetTickCount() on every
    // call made back-to-back calls within the same tick produce identical
    // buffers.
    static bool seeded = false;
    if (!seeded) {
        srand(::GetTickCount());
        seeded = true;
    }

    for (unsigned int i = 0; i < length; ++i) {
        // % 256 covers the whole byte range; % 255 could never yield 0xFF.
        buffer[i] = static_cast<unsigned char>(rand() % 256);
    }
}

#else

// Return the name of the named semaphore used to communicate with the usbrdr
// daemon, of the form "/usbrdr-<pid>".
// The result points into a function-local static buffer that is rewritten on
// every call, so the function is not reentrant.
static const char* get_sem_name()
{
	static char name_buf[64];
	const int pid = static_cast<int>(getpid());
	snprintf(name_buf, sizeof(name_buf), "/usbrdr-%d", pid);
	return name_buf;
}

#endif // WIN32

// Perform the USB Redirector authorization handshake.
//
// On Windows this is a two-step exchange with the USB Redirector service:
//   1. `data_in` is passed to the GetEstablishData service call and then read
//      as a USBRDRPRIVATE header that supplies offsets and a step
//      (presumably filled in by the service — confirm).
//   2. A response is built in `data_out` (random filler plus bytes derived
//      from `data_in` at those offsets) and sent back with the
//      EstablishConnection service call.
// A failure of either service call is reported via log_usb_error().
//
// Non-Windows builds perform no authorization.
//
// Returns true when authorization succeeded (always true on non-Windows).
static bool do_usb_authorization()
{
    bool result = false;
#ifdef WIN32
    unsigned char data_in[total_buf_len];
    unsigned char data_out[total_buf_len];

    generate_random_buffer(data_in, sizeof(data_in));
    generate_random_buffer(data_out, sizeof(data_out));

    if (UsbRedirectorServiceCall(GetEstablishData, data_in, total_buf_len)) {

        // Interpret the start of the returned buffer as the private header.
        USBRDRPRIVATE *usb_private = reinterpret_cast<USBRDRPRIVATE *>(data_in);

        // NOTE(review): mul_offset, div_offset, key_offset and step are read
        // from the service data and used as indices without being checked
        // against total_buf_len.
        // The shifted values are truncated to a byte on store.
        data_out[usb_private->mul_offset] = data_in[usb_private->mul_offset] << 1;
        data_out[usb_private->div_offset] = data_in[usb_private->div_offset] >> 1;

        // XOR the key bytes, which are `step` bytes apart, with the OEM key.
        for (int i = 0; i < oem_key_len; ++i) {
            data_out[usb_private->key_offset + (i * usb_private->step)] = 
                data_in[usb_private->key_offset + (i * usb_private->step)] ^ oem_key[i];
        }

        if (UsbRedirectorServiceCall(EstablishConnection, data_out, total_buf_len)) {
            result = true;
        } else {
            log_usb_error("USB service call EstablishConnection failed");
        }
    } else {
        log_usb_error("USB service call GetEstablishData failed");
    }
#else
    // No authorization is required on the Linux version for now.
    result = true;
#endif // WIN32
    return result;
}

// Construct an idle server: no thread, no connection, no client and no
// notification event. start() brings it up.
// `handler_func` is invoked by the notification thread after every device
// enumeration pass (see usb_notify()).
UsbServer::UsbServer(const UsbServer::handler& handler_func)
    : _handler_func(handler_func)
    , _usb_notify_thread(NULL)
    , _connection(0)
    , _client(0)
    , _usb_event(SEM_FAILED)
    , _client_ip(0)
    , _auto_sharing(false)
    , _stop_notify_thread(false)
{
}

// Shut the server down if it is still running.
// stop() itself does nothing when no notification thread exists.
UsbServer::~UsbServer()
{
    stop();
}

bool UsbServer::start(const std::string& client_addr, uint16_t client_port /*=32023*/)
{
    ASSERT(_usb_event == SEM_FAILED);

    uint32_t client_ip = inet_addr(client_addr.c_str());

    if (client_ip == INADDR_NONE)
    {
        struct addrinfo *result = NULL;

        if ((getaddrinfo(client_addr.c_str(), NULL, NULL, &result) == 0) && (result != NULL))
        {
            struct sockaddr_in *addr = (sockaddr_in *)(result->ai_addr);
            client_ip = addr->sin_addr.s_addr;

            freeaddrinfo(result);
        }
        else
        {
            LOG_ERROR("Failed to resolve client address: " << client_addr.c_str());
        }
    }

    if (client_ip == INADDR_NONE)
    {
        LOG_ERROR("USB server can't be started without a client ip.");
        return false;
    }

    if (InitUsbRedirectorApi() && do_usb_authorization()) {
        if (set_notification_event()) {
            _usb_notify_thread = new boost::thread(boost::bind(&UsbServer::usb_notify, this));
        }

        if (_usb_notify_thread != NULL) {
#ifdef WIN32
            unsigned long ulConnectionCount = 20;
            UsbRedirectorServiceCall(SetNumberOfAllowedDeviceConnectionsOnServer,
                &ulConnectionCount, sizeof(ulConnectionCount));
#endif

            // allow all classes, vendors, products and revisions.
            set_filter("-1,-1,-1,-1,1");

            _client_ip = client_ip;

            in_addr addr;
            addr.s_addr = _client_ip;

            // Re-using an existing callback connection caused the function
            // EnumConnectedClients not to work (might be related to a guest
            // that didn't do the USB authorization). Better close and create
            // a new connection.
            close_existing_callback_connection(inet_ntoa(addr), client_port);

            LOG_DEBUG("Creating a callback connection " << inet_ntoa(addr) <<
                ":" << client_port);

            if (!CreateCallBackConnectionToUsbClient(inet_ntoa(addr), client_port, &_connection)) {
                log_usb_error("CreateCallBackConnectionToUsbClient error");
            }

            if (_connection == 0)
            {
                LOG_DEBUG("No callback connection was established.");
                stop();
            }
            else
            {
                LOG_DEBUG("Callback connection is " << _connection);
                enum_local_usb_devices();
            }
        }

    } else {
        log_usb_error("USB Redirector init error");
    }

    return (_usb_notify_thread != NULL);
}

void UsbServer::stop()
{
    if (_usb_notify_thread != NULL) {
        LOG_INFO("Shutting down USB server.");

        unsigned long connection = _connection;

        // Assume that the client is no longer connected.
        _client = 0;
        _connection = 0;

        stop_waiting();

        _usb_notify_thread->join();
        delete _usb_notify_thread;
        _usb_notify_thread = NULL;

        if (connection) {
            CloseCallBackConnectionToUsbClient(connection);
        }
        
        DeInitUsbRedirectorApi();

        unset_notification_event();
        clear_criterias();

        _client_ip = 0;
    }
}

void UsbServer::set_filter(const std::string& filter)
{
    std::string filter_str = filter;

    clear_criterias();

    typedef boost::char_separator<char> separator;
    typedef boost::tokenizer<separator, std::string::const_iterator, std::string> tokenizer;
    
    separator sep("|");
    tokenizer criterias(filter_str, sep);

    for (tokenizer::iterator iter = criterias.begin(); iter != criterias.end(); ++iter) {
        try {
            _criterias.push_back(new UsbCriteria(iter->c_str()));
        } catch (...) {
            // ignore bad criterias.
        }
    }

    LOG_DEBUG(filter);
}

// Main loop of the notification thread.
//
// Each time the notification event fires, the loop:
//   1. looks up our client, if a callback connection exists but no client
//      has been found yet;
//   2. re-enumerates the local devices;
//   3. invokes the user handler.
//
// The stop flag is tested only before waiting, so the wake-up issued by
// stop_waiting() still runs one full pass. stop() clears _client and
// _connection before waking the thread. That final pass therefore presumably
// lets stop_sharing_unplugged() release devices still connected to the
// client, since is_disconnected_from_remote() treats _client == 0 as
// disconnected. Do not add an early exit after the wait without checking
// this.
void UsbServer::usb_notify()
{
    for (;;) {
        if (_stop_notify_thread) {
            break;
        }

        wait_for_notification();

        const bool client_unknown = (_connection != 0) && (_client == 0);
        if (client_unknown) {
            find_connected_client();
        }

        enum_local_usb_devices();

        _handler_func();
    }
}

void UsbServer::enum_local_usb_devices()
{
    boost::mutex::scoped_lock lock(_protect_devices);
    uint32 size;
    uint32 count = 0;

    if (EnumLocalUSBDevices(NULL, &size)) {
        count = size / sizeof(DEVICE_DESCRIPTOR);
    }

    _devices.resize(count);

    if (count > 0) {
        EnumLocalUSBDevices(&_devices[0], &size);
    }

    trace_usb_devices();

    if (_client != 0) {
        auto_share_devices();
        connect_shared_devices();
        update_connected_devices();
    }

    stop_sharing_unplugged();
    remove_ignored_unplugged();
}

// Connect to the client every pending shared device (_share_devices) that is
// now reported as UsbDeviceAvailable, and remove it from the pending list.
// Devices that are unknown or not yet available stay pending.
// Requires a connected client; called with _protect_devices held.
void UsbServer::connect_shared_devices()
{
    ASSERT(_client != 0);

    std::list<unsigned long>::iterator pending = _share_devices.begin();
    for (; pending != _share_devices.end(); /* advanced in the body */) {
        const unsigned long hdev = *pending;
        std::vector<DEVICE_DESCRIPTOR>::const_iterator device = find_device(hdev);
        const bool ready = (device != _devices.end()) &&
                           (device->ulDeviceStatus == UsbDeviceAvailable);
        if (!ready) {
            ++pending;
            continue;
        }

        connect_shared_device(hdev);
        pending = _share_devices.erase(pending);
    }
}

void UsbServer::connect_shared_device(unsigned long hdev)
{
    if (!ConnectUSBDeviceToClient(hdev, _client)) {
        log_usb_error("ConnectUSBDeviceToClient");
        stop_sharing_device(hdev);
    } else {
        LOG_DEBUG(fmt("Device 0x%08p is about to be connected to client.") % hdev);
    }

}

void UsbServer::stop_sharing_unplugged()
{
    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter;

    for (iter = _devices.begin(); iter != _devices.end(); ++iter) {
        const DEVICE_DESCRIPTOR& desc = *iter;
        // Stop sharing if device was physically removed from computer or if
        // "safely remove hardware" was used on the guest side.
        if ((desc.ulDeviceStatus == UsbDeviceNotPlugged) ||
            (is_disconnected_from_remote(desc) == true))
        {
            remove_connected_device(desc.hDevice);
            update_ignored_devices(desc);
            if (StopSharingUSBDevice(desc.hDevice)) {
                LOG_DEBUG(fmt("Stopped sharing unplugged device 0x%08p.") % desc.hDevice);
            } else {
                log_usb_error("StopSharingUSBDevice (unplugged)");
            }
        }
    }
}

int UsbServer::find_criteria_action(const DEVICE_DESCRIPTOR& desc)
{
    int action = UsbCriteria::BLOCK;

    std::vector<UsbCriteria*>::iterator iter;
    for (iter = _criterias.begin(); iter != _criterias.end(); ++iter) {
        UsbCriteria *criteria = *iter;
        ASSERT(criteria != NULL);
        action = criteria->match(desc);
        if (action != UsbCriteria::MISMATCH) {
            break;
        }
    }

    return action;
}

// Run update_ignored_devices() for every currently known device, under
// _protect_devices. That call is a no-op while auto-sharing is off.
void UsbServer::ignore_existing_devices()
{
    boost::mutex::scoped_lock guard(_protect_devices);

    for (size_t idx = 0; idx < _devices.size(); ++idx) {
        const DEVICE_DESCRIPTOR& desc = _devices[idx];
        update_ignored_devices(desc);
    }
}

// Keep the ignored-devices list in sync with `desc` while auto-sharing is
// on. An unplugged device is removed from the list; any other device is
// added. Does nothing while auto-sharing is disabled.
void UsbServer::update_ignored_devices(const DEVICE_DESCRIPTOR& desc)
{
    if (!_auto_sharing) {
        return;
    }

    std::list<unsigned long>::iterator iter =
        std::find(_ignored_devices.begin(), _ignored_devices.end(), desc.hDevice);

    if (desc.ulDeviceStatus == UsbDeviceNotPlugged) {
        if (iter != _ignored_devices.end()) {
            _ignored_devices.erase(iter);
            LOG_TRACE(fmt("Device 0x%08p was removed from the ignored devices list.") % desc.hDevice);
        }
    } else if (iter == _ignored_devices.end()) {
        // Add each handle at most once, because the branch above erases only
        // a single entry. Repeated calls (e.g. ignore_existing_devices() on
        // every auto_sharing(true)) used to accumulate duplicates.
        _ignored_devices.push_back(desc.hDevice);
        LOG_TRACE(fmt("Device 0x%08p was added to the ignored devices list.") % desc.hDevice);
    }
}

void UsbServer::auto_share_devices()
{
    if (!_auto_sharing) {
        return;
    }

    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter;

    for (iter = _devices.begin(); iter != _devices.end(); ++iter) {
        const DEVICE_DESCRIPTOR& desc = *iter;
        if (!is_shared_device(desc) && is_share_allowed(desc)) {
            std::list<unsigned long>::const_iterator iter = 
                find(_ignored_devices.begin(), _ignored_devices.end(), desc.hDevice);
            if (iter == _ignored_devices.end()) {
                share_device(desc);
            } else {
                LOG_TRACE(fmt("Device 0x%08p was found in the ignored devices list.") % desc.hDevice);
            }
        }
    }
}

// Pass every device in the current snapshot to usb_device_descriptor()
// (from usb_debug.h) for tracing.
void UsbServer::trace_usb_devices()
{
    for (size_t idx = 0; idx != _devices.size(); ++idx) {
        const DEVICE_DESCRIPTOR& desc = _devices[idx];
        usb_device_descriptor(desc);
    }
}

void UsbServer::close_existing_callback_connection(const char *host, unsigned short port)
{
    std::vector<SERVER_DESCRIPTOR> clients;
    uint32 size;
    uint32 count = 0;

    if (EnumCallBackConnectionsToUsbClients(NULL, &size)) {
        count = size / sizeof(SERVER_DESCRIPTOR);
    } else {
        log_usb_error("EnumCallBackConnectionsToUsbClients");
    }
    
    if (count > 0) {
        LOG_TRACE(count << " callback client(s) found.");

        clients.resize(count);
        EnumCallBackConnectionsToUsbClients(&clients[0], &size);
        
        std::vector<SERVER_DESCRIPTOR>::iterator iter;
        for (iter = clients.begin(); iter != clients.end(); ++iter) {
            SERVER_DESCRIPTOR& desc = *iter;
            LOG_TRACE("Callback client " << desc.cHostName << ":" << desc.tcpport);
            if ((strcmp(host, desc.cHostName) == 0) && (port == desc.tcpport)) {
                LOG_TRACE("Closing existing callback client " << desc.cHostName << ":" << desc.tcpport);
                CloseCallBackConnectionToUsbClient(desc.hServer);
                break;
            }
        }
    }
}

// Fetch the clients currently connected to the redirector. They are then
// passed to the list overload of find_connected_client(), which may set
// _client.
void UsbServer::find_connected_client()
{
    uint32 size;
    uint32 count = 0;

    if (EnumConnectedClients(NULL, &size)) {
        count = size / sizeof(SERVER_DESCRIPTOR);
    }

    if (count == 0) {
        LOG_TRACE("No connected USB clients were found.");
        return;
    }

    std::vector<SERVER_DESCRIPTOR> connected(count);
    EnumConnectedClients(&connected[0], &size);
    find_connected_client(connected);
}

void UsbServer::find_connected_client(const std::vector<SERVER_DESCRIPTOR>& clients)
{
    std::vector<SERVER_DESCRIPTOR>::const_iterator iter;

    for (iter = clients.begin(); iter != clients.end(); ++iter) {
        const SERVER_DESCRIPTOR& client = *iter;
        if (_client_ip == extract_host_ip(client.cHostName)) {
            LOG_INFO("Found our USB client (" << client.cHostName << ")");
            _client = client.hServer;
            break;
        } else {
            LOG_TRACE(client.cHostName << " is not our USB client.");
        }
    }
}

// Return inet_addr() of the first "digits.digits.digits.digits" substring in
// `hostname`, or INADDR_NONE if there is none.
// The regex only locates something shaped like an address; inet_addr() does
// the actual validation.
uint32 UsbServer::extract_host_ip(const char* hostname)
{
    static const boost::regex ip_re("((\\d+\\.){3}\\d+)");

    boost::cmatch found;
    if (!regex_search(hostname, found, ip_re)) {
        return INADDR_NONE;
    }

    return inet_addr(found.str(0).c_str());
}

// Delete every owned UsbCriteria, in list order, and empty the list.
void UsbServer::clear_criterias()
{
    for (size_t idx = 0; idx < _criterias.size(); ++idx) {
        delete _criterias[idx];
    }

    _criterias.clear();
}

// Append one entry per known device to `devices`, built from the device
// handle and a label of the form "<location> - <description>".
// An empty location becomes "USB Device", and the " - <description>" part is
// omitted when the description is empty.
// Returns false, leaving `devices` untouched, when no devices are known.
// Takes _protect_devices.
bool UsbServer::get_devices_desc(devices_desc& devices)
{
    boost::mutex::scoped_lock lock(_protect_devices);

    if (_devices.empty()) {
        return false;
    }

    devices.reserve(_devices.size());

    for (std::vector<DEVICE_DESCRIPTOR>::const_iterator it = _devices.begin();
         it != _devices.end(); ++it) {
        // cLocationInformation and cDeviceDescription might be empty strings.
        // IncentivesPro claims that it is probably a problem with the Windows
        // USB hub device driver.
        std::string label = (it->cLocationInformation[0] != '\0')
            ? std::string(it->cLocationInformation)
            : std::string("USB Device");

        if (it->cDeviceDescription[0] != '\0') {
            label += " - ";
            label += it->cDeviceDescription;
        }

        devices.push_back(UsbServer::device_desc(it->hDevice, label));
    }

    return true;
}

// True while the notification thread exists, i.e. between a successful
// start() and the matching stop().
bool UsbServer::is_running()
{
    const bool running = (_usb_notify_thread != NULL);
    return running;
}

bool UsbServer::is_shared_device(unsigned long hdev)
{
    boost::mutex::scoped_lock lock(_protect_devices);
    bool shared = false;

    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter = find_device(hdev);
    if (iter != _devices.end()) {
        shared = is_shared_device(*iter);
    }

    return shared;
}

bool UsbServer::is_share_allowed(unsigned long hdev)
{
    boost::mutex::scoped_lock lock(_protect_devices);
    bool allow = false;

    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter = find_device(hdev);
    if (iter != _devices.end()) {
        allow = is_share_allowed(*iter);
    }

    return allow;
}

void UsbServer::share_device(unsigned long hdev)
{
    ASSERT(_client != 0);
    boost::mutex::scoped_lock lock(_protect_devices);

    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter = find_device(hdev);
    if (iter != _devices.end()) {
        if (is_share_allowed(*iter)) {
            share_device(*iter);
        }
    }
}

void UsbServer::stop_sharing_device(unsigned long hdev)
{
    boost::mutex::scoped_lock lock(_protect_devices);

    std::vector<DEVICE_DESCRIPTOR>::const_iterator iter = find_device(hdev);
    if (iter != _devices.end()) {
        if (StopSharingUSBDevice(iter->hDevice)) {
            LOG_DEBUG(fmt("Stopped sharing device 0x%08p.") % iter->hDevice);
        } else {
            log_usb_error("StopSharingUSBDevice");
        }
        remove_connected_device(iter->hDevice);
    }
}

void UsbServer::auto_sharing(bool enable)
{
    _auto_sharing = enable;

    if (_auto_sharing == true) {
        ignore_existing_devices();
    } else {
        _ignored_devices.clear();
        LOG_TRACE("Ignored devices list was cleared.");
    }
}

// True if every bit of the SharedDevice property flag is set in `desc`.
bool UsbServer::is_shared_device(const DEVICE_DESCRIPTOR& desc)
{
    const bool shared = ((desc.ulDeviceProperties & SharedDevice) == SharedDevice);
    return shared;
}

// True if `desc` is bound to our client's IP and its status is
// UsbDeviceBusy.
bool UsbServer::is_connected_device(const DEVICE_DESCRIPTOR& desc)
{
    if (desc.ulIpClient != _client_ip) {
        return false;
    }
    return (desc.ulDeviceStatus == UsbDeviceBusy);
}

// Decide whether `desc` may be shared with the client.
//
// The filter criteria are consulted only when two conditions hold: a client
// is connected, and the device status is UsbDevice or the device is already
// connected to our client. Otherwise the action is BLOCK. The resulting
// action is traced, and the function returns true only for
// UsbCriteria::ALLOW.
bool UsbServer::is_share_allowed(const DEVICE_DESCRIPTOR& desc)
{
    const bool candidate = (_client != 0) &&
        ((desc.ulDeviceStatus == UsbDevice) || is_connected_device(desc));

    int action = UsbCriteria::BLOCK;
    if (candidate) {
        action = find_criteria_action(desc);
    }

    const char *action_name = "Mismatch";
    if (action == UsbCriteria::BLOCK) {
        action_name = "Block";
    } else if (action == UsbCriteria::ALLOW) {
        action_name = "Allow";
    }

    LOG_TRACE(fmt("Device 0x%08p Action: %s") % desc.hDevice % action_name);

    return (action == UsbCriteria::ALLOW);
}

// Ask the redirector to share `desc`. On success the handle is queued in
// _share_devices, so connect_shared_devices() can connect it once it becomes
// available. Failures are logged.
void UsbServer::share_device(const DEVICE_DESCRIPTOR& desc)
{
    if (!ShareUSBDevice(desc.hDevice)) {
        log_usb_error("ShareUSBDevice");
        return;
    }

    LOG_DEBUG(fmt("Adding device 0x%08p to the shared devices list.") % desc.hDevice);
    _share_devices.push_back(desc.hDevice);
}

// Erase from the "ignored list" the devices which were physically removed.
void UsbServer::remove_ignored_unplugged()
{
    std::list<unsigned long>::iterator iter = _ignored_devices.begin();
    
    while (iter != _ignored_devices.end()) {
        std::vector<DEVICE_DESCRIPTOR>::const_iterator jter;

        for (jter = _devices.begin(); jter != _devices.end(); ++jter) {
            const DEVICE_DESCRIPTOR& desc = *jter;
            if (*iter == desc.hDevice) {
                break;
            }
        }

        if (jter == _devices.end()) {
            LOG_TRACE(fmt("Device 0x%08p was physically removed (removing from ignored devices list).") % *iter);
            iter = _ignored_devices.erase(iter);
        } else {
            ++iter;
        }
    }
}

// Create the notification event (Windows) or named semaphore (POSIX) and
// hand it to the USB Redirector library via SetNotificationEvent().
// wait_for_notification() blocks on it, and stop_waiting() also signals it.
// Returns true on success. Must not be called while an event already exists.
bool UsbServer::set_notification_event()
{
    ASSERT(_usb_event == SEM_FAILED);

#ifdef WIN32
    // Create auto-reset notification event.
    _usb_event = ::CreateEvent(NULL, FALSE, FALSE, NULL);
#else
    // Named semaphore (see get_sem_name()) with an initial count of 0.
    // The "other" bits presumably let the usbrdr daemon open it from another
    // account — confirm. Owner bits are included as well: without them,
    // reopening a leftover semaphore of the same name (e.g. one left by a
    // crashed process whose pid was reused) fails with EACCES.
    // NOTE(review): the mode is filtered by the process umask, which commonly
    // strips S_IWOTH.
    _usb_event = sem_open(get_sem_name(), O_CREAT, S_IRUSR|S_IWUSR|S_IROTH|S_IWOTH, 0);
#endif

    if (_usb_event != SEM_FAILED) {
        // Pass event handle to USB Redirector DLL.
        SetNotificationEvent(_usb_event);
    }

    return (_usb_event != SEM_FAILED);
}

// Release the notification event or semaphore, if one exists. On POSIX the
// named semaphore is also unlinked so its name does not linger.
void UsbServer::unset_notification_event()
{
    if (_usb_event == SEM_FAILED) {
        return;
    }

#ifdef WIN32
    ::CloseHandle(_usb_event);
#else
    sem_close(_usb_event);
    sem_unlink(get_sem_name());
#endif

    _usb_event = SEM_FAILED;
}

// Signal the notification event or semaphore, if it exists, to wake a
// thread blocked in wait_for_notification().
void UsbServer::raise_notification_event()
{
    if (_usb_event == SEM_FAILED) {
        return;
    }

#ifdef WIN32
    (void)::SetEvent(_usb_event);
#else
    (void)sem_post(_usb_event);
#endif
}

// Block until the notification event or semaphore is signalled.
// The wait result is not inspected. On POSIX, an early return from
// sem_wait() (e.g. one interrupted by a signal) is therefore handled like a
// notification.
void UsbServer::wait_for_notification()
{
    ASSERT(_usb_event != SEM_FAILED);

#ifdef WIN32
    (void)::WaitForSingleObject(_usb_event, INFINITE);
#else
    (void)sem_wait(_usb_event);
#endif
}

// Ask the notification thread to finish. The stop flag is set first, so the
// thread sees it when it returns to the top of its loop; then the thread is
// woken.
void UsbServer::stop_waiting()
{
    _stop_notify_thread = true;
    raise_notification_event();
}

// Return an iterator to the first descriptor in _devices whose handle is
// `hdev`, or _devices.end() if there is none. Callers presumably hold
// _protect_devices, as the visible callers do.
std::vector<DEVICE_DESCRIPTOR>::const_iterator UsbServer::find_device(unsigned long hdev)
{
    std::vector<DEVICE_DESCRIPTOR>::const_iterator pos = _devices.begin();

    while ((pos != _devices.end()) && !(hdev == pos->hDevice)) {
        ++pos;
    }

    return pos;
}

// Remove the first occurrence of `hdev` from the connected-devices list, if
// it is present, and log the removal.
void UsbServer::remove_connected_device(unsigned long hdev)
{
    std::list<unsigned long>::iterator pos =
        std::find(_connected_devices.begin(), _connected_devices.end(), hdev);

    if (pos == _connected_devices.end()) {
        return;
    }

    _connected_devices.erase(pos);
    LOG_DEBUG(fmt("Device 0x%08p was removed from connected devices list.") % hdev);
}

// True if `desc` is on our connected list but the remote side no longer
// holds it: either it is reported as UsbDeviceAvailable again, or we have
// no client (_client == 0).
bool UsbServer::is_disconnected_from_remote(const DEVICE_DESCRIPTOR& desc)
{
    const bool was_connected =
        std::find(_connected_devices.begin(), _connected_devices.end(), desc.hDevice) !=
        _connected_devices.end();
    if (!was_connected) {
        return false;
    }

    const bool released = (desc.ulDeviceStatus == UsbDeviceAvailable) || (_client == 0);
    if (!released) {
        return false;
    }

    LOG_DEBUG(fmt("Device 0x%08p was disconnected from remote side.") % desc.hDevice);
    return true;
}

// Record in _connected_devices every device that is now connected to our
// client (see is_connected_device()) and not already listed.
void UsbServer::update_connected_devices()
{
    for (size_t idx = 0; idx < _devices.size(); ++idx) {
        const DEVICE_DESCRIPTOR& desc = _devices[idx];

        if (!is_connected_device(desc)) {
            continue;
        }

        const bool already_known =
            std::find(_connected_devices.begin(), _connected_devices.end(), desc.hDevice) !=
            _connected_devices.end();
        if (already_known) {
            continue;
        }

        LOG_DEBUG(fmt("Device 0x%08p is now connected to client.") % desc.hDevice);
        _connected_devices.push_back(desc.hDevice);
    }
}
