/ Assets / Resources / Lineart / ComputeShaders / SilhouetteEdge.compute
SilhouetteEdge.compute
  1  #pragma kernel FindSilhouetteEdge
  2  
  // Per-vertex record stored in _Vertices (one entry per mesh vertex).
// In _Vertices both fields are in object space; the kernel transforms them
// with _ObjectToWorld / _ObjectToWorldIT and reuses this struct to hold the
// world-space copies passed to TryGetZeroLine / TryGetZeroPoint.
// NOTE(review): the C# ComputeBuffer stride must match this layout
// (two float3 -> 24 bytes in a structured buffer) — confirm on the C# side.
struct VertexData
{
    float3 position; // vertex position
    float3 normal;   // vertex normal (re-normalized after transformation)
};
  9  
 10  #include "silhouette_defines.hh"
 11  
 // ---------------------------------------------------------------------------
// Input buffers
// ---------------------------------------------------------------------------
// All mesh vertices (object space), addressed by the values stored in _Indices.
StructuredBuffer<VertexData> _Vertices;
// Triangle list: entries [3*f, 3*f+2] are the vertex indices of face f.
// Stored as int; the kernel casts each entry to uint before use.
StructuredBuffer<int> _Indices;
// Per-edge adjacency: entry 3*f+k belongs to edge k of face f, where edge k
// runs from corner k to corner (k+1)%3. INVALID_UINT marks "no neighbour".
// NOTE(review): presumably the index of the face across that edge — confirm
// against the C# code that builds this buffer.
StructuredBuffer<uint> _AdjIndices;

// Output buffer: one StrokeData per face, indexed by face index.
RWStructuredBuffer<StrokeData> _outStrokes;

// Uniforms, set from C#.
// Camera position in world space; compared against world-space vertices.
float3 _WorldSpaceCameraPos;
// Number of triangles to process. _Indices and _AdjIndices must hold at least
// 3 * _NumFaces entries and _outStrokes at least _NumFaces.
uint _NumFaces;

// Object-to-world transform applied to vertex positions.
float4x4 _ObjectToWorld;
// Inverse-transpose of _ObjectToWorld, used to transform normals; only its
// upper-left 3x3 part is read by the kernel.
float4x4 _ObjectToWorldIT;
 27  
 // Finds where the silhouette crosses the edge v1 -> v2.
//
// The facing value f(v) = dot(v.normal, normalize(cameraPos - v.position)) is
// treated as varying linearly along the edge. Each endpoint is classified as
// front-facing (f > 0) or not (f <= 0), and the edge is crossed exactly when
// the two endpoints fall into different classes. The crossing parameter solves
//     f1 * (1 - t) + f2 * t = 0   =>   t = f1 / (f1 - f2)
// and zeroPoint = lerp(v1.position, v2.position, t).
//
// Why a strict two-way classification instead of "f1 * f2 > 0" plus a
// small-denominator cutoff:
//   * With two classes, the number of crossed edges around a triangle is
//     always even (0 or 2). An exact zero can no longer register a crossing on
//     all three edges.
//   * A genuine sign change with a tiny |f1 - f2| is no longer rejected, which
//     previously left a triangle with a single crossing and lost its segment.
//   * The product f1 * f2 can underflow to 0 for small values; comparisons
//     cannot.
// The "> 0" front-facing test matches the one TryGetZeroLine uses to orient
// segments.
//
// v1, v2    : edge endpoints. Positions and normals must be in the same space as
//             _WorldSpaceCameraPos (world space when called from the kernel).
// zeroPoint : crossing position, or (0,0,0) when the edge is not crossed.
// returns   : true if the silhouette crosses this edge.
bool TryGetZeroPoint(VertexData v1, VertexData v2, out float3 zeroPoint)
{
    zeroPoint = float3(0, 0, 0);

    float dot1 = dot(v1.normal, normalize(_WorldSpaceCameraPos - v1.position));
    float dot2 = dot(v2.normal, normalize(_WorldSpaceCameraPos - v2.position));

    bool front1 = dot1 > 0.0f;
    bool front2 = dot2 > 0.0f;
    if (front1 == front2)
    {
        // Both endpoints on the same side: no silhouette on this edge.
        return false;
    }

    // One value is > 0 and the other <= 0, so (dot1 - dot2) is strictly
    // nonzero and t lies in [0, 1]. saturate only guards against rounding and
    // NaN inputs (e.g. a vertex located exactly at the camera).
    float t = saturate(dot1 / (dot1 - dot2));

    zeroPoint = lerp(v1.position, v2.position, t);
    return true;
}
 68  
 // Converts a raw _AdjIndices entry into the signed value stored in
// StrokeData.adj. The sentinel INVALID_UINT ("no neighbour across this edge")
// maps to ADJ_NONE; every other value is passed through as an int.
int DecodeAdj(uint rawAdj)
{
    if (rawAdj != INVALID_UINT)
    {
        return (int)rawAdj;
    }

    return ADJ_NONE;
}
 75  
 // Finds the silhouette segment crossing triangle (v0, v1, v2) and records it
