Program Listing for File MessageSpatial3DDevice.cuh

Return to documentation for file (include/flamegpu/runtime/messaging/MessageSpatial3D/MessageSpatial3DDevice.cuh)

#ifndef INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL3D_MESSAGESPATIAL3DDEVICE_CUH_
#define INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL3D_MESSAGESPATIAL3DDEVICE_CUH_

#include "flamegpu/runtime/messaging/MessageSpatial3D.h"
#include "flamegpu/runtime/messaging/MessageSpatial2D/MessageSpatial2DDevice.cuh"
#include "flamegpu/runtime/messaging/MessageBruteForce/MessageBruteForceDevice.cuh"

namespace flamegpu {

/**
 * Device-side handle for reading spatial 3D messages.
 *
 * operator()(x, y, z) returns a Filter, which searches the bins around a location without wrapping.
 * wrap(x, y, z) returns a WrapFilter, which wraps the search around the environment bounds.
 */
class MessageSpatial3D::In {
 public:
    /**
     * Iterable search around a location, without wrapping at the environment bounds.
     *
     * Messages are visited strip by strip. For every (y, z) bin offset in [-1, 1] x [-1, 1] around the
     * search bin, the bins from x-1 to x+1 (clamped to the grid by getHash3D()) are read as one
     * contiguous PBM range. Strips whose y or z bin lies outside the grid are skipped.
     */
    class Filter {
     public:
        /**
         * Message returned by Filter's iterator.
         * It also holds the iteration state and is advanced in place by operator++().
         */
        class Message {
            // Filter which owns this search (provides the search bin and metadata)
            const Filter &_parent;
            // Offset of the current strip relative to the search bin: [0] = y offset, [1] = z offset.
            // relative_cell[0] >= 2 marks the end of iteration.
            int relative_cell[2] = { -2, 1 };
            // One past the index of the last message in the current strip
            int cell_index_max = 0;
            // Index of the current message, used to read its variables
            int cell_index = 0;
            /**
             * Step to the next strip offset.
             * The z offset runs -1, 0, 1, then resets to -1 and the y offset increments.
             */
            __device__ void nextStrip() {
                if (relative_cell[1] >= 1) {
                    relative_cell[1] = -1;
                    relative_cell[0]++;
                } else {
                    relative_cell[1]++;
                }
            }

         public:
            /**
             * Construct a message with explicit iteration state.
             * @param parent Filter which owns this search
             * @param relative_cell_y Initial strip y offset
             * @param relative_cell_z Initial strip z offset
             * @param _cell_index_max Initial one-past-last message index
             * @param _cell_index Initial message index
             */
            __device__ Message(const Filter &parent, const int relative_cell_y, const int relative_cell_z, const int _cell_index_max, const int _cell_index)
                : _parent(parent)
                , cell_index_max(_cell_index_max)
                , cell_index(_cell_index) {
                relative_cell[0] = relative_cell_y;
                relative_cell[1] = relative_cell_z;
            }
            /**
             * Construct a message with default iteration state (used for the end() sentinel).
             * @param parent Filter which owns this search
             */
            __device__ Message(const Filter &parent)
                : _parent(parent) { }
            /**
             * @return true if both messages hold identical iteration state
             */
            __device__ bool operator==(const Message &rhs) const {
                return this->relative_cell[0] == rhs.relative_cell[0]
                    && this->relative_cell[1] == rhs.relative_cell[1]
                    && this->cell_index_max == rhs.cell_index_max
                    && this->cell_index == rhs.cell_index;
            }
            /**
             * The argument is ignored (it is the end() sentinel).
             * @return true while this message has not exhausted the search
             */
            __device__ bool operator!=(const Message&) const {
                // The incoming Message& is end(), so we don't care about that
                // We only care that the host object has reached end
                // When the strip number equals 2, it has exceeded the [-1, 1] range
                return !(this->relative_cell[0] >= 2);
            }
            /**
             * Advance to the next message within the search.
             */
            __device__ Message& operator++();
            /**
             * Read a variable of the current message.
             * @param variable_name Name of the message variable
             * @tparam T Type of the variable
             * @tparam N Length of variable_name (deduced)
             */
            template<typename T, unsigned int N>
            __device__ T getVariable(const char(&variable_name)[N]) const;
            /**
             * Read one element of an array variable of the current message.
             * @param variable_name Name of the message array variable
             * @param index Element index within the array
             * @tparam T Element type
             * @tparam N Length of the array variable
             * @tparam M Length of variable_name (deduced)
             */
            template<typename T, flamegpu::size_type N, unsigned int M> __device__
            T getVariable(const char(&variable_name)[M], unsigned int index) const;
        };
        /**
         * Forward iterator over the messages of a Filter.
         */
        class iterator {
            Message _message;

