Program Listing for File MessageBucketDevice.cuh

Return to documentation for file (include/flamegpu/runtime/messaging/MessageBucket/MessageBucketDevice.cuh)

#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_

#include "flamegpu/runtime/messaging/MessageBucket.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh"

namespace flamegpu {

class MessageBucket::In {
 public:
    /**
     * A view over the messages stored in a contiguous range of buckets.
     *
     * Holds the half-open range of message indices [bucket_begin, bucket_end) and can be
     * traversed with a range-based for loop through begin()/end().
     */
    class Filter {
        friend class Message;

     public:
        /**
         * Accessor for the message at a single index within a Filter.
         */
        class Message {
            /**
             * Filter this message belongs to, consulted for bounds checks
             */
            const Filter &_parent;
            /**
             * Index of this message, used to look up its variables via curve
             */
            IntT cell_index;

         public:
            /**
             * @param parent Filter the message belongs to
             * @param _cell_index Message index
             */
            __device__ Message(const Filter &parent, const IntT &_cell_index)
                : _parent(parent), cell_index(_cell_index) { }
            /**
             * Messages compare equal when they refer to the same message index
             */
            __device__ bool operator==(const Message& rhs) const { return cell_index == rhs.cell_index; }
            __device__ bool operator!=(const Message& rhs) const { return cell_index != rhs.cell_index; }
            /**
             * Step to the next message index
             */
            __device__ Message& operator++() {
                ++cell_index;
                return *this;
            }
            /**
             * Read a scalar variable from this message
             * @param variable_name Name of the message variable
             */
            template<typename T, size_type N>
            __device__ T getVariable(const char(&variable_name)[N]) const;
            /**
             * Read a single element of an array variable from this message
             * @param variable_name Name of the message variable
             * @param index Element index within the array variable
             */
            template<typename T, flamegpu::size_type N, unsigned int M> __device__
            T getVariable(const char(&variable_name)[M], unsigned int index) const;
        };
        /**
         * Iterator over the messages of a Filter.
         */
        class iterator {
            Message _message;

         public:
            /**
             * The wrapped message is advanced once during construction, so cell_index must be
             * the index immediately preceding the first message to visit.
             */
            __device__ iterator(const Filter &parent, const IntT &cell_index)
                : _message(parent, cell_index) {
                ++_message;
            }
            __device__ iterator& operator++() {
                ++_message;
                return *this;
            }
            __device__ iterator operator++(int) {
                iterator previous(*this);
                ++(*this);
                return previous;
            }
            __device__ bool operator==(const iterator& rhs) const { return _message == rhs._message; }
            __device__ bool operator!=(const iterator& rhs) const { return !(*this == rhs); }
            __device__ Message& operator*() { return _message; }
            __device__ Message* operator->() { return &_message; }
        };
        /**
         * Construct a filter over the key range [beginKey, endKey)
         * @param _metadata Bucket messaging metadata
         * @param beginKey First key in the range
         * @param endKey One past the final key in the range
         */
        inline __device__ Filter(const MetaData *_metadata, const IntT &beginKey, const IntT &endKey);
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        /**
         * Construct an empty filter, returned by operator() after an invalid key is reported
         */
        inline __device__ Filter();
#endif
        /**
         * @return Iterator to the first message in the range
         */
        inline __device__ iterator begin(void) const {
            // iterator's constructor pre-increments, so hand it the index just before the first message
            return iterator(*this, bucket_begin - 1);
        }
        /**
         * @return Iterator one past the final message in the range
         */
        inline __device__ iterator end(void) const {
            // Likewise one before bucket_end, so the pre-increment lands exactly on bucket_end
            return iterator(*this, bucket_end - 1);
        }
        /**
         * @return Number of messages in the range
         */
        inline __device__ IntT size(void) const {
            return bucket_end - bucket_begin;
        }

     private:
        /**
         * Message index bounds of the range, [bucket_begin, bucket_end)
         */
        IntT bucket_begin, bucket_end;
        const MetaData *metadata;
    };
    /**
     * @param _metadata Pointer to this message list's MessageBucket::MetaData
     */
    __device__ In(const void *_metadata)
        : metadata(static_cast<const MetaData*>(_metadata))
    { }
    /**
     * Access the messages stored in the bucket with the given key.
     * @param key Bucket key, expected within [metadata->min, metadata->max)
     * @return Filter over the single key range [key, key + 1); with seatbelts, an empty Filter if key is invalid
     */
    inline __device__ Filter operator() (const IntT &key) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        if (key < metadata->min) {
            DTHROW("Bucket messaging iterator key %d is lower than minimum key (%d).\n", key, metadata->min);
            return Filter();
        }
        if (key >= metadata->max) {
            DTHROW("Bucket messaging iterator key %d is higher than maximum key (%d).\n", key, metadata->max - 1);
            return Filter();
        }
#endif
        const IntT next_key = key + 1;
        return Filter(metadata, key, next_key);
    }
    /**
     * Access the messages stored in the buckets with keys in [beginKey, endKey).
     * @param beginKey First key, expected >= metadata->min
     * @param endKey One past the final key, expected <= metadata->max and > beginKey
     * @return Filter over the key range; with seatbelts, an empty Filter if the range is invalid
     */
    inline __device__ Filter operator() (const IntT &beginKey, const IntT &endKey) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        if (beginKey < metadata->min) {
            DTHROW("Bucket messaging iterator begin key %d is lower than minimum key (%d).\n", beginKey, metadata->min);
            return Filter();
        }
        if (endKey > metadata->max) {
            DTHROW("Bucket messaging iterator end key %d is higher than maximum key + 1 (%d).\n", endKey, metadata->max);
            return Filter();
        }
        if (endKey <= beginKey) {
            DTHROW("Bucket messaging iterator begin key must be lower than end key (%d !< %d).\n", beginKey, endKey);
            return Filter();
        }
#endif
        return Filter(metadata, beginKey, endKey);
    }

 private:
    const MetaData *metadata;
};

