Program Listing for File MessageBruteForceHost.h

Return to documentation for file (include/flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceHost.h)

#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEHOST_H_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEHOST_H_

#include <typeindex>
#include <memory>
#include <unordered_map>
#include <string>
#include <vector>

#include "flamegpu/model/Variable.h"
#include "flamegpu/simulation/detail/CUDAErrorChecking.cuh"

#include "flamegpu/runtime/messaging/MessageNone/MessageNoneHost.h"
#include "flamegpu/runtime/messaging/MessageBruteForce.h"
#include "flamegpu/runtime/messaging/MessageSortingType.h"
#include "flamegpu/detail/type_decode.h"

namespace flamegpu {

/**
 * Brute force message specialisation handler used by the CUDA simulation.
 *
 * Keeps a host-side MetaData instance, a pointer to a second MetaData instance (initially null,
 * exposed via getMetaDataDevicePtr()), and a reference to the CUDAMessage it serves.
 * NOTE(review): d_metadata is presumably device memory managed by allocateMetaDataDevicePtr() /
 * freeMetaDataDevicePtr(), which are defined elsewhere — TODO confirm.
 */
class MessageBruteForce::CUDAModelHandler : public MessageSpecialisationHandler {
 public:
    /**
     * Constructs a handler bound to a simulation message.
     * @param a The CUDAMessage this handler serves. It is held by reference, so it must outlive this handler.
     */
    explicit CUDAModelHandler(detail::CUDAMessage &a)
        : MessageSpecialisationHandler(), sim_message(a) {}

    ~CUDAModelHandler() = default;
    void init(detail::CUDAScatter &scatter, unsigned int streamId, cudaStream_t stream) override;
    void buildIndex(detail::CUDAScatter &scatter, unsigned int streamId, cudaStream_t stream) override;
    void allocateMetaDataDevicePtr(cudaStream_t stream) override;
    void freeMetaDataDevicePtr() override;
    /**
     * @return The metadata pointer held by this handler (nullptr until something assigns it)
     */
    const void *getMetaDataDevicePtr() const override { return d_metadata; }

 private:
    /** Host-side metadata instance (presumably mirrored into d_metadata — TODO confirm). */
    MetaData hd_metadata;
    /** Metadata pointer returned by getMetaDataDevicePtr(); starts as nullptr. */
    MetaData *d_metadata = nullptr;
    /** The simulation message this handler operates on. */
    detail::CUDAMessage &sim_message;
};

/**
 * Internal data backing a brute force message description.
 *
 * Holds the message's name, variables and flags, plus a weak reference to the owning ModelData.
 * Specialised message types presumably derive from this (the virtual members suggest so — TODO confirm).
 */
struct MessageBruteForce::Data {
    friend class ModelDescription;
    friend struct ModelData;

    virtual ~Data() = default;
    // Non-owning reference to the model this message belongs to.
    std::weak_ptr<const ModelData> model;
    // The message's variables, keyed by variable name.
    VariableMap variables;
    // The message's name.
    std::string name;
    // Persistence flag, written by CDescription::setPersistent().
    bool persistent;
    // NOTE(review): semantics not visible in this file; presumably tracks optional message output usage — TODO confirm.
    unsigned int optional_outputs;
    bool operator==(const Data& rhs) const;
    bool operator!=(const Data& rhs) const;
    // Plain copying is disabled; copies go through clone() or the protected parent-taking constructor.
    // NOTE(review): copy-assignment is still implicitly declared; consider deleting it too — TODO confirm intent.
    Data(const Data &other) = delete;

    // Creates the specialisation handler for the given CUDA message (the name's spelling "Hander" is part of the public API).
    virtual std::unique_ptr<MessageSpecialisationHandler> getSpecialisationHander(detail::CUDAMessage &owner) const;

    // Returns a type_index, presumably identifying the concrete message type — TODO confirm.
    virtual std::type_index getType() const;
    // Returns the sorting type used by this message type.
    virtual flamegpu::MessageSortingType getSortingType() const;