// in _outStrokes[faceIdx].
//
// Edge k runs from corner k to corner (k + 1) % 3:
//   edge 0: v0 -> v1, edge 1: v1 -> v2, edge 2: v2 -> v0.
// The neighbour across edge k is read from _AdjIndices[faceIdx * 3 + k] and
// decoded with DecodeAdj.
//
// When exactly two edges are crossed, only one of the two crossing points is
// stored: .pos receives that point and .adj the neighbour across its edge.
// Assuming TryGetZeroPoint's crossings agree with the "dot(n, toCamera) > 0"
// front-facing split, the chosen edge (in v0 -> v1 -> v2 winding) runs from a
// corner that is not front-facing to one that is. This keeps every segment
// oriented the same way relative to the face winding.
//
// Any other number of crossings ORs STROKE_FLAG_IS_INVALID into .flags and
// leaves .pos/.adj untouched. The caller is expected to have cleared .flags.
//
// Returns true when a segment was written.
bool TryGetZeroLine(VertexData v0, VertexData v1, VertexData v2, uint faceIdx)
{
    float3 zeroPoint;

    // At most two crossings are stored. Further ones are only counted, so a
    // face that reports three crossings is flagged invalid instead of writing
    // past the end of these two-entry arrays.
    float3 hitPoint[2] = { float3(0, 0, 0), float3(0, 0, 0) };
    int    hitAdj[2]   = { 0, 0 };
    uint   hitEdge[2]  = { 0, 0 };
    uint   hitCount    = 0;

    // Edge 0: v0 -> v1
    if (TryGetZeroPoint(v0, v1, zeroPoint))
    {
        if (hitCount < 2)
        {
            hitPoint[hitCount] = zeroPoint;
            hitAdj[hitCount]   = DecodeAdj(_AdjIndices[faceIdx * 3 + 0]);
            hitEdge[hitCount]  = 0;
        }
        hitCount++;
    }

    // Edge 1: v1 -> v2
    if (TryGetZeroPoint(v1, v2, zeroPoint))
    {
        if (hitCount < 2)
        {
            hitPoint[hitCount] = zeroPoint;
            hitAdj[hitCount]   = DecodeAdj(_AdjIndices[faceIdx * 3 + 1]);
            hitEdge[hitCount]  = 1;
        }
        hitCount++;
    }

    // Edge 2: v2 -> v0
    if (TryGetZeroPoint(v2, v0, zeroPoint))
    {
        if (hitCount < 2)
        {
            hitPoint[hitCount] = zeroPoint;
            hitAdj[hitCount]   = DecodeAdj(_AdjIndices[faceIdx * 3 + 2]);
            hitEdge[hitCount]  = 2;
        }
        hitCount++;
    }

    if (hitCount != 2)
    {
        _outStrokes[faceIdx].flags |= STROKE_FLAG_IS_INVALID;
        return false;
    }

    // Edges are tested in increasing order, so the crossed pair is (0,1),
    // (1,2) or (0,2), and the uncrossed edge is 3 - hitEdge[0] - hitEdge[1].
    // Both corners of the uncrossed edge lie on the same side of the
    // silhouette, opposite the corner shared by the crossed edges. The start
    // corner of the uncrossed edge serves as the probe for that side.
    uint uncrossed = 3 - hitEdge[0] - hitEdge[1];

    VertexData probe = v2;      // pair (0,1): uncrossed edge 2 starts at v2
    if (uncrossed == 0)
    {
        probe = v0;             // pair (1,2): uncrossed edge 0 starts at v0
    }
    else if (uncrossed == 1)
    {
        probe = v1;             // pair (0,2): uncrossed edge 1 starts at v1
    }

    bool probeFront = dot(probe.normal, normalize(_WorldSpaceCameraPos - probe.position)) > 0.0f;

    // For pairs (0,1) and (1,2), the first recorded point lies on the edge
    // entering the shared corner. For (0,2) the recording order is reversed,
    // hence the XOR with (uncrossed == 1).
    bool useSecond = probeFront != (uncrossed == 1);

    _outStrokes[faceIdx].pos = useSecond ? hitPoint[1] : hitPoint[0];
    _outStrokes[faceIdx].adj = useSecond ? hitAdj[1]   : hitAdj[0];
    return true;
}
163  
164  
// Kernel: one thread per triangle of an indexed triangle list with per-edge
// adjacency (_Indices / _AdjIndices).
//
// For face f (f < _NumFaces) the thread:
//   1. clears _outStrokes[f].flags,
//   2. fetches the three corners through _Indices and moves them to world
//      space: positions by _ObjectToWorld, normals by the upper 3x3 of
//      _ObjectToWorldIT, then re-normalized,
//   3. calls TryGetZeroLine, which fills .pos/.adj when a silhouette segment
//      crosses the face and otherwise sets STROKE_FLAG_IS_INVALID,
//   4. stores the world-space geometric face normal in .faceNormal. This is
//      written for every face, whether or not a segment was found.
// Threads with id.x >= _NumFaces return without touching any buffer.
//
// NOTE(review): normalize() of a zero-length vector is undefined (typically
// NaN), so degenerate zero-area triangles get a non-finite faceNormal.
// Confirm that downstream consumers tolerate this.
[numthreads(64, 1, 1)]
void FindSilhouetteEdge(uint3 id : SV_DispatchThreadID)
{
    const uint faceIdx = id.x;
    if (faceIdx >= _NumFaces)
        return;

    _outStrokes[faceIdx].flags = 0;

    const uint firstIndex = faceIdx * 3;
    const float3x3 normalToWorld = (float3x3)_ObjectToWorldIT;

    // Gather the corners in world space: the silhouette test compares them
    // against the world-space camera position.
    VertexData corner[3];
    [unroll]
    for (uint k = 0; k < 3; ++k)
    {
        const VertexData src = _Vertices[(uint)_Indices[firstIndex + k]];
        corner[k].position = mul(_ObjectToWorld, float4(src.position, 1.0)).xyz;
        corner[k].normal   = normalize(mul(normalToWorld, src.normal));
    }

    // Writes .pos/.adj or the invalid flag into _outStrokes[faceIdx].
    TryGetZeroLine(corner[0], corner[1], corner[2], faceIdx);

    const float3 edgeA = corner[1].position - corner[0].position;
    const float3 edgeB = corner[2].position - corner[0].position;
    _outStrokes[faceIdx].faceNormal = normalize(cross(edgeA, edgeB));
}