module dparse.rollback_allocator;

import core.memory : GC;

//version = debug_rollback_allocator;

/**
 * Pointer-bump allocator with rollback functionality.
 *
 * Memory is carved out of a singly linked list of nodes obtained from
 * `AlignedMallocator`; the most recently created node is `first`. Every node
 * is registered with the GC via `GC.addRange`, so GC pointers stored inside
 * allocations are scanned. Individual allocations are never freed: instead,
 * `setCheckpoint` records the current position and `rollback` discards
 * everything allocated after it. All nodes are released on destruction.
 * Copying is disabled because the allocator owns its nodes.
 */
struct RollbackAllocator
{
public:

    /// Alignment (in bytes) of every pointer returned by `allocate`.
    /// Must be a power of two, at least 8, and divide `Node.sizeof`
    /// (enforced by static asserts below).
    enum memoryAlignment = 16u;

    static assert(memoryAlignment >= 8 && (memoryAlignment & (memoryAlignment - 1)) == 0,
        "memoryAlignment must be a power of two and at least 8");

    @disable this(this);

    ~this()
    {
        while (first !is null)
            deallocateNode();
    }

    /**
     * Allocates `size` bytes of memory.
     *
     * The returned slice starts at an address aligned to `memoryAlignment`.
     * The request is rounded up to a multiple of `memoryAlignment` internally
     * so that the next allocation is aligned as well.
     *
     * Params:
     *     size = number of bytes requested; `0` is allowed and yields an
     *         empty slice.
     * Returns: a slice of exactly `size` bytes.
     * Throws: `OutOfMemoryError` (via `onOutOfMemoryError`) if the request is
     *     too large to round up or the underlying allocation fails.
     */
    void[] allocate(const size_t size)
    out (arr)
    {
        assert(arr.length == size);
    }
    do
    {
        import core.exception : onOutOfMemoryError;

        // Rounding up and adding the node header below must not wrap around.
        if (size > size_t.max - memoryAlignment - Node.sizeof)
            onOutOfMemoryError();

        // Round the request up so the following allocation stays aligned.
        immutable size_t alignedSize =
            (size + (memoryAlignment - 1)) & ~(cast(size_t) memoryAlignment - 1);

        // Size a fresh node by the *aligned* request so it is always big
        // enough; `used <= mem.length` holds, so the subtraction cannot wrap.
        if (first is null || alignedSize > first.mem.length - first.used)
            allocateNode(alignedSize);

        immutable size_t begin = first.used;
        first.used = begin + alignedSize;
        assert(first.used <= first.mem.length);
        return first.mem[begin .. begin + size];
    }

    /**
     * Rolls back the allocator to the given checkpoint, discarding every
     * allocation made after `setCheckpoint` returned `point`.
     *
     * Params:
     *     point = a value previously returned by `setCheckpoint`. Passing `0`
     *         releases all memory held by the allocator.
     */
    void rollback(size_t point)
    {
        if (point == 0)
        {
            while (first !is null)
                deallocateNode();
            return;
        }

        assert(contains(point), "Attempted to roll back to a point not in the allocator.");

        // Nodes created after the one holding `point` are released entirely.
        // The `first !is null` test keeps the loop from dereferencing null
        // when asserts are compiled out and `point` is not actually ours.
        while (first !is null && !first.contains(point))
            deallocateNode();
        assert(first !is null);

        immutable size_t begin = point - cast(size_t) first.mem.ptr;
        version (debug_rollback_allocator)
            first.mem[begin .. $] = 0;
        first.used = begin;
        assert(cast(size_t) first.mem.ptr + first.used == point);
    }

    /**
     * Get a checkpoint for the allocator.
     *
     * Returns: `0` when the allocator holds no memory, otherwise the address
     *     (as `size_t`) of the next free byte in the current node. Pass the
     *     value to `rollback` to undo later allocations.
     */
    size_t setCheckpoint() const nothrow @nogc
    {
        assert(first is null || first.used <= first.mem.length);
        return first is null ? 0 : cast(size_t) first.mem.ptr + first.used;
    }

    /**
     * Allocates a T and returns a pointer to it (or the class reference when
     * `T` is a class), constructing it in place with `args`.
     */
    auto make(T, Args...)(auto ref Args args)
    {
        import std.algorithm.comparison : max;
        import stdx.allocator : stateSize;
        import std.conv : emplace;

        // Request at least one byte so zero-sized types still get a distinct slot.
        void[] mem = allocate(max(stateSize!T, 1));
        if (mem.ptr is null)
            return null;
        static if (is(T == class))
            return emplace!T(mem, args);
        else
            return emplace(cast(T*) mem.ptr, args);
    }

private:

