Program Listing for File MessageBucketDevice.cuh
↰ Return to documentation for file (include/flamegpu/runtime/messaging/MessageBucket/MessageBucketDevice.cuh)
#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_
#include "flamegpu/runtime/messaging/MessageBucket.h"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh"
namespace flamegpu {
/**
 * Device-side interface for reading bucket messages.
 * Calling the object with a single key, or with a half-open key range [beginKey, endKey),
 * returns a Filter whose begin()/end() iterators visit the matching messages.
 */
class MessageBucket::In {
public:
/**
 * A contiguous range of message indices [bucket_begin, bucket_end), selected by key.
 * If the constructor's bounds check rejects the requested keys, the range is empty.
 */
class Filter {
friend class Message;
public:
/**
 * A single message within a Filter's range, identified by its message index (cell_index).
 */
class Message {
// Filter that owns this message's range; read by getVariable's seatbelt bounds check
const Filter &_parent;
// Message index passed to curve when reading this message's variables
IntT cell_index;
public:
/**
 * @param parent Filter which owns the iteration range
 * @param _cell_index Initial message index
 */
__device__ Message(const Filter &parent, const IntT &_cell_index)
: _parent(parent)
, cell_index(_cell_index) { }
/**
 * Equality compares only the message index; the parent filter is not compared.
 */
__device__ bool operator==(const Message& rhs) const {
return this->cell_index == rhs.cell_index;
}
__device__ bool operator!=(const Message& rhs) const { return !(*this == rhs); }
/**
 * Advance to the next message index (prefix increment).
 */
__device__ Message& operator++() { ++cell_index; return *this; }
/**
 * Read a scalar message variable from the current message.
 * @tparam T Type of the variable
 * @tparam N Length of the variable name literal (deduced)
 * @param variable_name Name of the message variable
 */
template<typename T, size_type N>
__device__ T getVariable(const char(&variable_name)[N]) const;
/**
 * Read one element of an array message variable from the current message.
 * @tparam T Type of the array element
 * @tparam N Presumably the array variable's length; forwarded to curve's array accessor (TODO confirm)
 * @tparam M Length of the variable name literal (deduced)
 * @param variable_name Name of the message variable
 * @param index Index of the element within the array variable
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T getVariable(const char(&variable_name)[M], unsigned int index) const;
};
/**
 * Forward iterator over the messages of a Filter.
 * The constructor increments once immediately, so callers pass the index one before
 * the position the iterator should start at.
 */
class iterator {
Message _message;
public:
__device__ iterator(const Filter &parent, const IntT &cell_index)
: _message(parent, cell_index) {
// Increment to find first message
++_message;
}
__device__ iterator& operator++() { ++_message; return *this; }
// Postfix increment: returns a copy of the iterator as it was before advancing
__device__ iterator operator++(int) {
iterator temp = *this;
++*this;
return temp;
}
__device__ bool operator==(const iterator& rhs) const { return _message == rhs._message; }
__device__ bool operator!=(const iterator& rhs) const { return _message != rhs._message; }
__device__ Message& operator*() { return _message; }
__device__ Message* operator->() { return &_message; }
};
/**
 * Construct a filter over the messages whose keys lie in [beginKey, endKey).
 * @param _metadata Bucket metadata providing the key bounds and the PBM
 * @param beginKey First key (inclusive)
 * @param endKey Last key (exclusive)
 */
inline __device__ Filter(const MetaData *_metadata, const IntT &beginKey, const IntT &endKey);
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
/**
 * Construct an empty filter; returned by In::operator() after a seatbelt key-range error has been reported.
 */
inline __device__ Filter();
#endif
/**
 * @return Iterator positioned at the first message of the range (bucket_begin)
 */
inline __device__ iterator begin(void) const {
// Bin before initial bin, as the constructor calls increment operator
return iterator(*this, bucket_begin - 1);
}
/**
 * @return Iterator positioned one past the last message of the range (bucket_end)
 */
inline __device__ iterator end(void) const {
// Final bin, as the constructor calls increment operator
return iterator(*this, bucket_end - 1);
}
/**
 * @return Number of messages in the filtered range
 */
inline __device__ IntT size(void) const {
return bucket_end - bucket_begin;
}
private:
// Half-open range of message indices [bucket_begin, bucket_end) covered by this filter
IntT bucket_begin, bucket_end;
const MetaData *metadata;
};
/**
 * @param _metadata Untyped pointer to this message list's MetaData; reinterpreted as const MetaData*
 */
__device__ In(const void *_metadata)
: metadata(reinterpret_cast<const MetaData*>(_metadata))
{ }
/**
 * Select all messages with a single key.
 * With seatbelts enabled, a key outside [metadata->min, metadata->max) is reported via DTHROW
 * and an empty Filter is returned.
 * @param key Bucket key to read
 * @return Filter over the keys [key, key + 1)
 */
inline __device__ Filter operator() (const IntT &key) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
{
if (key < metadata->min) {
DTHROW("Bucket messaging iterator key %d is lower than minimum key (%d).\n", key, metadata->min);
return Filter();
} else if (key >= metadata->max) {
// metadata->max is exclusive, so the largest valid key reported to the user is max - 1
DTHROW("Bucket messaging iterator key %d is higher than maximum key (%d).\n", key, metadata->max - 1);
return Filter();
}
}
#endif
return Filter(metadata, key, key + 1);
}
/**
 * Select all messages whose keys lie in the half-open range [beginKey, endKey).
 * With seatbelts enabled, the call requires beginKey >= metadata->min, endKey <= metadata->max
 * and beginKey < endKey. Violations are reported via DTHROW and an empty Filter is returned.
 * @param beginKey First key (inclusive)
 * @param endKey Last key (exclusive)
 */
inline __device__ Filter operator() (const IntT &beginKey, const IntT &endKey) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
{
if (beginKey < metadata->min) {
DTHROW("Bucket messaging iterator begin key %d is lower than minimum key (%d).\n", beginKey, metadata->min);
return Filter();
} else if (endKey > metadata->max) {
DTHROW("Bucket messaging iterator end key %d is higher than maximum key + 1 (%d).\n", endKey, metadata->max);
return Filter();
} else if (endKey <= beginKey) {
DTHROW("Bucket messaging iterator begin key must be lower than end key (%d !< %d).\n", beginKey, endKey);
return Filter();
}
}
#endif
return Filter(metadata, beginKey, endKey);
}
private:
// Bucket metadata: key bounds (min, exclusive max) and the PBM mapping keys to message index ranges
const MetaData *metadata;
};
/**
 * Device-side interface for outputting bucket messages.
 * Inherits the brute-force output behaviour and adds a key setter plus a metadata pointer.
 */