 protected:
    // Returns a copy of this data re-parented to newParent. Returns a raw pointer; the caller presumably takes ownership — TODO confirm.
    virtual Data *clone(const std::shared_ptr<const ModelData> &newParent);
    // Copy-constructs from other, attaching the copy to model.
    Data(std::shared_ptr<const ModelData> model, const Data &other);
    // Constructs a new message named message_name belonging to model.
    Data(std::shared_ptr<const ModelData> model, const std::string &message_name);
};

/**
 * Const interface to a brute force message description.
 *
 * Wraps a shared pointer to the underlying MessageBruteForce::Data. The members which modify the
 * message are protected here (unless compiled by SWIG) and are made public by
 * MessageBruteForce::Description.
 */
class MessageBruteForce::CDescription {
    friend struct Data;
    friend class AgentFunctionDescription;
    // friend void AgentFunctionDescription::setMessageOutput(MessageBruteForce::Description&);
    // friend void AgentFunctionDescription::setMessageInput(MessageBruteForce::Description&);

 public:
    /**
     * Constructs a description wrapping the given message data.
     * @param data Message data to wrap
     */
    explicit CDescription(std::shared_ptr<Data> data);
    /**
     * Constructs a description wrapping the given const message data.
     * @param data Message data to wrap
     */
    explicit CDescription(std::shared_ptr<const Data> data);
    // Copies and moves only copy/move the shared pointer; both descriptions then refer to the same message data.
    CDescription(const CDescription& other_message) = default;
    CDescription(CDescription&& other_message) = default;
    CDescription& operator=(const CDescription& other_message) = default;
    CDescription& operator=(CDescription&& other_message) = default;
    /**
     * Equality comparison. Presumably compares the underlying message data — TODO confirm.
     * @param rhs Description to compare against
     */
    bool operator==(const CDescription& rhs) const;
    /**
     * Inequality comparison.
     * @param rhs Description to compare against
     */
    bool operator!=(const CDescription& rhs) const;

    /** @return The message's name */
    std::string getName() const;
    /** @return The message's persistent flag */
    bool getPersistent() const;
    /**
     * @param variable_name Name of the variable
     * @return The type of the named variable
     * NOTE(review): behaviour for unknown names is defined elsewhere (presumably throws) — TODO confirm.
     */
    const std::type_index& getVariableType(const std::string& variable_name) const;
    /**
     * @param variable_name Name of the variable
     * @return The size of the named variable's type
     */
    size_t getVariableSize(const std::string& variable_name) const;
    /**
     * @param variable_name Name of the variable
     * @return The number of elements in the named variable
     */
    size_type getVariableLength(const std::string& variable_name) const;
    /** @return The number of variables the message has */
    size_type getVariablesCount() const;
    /**
     * @param variable_name Name of the variable
     * @return true if the message has a variable with this name
     */
    bool hasVariable(const std::string& variable_name) const;

#ifndef SWIG

 protected:
#endif

    /**
     * Sets the message's persistent flag.
     * @param persistent New value of the flag
     */
    void setPersistent(const bool persistent) {
        message->persistent = persistent;
    }
    /**
     * Adds a new single-element variable to the message.
     * @tparam T Type of the variable
     * @param variable_name Name of the variable
     * @throws exception::ReservedName If variable_name begins with '_'
     * @throws exception::InvalidMessageVar If the message already has a variable named variable_name
     */
    template<typename T>
    void newVariable(const std::string& variable_name);
    /**
     * Adds a new array variable with a compile-time length of N to the message.
     * @tparam T Type of the variable's elements
     * @tparam N Number of elements (must be non-zero; enforced by static_assert)
     * @param variable_name Name of the variable
     * @throws exception::ReservedName If variable_name begins with '_'
     * @throws exception::InvalidMessageVar If the message already has a variable named variable_name
     */
    template<typename T, flamegpu::size_type N>
    void newVariable(const std::string& variable_name);
#ifdef SWIG
    /**
     * Adds a new array variable whose length is given at runtime (SWIG builds only).
     * @tparam T Type of the variable's elements
     * @param variable_name Name of the variable
     * @param length Number of elements
     * @throws exception::ReservedName If variable_name begins with '_'
     * @throws exception::InvalidMessageVar If length is 0 or the variable already exists
     */
    template<typename T>
    void newVariableArray(const std::string& variable_name, size_type length);
#endif

