Program Listing for File MessageBruteForceDevice.cuh

Return to documentation for file (include/flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh)

#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEDEVICE_CUH_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEDEVICE_CUH_

#include "flamegpu/defines.h"
#include "flamegpu/runtime/messaging/MessageNone.h"
#include "flamegpu/runtime/messaging/MessageBruteForce.h"
#ifndef __CUDACC_RTC__
#include "flamegpu/runtime/detail/curve/DeviceCurve.cuh"
#endif  // __CUDACC_RTC__

struct ModelData;

namespace flamegpu {

/**
 * Device-side read interface for brute force messaging.
 *
 * Every message in the list is visible to the reader. Iterate it with a range-based for loop;
 * begin()/end() walk message indices 0 .. size()-1.
 */
class MessageBruteForce::In {
 public:
    class Message;   // defined below
    class iterator;  // defined below

    /**
     * Construct the reader from the message list metadata.
     * @param metadata Pointer reinterpreted as a MessageBruteForce::MetaData; only its length is read
     */
    __device__ In(const void *metadata)
        : len(reinterpret_cast<const MetaData*>(metadata)->length) { }

    /**
     * @return The number of messages in the list
     */
    __device__ size_type size() const { return len; }

    /**
     * @return An iterator referring to the first message (index 0)
     */
    __device__ iterator begin() const { return iterator(*this, 0); }

    /**
     * @return An iterator referring to one past the last message (index size())
     */
    __device__ iterator end() const {
        // If there can be many begin, each with diff end, we need a middle layer to host the iterator/s
        return iterator(*this, len);
    }

    /**
     * A handle to a single message within the list, identified by its position.
     */
    class Message {
        // The list this message belongs to; consulted for bounds checks
        const MessageBruteForce::In &_parent;
        // Position of this message within the list
        size_type index;

     public:
        /**
         * Construct a handle to the first message (index 0) of parent.
         */
        __device__ Message(const MessageBruteForce::In &parent)
            : _parent(parent)
            , index(0) {}
        /**
         * Construct a handle to the message at position index of parent.
         */
        __device__ Message(const MessageBruteForce::In &parent, size_type index)
            : _parent(parent)
            , index(index) {}

        /** Two messages compare equal when they refer to the same index. */
        __host__ __device__ bool operator==(const Message &rhs) const {
            return getIndex() == rhs.getIndex();
        }
        /** Two messages compare unequal when they refer to different indices. */
        __host__ __device__ bool operator!=(const Message &rhs) const {
            return !(getIndex() == rhs.getIndex());
        }
        /** Advance this handle to the next message index. */
        __host__ __device__ Message &operator++() {
            index += 1;
            return *this;
        }
        /** @return The position of this message within the list */
        __host__ __device__ size_type getIndex() const { return index; }

        /**
         * Read a scalar variable of this message.
         * @param variable_name Name of the message variable
         */
        template<typename T, unsigned int N> __device__
        T getVariable(const char(&variable_name)[N]) const;
        /**
         * Read one element of an array variable of this message.
         * @param variable_name Name of the message array variable
         * @param array_index Element of the array to read
         */
        template<typename T, flamegpu::size_type N, unsigned int M> __device__
        T getVariable(const char(&variable_name)[M], unsigned int array_index) const;
    };

    /**
     * Forward iterator over the messages of an In, used by range-based for loops.
     */
    class iterator {
        // The message this iterator currently refers to
        Message _message;

     public:
        __device__ iterator(const In &parent, size_type index)
            : _message(parent, index) {}
        /** Step to the next message. */
        __device__ iterator &operator++() {
            ++_message;
            return *this;
        }
        __device__ bool operator==(const iterator &rhs) const { return _message == rhs._message; }
        __device__ bool operator!=(const iterator &rhs) const { return _message != rhs._message; }
        /** @return The message currently referred to */
        __device__ Message &operator*() { return _message; }
    };

 private:
    // Number of messages in the list, taken from the metadata at construction
    size_type len;
};



/**
 * Device-side write interface for brute force messaging.
 *
 * Each calling thread writes to the output slot addressed by its 1D global thread index
 * (blockDim.x * blockIdx.x + threadIdx.x) and sets that slot's entry in the scan flag array.
 */
class MessageBruteForce::Out {
 public:
    /**
     * @param (unnamed) Metadata pointer; not used by brute force output
     * @param scan_flag_messageOutput Flag array; setVariable() stores 1 at the calling thread's index
     */
    __device__ Out(const void *, unsigned int *scan_flag_messageOutput) : scan_flag(scan_flag_messageOutput) { }

    /**
     * Write a scalar variable of the calling thread's output message.
     * @param variable_name Name of the message variable (must not start with '_')
     * @param value Value to store
     */
    template<typename T, unsigned int N>
    __device__ void setVariable(const char(&variable_name)[N], T value) const;

    /**
     * Write one element of an array variable of the calling thread's output message.
     * @param variable_name Name of the message array variable (must not start with '_')
     * @param array_index Element of the array to write
     * @param value Value to store
     */
    template<typename T, unsigned int N, unsigned int M>
    __device__ void setVariable(const char(&variable_name)[M], unsigned int array_index, T value) const;

