source: CLRX/CLRadeonExtender/trunk/amdasm/AsmRegAlloc.cpp @ 3991

Last change on this file since 3991 was 3991, checked in by matszpk, 12 months ago

CLRadeonExtender: Move SimpleCache? to Containers.h. Move createSSAData stuff into new source file (AsmRegAllocSSAData.cpp). Add new include: AsmRegAlloc?.h.

File size: 55.1 KB
Line 
1/*
2 *  CLRadeonExtender - Unofficial OpenCL Radeon Extensions Library
3 *  Copyright (C) 2014-2018 Mateusz Szpakowski
4 *
5 *  This library is free software; you can redistribute it and/or
6 *  modify it under the terms of the GNU Lesser General Public
7 *  License as published by the Free Software Foundation; either
8 *  version 2.1 of the License, or (at your option) any later version.
9 *
10 *  This library is distributed in the hope that it will be useful,
11 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 *  Lesser General Public License for more details.
14 *
15 *  You should have received a copy of the GNU Lesser General Public
16 *  License along with this library; if not, write to the Free Software
17 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18 */
19
20#include <CLRX/Config.h>
21#include <iostream>
22#include <stack>
23#include <deque>
24#include <vector>
25#include <utility>
26#include <unordered_set>
27#include <map>
28#include <set>
29#include <unordered_map>
30#include <algorithm>
31#include <CLRX/utils/Utilities.h>
32#include <CLRX/utils/Containers.h>
33#include <CLRX/amdasm/Assembler.h>
34#include "AsmInternals.h"
35#include "AsmRegAlloc.h"
36
37using namespace CLRX;
38
39std::ostream& operator<<(std::ostream& os, const CLRX::BlockIndex& v)
40{
41    if (v.pass==0)
42        return os << v.index;
43    else
44        return os << v.index << "#" << v.pass;
45}
46
47ISAUsageHandler::ISAUsageHandler(const std::vector<cxbyte>& _content) :
48            content(_content), lastOffset(0), readOffset(0), instrStructPos(0),
49            regUsagesPos(0), regUsages2Pos(0), regVarUsagesPos(0),
50            pushedArgs(0), argPos(0), argFlags(0), isNext(false), useRegMode(false)
51{ }
52
53ISAUsageHandler::~ISAUsageHandler()
54{ }
55
56void ISAUsageHandler::rewind()
57{
58    readOffset = instrStructPos = 0;
59    regUsagesPos = regUsages2Pos = regVarUsagesPos = 0;
60    useRegMode = false;
61    pushedArgs = 0;
62    skipBytesInInstrStruct();
63}
64
65void ISAUsageHandler::skipBytesInInstrStruct()
66{
67    // do not add instruction size if usereg (usereg immediately before instr regusages)
68    if ((instrStructPos != 0 || argPos != 0) && !useRegMode)
69        readOffset += defaultInstrSize;
70    argPos = 0;
71    for (;instrStructPos < instrStruct.size() &&
72        instrStruct[instrStructPos] > 0x80; instrStructPos++)
73        readOffset += (instrStruct[instrStructPos] & 0x7f);
74    isNext = (instrStructPos < instrStruct.size());
75}
76
77void ISAUsageHandler::putSpace(size_t offset)
78{
79    if (lastOffset != offset)
80    {
81        flush(); // flush before new instruction
82        // useReg immediately before instruction regusages
83        size_t defaultInstrSize = (!useRegMode ? this->defaultInstrSize : 0);
84        if (lastOffset > offset)
85            throw AsmException("Offset before previous instruction");
86        if (!instrStruct.empty() && offset - lastOffset < defaultInstrSize)
87            throw AsmException("Offset between previous instruction");
88        size_t toSkip = !instrStruct.empty() ? 
89                offset - lastOffset - defaultInstrSize : offset;
90        while (toSkip > 0)
91        {
92            size_t skipped = std::min(toSkip, size_t(0x7f));
93            instrStruct.push_back(skipped | 0x80);
94            toSkip -= skipped;
95        }
96        lastOffset = offset;
97        argFlags = 0;
98        pushedArgs = 0;
99    } 
100}
101
102void ISAUsageHandler::pushUsage(const AsmRegVarUsage& rvu)
103{
104    if (lastOffset == rvu.offset && useRegMode)
105        flush(); // only flush if useRegMode and no change in offset
106    else // otherwise
107        putSpace(rvu.offset);
108    useRegMode = false;
109    if (rvu.regVar != nullptr)
110    {
111        argFlags |= (1U<<pushedArgs);
112        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
113            rvu.rwFlags, rvu.align });
114    }
115    else // reg usages
116        regUsages.push_back({ rvu.regField,cxbyte(rvu.rwFlags |
117                    getRwFlags(rvu.regField, rvu.rstart, rvu.rend)) });
118    pushedArgs++;
119}
120
121void ISAUsageHandler::pushUseRegUsage(const AsmRegVarUsage& rvu)
122{
123    if (lastOffset == rvu.offset && !useRegMode)
124        flush(); // only flush if useRegMode and no change in offset
125    else // otherwise
126        putSpace(rvu.offset);
127    useRegMode = true;
128    if (pushedArgs == 0 || pushedArgs == 256)
129    {
130        argFlags = 0;
131        pushedArgs = 0;
132        instrStruct.push_back(0x80); // sign of regvarusage from usereg
133        instrStruct.push_back(0);
134    }
135    if (rvu.regVar != nullptr)
136    {
137        argFlags |= (1U<<(pushedArgs & 7));
138        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
139            rvu.rwFlags, rvu.align });
140    }
141    else // reg usages
142        regUsages2.push_back({ rvu.rstart, rvu.rend, rvu.rwFlags });
143    pushedArgs++;
144    if ((pushedArgs & 7) == 0) // just flush per 8 bit
145    {
146        instrStruct.push_back(argFlags);
147        instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
148        argFlags = 0;
149    }
150}
151
152void ISAUsageHandler::flush()
153{
154    if (pushedArgs != 0)
155    {
156        if (!useRegMode)
157        {
158            // normal regvarusages
159            instrStruct.push_back(argFlags);
160            if ((argFlags & (1U<<(pushedArgs-1))) != 0)
161                regVarUsages.back().rwFlags |= 0x80;
162            else // reg usages
163                regUsages.back().rwFlags |= 0x80;
164        }
165        else
166        {
167            // use reg regvarusages
168            if ((pushedArgs & 7) != 0) //if only not pushed args remains
169                instrStruct.push_back(argFlags);
170            instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
171        }
172    }
173}
174
175AsmRegVarUsage ISAUsageHandler::nextUsage()
176{
177    if (!isNext)
178        throw AsmException("No reg usage in this code");
179    AsmRegVarUsage rvu;
180    // get regvarusage
181    bool lastRegUsage = false;
182    rvu.offset = readOffset;
183    if (!useRegMode && instrStruct[instrStructPos] == 0x80)
184    {
185        // useRegMode (begin fetching useregs)
186        useRegMode = true;
187        argPos = 0;
188        instrStructPos++;
189        // pushedArgs - numer of useregs, 0 - 256 useregs
190        pushedArgs = instrStruct[instrStructPos++];
191        argFlags = instrStruct[instrStructPos];
192    }
193    rvu.useRegMode = useRegMode; // no ArgPos
194   
195    if ((instrStruct[instrStructPos] & (1U << (argPos&7))) != 0)
196    {
197        // regvar usage
198        const AsmRegVarUsageInt& inRVU = regVarUsages[regVarUsagesPos++];
199        rvu.regVar = inRVU.regVar;
200        rvu.rstart = inRVU.rstart;
201        rvu.rend = inRVU.rend;
202        rvu.regField = inRVU.regField;
203        rvu.rwFlags = inRVU.rwFlags & ASMRVU_ACCESS_MASK;
204        rvu.align = inRVU.align;
205        if (!useRegMode)
206            lastRegUsage = ((inRVU.rwFlags&0x80) != 0);
207    }
208    else if (!useRegMode)
209    {
210        // simple reg usage
211        const AsmRegUsageInt& inRU = regUsages[regUsagesPos++];
212        rvu.regVar = nullptr;
213        const std::pair<uint16_t, uint16_t> regPair =
214                    getRegPair(inRU.regField, inRU.rwFlags);
215        rvu.rstart = regPair.first;
216        rvu.rend = regPair.second;
217        rvu.rwFlags = (inRU.rwFlags & ASMRVU_ACCESS_MASK);
218        rvu.regField = inRU.regField;
219        rvu.align = 0;
220        lastRegUsage = ((inRU.rwFlags&0x80) != 0);
221    }
222    else
223    {
224        // use reg (simple reg usage, second structure)
225        const AsmRegUsage2Int& inRU = regUsages2[regUsages2Pos++];
226        rvu.regVar = nullptr;
227        rvu.rstart = inRU.rstart;
228        rvu.rend = inRU.rend;
229        rvu.rwFlags = inRU.rwFlags;
230        rvu.regField = ASMFIELD_NONE;
231        rvu.align = 0;
232    }
233    argPos++;
234    if (useRegMode)
235    {
236        // if inside useregs
237        if (argPos == (pushedArgs&0xff))
238        {
239            instrStructPos++; // end
240            skipBytesInInstrStruct();
241            useRegMode = false;
242        }
243        else if ((argPos & 7) == 0) // fetch new flag
244        {
245            instrStructPos++;
246            argFlags = instrStruct[instrStructPos];
247        }
248    }
249    // after instr
250    if (lastRegUsage)
251    {
252        instrStructPos++;
253        skipBytesInInstrStruct();
254    }
255    return rvu;
256}
257
258AsmRegAllocator::AsmRegAllocator(Assembler& _assembler) : assembler(_assembler)
259{ }
260
261static inline bool codeBlockStartLess(const AsmRegAllocator::CodeBlock& c1,
262                  const AsmRegAllocator::CodeBlock& c2)
263{ return c1.start < c2.start; }
264
265static inline bool codeBlockEndLess(const AsmRegAllocator::CodeBlock& c1,
266                  const AsmRegAllocator::CodeBlock& c2)
267{ return c1.end < c2.end; }
268
269void AsmRegAllocator::createCodeStructure(const std::vector<AsmCodeFlowEntry>& codeFlow,
270             size_t codeSize, const cxbyte* code)
271{
272    ISAAssembler* isaAsm = assembler.isaAssembler;
273    if (codeSize == 0)
274        return;
275    std::vector<size_t> splits;
276    std::vector<size_t> codeStarts;
277    std::vector<size_t> codeEnds;
278    codeStarts.push_back(0);
279    codeEnds.push_back(codeSize);
280    for (const AsmCodeFlowEntry& entry: codeFlow)
281    {
282        size_t instrAfter = 0;
283        if (entry.type == AsmCodeFlowType::JUMP || entry.type == AsmCodeFlowType::CJUMP ||
284            entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::RETURN)
285            instrAfter = entry.offset + isaAsm->getInstructionSize(
286                        codeSize - entry.offset, code + entry.offset);
287       
288        switch(entry.type)
289        {
290            case AsmCodeFlowType::START:
291                codeStarts.push_back(entry.offset);
292                break;
293            case AsmCodeFlowType::END:
294                codeEnds.push_back(entry.offset);
295                break;
296            case AsmCodeFlowType::JUMP:
297                splits.push_back(entry.target);
298                codeEnds.push_back(instrAfter);
299                break;
300            case AsmCodeFlowType::CJUMP:
301                splits.push_back(entry.target);
302                splits.push_back(instrAfter);
303                break;
304            case AsmCodeFlowType::CALL:
305                splits.push_back(entry.target);
306                splits.push_back(instrAfter);
307                break;
308            case AsmCodeFlowType::RETURN:
309                codeEnds.push_back(instrAfter);
310                break;
311            default:
312                break;
313        }
314    }
315    std::sort(splits.begin(), splits.end());
316    splits.resize(std::unique(splits.begin(), splits.end()) - splits.begin());
317    std::sort(codeEnds.begin(), codeEnds.end());
318    codeEnds.resize(std::unique(codeEnds.begin(), codeEnds.end()) - codeEnds.begin());
319    // remove codeStarts between codeStart and codeEnd
320    size_t i = 0;
321    size_t ii = 0;
322    size_t ei = 0; // codeEnd i
323    while (i < codeStarts.size())
324    {
325        size_t end = (ei < codeEnds.size() ? codeEnds[ei] : SIZE_MAX);
326        if (ei < codeEnds.size())
327            ei++;
328        codeStarts[ii++] = codeStarts[i];
329        // skip codeStart to end
330        for (i++ ;i < codeStarts.size() && codeStarts[i] < end; i++);
331    }
332    codeStarts.resize(ii);
333    // add next codeStarts
334    auto splitIt = splits.begin();
335    for (size_t codeEnd: codeEnds)
336    {
337        auto it = std::lower_bound(splitIt, splits.end(), codeEnd);
338        if (it != splits.end())
339        {
340            codeStarts.push_back(*it);
341            splitIt = it;
342        }
343        else // if end
344            break;
345    }
346   
347    std::sort(codeStarts.begin(), codeStarts.end());
348    codeStarts.resize(std::unique(codeStarts.begin(), codeStarts.end()) -
349                codeStarts.begin());
350    // divide to blocks
351    splitIt = splits.begin();
352    for (size_t codeStart: codeStarts)
353    {
354        size_t codeEnd = *std::upper_bound(codeEnds.begin(), codeEnds.end(), codeStart);
355        splitIt = std::lower_bound(splitIt, splits.end(), codeStart);
356       
357        if (splitIt != splits.end() && *splitIt==codeStart)
358            ++splitIt; // skip split in codeStart
359       
360        for (size_t start = codeStart; start < codeEnd; )
361        {
362            size_t end = codeEnd;
363            if (splitIt != splits.end())
364            {
365                end = std::min(end, *splitIt);
366                ++splitIt;
367            }
368            codeBlocks.push_back({ start, end, { }, false, false, false });
369            start = end;
370        }
371    }
372    // force empty block at end if some jumps goes to its
373    if (!codeEnds.empty() && !codeStarts.empty() && !splits.empty() &&
374        codeStarts.back()==codeEnds.back() && codeStarts.back() == splits.back())
375        codeBlocks.push_back({ codeStarts.back(), codeStarts.back(), { },
376                             false, false, false });
377   
378    // construct flow-graph
379    for (const AsmCodeFlowEntry& entry: codeFlow)
380        if (entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::JUMP ||
381            entry.type == AsmCodeFlowType::CJUMP || entry.type == AsmCodeFlowType::RETURN)
382        {
383            std::vector<CodeBlock>::iterator it;
384            size_t instrAfter = entry.offset + isaAsm->getInstructionSize(
385                        codeSize - entry.offset, code + entry.offset);
386           
387            if (entry.type != AsmCodeFlowType::RETURN)
388                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
389                        CodeBlock{ entry.target }, codeBlockStartLess);
390            else // return
391            {
392                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
393                        CodeBlock{ 0, instrAfter }, codeBlockEndLess);
394                // if block have return
395                if (it != codeBlocks.end())
396                    it->haveEnd = it->haveReturn = true;
397                continue;
398            }
399           
400            if (it == codeBlocks.end())
401                continue; // error!
402            auto it2 = std::lower_bound(codeBlocks.begin(), codeBlocks.end(),
403                    CodeBlock{ instrAfter }, codeBlockStartLess);
404            auto curIt = it2;
405            --curIt;
406           
407            curIt->nexts.push_back({ size_t(it - codeBlocks.begin()),
408                        entry.type == AsmCodeFlowType::CALL });
409            curIt->haveCalls |= entry.type == AsmCodeFlowType::CALL;
410            if (entry.type == AsmCodeFlowType::CJUMP ||
411                 entry.type == AsmCodeFlowType::CALL)
412            {
413                curIt->haveEnd = false; // revert haveEnd if block have cond jump or call
414                if (it2 != codeBlocks.end() && entry.type == AsmCodeFlowType::CJUMP)
415                    // add next next block (only for cond jump)
416                    curIt->nexts.push_back({ size_t(it2 - codeBlocks.begin()), false });
417            }
418            else if (entry.type == AsmCodeFlowType::JUMP)
419                curIt->haveEnd = true; // set end
420        }
421    // force haveEnd for block with cf_end
422    for (const AsmCodeFlowEntry& entry: codeFlow)
423        if (entry.type == AsmCodeFlowType::END)
424        {
425            auto it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
426                    CodeBlock{ 0, entry.offset }, codeBlockEndLess);
427            if (it != codeBlocks.end())
428                it->haveEnd = true;
429        }
430   
431    if (!codeBlocks.empty()) // always set haveEnd to last block
432        codeBlocks.back().haveEnd = true;
433   
434    // reduce nexts
435    for (CodeBlock& block: codeBlocks)
436    {
437        // first non-call nexts, for correct resolving SSA conflicts
438        std::sort(block.nexts.begin(), block.nexts.end(),
439                  [](const NextBlock& n1, const NextBlock& n2)
440                  { return int(n1.isCall)<int(n2.isCall) ||
441                      (n1.isCall == n2.isCall && n1.block < n2.block); });
442        auto it = std::unique(block.nexts.begin(), block.nexts.end(),
443                  [](const NextBlock& n1, const NextBlock& n2)
444                  { return n1.block == n2.block && n1.isCall == n2.isCall; });
445        block.nexts.resize(it - block.nexts.begin());
446    }
447}
448
449
450void AsmRegAllocator::applySSAReplaces()
451{
452    /* prepare SSA id replaces */
453    struct MinSSAGraphNode
454    {
455        size_t minSSAId;
456        bool visited;
457        std::unordered_set<size_t> nexts;
458        MinSSAGraphNode() : minSSAId(SIZE_MAX), visited(false) { }
459    };
460    struct MinSSAGraphStackEntry
461    {
462        std::unordered_map<size_t, MinSSAGraphNode>::iterator nodeIt;
463        std::unordered_set<size_t>::const_iterator nextIt;
464        size_t minSSAId;
465    };
466   
467    for (auto& entry: ssaReplacesMap)
468    {
469        VectorSet<SSAReplace>& replaces = entry.second;
470        std::sort(replaces.begin(), replaces.end());
471        replaces.resize(std::unique(replaces.begin(), replaces.end()) - replaces.begin());
472        VectorSet<SSAReplace> newReplaces;
473       
474        std::unordered_map<size_t, MinSSAGraphNode> ssaGraphNodes;
475       
476        auto it = replaces.begin();
477        while (it != replaces.end())
478        {
479            auto itEnd = std::upper_bound(it, replaces.end(),
480                            std::make_pair(it->first, size_t(SIZE_MAX)));
481            {
482                MinSSAGraphNode& node = ssaGraphNodes[it->first];
483                node.minSSAId = std::min(node.minSSAId, it->second);
484                for (auto it2 = it; it2 != itEnd; ++it2)
485                    node.nexts.insert(it->second);
486            }
487            it = itEnd;
488        }
489        // propagate min value
490        std::stack<MinSSAGraphStackEntry> minSSAStack;
491        for (auto ssaGraphNodeIt = ssaGraphNodes.begin();
492                 ssaGraphNodeIt!=ssaGraphNodes.end(); )
493        {
494            minSSAStack.push({ ssaGraphNodeIt, ssaGraphNodeIt->second.nexts.begin() });
495            // traverse with minimalize SSA id
496            while (!minSSAStack.empty())
497            {
498                MinSSAGraphStackEntry& entry = minSSAStack.top();
499                MinSSAGraphNode& node = entry.nodeIt->second;
500                bool toPop = false;
501                if (entry.nextIt == node.nexts.begin())
502                {
503                    if (!node.visited)
504                        node.visited = true;
505                    else
506                        toPop = true;
507                }
508                if (!toPop && entry.nextIt != node.nexts.end())
509                {
510                    auto nodeIt = ssaGraphNodes.find(*entry.nextIt);
511                    if (nodeIt != ssaGraphNodes.end())
512                    {
513                        minSSAStack.push({ nodeIt, nodeIt->second.nexts.begin(),
514                                nodeIt->second.minSSAId });
515                    }
516                    ++entry.nextIt;
517                }
518                else
519                {
520                    node.minSSAId = std::min(node.minSSAId, entry.minSSAId);
521                    minSSAStack.pop();
522                    if (!minSSAStack.empty())
523                    {
524                        MinSSAGraphStackEntry& pentry = minSSAStack.top();
525                        pentry.minSSAId = std::min(pentry.minSSAId, node.minSSAId);
526                    }
527                }
528            }
529            // skip visited nodes
530            while (ssaGraphNodeIt != ssaGraphNodes.end())
531                if (!ssaGraphNodeIt->second.visited)
532                    break;
533        }
534       
535        for (const auto& entry: ssaGraphNodes)
536            newReplaces.push_back({ entry.first, entry.second.minSSAId });
537       
538        std::sort(newReplaces.begin(), newReplaces.end());
539        entry.second = newReplaces;
540    }
541   
542    /* apply SSA id replaces */
543    for (CodeBlock& cblock: codeBlocks)
544        for (auto& ssaEntry: cblock.ssaInfoMap)
545        {
546            auto it = ssaReplacesMap.find(ssaEntry.first);
547            if (it == ssaReplacesMap.end())
548                continue;
549            SSAInfo& sinfo = ssaEntry.second;
550            VectorSet<SSAReplace>& replaces = it->second;
551            if (sinfo.readBeforeWrite)
552            {
553                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
554                                 ssaEntry.second.ssaIdBefore);
555                if (rit != replaces.end())
556                    sinfo.ssaIdBefore = rit->second; // replace
557            }
558            if (sinfo.ssaIdFirst != SIZE_MAX)
559            {
560                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
561                                 ssaEntry.second.ssaIdFirst);
562                if (rit != replaces.end())
563                    sinfo.ssaIdFirst = rit->second; // replace
564            }
565            if (sinfo.ssaIdLast != SIZE_MAX)
566            {
567                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
568                                 ssaEntry.second.ssaIdLast);
569                if (rit != replaces.end())
570                    sinfo.ssaIdLast = rit->second; // replace
571            }
572        }
573}
574
575struct Liveness
576{
577    std::map<size_t, size_t> l;
578   
579    Liveness() { }
580   
581    void clear()
582    { l.clear(); }
583   
584    void expand(size_t k)
585    {
586        if (l.empty())
587            l.insert(std::make_pair(k, k+1));
588        else
589        {
590            auto it = l.end();
591            --it;
592            it->second = k+1;
593        }
594    }
595    void newRegion(size_t k)
596    {
597        if (l.empty())
598            l.insert(std::make_pair(k, k));
599        else
600        {
601            auto it = l.end();
602            --it;
603            if (it->first != k && it->second != k)
604                l.insert(std::make_pair(k, k));
605        }
606    }
607   
608    void insert(size_t k, size_t k2)
609    {
610        auto it1 = l.lower_bound(k);
611        if (it1!=l.begin() && (it1==l.end() || it1->first>k))
612            --it1;
613        if (it1->second < k)
614            ++it1;
615        auto it2 = l.lower_bound(k2);
616        if (it1!=it2)
617        {
618            k = std::min(k, it1->first);
619            k2 = std::max(k2, (--it2)->second);
620            l.erase(it1, it2);
621        }
622        l.insert(std::make_pair(k, k2));
623    }
624   
625    bool contain(size_t t) const
626    {
627        auto it = l.lower_bound(t);
628        if (it==l.begin() && it->first>t)
629            return false;
630        if (it==l.end() || it->first>t)
631            --it;
632        return it->first<=t && t<it->second;
633    }
634   
635    bool common(const Liveness& b) const
636    {
637        auto i = l.begin();
638        auto j = b.l.begin();
639        for (; i != l.end() && j != b.l.end();)
640        {
641            if (i->first==i->second)
642            {
643                ++i;
644                continue;
645            }
646            if (j->first==j->second)
647            {
648                ++j;
649                continue;
650            }
651            if (i->first<j->first)
652            {
653                if (i->second > j->first)
654                    return true; // common place
655                ++i;
656            }
657            else
658            {
659                if (i->first < j->second)
660                    return true; // common place
661                ++j;
662            }
663        }
664        return false;
665    }
666};
667
668typedef AsmRegAllocator::VarIndexMap VarIndexMap;
669
670static cxuint getRegType(size_t regTypesNum, const cxuint* regRanges,
671            const AsmSingleVReg& svreg)
672{
673    cxuint regType; // regtype
674    if (svreg.regVar!=nullptr)
675        regType = svreg.regVar->type;
676    else
677        for (regType = 0; regType < regTypesNum; regType++)
678            if (svreg.index >= regRanges[regType<<1] &&
679                svreg.index < regRanges[(regType<<1)+1])
680                break;
681    return regType;
682}
683
684static Liveness& getLiveness(const AsmSingleVReg& svreg, size_t ssaIdIdx,
685        const AsmRegAllocator::SSAInfo& ssaInfo, std::vector<Liveness>* livenesses,
686        const VarIndexMap* vregIndexMaps, size_t regTypesNum, const cxuint* regRanges)
687{
688    size_t ssaId;
689    if (svreg.regVar==nullptr)
690        ssaId = 0;
691    else if (ssaIdIdx==0)
692        ssaId = ssaInfo.ssaIdBefore;
693    else if (ssaIdIdx==1)
694        ssaId = ssaInfo.ssaIdFirst;
695    else if (ssaIdIdx<ssaInfo.ssaIdChange)
696        ssaId = ssaInfo.ssaId + ssaIdIdx-1;
697    else // last
698        ssaId = ssaInfo.ssaIdLast;
699   
700    cxuint regType = getRegType(regTypesNum, regRanges, svreg); // regtype
701    const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
702    const std::vector<size_t>& ssaIdIndices =
703                vregIndexMap.find(svreg)->second;
704    return livenesses[regType][ssaIdIndices[ssaId]];
705}
706
707typedef std::deque<FlowStackEntry3>::const_iterator FlowStackCIter;
708
709struct CLRX_INTERNAL VRegLastPos
710{
711    size_t ssaId; // last SSA id
712    std::vector<FlowStackCIter> blockChain; // subsequent blocks that changes SSAId
713};
714
715/* TODO: add handling calls
716 * handle many start points in this code (for example many kernel's in same code)
717 * replace sets by vector, and sort and remove same values on demand
718 */
719
720typedef std::unordered_map<AsmSingleVReg, VRegLastPos> LastVRegMap;
721
722static void putCrossBlockLivenesses(const std::deque<FlowStackEntry3>& flowStack,
723        const std::vector<CodeBlock>& codeBlocks,
724        const Array<size_t>& codeBlockLiveTimes, const LastVRegMap& lastVRegMap,
725        std::vector<Liveness>* livenesses, const VarIndexMap* vregIndexMaps,
726        size_t regTypesNum, const cxuint* regRanges)
727{
728    const CodeBlock& cblock = codeBlocks[flowStack.back().blockIndex];
729    for (const auto& entry: cblock.ssaInfoMap)
730        if (entry.second.readBeforeWrite)
731        {
732            // find last
733            auto lvrit = lastVRegMap.find(entry.first);
734            if (lvrit == lastVRegMap.end())
735                continue; // not found
736            const VRegLastPos& lastPos = lvrit->second;
737            FlowStackCIter flit = lastPos.blockChain.back();
738            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
739            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
740            const std::vector<size_t>& ssaIdIndices =
741                        vregIndexMap.find(entry.first)->second;
742            Liveness& lv = livenesses[regType][ssaIdIndices[entry.second.ssaIdBefore]];
743            FlowStackCIter flitEnd = flowStack.end();
744            --flitEnd; // before last element
745            // insert live time to last seen position
746            const CodeBlock& lastBlk = codeBlocks[flit->blockIndex];
747            size_t toLiveCvt = codeBlockLiveTimes[flit->blockIndex] - lastBlk.start;
748            lv.insert(lastBlk.ssaInfoMap.find(entry.first)->second.lastPos + toLiveCvt,
749                    toLiveCvt + lastBlk.end);
750            for (++flit; flit != flitEnd; ++flit)
751            {
752                const CodeBlock& cblock = codeBlocks[flit->blockIndex];
753                size_t blockLiveTime = codeBlockLiveTimes[flit->blockIndex];
754                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
755            }
756        }
757}
758
759static void putCrossBlockForLoop(const std::deque<FlowStackEntry3>& flowStack,
760        const std::vector<CodeBlock>& codeBlocks,
761        const Array<size_t>& codeBlockLiveTimes,
762        std::vector<Liveness>* livenesses, const VarIndexMap* vregIndexMaps,
763        size_t regTypesNum, const cxuint* regRanges)
764{
765    auto flitStart = flowStack.end();
766    --flitStart;
767    size_t curBlock = flitStart->blockIndex;
768    // find step in way
769    while (flitStart->blockIndex != curBlock) --flitStart;
770    auto flitEnd = flowStack.end();
771    --flitEnd;
772    std::unordered_map<AsmSingleVReg, std::pair<size_t, size_t> > varMap;
773   
774    // collect var to check
775    size_t flowPos = 0;
776    for (auto flit = flitStart; flit != flitEnd; ++flit, flowPos++)
777    {
778        const CodeBlock& cblock = codeBlocks[flit->blockIndex];
779        for (const auto& entry: cblock.ssaInfoMap)
780        {
781            const SSAInfo& sinfo = entry.second;
782            size_t lastSSAId = (sinfo.ssaIdChange != 0) ? sinfo.ssaIdLast :
783                    (sinfo.readBeforeWrite) ? sinfo.ssaIdBefore : 0;
784            varMap[entry.first] = { lastSSAId, flowPos };
785        }
786    }
787    // find connections
788    flowPos = 0;
789    for (auto flit = flitStart; flit != flitEnd; ++flit, flowPos++)
790    {
791        const CodeBlock& cblock = codeBlocks[flit->blockIndex];
792        for (const auto& entry: cblock.ssaInfoMap)
793        {
794            const SSAInfo& sinfo = entry.second;
795            auto varMapIt = varMap.find(entry.first);
796            if (!sinfo.readBeforeWrite || varMapIt == varMap.end() ||
797                flowPos > varMapIt->second.second ||
798                sinfo.ssaIdBefore != varMapIt->second.first)
799                continue;
800            // just connect
801           
802            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
803            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
804            const std::vector<size_t>& ssaIdIndices =
805                        vregIndexMap.find(entry.first)->second;
806            Liveness& lv = livenesses[regType][ssaIdIndices[entry.second.ssaIdBefore]];
807           
808            if (flowPos == varMapIt->second.second)
809            {
810                // fill whole loop
811                for (auto flit2 = flitStart; flit != flitEnd; ++flit)
812                {
813                    const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
814                    size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
815                    lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
816                }
817                continue;
818            }
819           
820            size_t flowPos2 = 0;
821            for (auto flit2 = flitStart; flowPos2 < flowPos; ++flit2, flowPos++)
822            {
823                const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
824                size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
825                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
826            }
827            // insert liveness for last block in loop of last SSAId (prev round)
828            auto flit2 = flitStart + flowPos;
829            const CodeBlock& firstBlk = codeBlocks[flit2->blockIndex];
830            size_t toLiveCvt = codeBlockLiveTimes[flit2->blockIndex] - firstBlk.start;
831            lv.insert(codeBlockLiveTimes[flit2->blockIndex],
832                    firstBlk.ssaInfoMap.find(entry.first)->second.firstPos + toLiveCvt);
833            // insert liveness for first block in loop of last SSAId
834            flit2 = flitStart + (varMapIt->second.second+1);
835            const CodeBlock& lastBlk = codeBlocks[flit2->blockIndex];
836            toLiveCvt = codeBlockLiveTimes[flit2->blockIndex] - lastBlk.start;
837            lv.insert(lastBlk.ssaInfoMap.find(entry.first)->second.lastPos + toLiveCvt,
838                    toLiveCvt + lastBlk.end);
839            // fill up loop end
840            for (++flit2; flit2 != flitEnd; ++flit2)
841            {
842                const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
843                size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
844                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
845            }
846        }
847    }
848}
849
850struct LiveBlock
851{
852    size_t start;
853    size_t end;
854    size_t vidx;
855   
856    bool operator==(const LiveBlock& b) const
857    { return start==b.start && end==b.end && vidx==b.vidx; }
858   
859    bool operator<(const LiveBlock& b) const
860    { return start<b.start || (start==b.start &&
861            (end<b.end || (end==b.end && vidx<b.vidx))); }
862};
863
864typedef AsmRegAllocator::LinearDep LinearDep;
865typedef AsmRegAllocator::EqualToDep EqualToDep;
866typedef std::unordered_map<size_t, LinearDep> LinearDepMap;
867typedef std::unordered_map<size_t, EqualToDep> EqualToDepMap;
868
869static void addUsageDeps(const cxbyte* ldeps, const cxbyte* edeps, cxuint rvusNum,
870            const AsmRegVarUsage* rvus, LinearDepMap* ldepsOut,
871            EqualToDepMap* edepsOut, const VarIndexMap* vregIndexMaps,
872            std::unordered_map<AsmSingleVReg, size_t> ssaIdIdxMap,
873            size_t regTypesNum, const cxuint* regRanges)
874{
875    // add linear deps
876    cxuint count = ldeps[0];
877    cxuint pos = 1;
878    cxbyte rvuAdded = 0;
879    for (cxuint i = 0; i < count; i++)
880    {
881        cxuint ccount = ldeps[pos++];
882        std::vector<size_t> vidxes;
883        cxuint regType = UINT_MAX;
884        cxbyte align = rvus[ldeps[pos]].align;
885        for (cxuint j = 0; j < ccount; j++)
886        {
887            rvuAdded |= 1U<<ldeps[pos];
888            const AsmRegVarUsage& rvu = rvus[ldeps[pos++]];
889            for (uint16_t k = rvu.rstart; k < rvu.rend; k++)
890            {
891                AsmSingleVReg svreg = {rvu.regVar, k};
892                auto sit = ssaIdIdxMap.find(svreg);
893                if (regType==UINT_MAX)
894                    regType = getRegType(regTypesNum, regRanges, svreg);
895                const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
896                const std::vector<size_t>& ssaIdIndices =
897                            vregIndexMap.find(svreg)->second;
898                // push variable index
899                vidxes.push_back(ssaIdIndices[sit->second]);
900            }
901        }
902        ldepsOut[regType][vidxes[0]].align = align;
903        for (size_t k = 1; k < vidxes.size(); k++)
904        {
905            ldepsOut[regType][vidxes[k-1]].nextVidxes.push_back(vidxes[k]);
906            ldepsOut[regType][vidxes[k]].prevVidxes.push_back(vidxes[k-1]);
907        }
908    }
909    // add single arg linear dependencies
910    for (cxuint i = 0; i < rvusNum; i++)
911        if ((rvuAdded & (1U<<i)) == 0 && rvus[i].rstart+1<rvus[i].rend)
912        {
913            const AsmRegVarUsage& rvu = rvus[i];
914            std::vector<size_t> vidxes;
915            cxuint regType = UINT_MAX;
916            for (uint16_t k = rvu.rstart; k < rvu.rend; k++)
917            {
918                AsmSingleVReg svreg = {rvu.regVar, k};
919                auto sit = ssaIdIdxMap.find(svreg);
920                if (regType==UINT_MAX)
921                    regType = getRegType(regTypesNum, regRanges, svreg);
922                const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
923                const std::vector<size_t>& ssaIdIndices =
924                            vregIndexMap.find(svreg)->second;
925                // push variable index
926                vidxes.push_back(ssaIdIndices[sit->second]);
927            }
928            for (size_t j = 1; j < vidxes.size(); j++)
929            {
930                ldepsOut[regType][vidxes[j-1]].nextVidxes.push_back(vidxes[j]);
931                ldepsOut[regType][vidxes[j]].prevVidxes.push_back(vidxes[j-1]);
932            }
933        }
934       
935    /* equalTo dependencies */
936    count = edeps[0];
937    pos = 1;
938    for (cxuint i = 0; i < count; i++)
939    {
940        cxuint ccount = edeps[pos++];
941        std::vector<size_t> vidxes;
942        cxuint regType = UINT_MAX;
943        for (cxuint j = 0; j < ccount; j++)
944        {
945            const AsmRegVarUsage& rvu = rvus[edeps[pos++]];
946            // only one register should be set for equalTo depencencies
947            // other registers in range will be resolved by linear dependencies
948            AsmSingleVReg svreg = {rvu.regVar, rvu.rstart};
949            auto sit = ssaIdIdxMap.find(svreg);
950            if (regType==UINT_MAX)
951                regType = getRegType(regTypesNum, regRanges, svreg);
952            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
953            const std::vector<size_t>& ssaIdIndices =
954                        vregIndexMap.find(svreg)->second;
955            // push variable index
956            vidxes.push_back(ssaIdIndices[sit->second]);
957        }
958        for (size_t j = 1; j < vidxes.size(); j++)
959        {
960            edepsOut[regType][vidxes[j-1]].nextVidxes.push_back(vidxes[j]);
961            edepsOut[regType][vidxes[j]].prevVidxes.push_back(vidxes[j-1]);
962        }
963    }
964}
965
966typedef std::unordered_map<size_t, EqualToDep>::const_iterator EqualToDepMapCIter;
967
968struct EqualStackEntry
969{
970    EqualToDepMapCIter etoDepIt;
971    size_t nextIdx; // over nextVidxes size, then prevVidxes[nextIdx-nextVidxes.size()]
972};
973
974void AsmRegAllocator::createInterferenceGraph(ISAUsageHandler& usageHandler)
975{
976    // construct var index maps
977    size_t graphVregsCounts[MAX_REGTYPES_NUM];
978    std::fill(graphVregsCounts, graphVregsCounts+regTypesNum, 0);
979    cxuint regRanges[MAX_REGTYPES_NUM*2];
980    size_t regTypesNum;
981    assembler.isaAssembler->getRegisterRanges(regTypesNum, regRanges);
982   
983    for (const CodeBlock& cblock: codeBlocks)
984        for (const auto& entry: cblock.ssaInfoMap)
985        {
986            const SSAInfo& sinfo = entry.second;
987            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
988            VarIndexMap& vregIndices = vregIndexMaps[regType];
989            size_t& graphVregsCount = graphVregsCounts[regType];
990            std::vector<size_t>& ssaIdIndices = vregIndices[entry.first];
991            size_t ssaIdCount = 0;
992            if (sinfo.readBeforeWrite)
993                ssaIdCount = sinfo.ssaIdBefore+1;
994            if (sinfo.ssaIdChange!=0)
995            {
996                ssaIdCount = std::max(ssaIdCount, sinfo.ssaIdLast+1);
997                ssaIdCount = std::max(ssaIdCount, sinfo.ssaIdFirst+1);
998            }
999            if (ssaIdIndices.size() < ssaIdCount)
1000                ssaIdIndices.resize(ssaIdCount, SIZE_MAX);
1001           
1002            if (sinfo.readBeforeWrite)
1003                ssaIdIndices[sinfo.ssaIdBefore] = graphVregsCount++;
1004            if (sinfo.ssaIdChange!=0)
1005            {
1006                // fill up ssaIdIndices (with graph Ids)
1007                ssaIdIndices[sinfo.ssaIdFirst] = graphVregsCount++;
1008                for (size_t ssaId = sinfo.ssaId+1;
1009                        ssaId < sinfo.ssaId+sinfo.ssaIdChange-1; ssaId++)
1010                    ssaIdIndices[ssaId] = graphVregsCount++;
1011                ssaIdIndices[sinfo.ssaIdLast] = graphVregsCount++;
1012            }
1013        }
1014   
1015    // construct vreg liveness
1016    std::deque<FlowStackEntry3> flowStack;
1017    std::vector<bool> visited(codeBlocks.size(), false);
1018    // hold last vreg ssaId and position
1019    LastVRegMap lastVRegMap;
1020    // hold start live time position for every code block
1021    Array<size_t> codeBlockLiveTimes(codeBlocks.size());
1022    std::unordered_set<size_t> blockInWay;
1023   
1024    std::vector<Liveness> livenesses[MAX_REGTYPES_NUM];
1025   
1026    for (size_t i = 0; i < regTypesNum; i++)
1027        livenesses[i].resize(graphVregsCounts[i]);
1028   
1029    size_t curLiveTime = 0;
1030   
1031    while (!flowStack.empty())
1032    {
1033        FlowStackEntry3& entry = flowStack.back();
1034        CodeBlock& cblock = codeBlocks[entry.blockIndex];
1035       
1036        if (entry.nextIndex == 0)
1037        {
1038            // process current block
1039            if (!blockInWay.insert(entry.blockIndex).second)
1040            {
1041                // if loop
1042                putCrossBlockForLoop(flowStack, codeBlocks, codeBlockLiveTimes, 
1043                        livenesses, vregIndexMaps, regTypesNum, regRanges);
1044                flowStack.pop_back();
1045                continue;
1046            }
1047           
1048            codeBlockLiveTimes[entry.blockIndex] = curLiveTime;
1049            putCrossBlockLivenesses(flowStack, codeBlocks, codeBlockLiveTimes, 
1050                    lastVRegMap, livenesses, vregIndexMaps, regTypesNum, regRanges);
1051           
1052            for (const auto& sentry: cblock.ssaInfoMap)
1053            {
1054                const SSAInfo& sinfo = sentry.second;
1055                // update
1056                size_t lastSSAId =  (sinfo.ssaIdChange != 0) ? sinfo.ssaIdLast :
1057                        (sinfo.readBeforeWrite) ? sinfo.ssaIdBefore : 0;
1058                FlowStackCIter flit = flowStack.end();
1059                --flit; // to last position
1060                auto res = lastVRegMap.insert({ sentry.first, 
1061                            { lastSSAId, { flit } } });
1062                if (!res.second) // if not first seen, just update
1063                {
1064                    // update last
1065                    res.first->second.ssaId = lastSSAId;
1066                    res.first->second.blockChain.push_back(flit);
1067                }
1068            }
1069           
1070            size_t curBlockLiveEnd = cblock.end - cblock.start + curLiveTime;
1071            if (!visited[entry.blockIndex])
1072            {
1073                visited[entry.blockIndex] = true;
1074                std::unordered_map<AsmSingleVReg, size_t> ssaIdIdxMap;
1075                AsmRegVarUsage instrRVUs[8];
1076                cxuint instrRVUsCount = 0;
1077               
1078                size_t oldOffset = cblock.usagePos.readOffset;
1079                std::vector<AsmSingleVReg> readSVRegs;
1080                std::vector<AsmSingleVReg> writtenSVRegs;
1081               
1082                usageHandler.setReadPos(cblock.usagePos);
1083                // register in liveness
1084                while (true)
1085                {
1086                    AsmRegVarUsage rvu = { 0U, nullptr, 0U, 0U };
1087                    size_t liveTimeNext = curBlockLiveEnd;
1088                    if (usageHandler.hasNext())
1089                    {
1090                        rvu = usageHandler.nextUsage();
1091                        if (rvu.offset >= cblock.end)
1092                            break;
1093                        if (!rvu.useRegMode)
1094                            instrRVUs[instrRVUsCount++] = rvu;
1095                        liveTimeNext = std::min(rvu.offset, cblock.end) -
1096                                cblock.start + curLiveTime;
1097                    }
1098                    size_t liveTime = oldOffset - cblock.start + curLiveTime;
1099                    if (!usageHandler.hasNext() || rvu.offset >= oldOffset)
1100                    {
1101                        // apply to liveness
1102                        for (AsmSingleVReg svreg: readSVRegs)
1103                        {
1104                            Liveness& lv = getLiveness(svreg, ssaIdIdxMap[svreg],
1105                                    cblock.ssaInfoMap.find(svreg)->second,
1106                                    livenesses, vregIndexMaps, regTypesNum, regRanges);
1107                            if (!lv.l.empty() && (--lv.l.end())->first < curLiveTime)
1108                                lv.newRegion(curLiveTime); // begin region from this block
1109                            lv.expand(liveTime);
1110                        }
1111                        for (AsmSingleVReg svreg: writtenSVRegs)
1112                        {
1113                            size_t& ssaIdIdx = ssaIdIdxMap[svreg];
1114                            ssaIdIdx++;
1115                            SSAInfo& sinfo = cblock.ssaInfoMap.find(svreg)->second;
1116                            Liveness& lv = getLiveness(svreg, ssaIdIdx, sinfo,
1117                                    livenesses, vregIndexMaps, regTypesNum, regRanges);
1118                            if (liveTimeNext != curBlockLiveEnd)
1119                                // because live after this instr
1120                                lv.newRegion(liveTimeNext);
1121                            sinfo.lastPos = liveTimeNext - curLiveTime + cblock.start;
1122                        }
1123                        // get linear deps and equal to
1124                        cxbyte lDeps[16];
1125                        cxbyte eDeps[16];
1126                        usageHandler.getUsageDependencies(instrRVUsCount, instrRVUs,
1127                                        lDeps, eDeps);
1128                       
1129                        addUsageDeps(lDeps, eDeps, instrRVUsCount, instrRVUs,
1130                                linearDepMaps, equalToDepMaps, vregIndexMaps, ssaIdIdxMap,
1131                                regTypesNum, regRanges);
1132                       
1133                        readSVRegs.clear();
1134                        writtenSVRegs.clear();
1135                        if (!usageHandler.hasNext())
1136                            break; // end
1137                        oldOffset = rvu.offset;
1138                        instrRVUsCount = 0;
1139                    }
1140                    if (rvu.offset >= cblock.end)
1141                        break;
1142                   
1143                    for (uint16_t rindex = rvu.rstart; rindex < rvu.rend; rindex++)
1144                    {
1145                        // per register/singlvreg
1146                        AsmSingleVReg svreg{ rvu.regVar, rindex };
1147                        if (rvu.rwFlags == ASMRVU_WRITE && rvu.regField == ASMFIELD_NONE)
1148                            writtenSVRegs.push_back(svreg);
1149                        else // read or treat as reading // expand previous region
1150                            readSVRegs.push_back(svreg);
1151                    }
1152                }
1153                curLiveTime += cblock.end-cblock.start;
1154            }
1155            else
1156            {
1157                // back, already visited
1158                flowStack.pop_back();
1159                continue;
1160            }
1161        }
1162        if (entry.nextIndex < cblock.nexts.size())
1163        {
1164            flowStack.push_back({ cblock.nexts[entry.nextIndex].block, 0 });
1165            entry.nextIndex++;
1166        }
1167        else if (entry.nextIndex==0 && cblock.nexts.empty() && !cblock.haveEnd)
1168        {
1169            flowStack.push_back({ entry.blockIndex+1, 0 });
1170            entry.nextIndex++;
1171        }
1172        else // back
1173        {
1174            // revert lastSSAIdMap
1175            blockInWay.erase(entry.blockIndex);
1176            flowStack.pop_back();
1177            if (!flowStack.empty())
1178            {
1179                for (const auto& sentry: cblock.ssaInfoMap)
1180                {
1181                    auto lvrit = lastVRegMap.find(sentry.first);
1182                    if (lvrit != lastVRegMap.end())
1183                    {
1184                        VRegLastPos& lastPos = lvrit->second;
1185                        lastPos.ssaId = sentry.second.ssaIdBefore;
1186                        lastPos.blockChain.pop_back();
1187                        if (lastPos.blockChain.empty()) // just remove from lastVRegs
1188                            lastVRegMap.erase(lvrit);
1189                    }
1190                }
1191            }
1192        }
1193    }
1194   
1195    /// construct liveBlockMaps
1196    std::set<LiveBlock> liveBlockMaps[MAX_REGTYPES_NUM];
1197    for (size_t regType = 0; regType < regTypesNum; regType++)
1198    {
1199        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
1200        std::vector<Liveness>& liveness = livenesses[regType];
1201        for (size_t li = 0; li < liveness.size(); li++)
1202        {
1203            Liveness& lv = liveness[li];
1204            for (const std::pair<size_t, size_t>& blk: lv.l)
1205                if (blk.first != blk.second)
1206                    liveBlockMap.insert({ blk.first, blk.second, li });
1207            lv.clear();
1208        }
1209        liveness.clear();
1210    }
1211   
1212    // create interference graphs
1213    for (size_t regType = 0; regType < regTypesNum; regType++)
1214    {
1215        InterGraph& interGraph = interGraphs[regType];
1216        interGraph.resize(graphVregsCounts[regType]);
1217        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
1218       
1219        auto lit = liveBlockMap.begin();
1220        size_t rangeStart = 0;
1221        if (lit != liveBlockMap.end())
1222            rangeStart = lit->start;
1223        while (lit != liveBlockMap.end())
1224        {
1225            const size_t blkStart = lit->start;
1226            const size_t blkEnd = lit->end;
1227            size_t rangeEnd = blkEnd;
1228            auto liStart = liveBlockMap.lower_bound({ rangeStart, 0, 0 });
1229            auto liEnd = liveBlockMap.lower_bound({ rangeEnd, 0, 0 });
1230            // collect from this range, variable indices
1231            std::set<size_t> varIndices;
1232            for (auto lit2 = liStart; lit2 != liEnd; ++lit2)
1233                varIndices.insert(lit2->vidx);
1234            // push to intergraph as full subgGraph
1235            for (auto vit = varIndices.begin(); vit != varIndices.end(); ++vit)
1236                for (auto vit2 = varIndices.begin(); vit2 != varIndices.end(); ++vit2)
1237                    if (vit != vit2)
1238                        interGraph[*vit].insert(*vit2);
1239            // go to next live blocks
1240            rangeStart = rangeEnd;
1241            for (; lit != liveBlockMap.end(); ++lit)
1242                if (lit->start != blkStart && lit->end != blkEnd)
1243                    break;
1244            if (lit == liveBlockMap.end())
1245                break; //
1246            rangeStart = std::max(rangeStart, lit->start);
1247        }
1248    }
1249   
1250    /*
1251     * resolve equalSets
1252     */
1253    for (cxuint regType = 0; regType < regTypesNum; regType++)
1254    {
1255        InterGraph& interGraph = interGraphs[regType];
1256        const size_t nodesNum = interGraph.size();
1257        const std::unordered_map<size_t, EqualToDep>& etoDepMap = equalToDepMaps[regType];
1258        std::vector<bool> visited(nodesNum, false);
1259        std::vector<std::vector<size_t> >& equalSetList = equalSetLists[regType];
1260       
1261        for (size_t v = 0; v < nodesNum;)
1262        {
1263            auto it = etoDepMap.find(v);
1264            if (it == etoDepMap.end())
1265            {
1266                // is not regvar in equalTo dependencies
1267                v++;
1268                continue;
1269            }
1270           
1271            std::stack<EqualStackEntry> etoStack;
1272            etoStack.push(EqualStackEntry{ it, 0 });
1273           
1274            std::unordered_map<size_t, size_t>& equalSetMap =  equalSetMaps[regType];
1275            const size_t equalSetIndex = equalSetList.size();
1276            equalSetList.push_back(std::vector<size_t>());
1277            std::vector<size_t>& equalSet = equalSetList.back();
1278           
1279            // traverse by this
1280            while (!etoStack.empty())
1281            {
1282                EqualStackEntry& entry = etoStack.top();
1283                size_t vidx = entry.etoDepIt->first; // node index, vreg index
1284                const EqualToDep& eToDep = entry.etoDepIt->second;
1285                if (entry.nextIdx == 0)
1286                {
1287                    if (!visited[vidx])
1288                    {
1289                        // push to this equalSet
1290                        equalSetMap.insert({ vidx, equalSetIndex });
1291                        equalSet.push_back(vidx);
1292                    }
1293                    else
1294                    {
1295                        // already visited
1296                        etoStack.pop();
1297                        continue;
1298                    }
1299                }
1300               
1301                if (entry.nextIdx < eToDep.nextVidxes.size())
1302                {
1303                    auto nextIt = etoDepMap.find(eToDep.nextVidxes[entry.nextIdx]);
1304                    etoStack.push(EqualStackEntry{ nextIt, 0 });
1305                    entry.nextIdx++;
1306                }
1307                else if (entry.nextIdx < eToDep.nextVidxes.size()+eToDep.prevVidxes.size())
1308                {
1309                    auto nextIt = etoDepMap.find(eToDep.prevVidxes[
1310                                entry.nextIdx - eToDep.nextVidxes.size()]);
1311                    etoStack.push(EqualStackEntry{ nextIt, 0 });
1312                    entry.nextIdx++;
1313                }
1314                else
1315                    etoStack.pop();
1316            }
1317           
1318            // to first already added node (var)
1319            while (v < nodesNum && !visited[v]) v++;
1320        }
1321    }
1322}
1323
1324typedef AsmRegAllocator::InterGraph InterGraph;
1325
1326struct CLRX_INTERNAL SDOLDOCompare
1327{
1328    const InterGraph& interGraph;
1329    const Array<size_t>& sdoCounts;
1330   
1331    SDOLDOCompare(const InterGraph& _interGraph, const Array<size_t>&_sdoCounts)
1332        : interGraph(_interGraph), sdoCounts(_sdoCounts)
1333    { }
1334   
1335    bool operator()(size_t a, size_t b) const
1336    {
1337        if (sdoCounts[a] > sdoCounts[b])
1338            return true;
1339        return interGraph[a].size() > interGraph[b].size();
1340    }
1341};
1342
1343/* algorithm to allocate regranges:
1344 * from smallest regranges to greatest regranges:
1345 *   choosing free register: from smallest free regranges
1346 *      to greatest regranges:
1347 *         in this same regrange:
1348 *               try to find free regs in regranges
1349 *               try to link free ends of two distinct regranges
1350 */
1351
1352void AsmRegAllocator::colorInterferenceGraph()
1353{
1354    const GPUArchitecture arch = getGPUArchitectureFromDeviceType(
1355                    assembler.deviceType);
1356   
1357    for (size_t regType = 0; regType < regTypesNum; regType++)
1358    {
1359        const size_t maxColorsNum = getGPUMaxRegistersNum(arch, regType);
1360        InterGraph& interGraph = interGraphs[regType];
1361        const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
1362        Array<cxuint>& gcMap = graphColorMaps[regType];
1363        const std::vector<std::vector<size_t> >& equalSetList = equalSetLists[regType];
1364        const std::unordered_map<size_t, size_t>& equalSetMap =  equalSetMaps[regType];
1365       
1366        const size_t nodesNum = interGraph.size();
1367        gcMap.resize(nodesNum);
1368        std::fill(gcMap.begin(), gcMap.end(), cxuint(UINT_MAX));
1369        Array<size_t> sdoCounts(nodesNum);
1370        std::fill(sdoCounts.begin(), sdoCounts.end(), 0);
1371       
1372        SDOLDOCompare compare(interGraph, sdoCounts);
1373        std::set<size_t, SDOLDOCompare> nodeSet(compare);
1374        for (size_t i = 0; i < nodesNum; i++)
1375            nodeSet.insert(i);
1376       
1377        cxuint colorsNum = 0;
1378        // firstly, allocate real registers
1379        for (const auto& entry: vregIndexMap)
1380            if (entry.first.regVar == nullptr)
1381                gcMap[entry.second[0]] = colorsNum++;
1382       
1383        for (size_t colored = 0; colored < nodesNum; colored++)
1384        {
1385            size_t node = *nodeSet.begin();
1386            if (gcMap[node] != UINT_MAX)
1387                continue; // already colored
1388            size_t color = 0;
1389            std::vector<size_t> equalNodes;
1390            equalNodes.push_back(node); // only one node, if equalSet not found
1391            auto equalSetMapIt = equalSetMap.find(node);
1392            if (equalSetMapIt != equalSetMap.end())
1393                // found, get equal set from equalSetList
1394                equalNodes = equalSetList[equalSetMapIt->second];
1395           
1396            for (color = 0; color <= colorsNum; color++)
1397            {
1398                // find first usable color
1399                bool thisSame = false;
1400                for (size_t nb: interGraph[node])
1401                    if (gcMap[nb] == color)
1402                    {
1403                        thisSame = true;
1404                        break;
1405                    }
1406                if (!thisSame)
1407                    break;
1408            }
1409            if (color==colorsNum) // add new color if needed
1410            {
1411                if (colorsNum >= maxColorsNum)
1412                    throw AsmException("Too many register is needed");
1413                colorsNum++;
1414            }
1415           
1416            for (size_t nextNode: equalNodes)
1417                gcMap[nextNode] = color;
1418            // update SDO for node
1419            bool colorExists = false;
1420            for (size_t node: equalNodes)
1421            {
1422                for (size_t nb: interGraph[node])
1423                    if (gcMap[nb] == color)
1424                    {
1425                        colorExists = true;
1426                        break;
1427                    }
1428                if (!colorExists)
1429                    sdoCounts[node]++;
1430            }
1431            // update SDO for neighbors
1432            for (size_t node: equalNodes)
1433                for (size_t nb: interGraph[node])
1434                {
1435                    colorExists = false;
1436                    for (size_t nb2: interGraph[nb])
1437                        if (gcMap[nb2] == color)
1438                        {
1439                            colorExists = true;
1440                            break;
1441                        }
1442                    if (!colorExists)
1443                    {
1444                        if (gcMap[nb] == UINT_MAX)
1445                            nodeSet.erase(nb);  // before update we erase from nodeSet
1446                        sdoCounts[nb]++;
1447                        if (gcMap[nb] == UINT_MAX)
1448                            nodeSet.insert(nb); // after update, insert again
1449                    }
1450                }
1451           
1452            for (size_t nextNode: equalNodes)
1453                gcMap[nextNode] = color;
1454        }
1455    }
1456}
1457
1458void AsmRegAllocator::allocateRegisters(cxuint sectionId)
1459{
1460    // before any operation, clear all
1461    codeBlocks.clear();
1462    for (size_t i = 0; i < MAX_REGTYPES_NUM; i++)
1463    {
1464        vregIndexMaps[i].clear();
1465        interGraphs[i].clear();
1466        linearDepMaps[i].clear();
1467        equalToDepMaps[i].clear();
1468        graphColorMaps[i].clear();
1469        equalSetMaps[i].clear();
1470        equalSetLists[i].clear();
1471    }
1472    ssaReplacesMap.clear();
1473    cxuint maxRegs[MAX_REGTYPES_NUM];
1474    assembler.isaAssembler->getMaxRegistersNum(regTypesNum, maxRegs);
1475   
1476    // set up
1477    const AsmSection& section = assembler.sections[sectionId];
1478    createCodeStructure(section.codeFlow, section.content.size(), section.content.data());
1479    createSSAData(*section.usageHandler);
1480    applySSAReplaces();
1481    createInterferenceGraph(*section.usageHandler);
1482    colorInterferenceGraph();
1483}
Note: See TracBrowser for help on using the repository browser.