    /** The message data viewed (and, via the protected members, modified) by this description. */
    std::shared_ptr<Data> message;
};

/**
 * Mutable description of a brute force message.
 *
 * Extends CDescription by making its modifying members (setPersistent, newVariable and, for SWIG
 * builds, newVariableArray) public.
 */
class MessageBruteForce::Description : public CDescription {
 public:
    /**
     * Constructs a mutable description wrapping the given message data.
     * @param data Message data to wrap
     */
    explicit Description(std::shared_ptr<Data> data);
    // Copying: both descriptions refer to the same underlying message data.
    Description(const Description& other_message) = default;
    Description& operator=(const Description& other_message) = default;
    // Moving: transfers the reference to the underlying message data.
    Description(Description&& other_message) = default;
    Description& operator=(Description&& other_message) = default;

    // Make the protected modifiers of the base class public.
    using MessageBruteForce::CDescription::setPersistent;
    using MessageBruteForce::CDescription::newVariable;
#ifdef SWIG
    using MessageBruteForce::CDescription::newVariableArray;
#endif
};
template<typename T>
void MessageBruteForce::CDescription::newVariable(const std::string &variable_name) {
    newVariable<T, 1>(variable_name);
}
template<typename T, flamegpu::size_type N>
void MessageBruteForce::CDescription::newVariable(const std::string& variable_name) {
    if (!variable_name.empty() && variable_name[0] == '_') {
        THROW exception::ReservedName("Message variable names cannot begin with '_', this is reserved for internal usage, "
            "in MessageDescription::newVariable().");
    }
    // Array length 0 makes no sense
    static_assert(detail::type_decode<T>::len_t * N > 0, "A variable cannot have 0 elements.");
    if (message->variables.find(variable_name) == message->variables.end()) {
        message->variables.emplace(variable_name, Variable(std::array<typename detail::type_decode<T>::type_t, detail::type_decode<T>::len_t * N>{}));
        return;
    }
    THROW exception::InvalidMessageVar("Message ('%s') already contains variable '%s', "
        "in MessageDescription::newVariable().",
        message->name.c_str(), variable_name.c_str());
}
#ifdef SWIG
/**
 * Adds an array variable whose length is only known at runtime (used by SWIG bindings).
 *
 * The variable is initialised with len_t * length value-initialised elements.
 * @tparam T Type of the variable's elements (decoded via detail::type_decode)
 * @param variable_name Name of the new variable
 * @param length Number of elements; must be greater than 0
 * @throws exception::ReservedName If variable_name begins with '_'
 * @throws exception::InvalidMessageVar If length is 0, or the message already has a variable named variable_name
 */
template<typename T>
void MessageBruteForce::CDescription::newVariableArray(const std::string& variable_name, const size_type length) {
    if (!variable_name.empty() && variable_name[0] == '_') {
        THROW exception::ReservedName("Message variable names cannot begin with '_', this is reserved for internal usage, "
            "in MessageDescription::newVariableArray().");
    }
    if (length == 0) {
        // Adjacent string literals are concatenated verbatim, so the separator must be explicit.
        THROW exception::InvalidMessageVar("Message variable arrays must have a length greater than 0, "
            "in MessageDescription::newVariableArray().");
    }
    if (message->variables.find(variable_name) == message->variables.end()) {
        const auto element_count = detail::type_decode<T>::len_t * length;
        std::vector<typename detail::type_decode<T>::type_t> temp(static_cast<size_t>(element_count));
        message->variables.emplace(variable_name, Variable(element_count, temp));
        return;
    }
    THROW exception::InvalidMessageVar("Message ('%s') already contains variable '%s', "
        "in MessageDescription::newVariableArray().",
        message->name.c_str(), variable_name.c_str());
}
#endif

}  // namespace flamegpu

#endif  // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEHOST_H_