 protected:
    // Per-slot output flags; a slot is set to 1 when its thread writes a variable
    unsigned int *scan_flag;
};

/**
 * Read a scalar variable from this message.
 *
 * @tparam T Type of the variable
 * @tparam N Length of the variable_name char array (deduced)
 * @param variable_name Name of the message variable
 * @return The variable's value for the message at this->index. With seatbelts enabled, an
 *         out-of-range index raises a device exception and 0 is returned instead.
 */
template<typename T, unsigned int N>
__device__ T MessageBruteForce::In::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Seatbelts: refuse to read beyond the end of the message list
    const bool out_of_range = !(index < this->_parent.len);
    if (out_of_range) {
        DTHROW("Brute force message index exceeds messagelist length, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    // Fetch from curve at this message's index; non-GLM builds use the _ldg accessor
#ifdef FLAMEGPU_USE_GLM
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, index);
#else
    return detail::curve::DeviceCurve::getMessageVariable_ldg<T>(variable_name, index);
#endif
}
/**
 * Read one element of an array variable from this message.
 *
 * @tparam T Element type of the array variable
 * @tparam N Length of the array variable
 * @tparam M Length of the variable_name char array (deduced)
 * @param variable_name Name of the message array variable
 * @param array_index Element of the array to read
 * @return The requested element for the message at this->index. With seatbelts enabled, an
 *         out-of-range message index raises a device exception and 0 is returned instead.
 *
 * The message is addressed by this handle's own index (the message the iterator currently
 * refers to), exactly like the scalar overload. It must not be the calling thread's global
 * index: that would make every message visited in an iteration return the same data, and
 * would trip the bounds check whenever there are more threads than messages.
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageBruteForce::In::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Ensure that the message is within bounds.
    if (index >= this->_parent.len) {
        DTHROW("Brute force message index exceeds messagelist length, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    // get the value from curve using the message index.
#ifdef FLAMEGPU_USE_GLM
    T value = detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, index, array_index);
#else
    T value = detail::curve::DeviceCurve::getMessageArrayVariable_ldg<T, N>(variable_name, index, array_index);
#endif
    return value;
}

/**
 * Write a scalar variable of the calling thread's output message.
 *
 * @tparam T Type of the variable
 * @tparam N Length of the variable_name char array (deduced)
 * @param variable_name Name of the message variable; names starting with '_' are rejected
 * @param value Value to store
 *
 * The output slot is the calling thread's 1D global index. After the write, the slot's scan
 * flag is set to 1. A reserved name raises a device exception when seatbelts are enabled, and
 * otherwise the call returns without writing anything.
 */
template<typename T, unsigned int N>
__device__ void MessageBruteForce::Out::setVariable(const char(&variable_name)[N], T value) const {
    // Names beginning with an underscore are reserved for internal use
    const bool reserved_name = (variable_name[0] == '_');
    if (reserved_name) {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        DTHROW("Variable names starting with '_' are reserved for internal use, with '%s', in MessageBruteForce::Out::setVariable().\n", variable_name);
#endif
        return;  // Without seatbelts this is a silent no-op
    }

    // One output slot per thread
    const unsigned int thread_index = threadIdx.x + blockIdx.x * blockDim.x;

    // Todo: checking if the output message type is single or optional?  (d_message_type)

    detail::curve::DeviceCurve::setMessageVariable<T>(variable_name, value, thread_index);

    // Mark the slot as written, in case message output is optional
    scan_flag[thread_index] = 1;
}
/**
 * Write one element of an array variable of the calling thread's output message.
 *
 * @tparam T Element type of the array variable
 * @tparam N Length of the array variable
 * @tparam M Length of the variable_name char array (deduced)
 * @param variable_name Name of the message array variable; names starting with '_' are rejected
 * @param array_index Element of the array to write
 * @param value Value to store
 *
 * The output slot is the calling thread's 1D global index. After the write, the slot's scan
 * flag is set to 1. A reserved name raises a device exception when seatbelts are enabled, and
 * otherwise the call returns without writing anything.
 */
template<typename T, unsigned int N, unsigned int M>
__device__ void MessageBruteForce::Out::setVariable(const char(&variable_name)[M], const unsigned int array_index, T value) const {
    // Names beginning with an underscore are reserved for internal use
    const bool reserved_name = (variable_name[0] == '_');
    if (reserved_name) {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        DTHROW("Variable names starting with '_' are reserved for internal use, with '%s', in MessageBruteForce::Out::setVariable().\n", variable_name);
#endif
        return;  // Without seatbelts this is a silent no-op
    }

    // One output slot per thread
    const unsigned int thread_index = threadIdx.x + blockIdx.x * blockDim.x;

    // Todo: checking if the output message type is single or optional?  (d_message_type)

    detail::curve::DeviceCurve::setMessageArrayVariable<T, N>(variable_name, value, thread_index, array_index);

    // Mark the slot as written, in case message output is optional
    scan_flag[thread_index] = 1;
}

}  // namespace flamegpu

#endif  // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBRUTEFORCE_MESSAGEBRUTEFORCEDEVICE_CUH_