    // Reports whether `point` lies within any node's memory (end inclusive);
    // used by the assertion in `rollback`.
    bool contains(size_t point) const
    {
        for (const(Node)* n = first; n !is null; n = n.next)
            if (n.contains(point))
                return true;
        return false;
    }

    // Header stored at the start of every block obtained from AlignedMallocator.
    static struct Node
    {
        Node* next;  // previously created node (older allocations)
        size_t used; // bytes of `mem` already handed out
        ubyte[] mem; // usable memory following this header

        // True when `p` is inside `mem`, including the one-past-the-end address
        // (a checkpoint taken when the node is full).
        bool contains(size_t p) const pure nothrow @nogc @safe
        {
            return p >= cast(size_t) mem.ptr && p <= cast(size_t) mem.ptr + mem.length;
        }
    }

    // The node header must keep the memory after it aligned.
    static assert(Node.sizeof % memoryAlignment == 0,
        "Node.sizeof must be a multiple of memoryAlignment");

    // Allocates a node whose usable memory holds at least `size` bytes
    // (never less than ALLOC_SIZE in total) and makes it the current node.
    void allocateNode(size_t size)
    {
        import core.exception : onOutOfMemoryError;
        import std.algorithm.comparison : max;
        import std.conv : emplace;
        import stdx.allocator.mallocator : AlignedMallocator;

        enum ALLOC_SIZE = 1024 * 8;

        ubyte[] m = cast(ubyte[]) AlignedMallocator.instance.alignedAllocate(
            max(size + Node.sizeof, ALLOC_SIZE), memoryAlignment);
        if (m is null)
            onOutOfMemoryError();
        GC.addRange(m.ptr, m.length);

        version (debug_rollback_allocator)
            m[] = 0;
        Node* n = emplace!Node(cast(Node*) m.ptr, first, 0, m[Node.sizeof .. $]);
        assert((cast(size_t) n.mem.ptr) % memoryAlignment == 0,
            "Node memory is not aligned to memoryAlignment");
        first = n;
    }

    // Releases the current node (header and memory) and makes the next one current.
    void deallocateNode()
    {
        assert(first !is null);
        import stdx.allocator.mallocator : AlignedMallocator;

        Node* next = first.next;
        // The node header and its memory were obtained as one block.
        ubyte[] mem = (cast(ubyte*) first)[0 .. Node.sizeof + first.mem.length];
        version (debug_rollback_allocator)
            mem[] = 0;
        GC.removeRange(mem.ptr);
        AlignedMallocator.instance.deallocate(mem);
        first = next;
    }

    Node* first;
}

@("most simple usage, including memory across multiple pointers")
unittest
{
    RollbackAllocator rba;
    size_t[10] checkpoint;
    foreach (i; 0 .. 10)
    {
        checkpoint[i] = rba.setCheckpoint();
        rba.allocate(4000);
    }

    foreach_reverse (i; 0 .. 10)
    {
        rba.rollback(checkpoint[i]);
    }
}

@("many allocates and frees while leaking memory")
unittest
{
    RollbackAllocator rba;
    foreach (i; 0 .. 10)
    {
        size_t[3] checkpoint;
        foreach (n; 0 .. 3)
        {
            checkpoint[n] = rba.setCheckpoint();
            rba.allocate(4000);
        }
        foreach_reverse (n; 1 .. 3)
        {
            rba.rollback(checkpoint[n]);
        }
    }
}

@("allocating overly big")
unittest
{
    import std.stdio : stderr;

    RollbackAllocator rba;
    size_t[200] checkpoint;
    size_t cp;
    foreach (i; 1024 * 8 - 100 .. 1024 * 8 + 100)
    {
        try
        {
            checkpoint[cp++] = rba.setCheckpoint();
            rba.allocate(i);
        }
        catch (Error e)
        {
            stderr.writeln("Unittest: crashed in allocating ", i, " bytes");
            throw e;
        }
    }

    foreach_reverse (i, c; checkpoint[0 .. cp])
    {
        try
        {
            rba.rollback(c);
        }
        catch (Error e)
        {
            stderr.writeln("Unittest: crashed in rolling back ", i, " (address ", c, ")");
            throw e;
        }
    }
}

@("returned memory is aligned and sized as requested")
unittest
{
    RollbackAllocator rba;
    immutable size_t[] sizes = [0, 1, 3, 15, 16, 17, 100, 1024 * 8 - 1, 1024 * 8 + 1];
    foreach (size; sizes)
    {
        void[] mem = rba.allocate(size);
        assert(mem.length == size);
        assert((cast(size_t) mem.ptr) % RollbackAllocator.memoryAlignment == 0);
    }
}

@("an oversized first allocation creates exactly one node")
unittest
{
    RollbackAllocator rba;
    rba.allocate(1024 * 8 + 1);
    assert(rba.first !is null);
    assert(rba.first.next is null);
}