class MessageBucket::Out : public MessageBruteForce::Out {
 public:
    /**
     * @param _metadata Untyped pointer to this message list's MetaData; only retained when seatbelts are enabled
     * @param scan_flag_messageOutput Scan flag buffer, forwarded to the brute-force base class
     */
    __device__ Out(const void *_metadata, unsigned int *scan_flag_messageOutput) :
        MessageBruteForce::Out(nullptr, scan_flag_messageOutput),
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        metadata(static_cast<const MetaData *>(_metadata))
#else
        // Within this header, metadata is only read by setKey's seatbelt checks
        metadata(nullptr)
#endif
    {}
    /**
     * Set the key of the message output by the calling thread (defined out of class).
     * @param key Bucket key of the message
     */
    inline __device__ void setKey(const IntT &key) const;
    // Bucket metadata (key bounds); nullptr when seatbelts are disabled
    const MetaData *const metadata;
};
/**
 * Construct a filter over the messages whose keys lie in the half-open range [beginKey, endKey).
 *
 * The message index range is taken from the PBM: [PBM[beginKey - min], PBM[endKey - min]).
 * If the keys are out of bounds (beginKey < min, endKey > max, or beginKey > endKey),
 * the filter is left empty (bucket_begin == bucket_end == 0).
 *
 * NOTE(review): reading PBM[endKey - min] with endKey == max requires the PBM to hold
 * (max - min + 1) entries; confirm against the host-side allocation.
 *
 * @param _metadata Bucket metadata (key bounds and PBM)
 * @param beginKey First key (inclusive)
 * @param endKey Last key (exclusive)
 */