         public:
            /**
             * Construct an iterator from explicit iteration state, then advance to the first message.
             */
            __device__ iterator(const Filter &parent, const int relative_cell_y, const int relative_cell_z, const int _cell_index_max, const int _cell_index)
                : _message(parent, relative_cell_y, relative_cell_z, _cell_index_max, _cell_index) {
                // Increment to find first message
                ++_message;
            }
            /**
             * Construct the end() sentinel.
             */
            __device__ iterator(const Filter &parent)
                : _message(parent) { }
            __device__ iterator& operator++() { ++_message;  return *this; }
            __device__ iterator operator++(int) {
                iterator temp = *this;
                ++*this;
                return temp;
            }
            __device__ bool operator==(const iterator& rhs) const { return  _message == rhs._message; }
            __device__ bool operator!=(const iterator& rhs) const { return  _message != rhs._message; }
            __device__ Message& operator*() { return _message; }
            __device__ Message* operator->() { return &_message; }
        };
        /**
         * Prepare a search around (x, y, z).
         * @param _metadata Spatial 3D message metadata
         */
        __device__ Filter(const MetaData *_metadata, float x, float y, float z);
        /**
         * @return Iterator positioned at the first message of the search
         */
        inline __device__ iterator begin(void) const {
            // Bin before initial bin, as the constructor calls increment operator
            return iterator(*this, -2, 1, 1, 0);
        }
        /**
         * @return Sentinel iterator; iteration ends when begin() has exhausted the search
         */
        inline __device__ iterator end(void) const {
            // Empty init, because this object is never used
            // iterator equality doesn't actually check the end object
            return iterator(*this);
        }

     private:
        // Search location
        float loc[3];
        // Bin containing the search location
        GridPos3D cell;
        // Spatial 3D message metadata
        const MetaData *metadata;
    };
    /**
     * Iterable search around a location, wrapping at the environment bounds.
     *
     * Messages are visited bin by bin over the 27 bins at offsets [-1, 1]^3 from the search bin; bin
     * coordinates wrap modulo the grid dimensions. Message::getVirtualX/Y/Z() return a message's
     * position mapped to the periodic image closest to a reference point.
     */
    class WrapFilter {
     public:
        /**
         * Message returned by WrapFilter's iterator.
         * It also holds the iteration state and is advanced in place by operator++().
         */
        class Message {
            // WrapFilter which owns this search (provides the search location, bin and metadata)
            const WrapFilter&_parent;
            // Offset of the current bin relative to the search bin: [0] = x, [1] = y, [2] = z.
            // relative_cell[0] >= 2 marks the end of iteration.
            int relative_cell[3] = { -2, 1, -1 };
            // One past the index of the last message in the current bin
            int cell_index_max = 0;
            // Index of the current message, used to read its variables
            int cell_index = 0;
            /**
             * Step to the next bin offset.
             * The z offset runs -1, 0, 1, then resets and y increments. When y passes 1 it resets and x increments.
             */
            __device__ void nextCell() {
                if (relative_cell[2] >= 1) {
                    relative_cell[2] = -1;
                    if (relative_cell[1] >= 1) {
                        relative_cell[1] = -1;
                        ++relative_cell[0];
                    } else {
                        ++relative_cell[1];
                    }
                } else {
                    ++relative_cell[2];
                }
            }