/**
 * Device-side message output interface for bucket messaging.
 *
 * Extends MessageBruteForce::Out with setKey(). The metadata pointer is only retained when
 * seatbelts are enabled, where setKey() uses it for key range validation; otherwise it is nullptr.
 */
class MessageBucket::Out : public MessageBruteForce::Out {
 public:
    /**
     * @param _metadata Pointer to this message list's MessageBucket::MetaData
     * @param scan_flag_messageOutput Scan flag buffer, forwarded to MessageBruteForce::Out
     */
    __device__ Out(const void *_metadata, unsigned int *scan_flag_messageOutput)
        : MessageBruteForce::Out(nullptr, scan_flag_messageOutput)
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        , metadata(static_cast<const MetaData*>(_metadata))
#else
        , metadata(nullptr)
#endif
    { }
    /**
     * Set the bucket key of the message output by the calling thread
     * @param key Bucket key
     */
    inline __device__ void setKey(const IntT &key) const;
    /**
     * Bucket metadata (nullptr when seatbelts are disabled)
     */
    const MetaData * const metadata;
};

/**
 * Construct a Filter over the messages whose keys fall within the half-open key range [beginKey, endKey).
 *
 * The message index bounds are read from metadata->PBM, indexed by each key's offset from metadata->min.
 * If the key range is out of bounds (beginKey < min, endKey > max) or reversed (beginKey > endKey),
 * the filter is left empty (bucket_begin == bucket_end == 0).
 *
 * @param _metadata Bucket messaging metadata
 * @param beginKey First key included in the range
 * @param endKey One past the final key included in the range
 */
__device__ MessageBucket::In::Filter::Filter(const MetaData* _metadata, const IntT& beginKey, const IntT& endKey)
    : bucket_begin(0)
    , bucket_end(0)
    , metadata(_metadata) {
    // endKey is exclusive and may legitimately equal metadata->max: operator()(key) passes key + 1 == max for
    // the last valid key, and operator()(beginKey, endKey) accepts endKey == max. A strict '<' here made the
    // final bucket silently appear empty. This assumes PBM holds (max - min + 1) boundaries, as implied by
    // the range operator() accepting endKey == max.
    if (beginKey >= metadata->min && endKey <= metadata->max && beginKey <= endKey) {
        bucket_begin = metadata->PBM[beginKey - metadata->min];
        bucket_end = metadata->PBM[endKey - metadata->min];
    }
}
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
/**
 * Construct an empty Filter (bucket_begin == bucket_end == 0) without metadata.
 *
 * Only available with seatbelts enabled, where operator() returns it after reporting an invalid key.
 */
__device__ MessageBucket::In::Filter::Filter() : bucket_begin(0), bucket_end(0), metadata(nullptr) { }
#endif

/**
 * Set the bucket key of the message output by the calling thread.
 *
 * Writes the "_key" message variable at this thread's global index and sets that thread's scan flag,
 * so the message is retained in case message output is optional.
 *
 * @param key The bucket key, expected within [metadata->min, metadata->max)
 * @note With seatbelts enabled, an out-of-range key is reported via DTHROW and nothing is written.
 */
__device__ void MessageBucket::Out::setKey(const IntT &key) const {
    const unsigned int index = (blockDim.x * blockIdx.x) + threadIdx.x;  // + d_message_count;

#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    if (key < metadata->min || key >= metadata->max) {
        // IntT is printed with %d everywhere else in this file; %u misreported negative keys as huge values.
        DTHROW("MessageBucket key %d is out of range [%d, %d).\n", key, metadata->min, metadata->max);
        return;
    }
#endif
    // Set the key variable using curve
    detail::curve::DeviceCurve::setMessageVariable<IntT>("_key", key, index);

    // Set scan flag in case the message is optional
    this->scan_flag[index] = 1;
}

/**
 * Read a scalar variable from this message.
 *
 * @param variable_name Name of the message variable
 * @return The variable's value at this message's index; with seatbelts, 0 if the index is not below
 *         the parent filter's bucket_end (an error is also reported via DTHROW)
 */
template<typename T, unsigned int N>
__device__ T MessageBucket::In::Filter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Refuse to read beyond the end of the filter's message range
    const bool past_end = cell_index >= _parent.bucket_end;
    if (past_end) {
        DTHROW("Bucket message index exceeds bin length, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
/**
 * Read a single element of an array variable from this message.
 *
 * @param variable_name Name of the message variable
 * @param array_index Element index within the array variable
 * @return The element's value at this message's index; with seatbelts, a value-initialised T if the
 *         index is not below the parent filter's bucket_end (an error is also reported via DTHROW)
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageBucket::In::Filter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Refuse to read beyond the end of the filter's message range
    const bool past_end = cell_index >= _parent.bucket_end;
    if (past_end) {
        DTHROW("Bucket message index exceeds bin length, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}
}  // namespace flamegpu


#endif  // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_