__device__ MessageBucket::In::Filter::Filter(const MetaData* _metadata, const IntT& beginKey, const IntT& endKey)
    : bucket_begin(0)
    , bucket_end(0)
    , metadata(_metadata) {
    // metadata->max is one past the largest valid key and endKey is exclusive, so endKey == max is valid.
    // Rejecting it (endKey < max) would make the final bucket permanently empty, e.g. for
    // the single-key filter (key, key + 1) built for the largest key.
    if (beginKey >= metadata->min && endKey <= metadata->max && beginKey <= endKey) {
        bucket_begin = metadata->PBM[beginKey - metadata->min];
        bucket_end = metadata->PBM[endKey - metadata->min];
    }
}
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
/**
 * Construct an empty filter: a zero-length index range with no metadata.
 * Only compiled with seatbelts enabled, where In::operator() returns it after reporting a key-range error.
 */
__device__ MessageBucket::In::Filter::Filter() : bucket_begin{0}, bucket_end{0}, metadata{nullptr} {}
#endif
/**
 * Set the key of the message output by the calling thread.
 *
 * The key is stored in the `_key` message variable via curve, and the thread's scan flag
 * is set so the message is kept.
 * With seatbelts enabled, a key outside [metadata->min, metadata->max) is reported via DTHROW
 * and nothing is written.
 *
 * @param key Bucket key of the message
 */
__device__ void MessageBucket::Out::setKey(const IntT &key) const {
    // Output slot for this thread, computed from a 1D (x-dimension only) launch index
    const unsigned int index = (blockDim.x * blockIdx.x) + threadIdx.x;
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    if (key < metadata->min || key >= metadata->max) {
        // IntT is signed (keys may be negative), so print with %d rather than %u
        DTHROW("MessageBucket key %d is out of range [%d, %d).\n", key, metadata->min, metadata->max);
        return;
    }
#endif
    // Set the key variable using curve
    detail::curve::DeviceCurve::setMessageVariable<IntT>("_key", key, index);
    // Set the scan flag, in case the message is optional
    this->scan_flag[index] = 1;
}
/**
 * Fetch a scalar variable of the message the iterator currently points at.
 *
 * With seatbelts enabled, reading at or beyond the filter's bucket_end is reported via DTHROW
 * and 0 is returned instead.
 *
 * @tparam T Type of the message variable
 * @tparam N Length of the variable name literal (deduced)
 * @param variable_name Name of the message variable
 * @return Value of the variable for the current message
 */
// NOTE(review): the in-class declaration spells N as `size_type`, while this definition uses
// `unsigned int`; the two only match while size_type is unsigned int.
template<typename T, unsigned int N>
__device__ T MessageBucket::In::Filter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Refuse to read past the last message index of the filtered range.
    const bool beyond_range = cell_index >= _parent.bucket_end;
    if (beyond_range) {
        DTHROW("Bucket message index exceeds bin length, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    // Look the variable up through curve, indexed by this message's position.
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
/**
 * Fetch one element of an array variable of the message the iterator currently points at.
 *
 * With seatbelts enabled, reading at or beyond the filter's bucket_end is reported via DTHROW
 * and a value-initialised T is returned instead.
 *
 * @tparam T Type of the array element
 * @tparam N Presumably the array variable's length; forwarded to getMessageArrayVariable (TODO confirm)
 * @tparam M Length of the variable name literal (deduced)
 * @param variable_name Name of the message variable
 * @param array_index Element index within the array variable
 * @return Value of the requested element for the current message
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageBucket::In::Filter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Only messages strictly before bucket_end belong to this filter.
    if (!(cell_index < _parent.bucket_end)) {
        DTHROW("Bucket message index exceeds bin length, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    // Look the array element up through curve, indexed by this message's position.
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}
} // namespace flamegpu
#endif // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGEBUCKET_MESSAGEBUCKETDEVICE_CUH_