module dparse.rollback_allocator;

//version = debug_rollback_allocator;

/**
 * Pointer-bump allocator with rollback functionality.
 *
 * Memory is carved out of a singly linked list of nodes, newest first. Each
 * node is one aligned block obtained from `AlignedMallocator`: a `Node` header
 * followed by the usable bytes. `setCheckpoint` records the current bump
 * position and `rollback` returns to it, freeing every node allocated since.
 * All remaining nodes are released by the destructor.
 */
struct RollbackAllocator
{
public:

    /// Alignment, in bytes, of every block returned by `allocate`.
    /// Must be a power of two and a multiple of 8 (the rounding in `alignUp`
    /// relies on the power-of-two property).
    enum memoryAlignment = 16u;
    static assert(memoryAlignment >= 8 && (memoryAlignment & (memoryAlignment - 1)) == 0,
        "memoryAlignment must be a power of two no smaller than 8");

    @disable this(this);

    /// Releases every node still owned by the allocator.
    ~this()
    {
        while (first !is null)
            deallocateNode();
    }

    /**
     * Allocates `size` bytes of memory.
     *
     * The returned slice starts at an address aligned to `memoryAlignment`.
     * Internally the bump position is advanced by `size` rounded up to
     * `memoryAlignment`, so the next allocation is aligned as well.
     *
     * Params:
     *     size = number of bytes requested; may be 0.
     * Returns: a slice of exactly `size` bytes.
     * Throws: `OutOfMemoryError` if the system allocator fails or `size` is
     *     so large that rounding it up would overflow `size_t`.
     */
    void[] allocate(const size_t size)
    out (arr)
    {
        assert(arr.length == size);
    }
    do
    {
        import core.exception : onOutOfMemoryError;

        // Reject requests whose rounded-up size (plus the node header) would
        // wrap around size_t; the arithmetic below would otherwise overflow.
        if (size > size_t.max - Node.sizeof - memoryAlignment)
            onOutOfMemoryError();

        // Round the request up so the following allocation stays aligned.
        immutable size_t alignedSize = alignUp(size);

        // Start a new node when there is none yet or the current one cannot
        // hold the aligned request. Written as a subtraction to avoid overflow;
        // `used` never exceeds `mem.length`. New nodes are sized from the
        // aligned size, so a single new node is always large enough.
        if (first is null || first.mem.length - first.used < alignedSize)
            allocateNode(alignedSize);

        immutable size_t offset = first.used;
        void[] m = first.mem[offset .. offset + size];
        first.used = offset + alignedSize;
        assert(first.used <= first.mem.length);
        return m;
    }

    /**
     * Rolls back the allocator to the given checkpoint.
     *
     * Every node allocated after the checkpoint is freed, and the bump position
     * of the node containing the checkpoint is reset to it.
     *
     * Params:
     *     point = a value previously returned by `setCheckpoint`. A value of 0
     *         (the checkpoint of an empty allocator) frees every node.
     */
    void rollback(size_t point)
    {
        if (point == 0)
        {
            while (first !is null)
                deallocateNode();
            return;
        }

        assert(contains(point), "Attempted to roll back to a point not in the allocator.");

        // `first !is null` always holds right after the contains(point) check,
        // but it may stop holding after a deallocateNode call.
        while (first !is null && !first.contains(point))
            deallocateNode();
        assert(first !is null);

        immutable begin = point - cast(size_t) first.mem.ptr;
        version (debug_rollback_allocator)
            (cast(ubyte[]) first.mem)[begin .. $] = 0;
        first.used = begin;
        assert(cast(size_t) first.mem.ptr + first.used == point);
    }

    /**
     * Get a checkpoint for the allocator.
     *
     * Returns: the address of the current bump position, or 0 if no node has
     *     been allocated yet. Pass it to `rollback` to undo later allocations.
     */
    size_t setCheckpoint() const nothrow @nogc
    {
        assert(first is null || first.used <= first.mem.length);
        return first is null ? 0 : cast(size_t) first.mem.ptr + first.used;
    }

    /**
     * Allocates a T and returns a pointer to it.
     *
     * For class types the emplaced class reference is returned. Otherwise a
     * `T*` is returned. At least one byte is always requested so distinct
     * objects get distinct addresses.
     */
    auto make(T, Args...)(auto ref Args args)
    {
        import std.algorithm.comparison : max;
        import stdx.allocator : stateSize;
        import std.conv : emplace;

        void[] mem = allocate(max(stateSize!T, 1));
        if (mem.ptr is null)
            return null;
        static if (is(T == class))
            return emplace!T(mem, args);
        else
            return emplace(cast(T*) mem.ptr, args);
    }

private:

    /// Rounds `size` up to the next multiple of `memoryAlignment`.
    /// The caller must ensure the addition cannot overflow.
    static size_t alignUp(size_t size) pure nothrow @nogc @safe
    {
        return (size + (memoryAlignment - 1)) & ~(cast(size_t) memoryAlignment - 1);
    }