         public:
            /**
             * Construct a message with explicit iteration state.
             * @param parent WrapFilter which owns this search
             * @param relative_cell_x Initial bin x offset
             * @param relative_cell_y Initial bin y offset
             * @param relative_cell_z Initial bin z offset
             * @param _cell_index_max Initial one-past-last message index
             * @param _cell_index Initial message index
             */
            __device__ Message(const WrapFilter& parent, const int relative_cell_x, const int relative_cell_y, const int relative_cell_z, const int _cell_index_max, const int _cell_index)
                : _parent(parent)
                , cell_index_max(_cell_index_max)
                , cell_index(_cell_index) {
                relative_cell[0] = relative_cell_x;
                relative_cell[1] = relative_cell_y;
                relative_cell[2] = relative_cell_z;
            }
            /**
             * Construct a message with default iteration state (used for the end() sentinel).
             * @param parent WrapFilter which owns this search
             */
            __device__ Message(const WrapFilter& parent)
                : _parent(parent) { }
            /**
             * @return true if both messages hold identical iteration state
             */
            __device__ bool operator==(const Message &rhs) const {
                return this->relative_cell[0] == rhs.relative_cell[0]
                    && this->relative_cell[1] == rhs.relative_cell[1]
                    && this->relative_cell[2] == rhs.relative_cell[2]
                    && this->cell_index_max == rhs.cell_index_max
                    && this->cell_index == rhs.cell_index;
            }
            /**
             * The argument is ignored (it is the end() sentinel).
             * @return true while this message has not exhausted the search
             */
            __device__ bool operator!=(const Message&) const {
                // The incoming Message& is end(), so we don't care about that
                // We only care that the host object has reached end
                // When the x offset equals 2, it has exceeded the [-1, 1] range
                return !(this->relative_cell[0] >= 2);
            }
            /**
             * Advance to the next message within the search.
             */
            __device__ Message& operator++();
            /**
             * Read a variable of the current message.
             * @param variable_name Name of the message variable
             * @tparam T Type of the variable
             * @tparam N Length of variable_name (deduced)
             */
            template<typename T, unsigned int N>
            __device__ T getVariable(const char(&variable_name)[N]) const;
            /**
             * Read one element of an array variable of the current message.
             * @param variable_name Name of the message array variable
             * @param index Element index within the array
             * @tparam T Element type
             * @tparam N Length of the array variable
             * @tparam M Length of variable_name (deduced)
             */
            template<typename T, flamegpu::size_type N, unsigned int M> __device__
            T getVariable(const char(&variable_name)[M], unsigned int index) const;
            /**
             * @return The message's x coordinate, mapped to the periodic image nearest the search location
             */
            __device__ float getVirtualX() const {
                return getVirtualX(_parent.loc[0]);
            }
            /**
             * @return The message's y coordinate, mapped to the periodic image nearest the search location
             */
            __device__ float getVirtualY() const {
                return getVirtualY(_parent.loc[1]);
            }
            /**
             * @return The message's z coordinate, mapped to the periodic image nearest the search location
             */
            __device__ float getVirtualZ() const {
                return getVirtualZ(_parent.loc[2]);
            }
            /**
             * If the message lies more than half the environment width from x1 along x, its x coordinate
             * is shifted by one environment width towards x1; otherwise it is returned unchanged.
             * @param x1 Reference x coordinate
             */
            __device__ float getVirtualX(const float x1) const {
                const float x2 = getVariable<float>("x");
                const float x21 = x2 - x1;
                return abs(x21) > _parent.metadata->environmentWidth[0] / 2.0f ? x2 - (x21 / abs(x21) * _parent.metadata->environmentWidth[0]) : x2;
            }
            /**
             * If the message lies more than half the environment height from y1 along y, its y coordinate
             * is shifted by one environment height towards y1; otherwise it is returned unchanged.
             * @param y1 Reference y coordinate
             */
            __device__ float getVirtualY(const float y1) const {
                const float y2 = getVariable<float>("y");
                const float y21 = y2 - y1;
                return abs(y21) > _parent.metadata->environmentWidth[1] / 2.0f ? y2 - (y21 / abs(y21) * _parent.metadata->environmentWidth[1]) : y2;
            }
            /**
             * If the message lies more than half the environment depth from z1 along z, its z coordinate
             * is shifted by one environment depth towards z1; otherwise it is returned unchanged.
             * @param z1 Reference z coordinate
             */
            __device__ float getVirtualZ(const float z1) const {
                const float z2 = getVariable<float>("z");
                const float z21 = z2 - z1;
                return abs(z21) > _parent.metadata->environmentWidth[2] / 2.0f ? z2 - (z21 / abs(z21) * _parent.metadata->environmentWidth[2]) : z2;
            }
        };
        /**
         * Forward iterator over the messages of a WrapFilter.
         */
        class iterator {
            Message _message;

