Commit 1c2e9025 authored by Filip Hauzvic's avatar Filip Hauzvic
Browse files

RPC System

parent beec9644
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -36,12 +36,17 @@ namespace net
        void register_on_peer_connected(const std::function<void(ConnectionId)>& callback);
        void register_on_peer_disconnected(std::function<void(ConnectionId)> callback);

        SocketMode mode() const { return socket->get_mode(); }
        ConnectionStatus status() const { return socket->get_status(); }

    private:
        void initialize() override;

        void add_connection(ConnectionId connection);
        void remove_connection(ConnectionId connection);

        void handle_packet(ConnectionId from, NetworkMessage* message);

        std::unique_ptr<ISocket> socket;
        std::unordered_map<ConnectionId, Connection> connections;
    };
+5 −0
Original line number Diff line number Diff line
@@ -22,6 +22,11 @@ namespace net
        int channel;
    };

    enum class PacketType : uint8_t
    {
        RPC
    };

    enum class SocketMode : uint8_t
    {
        SERVER,
+85 −0
Original line number Diff line number Diff line
#pragma once
#include <string>
#include <vector>
#include <com/context.hpp>
#include <net/index.hpp>
#include <net/network_types.hpp>

#define RPC_RELIABLE(function, peer, ...) \
    net::RPCSystem::call_remote(this, #function, peer, net::ReliabilityMode::RELIABLE, __VA_ARGS__)
#define RPC_UNRELIABLE(function, peer, ...) \
    net::RPCSystem::call_remote(this, #function, peer, net::ReliabilityMode::UNRELIABLE, __VA_ARGS__)

namespace net
{
    struct RPCSystem
    {
        static void process_packet(ConnectionId from, void* deserializer);

        template<typename... Args>
        static void call_remote(
            com::ContextItem* context_item,
            const std::string& method_name,
            ConnectionId target,
            ReliabilityMode reliability,
            Args&&... args)
        {

            std::vector<com::Reflection::ValuePtr> params;
            (params.push_back(make_reflection_value(std::forward<Args>(args))), ...); // Convert each argument to a ValuePtr and add to params vector

            int written_bytes = 0;
            auto buffer = create_packet(context_item, method_name, params, &written_bytes);
            std::span<const uint8_t> buffer_span(buffer.data(), written_bytes);
            driver()->send(target, buffer_span, reliability);
        }

    private:
        static std::vector<uint8_t> create_packet(
            com::ContextItem* context_item,
            std::string const& function_name,
            const std::vector<std::shared_ptr<com::Reflection::Value>>& params,
            int* written_bytes = nullptr);

        static void call_function_local(
            std::vector<std::string> const& context_item_path,
            std::string const& function_name,
            void* deserializer,
            ConnectionId from_id);

        template<typename T>
        static com::Reflection::ValuePtr make_reflection_value(const T& val)
        {
            if constexpr (std::is_same_v<T, int>)
            {
                return com::make_value<com::Reflection::ValueINT>(val);
            } else if constexpr (std::is_same_v<T, float>)
            {
                return com::make_value<com::Reflection::ValueFLOAT>(val);
            } else if constexpr (std::is_same_v<T, bool>)
            {
                return com::make_value<com::Reflection::ValueBOOL>(val);
            } else if constexpr (std::is_same_v<T, std::string>)
            {
                return com::make_value<com::Reflection::ValueSTRING>(val);
            } else if constexpr (std::is_same_v<std::decay_t<T>, const char*> ||
                                 std::is_same_v<std::decay_t<T>, char*> ||
                                 std::is_array_v<std::remove_reference_t<T>>)
            {
                return com::make_value<com::Reflection::ValueSTRING>(std::string(val));
            } else if constexpr (std::is_same_v<T, vec2>)
            {
                return com::make_value<com::Reflection::ValueVEC2>(val);
            } else if constexpr (std::is_same_v<T, vec3>)
            {
                return com::make_value<com::Reflection::ValueVEC3>(val);
            } else if constexpr (std::is_same_v<T, vec2i>)
            {
                return com::make_value<com::Reflection::ValueVEC2I>(val);
            } else
            {
                throw std::runtime_error("Unsupported type for reflection value creation");
            }
        }
    };
}
+22 −2
Original line number Diff line number Diff line
#include <net/net_driver.hpp>
#include <net/socket/gns_socket.hpp>
#include <ranges>
#include <net/rpc_system.hpp>
#include <osi/index.hpp>
#include <utils/serialization.hpp>

using namespace std;
using namespace net;
@@ -25,7 +27,6 @@ void NetDriver::initialize()

bool NetDriver::start_server(int port, int max_connections)
{
    if (socket->get_status() != ConnectionStatus::DISCONNECTED)
    if (!socket->start_server(port, max_connections))
    {
        return false;
@@ -34,7 +35,6 @@ bool NetDriver::start_server(int port, int max_connections)
    socket->register_on_peer_connected([this](ConnectionId id)
    {
        LOG(LSL_DEBUG, "[NetDriver] New connection added with ID " << id);
        std::cout << "New connection added with ID " << id << std::endl;
        add_connection(id);
    });
    socket->register_on_peer_disconnected([this](ConnectionId id)
@@ -126,8 +126,28 @@ void NetDriver::next_round()
    socket->update();
    while (auto msg = socket->receive())
    {
        handle_packet(msg->sender, msg);
    }
}

void NetDriver::handle_packet(ConnectionId from, NetworkMessage* message)
{
    auto deserializer = create_deserializer(message->data);

    PacketType packet_type;
    deserializer.value1b(packet_type);

    switch (packet_type)
    {
        case PacketType::RPC:
            RPCSystem::process_packet(from, &deserializer);
            break;

        default:
            LOG(LSL_WARNING, "[NetDriver] Received packet with unknown type from connection " << from);
            break;
    }
}


net/src/rpc_system.cpp

0 → 100644
+99 −0
Original line number Diff line number Diff line
#include <net/rpc_system.hpp>
#include <utils/serialization.hpp>

namespace net
{
    void RPCSystem::process_packet(ConnectionId from, void* deserializer)
    {
        auto des = static_cast<Deserializer*>(deserializer);

        std::vector<std::string> context_item_path;
        des->container(context_item_path, 32, [](auto& s, std::string& str) {
            s.text1b(str, 250);
        });

        std::string function_name;
        des->text1b(function_name, 250);

        std::vector<uint8_t> param_data;
        call_function_local(context_item_path, function_name, des, from);
    }

    std::vector<uint8_t> RPCSystem::create_packet(com::ContextItem* context_item, std::string const& function_name, const std::vector<std::shared_ptr<com::Reflection::Value>>& params, int* written_bytes)
    {
        Buffer buffer(1024);
        auto serializer = create_serializer(buffer);
        serializer.value1b(static_cast<uint8_t>(PacketType::RPC));

        const std::vector<std::string> context_item_path = context_item->path(root());
        serializer.container(context_item_path, 32, [](auto& s, std::string& str) {
            s.text1b(str, 250);
        });

        serializer.text1b(function_name, 250);

        if (!context_item->reflection()->functions().contains(function_name))
        {
            LOG(LSL_ERROR, "[RPCSystem] Function " << function_name << " not found in context item " << context_item->name());
            return {};
        }
        const auto function = context_item->reflection()->functions().at(function_name);

        if (function.param_types.size() != params.size())
        {
            LOG(LSL_ERROR, "[RPCSystem] Parameter count mismatch for function " << function_name << " in context item " << context_item->name());
            return {};
        }

        for (int i = 0; i < params.size(); i++)
        {
            if (params[i]->typeID() != function.param_types[i])
            {
                LOG(LSL_ERROR, "[RPCSystem] Parameter type mismatch for parameter " << i << " of function " << function_name << " in context item " << context_item->name());
                return {};
            }
            serialize_reflection_value(serializer, params[i]);
        }

        if (written_bytes)
        {
            *written_bytes = static_cast<int>(serializer.adapter().writtenBytesCount());
        }

        return buffer;
    }

    void RPCSystem::call_function_local(std::vector<std::string> const& context_item_path, std::string const& function_name, void* deserializer, ConnectionId from_id)
    {
        auto des = static_cast<Deserializer*>(deserializer);

        const auto context_item = root()->locate<com::ContextItem>(context_item_path);
        const auto function = context_item->reflection()->functions().at(function_name);

        std::vector<com::Reflection::ValuePtr> params;
        params.push_back(com::make_value<com::Reflection::ValueINT>(from_id)); // First parameter is the sender's connection ID

        for (int i = 0; i < function.param_types.size(); i++)
        {
            params.push_back(deserialize_reflection_value(*des));
        }

        if (function.param_types.size() != params.size() - 1)
        {
            LOG(LSL_ERROR, "[RPCSystem] Parameter count mismatch when calling function " << function_name << " in context item " << context_item->name());
            return;
        }

        for (int i = 0; i < function.param_types.size(); i++)
        {
            if (params[i + 1]->typeID() != function.param_types[i])
            {
                LOG(LSL_ERROR, "[RPCSystem] Parameter type mismatch for parameter " << i << " when calling function " << function_name << " in context item " << context_item->name());
                return;
            }
        }

        function.code(params);
    }

}
 No newline at end of file