    // Used for debugging: true if any node in the list contains `point`.
    bool contains(size_t point) const
    {
        for (const(Node)* n = first; n !is null; n = n.next)
            if (n.contains(point))
                return true;
        return false;
    }

    /// Header stored at the start of every block obtained from the mallocator.
    static struct Node
    {
        Node* next;  // older node, or null for the oldest
        size_t used; // bytes of `mem` handed out so far
        ubyte[] mem; // usable bytes following this header

        /// True if `p` lies within `mem`, including one-past-the-end.
        bool contains(size_t p) const pure nothrow @nogc @safe
        {
            return p >= cast(size_t) mem.ptr && p <= cast(size_t) mem.ptr + mem.length;
        }
    }

    // `mem` begins right after the header, so the header size must keep it aligned.
    static assert(Node.sizeof % memoryAlignment == 0,
        "Node header size must be a multiple of memoryAlignment");

    /// Pushes a new node with at least `size` usable bytes (and at least the
    /// default node size) onto the front of the list.
    void allocateNode(size_t size)
    {
        import core.exception : onOutOfMemoryError;
        import std.algorithm.comparison : max;
        import std.conv : emplace;
        import stdx.allocator.mallocator : AlignedMallocator;

        enum ALLOC_SIZE = 1024 * 8;

        ubyte[] m = cast(ubyte[]) AlignedMallocator.instance.alignedAllocate(
            max(size + Node.sizeof, ALLOC_SIZE), memoryAlignment);
        if (m is null)
            onOutOfMemoryError();

        version (debug_rollback_allocator)
            m[] = 0;
        Node* n = emplace!Node(cast(Node*) m.ptr, first, 0, m[Node.sizeof .. $]);
        assert((cast(size_t) n.mem.ptr) % memoryAlignment == 0, "The memoriez!");
        first = n;
    }

    /// Frees the newest node; the list must not be empty.
    void deallocateNode()
    {
        assert(first !is null);
        import stdx.allocator.mallocator : AlignedMallocator;

        Node* next = first.next;
        // Rebuild the full block (header + usable bytes) as originally allocated.
        ubyte[] mem = (cast(ubyte*) first)[0 .. Node.sizeof + first.mem.length];
        version (debug_rollback_allocator)
            mem[] = 0;
        AlignedMallocator.instance.deallocate(mem);
        first = next;
    }

    Node* first;
}

@("most simple usage, including memory across multiple pointers")
unittest
{
    RollbackAllocator rba;
    size_t[10] checkpoint;
    foreach (i; 0 .. 10)
    {
        checkpoint[i] = rba.setCheckpoint();
        rba.allocate(4000);
    }

    foreach_reverse (i; 0 .. 10)
    {
        rba.rollback(checkpoint[i]);
    }
}

@("many allocates and frees while leaking memory")
unittest
{
    RollbackAllocator rba;
    foreach (i; 0 .. 10)
    {
        size_t[3] checkpoint;
        foreach (n; 0 .. 3)
        {
            checkpoint[n] = rba.setCheckpoint();
            rba.allocate(4000);
        }
        foreach_reverse (n; 1 .. 3)
        {
            rba.rollback(checkpoint[n]);
        }
    }
}

@("allocating overly big")
unittest
{
    import std.stdio : stderr;

    RollbackAllocator rba;
    size_t[200] checkpoint;
    size_t cp;
    foreach (i; 1024 * 8 - 100 .. 1024 * 8 + 100)
    {
        try
        {
            checkpoint[cp++] = rba.setCheckpoint();
            rba.allocate(i);
        }
        catch (Error e)
        {
            stderr.writeln("Unittest: crashed in allocating ", i, " bytes");
            throw e;
        }
    }

    foreach_reverse (i, c; checkpoint[0 .. cp])
    {
        try
        {
            rba.rollback(c);
        }
        catch (Error e)
        {
            stderr.writeln("Unittest: crashed in rolling back ", i, " (address ", c, ")");
            throw e;
        }
    }
}

@("allocations are aligned and an oversized first allocation fits in one node")
unittest
{
    RollbackAllocator rba;

    // 8170 is not a multiple of memoryAlignment and is large enough that the
    // node is sized from the request rather than from the default node size.
    void[] big = rba.allocate(8170);
    assert(big.length == 8170);
    assert(cast(size_t) big.ptr % RollbackAllocator.memoryAlignment == 0);
    assert(rba.first !is null && rba.first.next is null);
    // The bump position stays aligned even when the request was not.
    assert(rba.setCheckpoint() % RollbackAllocator.memoryAlignment == 0);

    foreach (size; [1, 7, 15, 16, 17, 100, 4000])
    {
        immutable cp = rba.setCheckpoint();
        void[] mem = rba.allocate(size);
        assert(mem.length == size);
        assert(cast(size_t) mem.ptr % RollbackAllocator.memoryAlignment == 0);
        rba.rollback(cp);
        assert(rba.setCheckpoint() == cp);
    }
}