         public:
            /**
             * Construct an iterator from explicit iteration state, then advance to the first message.
             */
            __device__ iterator(const WrapFilter& parent, const int relative_cell_x, const int relative_cell_y, const int relative_cell_z, const int _cell_index_max, const int _cell_index)
                : _message(parent, relative_cell_x, relative_cell_y, relative_cell_z, _cell_index_max, _cell_index) {
                // Increment to find first message
                ++_message;
            }
            /**
             * Construct the end() sentinel.
             */
            __device__ iterator(const WrapFilter& parent)
                : _message(parent) { }
            __device__ iterator& operator++() { ++_message;  return *this; }
            __device__ iterator operator++(int) {
                iterator temp = *this;
                ++*this;
                return temp;
            }
            __device__ bool operator==(const iterator& rhs) const { return  _message == rhs._message; }
            __device__ bool operator!=(const iterator& rhs) const { return  _message != rhs._message; }
            __device__ Message& operator*() { return _message; }
            __device__ Message* operator->() { return &_message; }
        };
        /**
         * Prepare a wrapped search around (x, y, z).
         * @param _metadata Spatial 3D message metadata
         */
        __device__ WrapFilter(const MetaData *_metadata, float x, float y, float z);
        /**
         * @return Iterator positioned at the first message of the search
         */
        inline __device__ iterator begin(void) const {
            // Bin before initial bin, as the constructor calls increment operator
            return iterator(*this, -2, 1, 1, 1, 0);
        }
        /**
         * @return Sentinel iterator; iteration ends when begin() has exhausted the search
         */
        inline __device__ iterator end(void) const {
            // Empty init, because this object is never used
            // iterator equality doesn't actually check the end object
            return iterator(*this);
        }

     private:
        // Search location
        float loc[3];
        // Bin containing the search location
        GridPos3D cell;
        // Spatial 3D message metadata
        const MetaData *metadata;
    };

    /**
     * @param _metadata Pointer to this message list's MetaData
     */
    __device__ In(const void *_metadata)
        : metadata(reinterpret_cast<const MetaData*>(_metadata))
    { }
    /**
     * Create a (non-wrapping) search around (x, y, z).
     * @return A Filter which can be iterated
     */
    inline __device__ Filter operator() (const float x, const float y, const float z) const {
        return Filter(metadata, x, y, z);
    }
    /**
     * Create a search around (x, y, z) which wraps at the environment bounds.
     *
     * When seatbelts are enabled, an error is reported via DTHROW if:
     * - the location lies outside the environment bounds (a filter at the environment's min corner
     *   is returned instead), or
     * - the radius does not evenly divide each environment dimension.
     * @return A WrapFilter which can be iterated
     */
    inline __device__ WrapFilter wrap(const float x, const float y, const float z) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
        if (x > metadata->max[0] ||
            y > metadata->max[1] ||
            z > metadata->max[2] ||
            x < metadata->min[0] ||
            y < metadata->min[1] ||
            z < metadata->min[2]) {
            DTHROW("Location (%f, %f, %f) exceeds environment bounds (%g, %g, %g):(%g, %g, %g),"
                " this is unsupported for the wrapped iterator, MessageSpatial3D::In::wrap().\n", x, y, z,
                metadata->min[0], metadata->min[1], metadata->min[2],
                metadata->max[0], metadata->max[1], metadata->max[2]);
            // Return iterator at min corner of env, this should be safe
            return WrapFilter(metadata, metadata->min[0], metadata->min[1], metadata->min[2]);
        }
        // Each environment dimension must be a whole multiple of the radius.
        // The fmodf() remainder is compared against both 0 and the radius. When the radius is not
        // exactly representable (e.g. 0.1f), an exact multiple can produce a remainder just below the
        // radius (fmodf(1.0f, 0.1f) is ~0.0999999f), and that case must still be accepted.
        const float remainder[3] = {
            fmodf(metadata->max[0] - metadata->min[0], metadata->radius),
            fmodf(metadata->max[1] - metadata->min[1], metadata->radius),
            fmodf(metadata->max[2] - metadata->min[2], metadata->radius)
        };
        bool radius_is_factor = true;
        for (int d = 0; d < 3; ++d) {
            if (remainder[d] > 0.00001f && metadata->radius - remainder[d] > 0.00001f) {
                radius_is_factor = false;
            }
        }
        if (!radius_is_factor) {
            DTHROW("Spatial messaging radius (%g) is not a factor of environment dimensions (%g, %g, %g),"
                " this is unsupported for the wrapped iterator, MessageSpatial3D::In::wrap().\n", metadata->radius,
                metadata->max[0] - metadata->min[0],
                metadata->max[1] - metadata->min[1],
                metadata->max[2] - metadata->min[2]);
        }
#endif
        return WrapFilter(metadata, x, y, z);
    }

