source: CLRX/CLRadeonExtender/trunk/amdasm/AsmRegAlloc.cpp @ 3993

Last change on this file since 3993 was 3993, checked in by matszpk, 15 months ago

CLRadeonExtender: AsmRegAlloc?: Conditional compilation of ostream<<BlockIndex? operator.

File size: 55.1 KB
Line 
1/*
2 *  CLRadeonExtender - Unofficial OpenCL Radeon Extensions Library
3 *  Copyright (C) 2014-2018 Mateusz Szpakowski
4 *
5 *  This library is free software; you can redistribute it and/or
6 *  modify it under the terms of the GNU Lesser General Public
7 *  License as published by the Free Software Foundation; either
8 *  version 2.1 of the License, or (at your option) any later version.
9 *
10 *  This library is distributed in the hope that it will be useful,
11 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
12 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
13 *  Lesser General Public License for more details.
14 *
15 *  You should have received a copy of the GNU Lesser General Public
16 *  License along with this library; if not, write to the Free Software
17 *  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
18 */
19
20#include <CLRX/Config.h>
21#include <iostream>
22#include <stack>
23#include <deque>
24#include <vector>
25#include <utility>
26#include <unordered_set>
27#include <map>
28#include <set>
29#include <unordered_map>
30#include <algorithm>
31#include <CLRX/utils/Utilities.h>
32#include <CLRX/utils/Containers.h>
33#include <CLRX/amdasm/Assembler.h>
34#include "AsmInternals.h"
35#include "AsmRegAlloc.h"
36
37using namespace CLRX;
38
39#if ASMREGALLOC_DEBUGDUMP
40std::ostream& operator<<(std::ostream& os, const CLRX::BlockIndex& v)
41{
42    if (v.pass==0)
43        return os << v.index;
44    else
45        return os << v.index << "#" << v.pass;
46}
47#endif
48
49ISAUsageHandler::ISAUsageHandler(const std::vector<cxbyte>& _content) :
50            content(_content), lastOffset(0), readOffset(0), instrStructPos(0),
51            regUsagesPos(0), regUsages2Pos(0), regVarUsagesPos(0),
52            pushedArgs(0), argPos(0), argFlags(0), isNext(false), useRegMode(false)
53{ }
54
55ISAUsageHandler::~ISAUsageHandler()
56{ }
57
58void ISAUsageHandler::rewind()
59{
60    readOffset = instrStructPos = 0;
61    regUsagesPos = regUsages2Pos = regVarUsagesPos = 0;
62    useRegMode = false;
63    pushedArgs = 0;
64    skipBytesInInstrStruct();
65}
66
67void ISAUsageHandler::skipBytesInInstrStruct()
68{
69    // do not add instruction size if usereg (usereg immediately before instr regusages)
70    if ((instrStructPos != 0 || argPos != 0) && !useRegMode)
71        readOffset += defaultInstrSize;
72    argPos = 0;
73    for (;instrStructPos < instrStruct.size() &&
74        instrStruct[instrStructPos] > 0x80; instrStructPos++)
75        readOffset += (instrStruct[instrStructPos] & 0x7f);
76    isNext = (instrStructPos < instrStruct.size());
77}
78
79void ISAUsageHandler::putSpace(size_t offset)
80{
81    if (lastOffset != offset)
82    {
83        flush(); // flush before new instruction
84        // useReg immediately before instruction regusages
85        size_t defaultInstrSize = (!useRegMode ? this->defaultInstrSize : 0);
86        if (lastOffset > offset)
87            throw AsmException("Offset before previous instruction");
88        if (!instrStruct.empty() && offset - lastOffset < defaultInstrSize)
89            throw AsmException("Offset between previous instruction");
90        size_t toSkip = !instrStruct.empty() ? 
91                offset - lastOffset - defaultInstrSize : offset;
92        while (toSkip > 0)
93        {
94            size_t skipped = std::min(toSkip, size_t(0x7f));
95            instrStruct.push_back(skipped | 0x80);
96            toSkip -= skipped;
97        }
98        lastOffset = offset;
99        argFlags = 0;
100        pushedArgs = 0;
101    } 
102}
103
104void ISAUsageHandler::pushUsage(const AsmRegVarUsage& rvu)
105{
106    if (lastOffset == rvu.offset && useRegMode)
107        flush(); // only flush if useRegMode and no change in offset
108    else // otherwise
109        putSpace(rvu.offset);
110    useRegMode = false;
111    if (rvu.regVar != nullptr)
112    {
113        argFlags |= (1U<<pushedArgs);
114        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
115            rvu.rwFlags, rvu.align });
116    }
117    else // reg usages
118        regUsages.push_back({ rvu.regField,cxbyte(rvu.rwFlags |
119                    getRwFlags(rvu.regField, rvu.rstart, rvu.rend)) });
120    pushedArgs++;
121}
122
123void ISAUsageHandler::pushUseRegUsage(const AsmRegVarUsage& rvu)
124{
125    if (lastOffset == rvu.offset && !useRegMode)
126        flush(); // only flush if useRegMode and no change in offset
127    else // otherwise
128        putSpace(rvu.offset);
129    useRegMode = true;
130    if (pushedArgs == 0 || pushedArgs == 256)
131    {
132        argFlags = 0;
133        pushedArgs = 0;
134        instrStruct.push_back(0x80); // sign of regvarusage from usereg
135        instrStruct.push_back(0);
136    }
137    if (rvu.regVar != nullptr)
138    {
139        argFlags |= (1U<<(pushedArgs & 7));
140        regVarUsages.push_back({ rvu.regVar, rvu.rstart, rvu.rend, rvu.regField,
141            rvu.rwFlags, rvu.align });
142    }
143    else // reg usages
144        regUsages2.push_back({ rvu.rstart, rvu.rend, rvu.rwFlags });
145    pushedArgs++;
146    if ((pushedArgs & 7) == 0) // just flush per 8 bit
147    {
148        instrStruct.push_back(argFlags);
149        instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
150        argFlags = 0;
151    }
152}
153
154void ISAUsageHandler::flush()
155{
156    if (pushedArgs != 0)
157    {
158        if (!useRegMode)
159        {
160            // normal regvarusages
161            instrStruct.push_back(argFlags);
162            if ((argFlags & (1U<<(pushedArgs-1))) != 0)
163                regVarUsages.back().rwFlags |= 0x80;
164            else // reg usages
165                regUsages.back().rwFlags |= 0x80;
166        }
167        else
168        {
169            // use reg regvarusages
170            if ((pushedArgs & 7) != 0) //if only not pushed args remains
171                instrStruct.push_back(argFlags);
172            instrStruct[instrStruct.size() - ((pushedArgs+7) >> 3) - 1] = pushedArgs;
173        }
174    }
175}
176
177AsmRegVarUsage ISAUsageHandler::nextUsage()
178{
179    if (!isNext)
180        throw AsmException("No reg usage in this code");
181    AsmRegVarUsage rvu;
182    // get regvarusage
183    bool lastRegUsage = false;
184    rvu.offset = readOffset;
185    if (!useRegMode && instrStruct[instrStructPos] == 0x80)
186    {
187        // useRegMode (begin fetching useregs)
188        useRegMode = true;
189        argPos = 0;
190        instrStructPos++;
191        // pushedArgs - numer of useregs, 0 - 256 useregs
192        pushedArgs = instrStruct[instrStructPos++];
193        argFlags = instrStruct[instrStructPos];
194    }
195    rvu.useRegMode = useRegMode; // no ArgPos
196   
197    if ((instrStruct[instrStructPos] & (1U << (argPos&7))) != 0)
198    {
199        // regvar usage
200        const AsmRegVarUsageInt& inRVU = regVarUsages[regVarUsagesPos++];
201        rvu.regVar = inRVU.regVar;
202        rvu.rstart = inRVU.rstart;
203        rvu.rend = inRVU.rend;
204        rvu.regField = inRVU.regField;
205        rvu.rwFlags = inRVU.rwFlags & ASMRVU_ACCESS_MASK;
206        rvu.align = inRVU.align;
207        if (!useRegMode)
208            lastRegUsage = ((inRVU.rwFlags&0x80) != 0);
209    }
210    else if (!useRegMode)
211    {
212        // simple reg usage
213        const AsmRegUsageInt& inRU = regUsages[regUsagesPos++];
214        rvu.regVar = nullptr;
215        const std::pair<uint16_t, uint16_t> regPair =
216                    getRegPair(inRU.regField, inRU.rwFlags);
217        rvu.rstart = regPair.first;
218        rvu.rend = regPair.second;
219        rvu.rwFlags = (inRU.rwFlags & ASMRVU_ACCESS_MASK);
220        rvu.regField = inRU.regField;
221        rvu.align = 0;
222        lastRegUsage = ((inRU.rwFlags&0x80) != 0);
223    }
224    else
225    {
226        // use reg (simple reg usage, second structure)
227        const AsmRegUsage2Int& inRU = regUsages2[regUsages2Pos++];
228        rvu.regVar = nullptr;
229        rvu.rstart = inRU.rstart;
230        rvu.rend = inRU.rend;
231        rvu.rwFlags = inRU.rwFlags;
232        rvu.regField = ASMFIELD_NONE;
233        rvu.align = 0;
234    }
235    argPos++;
236    if (useRegMode)
237    {
238        // if inside useregs
239        if (argPos == (pushedArgs&0xff))
240        {
241            instrStructPos++; // end
242            skipBytesInInstrStruct();
243            useRegMode = false;
244        }
245        else if ((argPos & 7) == 0) // fetch new flag
246        {
247            instrStructPos++;
248            argFlags = instrStruct[instrStructPos];
249        }
250    }
251    // after instr
252    if (lastRegUsage)
253    {
254        instrStructPos++;
255        skipBytesInInstrStruct();
256    }
257    return rvu;
258}
259
260AsmRegAllocator::AsmRegAllocator(Assembler& _assembler) : assembler(_assembler)
261{ }
262
263static inline bool codeBlockStartLess(const AsmRegAllocator::CodeBlock& c1,
264                  const AsmRegAllocator::CodeBlock& c2)
265{ return c1.start < c2.start; }
266
267static inline bool codeBlockEndLess(const AsmRegAllocator::CodeBlock& c1,
268                  const AsmRegAllocator::CodeBlock& c2)
269{ return c1.end < c2.end; }
270
271void AsmRegAllocator::createCodeStructure(const std::vector<AsmCodeFlowEntry>& codeFlow,
272             size_t codeSize, const cxbyte* code)
273{
274    ISAAssembler* isaAsm = assembler.isaAssembler;
275    if (codeSize == 0)
276        return;
277    std::vector<size_t> splits;
278    std::vector<size_t> codeStarts;
279    std::vector<size_t> codeEnds;
280    codeStarts.push_back(0);
281    codeEnds.push_back(codeSize);
282    for (const AsmCodeFlowEntry& entry: codeFlow)
283    {
284        size_t instrAfter = 0;
285        if (entry.type == AsmCodeFlowType::JUMP || entry.type == AsmCodeFlowType::CJUMP ||
286            entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::RETURN)
287            instrAfter = entry.offset + isaAsm->getInstructionSize(
288                        codeSize - entry.offset, code + entry.offset);
289       
290        switch(entry.type)
291        {
292            case AsmCodeFlowType::START:
293                codeStarts.push_back(entry.offset);
294                break;
295            case AsmCodeFlowType::END:
296                codeEnds.push_back(entry.offset);
297                break;
298            case AsmCodeFlowType::JUMP:
299                splits.push_back(entry.target);
300                codeEnds.push_back(instrAfter);
301                break;
302            case AsmCodeFlowType::CJUMP:
303                splits.push_back(entry.target);
304                splits.push_back(instrAfter);
305                break;
306            case AsmCodeFlowType::CALL:
307                splits.push_back(entry.target);
308                splits.push_back(instrAfter);
309                break;
310            case AsmCodeFlowType::RETURN:
311                codeEnds.push_back(instrAfter);
312                break;
313            default:
314                break;
315        }
316    }
317    std::sort(splits.begin(), splits.end());
318    splits.resize(std::unique(splits.begin(), splits.end()) - splits.begin());
319    std::sort(codeEnds.begin(), codeEnds.end());
320    codeEnds.resize(std::unique(codeEnds.begin(), codeEnds.end()) - codeEnds.begin());
321    // remove codeStarts between codeStart and codeEnd
322    size_t i = 0;
323    size_t ii = 0;
324    size_t ei = 0; // codeEnd i
325    while (i < codeStarts.size())
326    {
327        size_t end = (ei < codeEnds.size() ? codeEnds[ei] : SIZE_MAX);
328        if (ei < codeEnds.size())
329            ei++;
330        codeStarts[ii++] = codeStarts[i];
331        // skip codeStart to end
332        for (i++ ;i < codeStarts.size() && codeStarts[i] < end; i++);
333    }
334    codeStarts.resize(ii);
335    // add next codeStarts
336    auto splitIt = splits.begin();
337    for (size_t codeEnd: codeEnds)
338    {
339        auto it = std::lower_bound(splitIt, splits.end(), codeEnd);
340        if (it != splits.end())
341        {
342            codeStarts.push_back(*it);
343            splitIt = it;
344        }
345        else // if end
346            break;
347    }
348   
349    std::sort(codeStarts.begin(), codeStarts.end());
350    codeStarts.resize(std::unique(codeStarts.begin(), codeStarts.end()) -
351                codeStarts.begin());
352    // divide to blocks
353    splitIt = splits.begin();
354    for (size_t codeStart: codeStarts)
355    {
356        size_t codeEnd = *std::upper_bound(codeEnds.begin(), codeEnds.end(), codeStart);
357        splitIt = std::lower_bound(splitIt, splits.end(), codeStart);
358       
359        if (splitIt != splits.end() && *splitIt==codeStart)
360            ++splitIt; // skip split in codeStart
361       
362        for (size_t start = codeStart; start < codeEnd; )
363        {
364            size_t end = codeEnd;
365            if (splitIt != splits.end())
366            {
367                end = std::min(end, *splitIt);
368                ++splitIt;
369            }
370            codeBlocks.push_back({ start, end, { }, false, false, false });
371            start = end;
372        }
373    }
374    // force empty block at end if some jumps goes to its
375    if (!codeEnds.empty() && !codeStarts.empty() && !splits.empty() &&
376        codeStarts.back()==codeEnds.back() && codeStarts.back() == splits.back())
377        codeBlocks.push_back({ codeStarts.back(), codeStarts.back(), { },
378                             false, false, false });
379   
380    // construct flow-graph
381    for (const AsmCodeFlowEntry& entry: codeFlow)
382        if (entry.type == AsmCodeFlowType::CALL || entry.type == AsmCodeFlowType::JUMP ||
383            entry.type == AsmCodeFlowType::CJUMP || entry.type == AsmCodeFlowType::RETURN)
384        {
385            std::vector<CodeBlock>::iterator it;
386            size_t instrAfter = entry.offset + isaAsm->getInstructionSize(
387                        codeSize - entry.offset, code + entry.offset);
388           
389            if (entry.type != AsmCodeFlowType::RETURN)
390                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
391                        CodeBlock{ entry.target }, codeBlockStartLess);
392            else // return
393            {
394                it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
395                        CodeBlock{ 0, instrAfter }, codeBlockEndLess);
396                // if block have return
397                if (it != codeBlocks.end())
398                    it->haveEnd = it->haveReturn = true;
399                continue;
400            }
401           
402            if (it == codeBlocks.end())
403                continue; // error!
404            auto it2 = std::lower_bound(codeBlocks.begin(), codeBlocks.end(),
405                    CodeBlock{ instrAfter }, codeBlockStartLess);
406            auto curIt = it2;
407            --curIt;
408           
409            curIt->nexts.push_back({ size_t(it - codeBlocks.begin()),
410                        entry.type == AsmCodeFlowType::CALL });
411            curIt->haveCalls |= entry.type == AsmCodeFlowType::CALL;
412            if (entry.type == AsmCodeFlowType::CJUMP ||
413                 entry.type == AsmCodeFlowType::CALL)
414            {
415                curIt->haveEnd = false; // revert haveEnd if block have cond jump or call
416                if (it2 != codeBlocks.end() && entry.type == AsmCodeFlowType::CJUMP)
417                    // add next next block (only for cond jump)
418                    curIt->nexts.push_back({ size_t(it2 - codeBlocks.begin()), false });
419            }
420            else if (entry.type == AsmCodeFlowType::JUMP)
421                curIt->haveEnd = true; // set end
422        }
423    // force haveEnd for block with cf_end
424    for (const AsmCodeFlowEntry& entry: codeFlow)
425        if (entry.type == AsmCodeFlowType::END)
426        {
427            auto it = binaryFind(codeBlocks.begin(), codeBlocks.end(),
428                    CodeBlock{ 0, entry.offset }, codeBlockEndLess);
429            if (it != codeBlocks.end())
430                it->haveEnd = true;
431        }
432   
433    if (!codeBlocks.empty()) // always set haveEnd to last block
434        codeBlocks.back().haveEnd = true;
435   
436    // reduce nexts
437    for (CodeBlock& block: codeBlocks)
438    {
439        // first non-call nexts, for correct resolving SSA conflicts
440        std::sort(block.nexts.begin(), block.nexts.end(),
441                  [](const NextBlock& n1, const NextBlock& n2)
442                  { return int(n1.isCall)<int(n2.isCall) ||
443                      (n1.isCall == n2.isCall && n1.block < n2.block); });
444        auto it = std::unique(block.nexts.begin(), block.nexts.end(),
445                  [](const NextBlock& n1, const NextBlock& n2)
446                  { return n1.block == n2.block && n1.isCall == n2.isCall; });
447        block.nexts.resize(it - block.nexts.begin());
448    }
449}
450
451
452void AsmRegAllocator::applySSAReplaces()
453{
454    /* prepare SSA id replaces */
455    struct MinSSAGraphNode
456    {
457        size_t minSSAId;
458        bool visited;
459        std::unordered_set<size_t> nexts;
460        MinSSAGraphNode() : minSSAId(SIZE_MAX), visited(false) { }
461    };
462    struct MinSSAGraphStackEntry
463    {
464        std::unordered_map<size_t, MinSSAGraphNode>::iterator nodeIt;
465        std::unordered_set<size_t>::const_iterator nextIt;
466        size_t minSSAId;
467    };
468   
469    for (auto& entry: ssaReplacesMap)
470    {
471        VectorSet<SSAReplace>& replaces = entry.second;
472        std::sort(replaces.begin(), replaces.end());
473        replaces.resize(std::unique(replaces.begin(), replaces.end()) - replaces.begin());
474        VectorSet<SSAReplace> newReplaces;
475       
476        std::unordered_map<size_t, MinSSAGraphNode> ssaGraphNodes;
477       
478        auto it = replaces.begin();
479        while (it != replaces.end())
480        {
481            auto itEnd = std::upper_bound(it, replaces.end(),
482                            std::make_pair(it->first, size_t(SIZE_MAX)));
483            {
484                MinSSAGraphNode& node = ssaGraphNodes[it->first];
485                node.minSSAId = std::min(node.minSSAId, it->second);
486                for (auto it2 = it; it2 != itEnd; ++it2)
487                    node.nexts.insert(it->second);
488            }
489            it = itEnd;
490        }
491        // propagate min value
492        std::stack<MinSSAGraphStackEntry> minSSAStack;
493        for (auto ssaGraphNodeIt = ssaGraphNodes.begin();
494                 ssaGraphNodeIt!=ssaGraphNodes.end(); )
495        {
496            minSSAStack.push({ ssaGraphNodeIt, ssaGraphNodeIt->second.nexts.begin() });
497            // traverse with minimalize SSA id
498            while (!minSSAStack.empty())
499            {
500                MinSSAGraphStackEntry& entry = minSSAStack.top();
501                MinSSAGraphNode& node = entry.nodeIt->second;
502                bool toPop = false;
503                if (entry.nextIt == node.nexts.begin())
504                {
505                    if (!node.visited)
506                        node.visited = true;
507                    else
508                        toPop = true;
509                }
510                if (!toPop && entry.nextIt != node.nexts.end())
511                {
512                    auto nodeIt = ssaGraphNodes.find(*entry.nextIt);
513                    if (nodeIt != ssaGraphNodes.end())
514                    {
515                        minSSAStack.push({ nodeIt, nodeIt->second.nexts.begin(),
516                                nodeIt->second.minSSAId });
517                    }
518                    ++entry.nextIt;
519                }
520                else
521                {
522                    node.minSSAId = std::min(node.minSSAId, entry.minSSAId);
523                    minSSAStack.pop();
524                    if (!minSSAStack.empty())
525                    {
526                        MinSSAGraphStackEntry& pentry = minSSAStack.top();
527                        pentry.minSSAId = std::min(pentry.minSSAId, node.minSSAId);
528                    }
529                }
530            }
531            // skip visited nodes
532            while (ssaGraphNodeIt != ssaGraphNodes.end())
533                if (!ssaGraphNodeIt->second.visited)
534                    break;
535        }
536       
537        for (const auto& entry: ssaGraphNodes)
538            newReplaces.push_back({ entry.first, entry.second.minSSAId });
539       
540        std::sort(newReplaces.begin(), newReplaces.end());
541        entry.second = newReplaces;
542    }
543   
544    /* apply SSA id replaces */
545    for (CodeBlock& cblock: codeBlocks)
546        for (auto& ssaEntry: cblock.ssaInfoMap)
547        {
548            auto it = ssaReplacesMap.find(ssaEntry.first);
549            if (it == ssaReplacesMap.end())
550                continue;
551            SSAInfo& sinfo = ssaEntry.second;
552            VectorSet<SSAReplace>& replaces = it->second;
553            if (sinfo.readBeforeWrite)
554            {
555                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
556                                 ssaEntry.second.ssaIdBefore);
557                if (rit != replaces.end())
558                    sinfo.ssaIdBefore = rit->second; // replace
559            }
560            if (sinfo.ssaIdFirst != SIZE_MAX)
561            {
562                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
563                                 ssaEntry.second.ssaIdFirst);
564                if (rit != replaces.end())
565                    sinfo.ssaIdFirst = rit->second; // replace
566            }
567            if (sinfo.ssaIdLast != SIZE_MAX)
568            {
569                auto rit = binaryMapFind(replaces.begin(), replaces.end(),
570                                 ssaEntry.second.ssaIdLast);
571                if (rit != replaces.end())
572                    sinfo.ssaIdLast = rit->second; // replace
573            }
574        }
575}
576
577struct Liveness
578{
579    std::map<size_t, size_t> l;
580   
581    Liveness() { }
582   
583    void clear()
584    { l.clear(); }
585   
586    void expand(size_t k)
587    {
588        if (l.empty())
589            l.insert(std::make_pair(k, k+1));
590        else
591        {
592            auto it = l.end();
593            --it;
594            it->second = k+1;
595        }
596    }
597    void newRegion(size_t k)
598    {
599        if (l.empty())
600            l.insert(std::make_pair(k, k));
601        else
602        {
603            auto it = l.end();
604            --it;
605            if (it->first != k && it->second != k)
606                l.insert(std::make_pair(k, k));
607        }
608    }
609   
610    void insert(size_t k, size_t k2)
611    {
612        auto it1 = l.lower_bound(k);
613        if (it1!=l.begin() && (it1==l.end() || it1->first>k))
614            --it1;
615        if (it1->second < k)
616            ++it1;
617        auto it2 = l.lower_bound(k2);
618        if (it1!=it2)
619        {
620            k = std::min(k, it1->first);
621            k2 = std::max(k2, (--it2)->second);
622            l.erase(it1, it2);
623        }
624        l.insert(std::make_pair(k, k2));
625    }
626   
627    bool contain(size_t t) const
628    {
629        auto it = l.lower_bound(t);
630        if (it==l.begin() && it->first>t)
631            return false;
632        if (it==l.end() || it->first>t)
633            --it;
634        return it->first<=t && t<it->second;
635    }
636   
637    bool common(const Liveness& b) const
638    {
639        auto i = l.begin();
640        auto j = b.l.begin();
641        for (; i != l.end() && j != b.l.end();)
642        {
643            if (i->first==i->second)
644            {
645                ++i;
646                continue;
647            }
648            if (j->first==j->second)
649            {
650                ++j;
651                continue;
652            }
653            if (i->first<j->first)
654            {
655                if (i->second > j->first)
656                    return true; // common place
657                ++i;
658            }
659            else
660            {
661                if (i->first < j->second)
662                    return true; // common place
663                ++j;
664            }
665        }
666        return false;
667    }
668};
669
670typedef AsmRegAllocator::VarIndexMap VarIndexMap;
671
672static cxuint getRegType(size_t regTypesNum, const cxuint* regRanges,
673            const AsmSingleVReg& svreg)
674{
675    cxuint regType; // regtype
676    if (svreg.regVar!=nullptr)
677        regType = svreg.regVar->type;
678    else
679        for (regType = 0; regType < regTypesNum; regType++)
680            if (svreg.index >= regRanges[regType<<1] &&
681                svreg.index < regRanges[(regType<<1)+1])
682                break;
683    return regType;
684}
685
686static Liveness& getLiveness(const AsmSingleVReg& svreg, size_t ssaIdIdx,
687        const AsmRegAllocator::SSAInfo& ssaInfo, std::vector<Liveness>* livenesses,
688        const VarIndexMap* vregIndexMaps, size_t regTypesNum, const cxuint* regRanges)
689{
690    size_t ssaId;
691    if (svreg.regVar==nullptr)
692        ssaId = 0;
693    else if (ssaIdIdx==0)
694        ssaId = ssaInfo.ssaIdBefore;
695    else if (ssaIdIdx==1)
696        ssaId = ssaInfo.ssaIdFirst;
697    else if (ssaIdIdx<ssaInfo.ssaIdChange)
698        ssaId = ssaInfo.ssaId + ssaIdIdx-1;
699    else // last
700        ssaId = ssaInfo.ssaIdLast;
701   
702    cxuint regType = getRegType(regTypesNum, regRanges, svreg); // regtype
703    const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
704    const std::vector<size_t>& ssaIdIndices =
705                vregIndexMap.find(svreg)->second;
706    return livenesses[regType][ssaIdIndices[ssaId]];
707}
708
709typedef std::deque<FlowStackEntry3>::const_iterator FlowStackCIter;
710
711struct CLRX_INTERNAL VRegLastPos
712{
713    size_t ssaId; // last SSA id
714    std::vector<FlowStackCIter> blockChain; // subsequent blocks that changes SSAId
715};
716
717/* TODO: add handling calls
718 * handle many start points in this code (for example many kernel's in same code)
719 * replace sets by vector, and sort and remove same values on demand
720 */
721
722typedef std::unordered_map<AsmSingleVReg, VRegLastPos> LastVRegMap;
723
724static void putCrossBlockLivenesses(const std::deque<FlowStackEntry3>& flowStack,
725        const std::vector<CodeBlock>& codeBlocks,
726        const Array<size_t>& codeBlockLiveTimes, const LastVRegMap& lastVRegMap,
727        std::vector<Liveness>* livenesses, const VarIndexMap* vregIndexMaps,
728        size_t regTypesNum, const cxuint* regRanges)
729{
730    const CodeBlock& cblock = codeBlocks[flowStack.back().blockIndex];
731    for (const auto& entry: cblock.ssaInfoMap)
732        if (entry.second.readBeforeWrite)
733        {
734            // find last
735            auto lvrit = lastVRegMap.find(entry.first);
736            if (lvrit == lastVRegMap.end())
737                continue; // not found
738            const VRegLastPos& lastPos = lvrit->second;
739            FlowStackCIter flit = lastPos.blockChain.back();
740            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
741            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
742            const std::vector<size_t>& ssaIdIndices =
743                        vregIndexMap.find(entry.first)->second;
744            Liveness& lv = livenesses[regType][ssaIdIndices[entry.second.ssaIdBefore]];
745            FlowStackCIter flitEnd = flowStack.end();
746            --flitEnd; // before last element
747            // insert live time to last seen position
748            const CodeBlock& lastBlk = codeBlocks[flit->blockIndex];
749            size_t toLiveCvt = codeBlockLiveTimes[flit->blockIndex] - lastBlk.start;
750            lv.insert(lastBlk.ssaInfoMap.find(entry.first)->second.lastPos + toLiveCvt,
751                    toLiveCvt + lastBlk.end);
752            for (++flit; flit != flitEnd; ++flit)
753            {
754                const CodeBlock& cblock = codeBlocks[flit->blockIndex];
755                size_t blockLiveTime = codeBlockLiveTimes[flit->blockIndex];
756                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
757            }
758        }
759}
760
761static void putCrossBlockForLoop(const std::deque<FlowStackEntry3>& flowStack,
762        const std::vector<CodeBlock>& codeBlocks,
763        const Array<size_t>& codeBlockLiveTimes,
764        std::vector<Liveness>* livenesses, const VarIndexMap* vregIndexMaps,
765        size_t regTypesNum, const cxuint* regRanges)
766{
767    auto flitStart = flowStack.end();
768    --flitStart;
769    size_t curBlock = flitStart->blockIndex;
770    // find step in way
771    while (flitStart->blockIndex != curBlock) --flitStart;
772    auto flitEnd = flowStack.end();
773    --flitEnd;
774    std::unordered_map<AsmSingleVReg, std::pair<size_t, size_t> > varMap;
775   
776    // collect var to check
777    size_t flowPos = 0;
778    for (auto flit = flitStart; flit != flitEnd; ++flit, flowPos++)
779    {
780        const CodeBlock& cblock = codeBlocks[flit->blockIndex];
781        for (const auto& entry: cblock.ssaInfoMap)
782        {
783            const SSAInfo& sinfo = entry.second;
784            size_t lastSSAId = (sinfo.ssaIdChange != 0) ? sinfo.ssaIdLast :
785                    (sinfo.readBeforeWrite) ? sinfo.ssaIdBefore : 0;
786            varMap[entry.first] = { lastSSAId, flowPos };
787        }
788    }
789    // find connections
790    flowPos = 0;
791    for (auto flit = flitStart; flit != flitEnd; ++flit, flowPos++)
792    {
793        const CodeBlock& cblock = codeBlocks[flit->blockIndex];
794        for (const auto& entry: cblock.ssaInfoMap)
795        {
796            const SSAInfo& sinfo = entry.second;
797            auto varMapIt = varMap.find(entry.first);
798            if (!sinfo.readBeforeWrite || varMapIt == varMap.end() ||
799                flowPos > varMapIt->second.second ||
800                sinfo.ssaIdBefore != varMapIt->second.first)
801                continue;
802            // just connect
803           
804            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
805            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
806            const std::vector<size_t>& ssaIdIndices =
807                        vregIndexMap.find(entry.first)->second;
808            Liveness& lv = livenesses[regType][ssaIdIndices[entry.second.ssaIdBefore]];
809           
810            if (flowPos == varMapIt->second.second)
811            {
812                // fill whole loop
813                for (auto flit2 = flitStart; flit != flitEnd; ++flit)
814                {
815                    const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
816                    size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
817                    lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
818                }
819                continue;
820            }
821           
822            size_t flowPos2 = 0;
823            for (auto flit2 = flitStart; flowPos2 < flowPos; ++flit2, flowPos++)
824            {
825                const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
826                size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
827                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
828            }
829            // insert liveness for last block in loop of last SSAId (prev round)
830            auto flit2 = flitStart + flowPos;
831            const CodeBlock& firstBlk = codeBlocks[flit2->blockIndex];
832            size_t toLiveCvt = codeBlockLiveTimes[flit2->blockIndex] - firstBlk.start;
833            lv.insert(codeBlockLiveTimes[flit2->blockIndex],
834                    firstBlk.ssaInfoMap.find(entry.first)->second.firstPos + toLiveCvt);
835            // insert liveness for first block in loop of last SSAId
836            flit2 = flitStart + (varMapIt->second.second+1);
837            const CodeBlock& lastBlk = codeBlocks[flit2->blockIndex];
838            toLiveCvt = codeBlockLiveTimes[flit2->blockIndex] - lastBlk.start;
839            lv.insert(lastBlk.ssaInfoMap.find(entry.first)->second.lastPos + toLiveCvt,
840                    toLiveCvt + lastBlk.end);
841            // fill up loop end
842            for (++flit2; flit2 != flitEnd; ++flit2)
843            {
844                const CodeBlock& cblock = codeBlocks[flit2->blockIndex];
845                size_t blockLiveTime = codeBlockLiveTimes[flit2->blockIndex];
846                lv.insert(blockLiveTime, cblock.end-cblock.start + blockLiveTime);
847            }
848        }
849    }
850}
851
852struct LiveBlock
853{
854    size_t start;
855    size_t end;
856    size_t vidx;
857   
858    bool operator==(const LiveBlock& b) const
859    { return start==b.start && end==b.end && vidx==b.vidx; }
860   
861    bool operator<(const LiveBlock& b) const
862    { return start<b.start || (start==b.start &&
863            (end<b.end || (end==b.end && vidx<b.vidx))); }
864};
865
866typedef AsmRegAllocator::LinearDep LinearDep;
867typedef AsmRegAllocator::EqualToDep EqualToDep;
868typedef std::unordered_map<size_t, LinearDep> LinearDepMap;
869typedef std::unordered_map<size_t, EqualToDep> EqualToDepMap;
870
871static void addUsageDeps(const cxbyte* ldeps, const cxbyte* edeps, cxuint rvusNum,
872            const AsmRegVarUsage* rvus, LinearDepMap* ldepsOut,
873            EqualToDepMap* edepsOut, const VarIndexMap* vregIndexMaps,
874            std::unordered_map<AsmSingleVReg, size_t> ssaIdIdxMap,
875            size_t regTypesNum, const cxuint* regRanges)
876{
877    // add linear deps
878    cxuint count = ldeps[0];
879    cxuint pos = 1;
880    cxbyte rvuAdded = 0;
881    for (cxuint i = 0; i < count; i++)
882    {
883        cxuint ccount = ldeps[pos++];
884        std::vector<size_t> vidxes;
885        cxuint regType = UINT_MAX;
886        cxbyte align = rvus[ldeps[pos]].align;
887        for (cxuint j = 0; j < ccount; j++)
888        {
889            rvuAdded |= 1U<<ldeps[pos];
890            const AsmRegVarUsage& rvu = rvus[ldeps[pos++]];
891            for (uint16_t k = rvu.rstart; k < rvu.rend; k++)
892            {
893                AsmSingleVReg svreg = {rvu.regVar, k};
894                auto sit = ssaIdIdxMap.find(svreg);
895                if (regType==UINT_MAX)
896                    regType = getRegType(regTypesNum, regRanges, svreg);
897                const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
898                const std::vector<size_t>& ssaIdIndices =
899                            vregIndexMap.find(svreg)->second;
900                // push variable index
901                vidxes.push_back(ssaIdIndices[sit->second]);
902            }
903        }
904        ldepsOut[regType][vidxes[0]].align = align;
905        for (size_t k = 1; k < vidxes.size(); k++)
906        {
907            ldepsOut[regType][vidxes[k-1]].nextVidxes.push_back(vidxes[k]);
908            ldepsOut[regType][vidxes[k]].prevVidxes.push_back(vidxes[k-1]);
909        }
910    }
911    // add single arg linear dependencies
912    for (cxuint i = 0; i < rvusNum; i++)
913        if ((rvuAdded & (1U<<i)) == 0 && rvus[i].rstart+1<rvus[i].rend)
914        {
915            const AsmRegVarUsage& rvu = rvus[i];
916            std::vector<size_t> vidxes;
917            cxuint regType = UINT_MAX;
918            for (uint16_t k = rvu.rstart; k < rvu.rend; k++)
919            {
920                AsmSingleVReg svreg = {rvu.regVar, k};
921                auto sit = ssaIdIdxMap.find(svreg);
922                if (regType==UINT_MAX)
923                    regType = getRegType(regTypesNum, regRanges, svreg);
924                const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
925                const std::vector<size_t>& ssaIdIndices =
926                            vregIndexMap.find(svreg)->second;
927                // push variable index
928                vidxes.push_back(ssaIdIndices[sit->second]);
929            }
930            for (size_t j = 1; j < vidxes.size(); j++)
931            {
932                ldepsOut[regType][vidxes[j-1]].nextVidxes.push_back(vidxes[j]);
933                ldepsOut[regType][vidxes[j]].prevVidxes.push_back(vidxes[j-1]);
934            }
935        }
936       
937    /* equalTo dependencies */
938    count = edeps[0];
939    pos = 1;
940    for (cxuint i = 0; i < count; i++)
941    {
942        cxuint ccount = edeps[pos++];
943        std::vector<size_t> vidxes;
944        cxuint regType = UINT_MAX;
945        for (cxuint j = 0; j < ccount; j++)
946        {
947            const AsmRegVarUsage& rvu = rvus[edeps[pos++]];
948            // only one register should be set for equalTo depencencies
949            // other registers in range will be resolved by linear dependencies
950            AsmSingleVReg svreg = {rvu.regVar, rvu.rstart};
951            auto sit = ssaIdIdxMap.find(svreg);
952            if (regType==UINT_MAX)
953                regType = getRegType(regTypesNum, regRanges, svreg);
954            const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
955            const std::vector<size_t>& ssaIdIndices =
956                        vregIndexMap.find(svreg)->second;
957            // push variable index
958            vidxes.push_back(ssaIdIndices[sit->second]);
959        }
960        for (size_t j = 1; j < vidxes.size(); j++)
961        {
962            edepsOut[regType][vidxes[j-1]].nextVidxes.push_back(vidxes[j]);
963            edepsOut[regType][vidxes[j]].prevVidxes.push_back(vidxes[j-1]);
964        }
965    }
966}
967
968typedef std::unordered_map<size_t, EqualToDep>::const_iterator EqualToDepMapCIter;
969
970struct EqualStackEntry
971{
972    EqualToDepMapCIter etoDepIt;
973    size_t nextIdx; // over nextVidxes size, then prevVidxes[nextIdx-nextVidxes.size()]
974};
975
976void AsmRegAllocator::createInterferenceGraph(ISAUsageHandler& usageHandler)
977{
978    // construct var index maps
979    size_t graphVregsCounts[MAX_REGTYPES_NUM];
980    std::fill(graphVregsCounts, graphVregsCounts+regTypesNum, 0);
981    cxuint regRanges[MAX_REGTYPES_NUM*2];
982    size_t regTypesNum;
983    assembler.isaAssembler->getRegisterRanges(regTypesNum, regRanges);
984   
985    for (const CodeBlock& cblock: codeBlocks)
986        for (const auto& entry: cblock.ssaInfoMap)
987        {
988            const SSAInfo& sinfo = entry.second;
989            cxuint regType = getRegType(regTypesNum, regRanges, entry.first);
990            VarIndexMap& vregIndices = vregIndexMaps[regType];
991            size_t& graphVregsCount = graphVregsCounts[regType];
992            std::vector<size_t>& ssaIdIndices = vregIndices[entry.first];
993            size_t ssaIdCount = 0;
994            if (sinfo.readBeforeWrite)
995                ssaIdCount = sinfo.ssaIdBefore+1;
996            if (sinfo.ssaIdChange!=0)
997            {
998                ssaIdCount = std::max(ssaIdCount, sinfo.ssaIdLast+1);
999                ssaIdCount = std::max(ssaIdCount, sinfo.ssaIdFirst+1);
1000            }
1001            if (ssaIdIndices.size() < ssaIdCount)
1002                ssaIdIndices.resize(ssaIdCount, SIZE_MAX);
1003           
1004            if (sinfo.readBeforeWrite)
1005                ssaIdIndices[sinfo.ssaIdBefore] = graphVregsCount++;
1006            if (sinfo.ssaIdChange!=0)
1007            {
1008                // fill up ssaIdIndices (with graph Ids)
1009                ssaIdIndices[sinfo.ssaIdFirst] = graphVregsCount++;
1010                for (size_t ssaId = sinfo.ssaId+1;
1011                        ssaId < sinfo.ssaId+sinfo.ssaIdChange-1; ssaId++)
1012                    ssaIdIndices[ssaId] = graphVregsCount++;
1013                ssaIdIndices[sinfo.ssaIdLast] = graphVregsCount++;
1014            }
1015        }
1016   
1017    // construct vreg liveness
1018    std::deque<FlowStackEntry3> flowStack;
1019    std::vector<bool> visited(codeBlocks.size(), false);
1020    // hold last vreg ssaId and position
1021    LastVRegMap lastVRegMap;
1022    // hold start live time position for every code block
1023    Array<size_t> codeBlockLiveTimes(codeBlocks.size());
1024    std::unordered_set<size_t> blockInWay;
1025   
1026    std::vector<Liveness> livenesses[MAX_REGTYPES_NUM];
1027   
1028    for (size_t i = 0; i < regTypesNum; i++)
1029        livenesses[i].resize(graphVregsCounts[i]);
1030   
1031    size_t curLiveTime = 0;
1032   
1033    while (!flowStack.empty())
1034    {
1035        FlowStackEntry3& entry = flowStack.back();
1036        CodeBlock& cblock = codeBlocks[entry.blockIndex];
1037       
1038        if (entry.nextIndex == 0)
1039        {
1040            // process current block
1041            if (!blockInWay.insert(entry.blockIndex).second)
1042            {
1043                // if loop
1044                putCrossBlockForLoop(flowStack, codeBlocks, codeBlockLiveTimes, 
1045                        livenesses, vregIndexMaps, regTypesNum, regRanges);
1046                flowStack.pop_back();
1047                continue;
1048            }
1049           
1050            codeBlockLiveTimes[entry.blockIndex] = curLiveTime;
1051            putCrossBlockLivenesses(flowStack, codeBlocks, codeBlockLiveTimes, 
1052                    lastVRegMap, livenesses, vregIndexMaps, regTypesNum, regRanges);
1053           
1054            for (const auto& sentry: cblock.ssaInfoMap)
1055            {
1056                const SSAInfo& sinfo = sentry.second;
1057                // update
1058                size_t lastSSAId =  (sinfo.ssaIdChange != 0) ? sinfo.ssaIdLast :
1059                        (sinfo.readBeforeWrite) ? sinfo.ssaIdBefore : 0;
1060                FlowStackCIter flit = flowStack.end();
1061                --flit; // to last position
1062                auto res = lastVRegMap.insert({ sentry.first, 
1063                            { lastSSAId, { flit } } });
1064                if (!res.second) // if not first seen, just update
1065                {
1066                    // update last
1067                    res.first->second.ssaId = lastSSAId;
1068                    res.first->second.blockChain.push_back(flit);
1069                }
1070            }
1071           
1072            size_t curBlockLiveEnd = cblock.end - cblock.start + curLiveTime;
1073            if (!visited[entry.blockIndex])
1074            {
1075                visited[entry.blockIndex] = true;
1076                std::unordered_map<AsmSingleVReg, size_t> ssaIdIdxMap;
1077                AsmRegVarUsage instrRVUs[8];
1078                cxuint instrRVUsCount = 0;
1079               
1080                size_t oldOffset = cblock.usagePos.readOffset;
1081                std::vector<AsmSingleVReg> readSVRegs;
1082                std::vector<AsmSingleVReg> writtenSVRegs;
1083               
1084                usageHandler.setReadPos(cblock.usagePos);
1085                // register in liveness
1086                while (true)
1087                {
1088                    AsmRegVarUsage rvu = { 0U, nullptr, 0U, 0U };
1089                    size_t liveTimeNext = curBlockLiveEnd;
1090                    if (usageHandler.hasNext())
1091                    {
1092                        rvu = usageHandler.nextUsage();
1093                        if (rvu.offset >= cblock.end)
1094                            break;
1095                        if (!rvu.useRegMode)
1096                            instrRVUs[instrRVUsCount++] = rvu;
1097                        liveTimeNext = std::min(rvu.offset, cblock.end) -
1098                                cblock.start + curLiveTime;
1099                    }
1100                    size_t liveTime = oldOffset - cblock.start + curLiveTime;
1101                    if (!usageHandler.hasNext() || rvu.offset >= oldOffset)
1102                    {
1103                        // apply to liveness
1104                        for (AsmSingleVReg svreg: readSVRegs)
1105                        {
1106                            Liveness& lv = getLiveness(svreg, ssaIdIdxMap[svreg],
1107                                    cblock.ssaInfoMap.find(svreg)->second,
1108                                    livenesses, vregIndexMaps, regTypesNum, regRanges);
1109                            if (!lv.l.empty() && (--lv.l.end())->first < curLiveTime)
1110                                lv.newRegion(curLiveTime); // begin region from this block
1111                            lv.expand(liveTime);
1112                        }
1113                        for (AsmSingleVReg svreg: writtenSVRegs)
1114                        {
1115                            size_t& ssaIdIdx = ssaIdIdxMap[svreg];
1116                            ssaIdIdx++;
1117                            SSAInfo& sinfo = cblock.ssaInfoMap.find(svreg)->second;
1118                            Liveness& lv = getLiveness(svreg, ssaIdIdx, sinfo,
1119                                    livenesses, vregIndexMaps, regTypesNum, regRanges);
1120                            if (liveTimeNext != curBlockLiveEnd)
1121                                // because live after this instr
1122                                lv.newRegion(liveTimeNext);
1123                            sinfo.lastPos = liveTimeNext - curLiveTime + cblock.start;
1124                        }
1125                        // get linear deps and equal to
1126                        cxbyte lDeps[16];
1127                        cxbyte eDeps[16];
1128                        usageHandler.getUsageDependencies(instrRVUsCount, instrRVUs,
1129                                        lDeps, eDeps);
1130                       
1131                        addUsageDeps(lDeps, eDeps, instrRVUsCount, instrRVUs,
1132                                linearDepMaps, equalToDepMaps, vregIndexMaps, ssaIdIdxMap,
1133                                regTypesNum, regRanges);
1134                       
1135                        readSVRegs.clear();
1136                        writtenSVRegs.clear();
1137                        if (!usageHandler.hasNext())
1138                            break; // end
1139                        oldOffset = rvu.offset;
1140                        instrRVUsCount = 0;
1141                    }
1142                    if (rvu.offset >= cblock.end)
1143                        break;
1144                   
1145                    for (uint16_t rindex = rvu.rstart; rindex < rvu.rend; rindex++)
1146                    {
1147                        // per register/singlvreg
1148                        AsmSingleVReg svreg{ rvu.regVar, rindex };
1149                        if (rvu.rwFlags == ASMRVU_WRITE && rvu.regField == ASMFIELD_NONE)
1150                            writtenSVRegs.push_back(svreg);
1151                        else // read or treat as reading // expand previous region
1152                            readSVRegs.push_back(svreg);
1153                    }
1154                }
1155                curLiveTime += cblock.end-cblock.start;
1156            }
1157            else
1158            {
1159                // back, already visited
1160                flowStack.pop_back();
1161                continue;
1162            }
1163        }
1164        if (entry.nextIndex < cblock.nexts.size())
1165        {
1166            flowStack.push_back({ cblock.nexts[entry.nextIndex].block, 0 });
1167            entry.nextIndex++;
1168        }
1169        else if (entry.nextIndex==0 && cblock.nexts.empty() && !cblock.haveEnd)
1170        {
1171            flowStack.push_back({ entry.blockIndex+1, 0 });
1172            entry.nextIndex++;
1173        }
1174        else // back
1175        {
1176            // revert lastSSAIdMap
1177            blockInWay.erase(entry.blockIndex);
1178            flowStack.pop_back();
1179            if (!flowStack.empty())
1180            {
1181                for (const auto& sentry: cblock.ssaInfoMap)
1182                {
1183                    auto lvrit = lastVRegMap.find(sentry.first);
1184                    if (lvrit != lastVRegMap.end())
1185                    {
1186                        VRegLastPos& lastPos = lvrit->second;
1187                        lastPos.ssaId = sentry.second.ssaIdBefore;
1188                        lastPos.blockChain.pop_back();
1189                        if (lastPos.blockChain.empty()) // just remove from lastVRegs
1190                            lastVRegMap.erase(lvrit);
1191                    }
1192                }
1193            }
1194        }
1195    }
1196   
1197    /// construct liveBlockMaps
1198    std::set<LiveBlock> liveBlockMaps[MAX_REGTYPES_NUM];
1199    for (size_t regType = 0; regType < regTypesNum; regType++)
1200    {
1201        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
1202        std::vector<Liveness>& liveness = livenesses[regType];
1203        for (size_t li = 0; li < liveness.size(); li++)
1204        {
1205            Liveness& lv = liveness[li];
1206            for (const std::pair<size_t, size_t>& blk: lv.l)
1207                if (blk.first != blk.second)
1208                    liveBlockMap.insert({ blk.first, blk.second, li });
1209            lv.clear();
1210        }
1211        liveness.clear();
1212    }
1213   
1214    // create interference graphs
1215    for (size_t regType = 0; regType < regTypesNum; regType++)
1216    {
1217        InterGraph& interGraph = interGraphs[regType];
1218        interGraph.resize(graphVregsCounts[regType]);
1219        std::set<LiveBlock>& liveBlockMap = liveBlockMaps[regType];
1220       
1221        auto lit = liveBlockMap.begin();
1222        size_t rangeStart = 0;
1223        if (lit != liveBlockMap.end())
1224            rangeStart = lit->start;
1225        while (lit != liveBlockMap.end())
1226        {
1227            const size_t blkStart = lit->start;
1228            const size_t blkEnd = lit->end;
1229            size_t rangeEnd = blkEnd;
1230            auto liStart = liveBlockMap.lower_bound({ rangeStart, 0, 0 });
1231            auto liEnd = liveBlockMap.lower_bound({ rangeEnd, 0, 0 });
1232            // collect from this range, variable indices
1233            std::set<size_t> varIndices;
1234            for (auto lit2 = liStart; lit2 != liEnd; ++lit2)
1235                varIndices.insert(lit2->vidx);
1236            // push to intergraph as full subgGraph
1237            for (auto vit = varIndices.begin(); vit != varIndices.end(); ++vit)
1238                for (auto vit2 = varIndices.begin(); vit2 != varIndices.end(); ++vit2)
1239                    if (vit != vit2)
1240                        interGraph[*vit].insert(*vit2);
1241            // go to next live blocks
1242            rangeStart = rangeEnd;
1243            for (; lit != liveBlockMap.end(); ++lit)
1244                if (lit->start != blkStart && lit->end != blkEnd)
1245                    break;
1246            if (lit == liveBlockMap.end())
1247                break; //
1248            rangeStart = std::max(rangeStart, lit->start);
1249        }
1250    }
1251   
1252    /*
1253     * resolve equalSets
1254     */
1255    for (cxuint regType = 0; regType < regTypesNum; regType++)
1256    {
1257        InterGraph& interGraph = interGraphs[regType];
1258        const size_t nodesNum = interGraph.size();
1259        const std::unordered_map<size_t, EqualToDep>& etoDepMap = equalToDepMaps[regType];
1260        std::vector<bool> visited(nodesNum, false);
1261        std::vector<std::vector<size_t> >& equalSetList = equalSetLists[regType];
1262       
1263        for (size_t v = 0; v < nodesNum;)
1264        {
1265            auto it = etoDepMap.find(v);
1266            if (it == etoDepMap.end())
1267            {
1268                // is not regvar in equalTo dependencies
1269                v++;
1270                continue;
1271            }
1272           
1273            std::stack<EqualStackEntry> etoStack;
1274            etoStack.push(EqualStackEntry{ it, 0 });
1275           
1276            std::unordered_map<size_t, size_t>& equalSetMap =  equalSetMaps[regType];
1277            const size_t equalSetIndex = equalSetList.size();
1278            equalSetList.push_back(std::vector<size_t>());
1279            std::vector<size_t>& equalSet = equalSetList.back();
1280           
1281            // traverse by this
1282            while (!etoStack.empty())
1283            {
1284                EqualStackEntry& entry = etoStack.top();
1285                size_t vidx = entry.etoDepIt->first; // node index, vreg index
1286                const EqualToDep& eToDep = entry.etoDepIt->second;
1287                if (entry.nextIdx == 0)
1288                {
1289                    if (!visited[vidx])
1290                    {
1291                        // push to this equalSet
1292                        equalSetMap.insert({ vidx, equalSetIndex });
1293                        equalSet.push_back(vidx);
1294                    }
1295                    else
1296                    {
1297                        // already visited
1298                        etoStack.pop();
1299                        continue;
1300                    }
1301                }
1302               
1303                if (entry.nextIdx < eToDep.nextVidxes.size())
1304                {
1305                    auto nextIt = etoDepMap.find(eToDep.nextVidxes[entry.nextIdx]);
1306                    etoStack.push(EqualStackEntry{ nextIt, 0 });
1307                    entry.nextIdx++;
1308                }
1309                else if (entry.nextIdx < eToDep.nextVidxes.size()+eToDep.prevVidxes.size())
1310                {
1311                    auto nextIt = etoDepMap.find(eToDep.prevVidxes[
1312                                entry.nextIdx - eToDep.nextVidxes.size()]);
1313                    etoStack.push(EqualStackEntry{ nextIt, 0 });
1314                    entry.nextIdx++;
1315                }
1316                else
1317                    etoStack.pop();
1318            }
1319           
1320            // to first already added node (var)
1321            while (v < nodesNum && !visited[v]) v++;
1322        }
1323    }
1324}
1325
1326typedef AsmRegAllocator::InterGraph InterGraph;
1327
1328struct CLRX_INTERNAL SDOLDOCompare
1329{
1330    const InterGraph& interGraph;
1331    const Array<size_t>& sdoCounts;
1332   
1333    SDOLDOCompare(const InterGraph& _interGraph, const Array<size_t>&_sdoCounts)
1334        : interGraph(_interGraph), sdoCounts(_sdoCounts)
1335    { }
1336   
1337    bool operator()(size_t a, size_t b) const
1338    {
1339        if (sdoCounts[a] > sdoCounts[b])
1340            return true;
1341        return interGraph[a].size() > interGraph[b].size();
1342    }
1343};
1344
1345/* algorithm to allocate regranges:
1346 * from smallest regranges to greatest regranges:
1347 *   choosing free register: from smallest free regranges
1348 *      to greatest regranges:
1349 *         in this same regrange:
1350 *               try to find free regs in regranges
1351 *               try to link free ends of two distinct regranges
1352 */
1353
1354void AsmRegAllocator::colorInterferenceGraph()
1355{
1356    const GPUArchitecture arch = getGPUArchitectureFromDeviceType(
1357                    assembler.deviceType);
1358   
1359    for (size_t regType = 0; regType < regTypesNum; regType++)
1360    {
1361        const size_t maxColorsNum = getGPUMaxRegistersNum(arch, regType);
1362        InterGraph& interGraph = interGraphs[regType];
1363        const VarIndexMap& vregIndexMap = vregIndexMaps[regType];
1364        Array<cxuint>& gcMap = graphColorMaps[regType];
1365        const std::vector<std::vector<size_t> >& equalSetList = equalSetLists[regType];
1366        const std::unordered_map<size_t, size_t>& equalSetMap =  equalSetMaps[regType];
1367       
1368        const size_t nodesNum = interGraph.size();
1369        gcMap.resize(nodesNum);
1370        std::fill(gcMap.begin(), gcMap.end(), cxuint(UINT_MAX));
1371        Array<size_t> sdoCounts(nodesNum);
1372        std::fill(sdoCounts.begin(), sdoCounts.end(), 0);
1373       
1374        SDOLDOCompare compare(interGraph, sdoCounts);
1375        std::set<size_t, SDOLDOCompare> nodeSet(compare);
1376        for (size_t i = 0; i < nodesNum; i++)
1377            nodeSet.insert(i);
1378       
1379        cxuint colorsNum = 0;
1380        // firstly, allocate real registers
1381        for (const auto& entry: vregIndexMap)
1382            if (entry.first.regVar == nullptr)
1383                gcMap[entry.second[0]] = colorsNum++;
1384       
1385        for (size_t colored = 0; colored < nodesNum; colored++)
1386        {
1387            size_t node = *nodeSet.begin();
1388            if (gcMap[node] != UINT_MAX)
1389                continue; // already colored
1390            size_t color = 0;
1391            std::vector<size_t> equalNodes;
1392            equalNodes.push_back(node); // only one node, if equalSet not found
1393            auto equalSetMapIt = equalSetMap.find(node);
1394            if (equalSetMapIt != equalSetMap.end())
1395                // found, get equal set from equalSetList
1396                equalNodes = equalSetList[equalSetMapIt->second];
1397           
1398            for (color = 0; color <= colorsNum; color++)
1399            {
1400                // find first usable color
1401                bool thisSame = false;
1402                for (size_t nb: interGraph[node])
1403                    if (gcMap[nb] == color)
1404                    {
1405                        thisSame = true;
1406                        break;
1407                    }
1408                if (!thisSame)
1409                    break;
1410            }
1411            if (color==colorsNum) // add new color if needed
1412            {
1413                if (colorsNum >= maxColorsNum)
1414                    throw AsmException("Too many register is needed");
1415                colorsNum++;
1416            }
1417           
1418            for (size_t nextNode: equalNodes)
1419                gcMap[nextNode] = color;
1420            // update SDO for node
1421            bool colorExists = false;
1422            for (size_t node: equalNodes)
1423            {
1424                for (size_t nb: interGraph[node])
1425                    if (gcMap[nb] == color)
1426                    {
1427                        colorExists = true;
1428                        break;
1429                    }
1430                if (!colorExists)
1431                    sdoCounts[node]++;
1432            }
1433            // update SDO for neighbors
1434            for (size_t node: equalNodes)
1435                for (size_t nb: interGraph[node])
1436                {
1437                    colorExists = false;
1438                    for (size_t nb2: interGraph[nb])
1439                        if (gcMap[nb2] == color)
1440                        {
1441                            colorExists = true;
1442                            break;
1443                        }
1444                    if (!colorExists)
1445                    {
1446                        if (gcMap[nb] == UINT_MAX)
1447                            nodeSet.erase(nb);  // before update we erase from nodeSet
1448                        sdoCounts[nb]++;
1449                        if (gcMap[nb] == UINT_MAX)
1450                            nodeSet.insert(nb); // after update, insert again
1451                    }
1452                }
1453           
1454            for (size_t nextNode: equalNodes)
1455                gcMap[nextNode] = color;
1456        }
1457    }
1458}
1459
1460void AsmRegAllocator::allocateRegisters(cxuint sectionId)
1461{
1462    // before any operation, clear all
1463    codeBlocks.clear();
1464    for (size_t i = 0; i < MAX_REGTYPES_NUM; i++)
1465    {
1466        vregIndexMaps[i].clear();
1467        interGraphs[i].clear();
1468        linearDepMaps[i].clear();
1469        equalToDepMaps[i].clear();
1470        graphColorMaps[i].clear();
1471        equalSetMaps[i].clear();
1472        equalSetLists[i].clear();
1473    }
1474    ssaReplacesMap.clear();
1475    cxuint maxRegs[MAX_REGTYPES_NUM];
1476    assembler.isaAssembler->getMaxRegistersNum(regTypesNum, maxRegs);
1477   
1478    // set up
1479    const AsmSection& section = assembler.sections[sectionId];
1480    createCodeStructure(section.codeFlow, section.content.size(), section.content.data());
1481    createSSAData(*section.usageHandler);
1482    applySSAReplaces();
1483    createInterferenceGraph(*section.usageHandler);
1484    colorInterferenceGraph();
1485}
Note: See TracBrowser for help on using the repository browser.