    /**
     * @return The search radius from the message list's metadata
     */
    __forceinline__ __device__ float radius() const {
        return metadata->radius;
    }

 private:
    // Spatial 3D message metadata
    const MetaData *metadata;
};

/**
 * Device-side handle for outputting spatial 3D messages.
 * Extends MessageBruteForce::Out with setLocation(), which writes the "x", "y" and "z" message variables.
 */
class MessageSpatial3D::Out : public MessageBruteForce::Out {
 public:
    /**
     * @param metadata Unused; the base class is constructed with nullptr metadata
     * @param scan_flag_messageOutput Scan flag buffer, forwarded to MessageBruteForce::Out
     */
    __device__ Out(const void * /*metadata*/, unsigned int *scan_flag_messageOutput)
        : MessageBruteForce::Out(nullptr, scan_flag_messageOutput) {}
    /**
     * Set the location of the message output by the calling thread.
     */
    __device__ void setLocation(float x, float y, float z) const;
};

/**
 * Read a scalar variable of the current message of a (non-wrapped) spatial 3D search.
 * With seatbelts enabled, reading from an exhausted iterator reports an error via DTHROW and yields 0.
 */
template<typename T, unsigned int N>
__device__ T MessageSpatial3D::In::Filter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Past the final strip means the iterator is exhausted and has no message to read
    const bool exhausted = relative_cell[0] >= 2;
    if (exhausted) {
        DTHROW("MessageSpatial3D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
/**
 * Read one element of an array variable of the current message of a (non-wrapped) spatial 3D search.
 * With seatbelts enabled, reading from an exhausted iterator reports an error via DTHROW and yields T{}.
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageSpatial3D::In::Filter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Past the final strip means the iterator is exhausted and has no message to read
    const bool exhausted = relative_cell[0] >= 2;
    if (exhausted) {
        DTHROW("MessageSpatial3D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}
/**
 * Read a scalar variable of the current message of a wrapped spatial 3D search.
 * With seatbelts enabled, reading from an exhausted iterator reports an error via DTHROW and yields 0.
 */
template<typename T, unsigned int N>
__device__ T MessageSpatial3D::In::WrapFilter::Message::getVariable(const char(&variable_name)[N]) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Past the final bin offset means the iterator is exhausted and has no message to read
    const bool exhausted = relative_cell[0] >= 2;
    if (exhausted) {
        DTHROW("MessageSpatial3D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return static_cast<T>(0);
    }
#endif
    return detail::curve::DeviceCurve::getMessageVariable<T>(variable_name, cell_index);
}
/**
 * Read one element of an array variable of the current message of a wrapped spatial 3D search.
 * With seatbelts enabled, reading from an exhausted iterator reports an error via DTHROW and yields T{}.
 */
template<typename T, flamegpu::size_type N, unsigned int M> __device__
T MessageSpatial3D::In::WrapFilter::Message::getVariable(const char(&variable_name)[M], const unsigned int array_index) const {
#if !defined(FLAMEGPU_SEATBELTS) || FLAMEGPU_SEATBELTS
    // Past the final bin offset means the iterator is exhausted and has no message to read
    const bool exhausted = relative_cell[0] >= 2;
    if (exhausted) {
        DTHROW("MessageSpatial3D in invalid bin, unable to get variable '%s'.\n", variable_name);
        return {};
    }
#endif
    return detail::curve::DeviceCurve::getMessageArrayVariable<T, N>(variable_name, cell_index, array_index);
}


/**
 * Map a continuous location to the bin of the spatial partitioning grid that contains it.
 *
 * Each axis is computed as floor((p - min) / radius) and then clamped into [0, gridDim - 1],
 * so locations outside the environment map to the nearest edge bin.
 * @param md Spatial 3D message metadata (provides min, radius and gridDim)
 * @return Bin coordinate of (x, y, z)
 */
__device__ __forceinline__ MessageSpatial3D::GridPos3D getGridPosition3D(const MessageSpatial3D::MetaData *md, float x, float y, float z) {
    const float pos[3] = { x, y, z };
    int bin[3];
    for (int axis = 0; axis < 3; ++axis) {
        const int raw = static_cast<int>(floorf((pos[axis] - md->min[axis]) / md->radius));
        const int last = static_cast<int>(md->gridDim[axis]) - 1;
        bin[axis] = raw < 0 ? 0 : (raw > last ? last : raw);
    }
    return { bin[0], bin[1], bin[2] };
}
/**
 * Compute the linear index (hash) of a bin within the spatial partitioning grid.
 *
 * Bins are laid out x-fastest: hash = (z * gridDim[1] + y) * gridDim[0] + x.
 * Only x is clamped into [0, gridDim[0] - 1]; y and z are used as given, so callers must pass
 * in-range y and z.
 * @param md Spatial 3D message metadata (provides gridDim)
 * @param xyz Bin coordinate
 * @return Index of the bin, as used to index the PBM
 */
__device__ __forceinline__ unsigned int getHash3D(const MessageSpatial3D::MetaData *md, const MessageSpatial3D::GridPos3D &xyz) {
    const int last_x = static_cast<int>(md->gridDim[0]) - 1;
    int bin_x = xyz.x;
    if (bin_x < 0) {
        bin_x = 0;
    } else if (bin_x >= last_x) {
        bin_x = last_x;
    }
    const unsigned int ux = static_cast<unsigned int>(bin_x);
    const unsigned int uy = static_cast<unsigned int>(xyz.y);
    const unsigned int uz = static_cast<unsigned int>(xyz.z);
    return static_cast<unsigned int>((uz * md->gridDim[1] + uy) * md->gridDim[0] + ux);
}

/**
 * Write the location of the message output by the calling thread.
 *
 * The message slot is the thread's 1D global index (blockIdx.x * blockDim.x + threadIdx.x).
 * The thread's scan flag is also set, in case message output is optional.
 */
__device__ inline void MessageSpatial3D::Out::setLocation(const float x, const float y, const float z) const {
    const unsigned int message_index = blockIdx.x * blockDim.x + threadIdx.x;

    detail::curve::DeviceCurve::setMessageVariable<float>("x", x, message_index);
    detail::curve::DeviceCurve::setMessageVariable<float>("y", y, message_index);
    detail::curve::DeviceCurve::setMessageVariable<float>("z", z, message_index);

    // Mark this thread as having output a message
    this->scan_flag[message_index] = 1;
}

/**
 * Record the search location and the bin that contains it.
 * @param _metadata Spatial 3D message metadata
 */
__device__ inline MessageSpatial3D::In::Filter::Filter(const MetaData* _metadata, const float x, const float y, const float z)
    : loc{ x, y, z }
    , cell(getGridPosition3D(_metadata, x, y, z))
    , metadata(_metadata) { }
/**
 * Advance to the next message of the (non-wrapped) search.
 *
 * Messages in the current strip are consumed first. Once the strip is exhausted, strips are stepped
 * through with nextStrip(). A strip whose y or z bin lies outside the grid is skipped. Otherwise its
 * message range runs from PBM[hash(x-1, y, z)] up to (not including) PBM[hash(x+1, y, z) + 1].
 * When all strips are exhausted (relative_cell[0] reaches 2), the message is left in the end state
 * (cell_index 0, cell_index_max 1).
 */
__device__ inline MessageSpatial3D::In::Filter::Message& MessageSpatial3D::In::Filter::Message::operator++() {
    ++cell_index;
    while (cell_index >= cell_index_max) {
        nextStrip();
        if (relative_cell[0] >= 2) {
            // Every strip visited: park in the end state
            cell_index = 0;
            cell_index_max = 1;
            break;
        }
        const int strip_y = _parent.cell.y + relative_cell[0];
        const int strip_z = _parent.cell.z + relative_cell[1];
        const bool y_in_grid = strip_y >= 0 && strip_y < static_cast<int>(_parent.metadata->gridDim[1]);
        const bool z_in_grid = strip_z >= 0 && strip_z < static_cast<int>(_parent.metadata->gridDim[2]);
        if (y_in_grid && z_in_grid) {
            // x bounds are clamped by getHash3D(), so the strip is one contiguous PBM range
            const unsigned int first_bin = getHash3D(_parent.metadata, { _parent.cell.x - 1, strip_y, strip_z });
            const unsigned int last_bin = getHash3D(_parent.metadata, { _parent.cell.x + 1, strip_y, strip_z });
            cell_index = _parent.metadata->PBM[first_bin];
            cell_index_max = _parent.metadata->PBM[last_bin + 1];
        }
        // An out-of-grid strip leaves the exhausted range untouched, so the loop moves to the next strip
    }
    return *this;
}
/**
 * Record the search location and the bin that contains it.
 * @param _metadata Spatial 3D message metadata
 */
__device__ inline MessageSpatial3D::In::WrapFilter::WrapFilter(const MetaData* _metadata, const float x, const float y, const float z)
    : loc{ x, y, z }
    , cell(getGridPosition3D(_metadata, x, y, z))
    , metadata(_metadata) { }
/**
 * Advance to the next message of the wrapped search.
 *
 * Messages in the current bin are consumed first. Once the bin is exhausted, the next of the 27 bin
 * offsets in [-1, 1]^3 is selected via nextCell(). Bin coordinates wrap modulo the grid dimensions,
 * and the bin's message range is [PBM[hash], PBM[hash + 1]). When all offsets are exhausted
 * (relative_cell[0] reaches 2), the message is left in the end state (cell_index 0, cell_index_max 1).
 *
 * NOTE(review): if a grid dimension is smaller than 3, different offsets wrap onto the same bin, so its
 * messages are visited more than once. Confirm whether this is prevented elsewhere.
 */
__device__ inline MessageSpatial3D::In::WrapFilter::Message& MessageSpatial3D::In::WrapFilter::Message::operator++() {
    ++cell_index;
    while (cell_index >= cell_index_max) {
        nextCell();
        if (relative_cell[0] >= 2) {
            // Every bin offset visited: park in the end state
            cell_index = 0;
            cell_index_max = 1;
            break;
        }
        const MetaData *const md = _parent.metadata;
        // Adding gridDim before the modulo keeps the -1 offset non-negative
        const int wrapped_x = (_parent.cell.x + relative_cell[0] + static_cast<int>(md->gridDim[0])) % md->gridDim[0];
        const int wrapped_y = (_parent.cell.y + relative_cell[1] + static_cast<int>(md->gridDim[1])) % md->gridDim[1];
        const int wrapped_z = (_parent.cell.z + relative_cell[2] + static_cast<int>(md->gridDim[2])) % md->gridDim[2];
        const unsigned int bin_hash = getHash3D(md, { wrapped_x, wrapped_y, wrapped_z });
        cell_index = md->PBM[bin_hash];
        cell_index_max = md->PBM[bin_hash + 1];
    }
    return *this;
}

}  // namespace flamegpu

#endif  // INCLUDE_FLAMEGPU_RUNTIME_MESSAGING_MESSAGESPATIAL3D_MESSAGESPATIAL3DDEVICE_CUH_