
Inference API

swebench.inference

llamao

distributed_attention
SeqAllToAll

Bases: Function

forward staticmethod
forward(ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, group: Any) -> Tensor
Source code in swebench/inference/llamao/distributed_attention.py
@staticmethod
def forward(
    ctx: Any, input: Tensor, scatter_idx: int, gather_idx: int, group: Any
) -> Tensor:
    ctx.scatter_idx = scatter_idx
    ctx.gather_idx = gather_idx
    ctx.group = group

    world_size = dist.get_world_size(group)

    input_list = [
        t.contiguous() for t in torch.tensor_split(input, world_size, scatter_idx)
    ]
    output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]

    dist.all_to_all(output_list, input_list, group=group)
    return torch.cat(output_list, dim=gather_idx).contiguous()
backward staticmethod
backward(ctx: Any, *grad_output: Tensor) -> tuple[Tensor, None, None, None]
Source code in swebench/inference/llamao/distributed_attention.py
@staticmethod
def backward(ctx: Any, *grad_output: Tensor) -> tuple[Tensor, None, None, None]:
    return (
        SeqAllToAll.apply(*grad_output, ctx.gather_idx, ctx.scatter_idx, ctx.group),
        None,
        None,
        None,
    )
DistributedAttention
DistributedAttention(local_attention: Module, scatter_idx: int = -2, gather_idx: int = 1)

Bases: Module

Initialization.

Parameters:

local_attention (Module): local attention with q,k,v. Required.
scatter_idx (int): scatter_idx for all2all comm. Default: -2.
gather_idx (int): gather_idx for all2all comm. Default: 1.
Source code in swebench/inference/llamao/distributed_attention.py
def __init__(
    self,
    local_attention: Module,
    scatter_idx: int = -2,
    gather_idx: int = 1,
) -> None:
    super().__init__()
    self.local_attn = local_attention
    self.scatter_idx = scatter_idx  # head axis
    self.gather_idx = gather_idx  # seq axis
local_attn instance-attribute
local_attn = local_attention
scatter_idx instance-attribute
scatter_idx = scatter_idx
gather_idx instance-attribute
gather_idx = gather_idx
forward
forward(query: Tensor, key_values: Tensor, group: Any = None, **kwargs) -> Tensor

forward

Parameters:

query (Tensor): query input to the layer. Required.
key (Tensor): key input to the layer. Required.
value (Tensor): value input to the layer. Required.
args: other args. Required.

Returns:

output (Tensor): context output
Source code in swebench/inference/llamao/distributed_attention.py
def forward(
    self, query: Tensor, key_values: Tensor, group: Any = None, **kwargs
) -> Tensor:
    """forward

    Arguments:
        query (Tensor): query input to the layer
        key (Tensor): key input to the layer
        value (Tensor): value input to the layer
        args: other args

    Returns:
        * output (Tensor): context output
    """
    # in shape : e.g.,  [s/p:h:]
    query_heads = SeqAllToAll.apply(query, self.scatter_idx, self.gather_idx, group)
    key_values_heads = SeqAllToAll.apply(
        key_values, self.scatter_idx, self.gather_idx, group
    )

    # out shape : e.g., [s:h/p:]
    output_heads = self.local_attn(query_heads, key_values_heads, **kwargs)

    # out e.g., [s/p::h]
    return SeqAllToAll.apply(output_heads, self.gather_idx, self.scatter_idx, group)
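The sketch below shows one way to drive DistributedAttention across a sequence-parallel process group, mirroring how LlamaAttention in this package wraps flash_attn_kvpacked_func. The group setup, tensor shapes, and dtype are illustrative assumptions rather than part of the API, and the snippet needs an initialized NCCL process group plus flash-attn.

import torch
import torch.distributed as dist
from flash_attn import flash_attn_kvpacked_func

from swebench.inference.llamao.distributed_attention import DistributedAttention

# Assumes dist.init_process_group("nccl") has already been called on every rank.
seq_parallel_group = dist.new_group(ranks=list(range(dist.get_world_size())))
attn = DistributedAttention(flash_attn_kvpacked_func)

# Each rank holds a slice of the sequence: (batch, seq_len / world_size, n_heads, head_dim).
q = torch.randn(1, 512, 32, 128, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(1, 512, 2, 32, 128, dtype=torch.bfloat16, device="cuda")

# The all-to-all scatters heads and gathers the full sequence before local attention,
# then reverses the exchange on the output.
out = attn(q, kv, group=seq_parallel_group, causal=True)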
modeling_flash_llama

PyTorch LLaMA model.

logger module-attribute
logger = get_logger(__name__)
LlamaRMSNorm
LlamaRMSNorm(hidden_size, eps=1e-06)

Bases: Module

LlamaRMSNorm is equivalent to T5LayerNorm

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, hidden_size, eps=1e-6):
    """
    LlamaRMSNorm is equivalent to T5LayerNorm
    """
    super().__init__()
    self.weight = nn.Parameter(torch.ones(hidden_size))
    self.register_buffer(
        "variance_epsilon",
        torch.tensor(eps),
        persistent=False,
    )
weight instance-attribute
weight = Parameter(ones(hidden_size))
forward
forward(hidden_states)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(self, hidden_states):
    return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
FlashRotaryEmbedding
FlashRotaryEmbedding(dim: int, base=10000.0, interleaved=False, scale_base=None, scaling_factor=1.0, pos_idx_in_fp32=True, device=None)

Bases: Module

The rotary position embeddings from RoFormer_ (Su et al.). A crucial insight from the method is that the query and keys are transformed by rotation matrices which depend on the relative positions.

Other implementations are available in the Rotary Transformer repo_ and in GPT-NeoX_; GPT-NeoX was an inspiration.

.. _RoFormer: https://arxiv.org/abs/2104.09864 .. _repo: https://github.com/ZhuiyiTechnology/roformer .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox

If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py

interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of 1st half and 2nd half (GPT-NeoX style).

pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32; otherwise they might be in lower precision. This option was added because previously (before 2023-07-02) the position indices were constructed with the dtype of self.inv_freq. In most cases this would be fp32, but if the model is trained in pure bf16 (not mixed precision), then self.inv_freq would be bf16, and the position indices would also be in bf16. Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the embeddings for some positions would coincide. This option maintains compatibility with models previously trained in pure bf16.

scaling_factor: RotaryEmbedding extended with linear scaling.

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(
    self,
    dim: int,
    base=10000.0,
    interleaved=False,
    scale_base=None,
    scaling_factor=1.0,
    pos_idx_in_fp32=True,
    device=None,
):
    """
    interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
        of 1st half and 2nd half (GPT-NeoX style).
    pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32,
        otherwise they might be in lower precision.
        This option was added because previously (before 2023-07-02), when we construct
        the position indices, we use the dtype of self.inv_freq. In most cases this would
        be fp32, but if the model is trained in pure bf16 (not mixed precision), then
        self.inv_freq would be bf16, and the position indices are also in bf16.
        Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the
        embeddings for some positions will coincide.
        To maintain compatibility with models previously trained in pure bf16,
        we add this option.
    scaling_factor: RotaryEmbedding extended with linear scaling.
    """
    super().__init__()
    self.dim = dim
    self.base = float(base)
    self.pos_idx_in_fp32 = pos_idx_in_fp32
    # Generate and save the inverse frequency buffer (non trainable)
    inv_freq = self._compute_inv_freq(device)
    self.register_buffer("inv_freq", inv_freq, persistent=False)
    self.interleaved = interleaved
    self.scale_base = scale_base
    self.scaling_factor = scaling_factor
    scale = (
        (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
        / (1.4 * dim)
        if scale_base is not None
        else None
    )
    self.register_buffer("scale", scale)

    self._seq_len_cached = 0
    self._cos_cached = None
    self._sin_cached = None
    self._cos_k_cached = None
    self._sin_k_cached = None
dim instance-attribute
dim = dim
base instance-attribute
base = float(base)
pos_idx_in_fp32 instance-attribute
pos_idx_in_fp32 = pos_idx_in_fp32
interleaved instance-attribute
interleaved = interleaved
scale_base instance-attribute
scale_base = scale_base
scaling_factor instance-attribute
scaling_factor = scaling_factor
forward
forward(q: Tensor, k: Tensor, seqlen_offset: int = 0, unpadded_lengths: Optional[tuple[Tensor]] = None) -> tuple[Tensor, Tensor]

q: (batch, seqlen, nheads, headdim)
k: (batch, seqlen, nheads, headdim)
seqlen_offset: can be used in generation where the qkv being passed in is only the last token in the batch.

Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    seqlen_offset: int = 0,
    unpadded_lengths: Optional[tuple[torch.Tensor]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    q: (batch, seqlen, nheads, headdim)
    k: (batch, seqlen, nheads, headdim)
    seqlen_offset: can be used in generation where the qkv being passed in is only the last
    token in the batch.
    """
    if unpadded_lengths is not None:
        cu_seqlens, max_seqlen = unpadded_lengths
    else:
        cu_seqlens, max_seqlen = None, q.shape[1]
    self._update_cos_sin_cache(
        max_seqlen + seqlen_offset, device=q.device, dtype=q.dtype
    )

    if self.scale is None:
        return (
            apply_rotary_emb_func(
                q,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,  # inplace=True,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            ),
            apply_rotary_emb_func(
                k,
                self._cos_cached[seqlen_offset:],
                self._sin_cached[seqlen_offset:],
                self.interleaved,
                True,  # inplace=True
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
            ),
        )
    else:
        assert False
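For readers who want the math without the fused kernel, the non-interleaved (GPT-NeoX style) rotation that apply_rotary_emb_func performs can be written in plain PyTorch as below. This is an illustrative reference implementation with made-up shapes, not the code path used by this module.

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the head dimension in halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def reference_rotary(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, headdim // 2).
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]  # broadcast over batch and heads
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    return q * cos + rotate_half(q) * sin

q = torch.randn(1, 8, 4, 16)
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 16, 2).float() / 16))
t = torch.arange(8).float()
freqs = torch.outer(t, inv_freq)          # (seqlen, headdim // 2)
q_rot = reference_rotary(q, freqs.cos(), freqs.sin())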
LlamaMLP
LlamaMLP(config)

Bases: Module

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config):
    super().__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
    self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
    self.act_fn = ACT2FN[config.hidden_act]
config instance-attribute
config = config
hidden_size instance-attribute
hidden_size = hidden_size
intermediate_size instance-attribute
intermediate_size = intermediate_size
gate_proj instance-attribute
gate_proj = Linear(hidden_size, intermediate_size, bias=False)
up_proj instance-attribute
up_proj = Linear(hidden_size, intermediate_size, bias=False)
down_proj instance-attribute
down_proj = Linear(intermediate_size, hidden_size, bias=False)
act_fn instance-attribute
act_fn = ACT2FN[hidden_act]
forward
forward(x)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(self, x):
    return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
LlamaAttention
LlamaAttention(config: LlamaConfig)

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config: LlamaConfig):
    super().__init__()
    self.config = config
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.hidden_size // self.num_heads
    self.num_key_value_heads = getattr(
        config, "num_key_value_heads", self.num_heads
    )
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.max_position_embeddings = config.max_position_embeddings

    if (self.head_dim * self.num_heads) != self.hidden_size:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
            f" and `num_heads`: {self.num_heads})."
        )
    self.q_proj = nn.Linear(
        self.hidden_size, self.num_heads * self.head_dim, bias=False
    )
    self.k_proj = nn.Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
    )
    self.v_proj = nn.Linear(
        self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
    )
    self.o_proj = nn.Linear(
        self.num_heads * self.head_dim, self.hidden_size, bias=False
    )

    self.register_buffer(
        "norm_factor",
        torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(
            torch.get_default_dtype()
        ),
        persistent=False,
    )

    if not getattr(self.config, "rope_scaling", None):
        scaling_factor = 1
    else:
        scaling_type = self.config.rope_scaling["type"]
        scaling_factor = self.config.rope_scaling["factor"]
        assert scaling_type == "linear"
    theta = getattr(self.config, "rope_theta", 10000)
    self.rotary_emb = FlashRotaryEmbedding(
        self.head_dim,
        base=theta,
        interleaved=False,
        scaling_factor=scaling_factor,
    )

    self.distributed_attn_func = DistributedAttention(flash_attn_kvpacked_func)
config instance-attribute
config = config
hidden_size instance-attribute
hidden_size = hidden_size
num_heads instance-attribute
num_heads = num_attention_heads
head_dim instance-attribute
head_dim = hidden_size // num_heads
num_key_value_heads instance-attribute
num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
num_key_value_groups instance-attribute
num_key_value_groups = num_heads // num_key_value_heads
max_position_embeddings instance-attribute
max_position_embeddings = max_position_embeddings
q_proj instance-attribute
q_proj = Linear(hidden_size, num_heads * head_dim, bias=False)
k_proj instance-attribute
k_proj = Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
v_proj instance-attribute
v_proj = Linear(hidden_size, num_key_value_heads * head_dim, bias=False)
o_proj instance-attribute
o_proj = Linear(num_heads * head_dim, hidden_size, bias=False)
rotary_emb instance-attribute
rotary_emb = FlashRotaryEmbedding(head_dim, base=theta, interleaved=False, scaling_factor=scaling_factor)
distributed_attn_func instance-attribute
distributed_attn_func = DistributedAttention(flash_attn_kvpacked_func)
forward
forward(hidden_states: Tensor, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None, past_key_value: Optional[tuple[Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, unpadded_lengths: Optional[tuple[Tensor]] = None, seq_parallel_group: Optional[Any] = None) -> tuple[Tensor, Optional[Tensor], Optional[tuple[Tensor]]]
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    unpadded_lengths: Optional[tuple[torch.Tensor]] = None,
    seq_parallel_group: Optional[Any] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
    h_size = hidden_states.size(-1)

    has_layer_past = past_key_value is not None

    if has_layer_past:
        past_kv = past_key_value[0]
        past_len = past_key_value[1]
    else:
        past_len = 0

    # NOTE: Hack to include position_ids, assuming they are increasing uniformly per block
    if position_ids is not None:
        past_len += position_ids.min()

    q = self.q_proj(hidden_states)
    k = self.k_proj(hidden_states)
    v = self.v_proj(hidden_states)

    q = q.view(*q.shape[:-1], self.num_heads, self.head_dim)
    k = k.view(*k.shape[:-1], self.num_key_value_heads, self.head_dim)
    v = v.view(*v.shape[:-1], self.num_key_value_heads, self.head_dim)

    q, k = self.rotary_emb(q, k, past_len, unpadded_lengths)

    kv = torch.stack([k, v], -3)
    kv = repeat_kv(kv, self.num_key_value_groups)

    # Cache QKV values
    if has_layer_past:
        new_len = past_len + q.size(1)
        if new_len > past_kv.size(1):
            past_kv = torch.cat(
                [
                    past_kv,
                    torch.empty(
                        hidden_states.size(0),
                        256,
                        2,
                        kv.size(3),
                        kv.size(4),
                        dtype=kv.dtype,
                        device=kv.device,
                    ),
                ],
                1,
            )
        past_kv[:, past_len:new_len] = kv
        kv = past_kv[:, :new_len]
    else:
        past_kv = kv

    past_key_value = (past_kv, past_len + q.size(1)) if use_cache else None

    if dist.is_initialized() and dist.get_world_size(seq_parallel_group) > 1:
        # NOTE: we assume that padding tokens are at the end of the sequence and may ignore `attention_mask`
        assert output_attentions is False
        attn_outputs = self.distributed_attn_func(
            q,
            kv,
            dropout_p=0.0,
            softmax_scale=1.0 / self.norm_factor,
            causal=True,
            return_attn_probs=False,
            group=seq_parallel_group,
        )
    else:
        if unpadded_lengths is not None:
            # varlen, ignore padding tokens, efficient for large batch with many paddings
            assert attention_mask is not None
            cu_seqlens, max_seqlen = unpadded_lengths

            attn_outputs = flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens,
                max_seqlen,
                max_seqlen,
                dropout_p=0.0,
                softmax_scale=1.0 / self.norm_factor,
                causal=True,
                return_attn_probs=output_attentions,
            )
        else:
            attn_outputs = flash_attn_kvpacked_func(
                q,
                kv,
                dropout_p=0.0,
                softmax_scale=1.0 / self.norm_factor,
                causal=True,
                return_attn_probs=output_attentions,
            )

    attn_output = attn_outputs[0] if output_attentions else attn_outputs
    attn_output = attn_output.reshape(*attn_output.shape[:-2], h_size)
    attn_weights = attn_outputs[2] if output_attentions else None

    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
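A shape walkthrough for the tensors that reach the flash-attention call above; the batch size, sequence length, and head counts are arbitrary example values.

import torch

batch, seqlen, n_heads, n_kv_heads, head_dim = 2, 128, 32, 8, 128

q = torch.randn(batch, seqlen, n_heads, head_dim)      # after q_proj + view
k = torch.randn(batch, seqlen, n_kv_heads, head_dim)   # after k_proj + view
v = torch.randn(batch, seqlen, n_kv_heads, head_dim)   # after v_proj + view

kv = torch.stack([k, v], -3)                            # (batch, seqlen, 2, n_kv_heads, head_dim)
print(kv.shape)                                         # torch.Size([2, 128, 2, 8, 128])

# After repeat_kv with num_key_value_groups = n_heads // n_kv_heads = 4,
# kv matches q's head count, (batch, seqlen, 2, n_heads, head_dim),
# before being handed to flash_attn_kvpacked_func.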
LlamaDecoderLayer
LlamaDecoderLayer(config: LlamaConfig)

Bases: Module

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config: LlamaConfig):
    super().__init__()
    self.hidden_size = config.hidden_size
    self.self_attn = LlamaAttention(config=config)
    self.mlp = LlamaMLP(config)
    self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
    self.post_attention_layernorm = LlamaRMSNorm(
        config.hidden_size, eps=config.rms_norm_eps
    )
    self._fsdp_wrap = True
hidden_size instance-attribute
hidden_size = hidden_size
self_attn instance-attribute
self_attn = LlamaAttention(config=config)
mlp instance-attribute
mlp = LlamaMLP(config)
input_layernorm instance-attribute
input_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
post_attention_layernorm instance-attribute
post_attention_layernorm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
forward
forward(hidden_states: Tensor, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None, past_key_value: Optional[tuple[Tensor]] = None, unpadded_lengths: Optional[tuple[Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, seq_parallel_group: Optional[Any] = None) -> tuple[FloatTensor, Optional[tuple[FloatTensor, FloatTensor]]]

Parameters:

hidden_states (`torch.FloatTensor`): input to the layer of shape (batch, seq_len, embed_dim). Required.
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size (batch, 1, tgt_len, src_len) where padding elements are indicated by very large negative values. Default: None.
output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail. Default: False.
use_cache (`bool`, *optional*): If set to True, past_key_values key value states are returned and can be used to speed up decoding (see past_key_values). Default: False.
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states. Default: None.
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[tuple[torch.Tensor]] = None,
    unpadded_lengths: Optional[tuple[torch.Tensor]] = None,
    output_attentions: Optional[bool] = False,
    use_cache: Optional[bool] = False,
    seq_parallel_group: Optional[Any] = None,
) -> tuple[
    torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]
]:
    """
    Args:
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
            `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
            (see `past_key_values`).
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
    """

    residual = hidden_states

    hidden_states = self.input_layernorm(hidden_states)

    # Self Attention
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_value=past_key_value,
        output_attentions=output_attentions,
        use_cache=use_cache,
        unpadded_lengths=unpadded_lengths,
        seq_parallel_group=seq_parallel_group,
    )
    hidden_states = residual + hidden_states

    # Fully Connected
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states

    outputs = (hidden_states,)

    if output_attentions:
        outputs += (self_attn_weights,)

    if use_cache:
        outputs += (present_key_value,)

    return outputs
LlamaPreTrainedModel

Bases: PreTrainedModel

config_class class-attribute instance-attribute
config_class = LlamaConfig
base_model_prefix class-attribute instance-attribute
base_model_prefix = 'model'
supports_gradient_checkpointing class-attribute instance-attribute
supports_gradient_checkpointing = True
LlamaModel
LlamaModel(config: LlamaConfig)

Bases: LlamaPreTrainedModel

Transformer decoder consisting of config.num_hidden_layers layers. Each layer is a [LlamaDecoderLayer]

Parameters:

config (LlamaConfig): LlamaConfig. Required.
Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config: LlamaConfig):
    super().__init__(config)
    self.padding_idx = config.pad_token_id
    self.vocab_size = config.vocab_size

    self.embed_tokens = nn.Embedding(
        config.vocab_size, config.hidden_size, self.padding_idx
    )
    self.layers = nn.ModuleList(
        [LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
    )
    self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    self.gradient_checkpointing = False
    # Initialize weights and apply final processing
    self.post_init()
padding_idx instance-attribute
padding_idx = pad_token_id
vocab_size instance-attribute
vocab_size = vocab_size
embed_tokens instance-attribute
embed_tokens = Embedding(vocab_size, hidden_size, padding_idx)
layers instance-attribute
layers = ModuleList([LlamaDecoderLayer(config) for _ in range(num_hidden_layers)])
norm instance-attribute
norm = LlamaRMSNorm(hidden_size, eps=rms_norm_eps)
gradient_checkpointing instance-attribute
gradient_checkpointing = False
get_input_embeddings
get_input_embeddings()
Source code in swebench/inference/llamao/modeling_flash_llama.py
def get_input_embeddings(self):
    return self.embed_tokens
set_input_embeddings
set_input_embeddings(value)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def set_input_embeddings(self, value):
    self.embed_tokens = value
forward
forward(input_ids: LongTensor = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None, past_key_values: Optional[list[FloatTensor]] = None, inputs_embeds: Optional[FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, seq_parallel_group: Optional[Any] = None) -> Union[tuple, BaseModelOutputWithPast]
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    seq_parallel_group: Optional[Any] = None,
) -> Union[tuple, BaseModelOutputWithPast]:
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError(
            "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
        )
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError(
            "You have to specify either decoder_input_ids or decoder_inputs_embeds"
        )

    # position_ids = None

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    hidden_states = inputs_embeds
    bsz = hidden_states.size(0)

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    if (
        ((attention_mask is not None) and (not attention_mask.all().item()))
        and not use_cache
        and not (
            dist.is_initialized() and dist.get_world_size(seq_parallel_group) > 1
        )
    ):
        hidden_states, unpad_indices, cu_seqlens, max_seqlen = unpad_input(
            hidden_states, attention_mask
        )
        unpadded_lengths = (cu_seqlens, max_seqlen)
    else:
        unpadded_lengths = None

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            if unpadded_lengths is not None:
                all_hidden_states += (
                    pad_input(hidden_states, unpad_indices, bsz, max_seqlen),
                )
            else:
                all_hidden_states += (hidden_states,)

        past_key_value = (
            past_key_values[idx] if past_key_values is not None else None
        )

        if self.gradient_checkpointing and self.training:
            layer_outputs = torch.utils.checkpoint.checkpoint(
                decoder_layer,
                hidden_states,
                attention_mask,
                position_ids,
                None,
                unpadded_lengths,
                output_attentions,
                False,
                seq_parallel_group,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                unpadded_lengths=unpadded_lengths,
                output_attentions=output_attentions,
                use_cache=use_cache,
                seq_parallel_group=seq_parallel_group,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    if unpadded_lengths is not None:
        hidden_states = pad_input(hidden_states, unpad_indices, bsz, max_seqlen)
    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(
            v
            for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
            if v is not None
        )
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
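The unpadded (varlen) branch above packs all non-padding tokens into one flat sequence and tracks per-sequence boundaries with cumulative lengths. The following is a minimal pure-PyTorch illustration of that bookkeeping, not the flash-attn unpad_input/pad_input helpers themselves.

import torch

hidden = torch.randn(2, 5, 8)                       # (batch, seqlen, hidden)
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])              # padding at the end of sequence 0

seqlens = mask.sum(dim=-1)                          # tensor([3, 5])
cu_seqlens = torch.nn.functional.pad(seqlens.cumsum(0), (1, 0)).to(torch.int32)  # tensor([0, 3, 8])
max_seqlen = int(seqlens.max())                     # 5

indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
packed = hidden.reshape(-1, 8)[indices]             # (total_tokens, hidden) == (8, 8)

# pad_input reverses this step: it scatters `packed` back into a (batch, seqlen, hidden)
# tensor at `indices`, which is what happens before the final norm above.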
LlamaForCausalLM
LlamaForCausalLM(config)

Bases: LlamaPreTrainedModel

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config):
    super().__init__(config)
    self.model = LlamaModel(config)
    self.vocab_size = config.vocab_size
    self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
model instance-attribute
model = LlamaModel(config)
vocab_size instance-attribute
vocab_size = vocab_size
lm_head instance-attribute
lm_head = Linear(hidden_size, vocab_size, bias=False)
get_input_embeddings
get_input_embeddings()
Source code in swebench/inference/llamao/modeling_flash_llama.py
def get_input_embeddings(self):
    return self.model.embed_tokens
set_input_embeddings
set_input_embeddings(value)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def set_input_embeddings(self, value):
    self.model.embed_tokens = value
get_output_embeddings
get_output_embeddings()
Source code in swebench/inference/llamao/modeling_flash_llama.py
def get_output_embeddings(self):
    return self.lm_head
set_output_embeddings
set_output_embeddings(new_embeddings)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def set_output_embeddings(self, new_embeddings):
    self.lm_head = new_embeddings
set_decoder
set_decoder(decoder)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def set_decoder(self, decoder):
    self.model = decoder
get_decoder
get_decoder()
Source code in swebench/inference/llamao/modeling_flash_llama.py
def get_decoder(self):
    return self.model
forward
forward(input_ids: LongTensor = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None, past_key_values: Optional[list[FloatTensor]] = None, inputs_embeds: Optional[FloatTensor] = None, labels: Optional[LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, unpadded_lengths: Optional[bool] = None, avg_valid_labels_per_chunk: Optional[float] = None, seq_parallel_group: Optional[Any] = None) -> Union[tuple, CausalLMOutputWithPast]

Parameters:

labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in [0, ..., config.vocab_size] or -100 (see input_ids docstring). Tokens with indices set to -100 are ignored (masked); the loss is only computed for the tokens with labels in [0, ..., config.vocab_size]. Default: None.

Returns:

Example:

>>> from transformers import AutoTokenizer, LlamaForCausalLM

>>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
>>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    unpadded_lengths: Optional[bool] = None,
    avg_valid_labels_per_chunk: Optional[float] = None,
    seq_parallel_group: Optional[Any] = None,
) -> Union[tuple, CausalLMOutputWithPast]:
    r"""
    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        seq_parallel_group=seq_parallel_group,
    )

    hidden_states = outputs[0]
    loss = None
    if labels is not None:
        # Only compute loss on tokens that contribute to loss
        valid_prediction = labels != -100
        hidden_states_ = hidden_states[valid_prediction]
        logits = self.lm_head(hidden_states_).float()

        # NOTE: We don't shift the labels inside the model here!
        labels_ = labels[valid_prediction]

        if (
            avg_valid_labels_per_chunk is not None
            and avg_valid_labels_per_chunk > 0
        ):
            # Don't take mean since this will give unequal weight to GPUs with unequal amount of padding
            loss = F.cross_entropy(logits, labels_, reduction="mean") * (
                labels_.numel() / avg_valid_labels_per_chunk
            )
            if not valid_prediction.any():
                loss.data = torch.zeros_like(loss)
        else:
            loss = F.cross_entropy(logits, labels_, reduction="mean")
    else:
        logits = self.lm_head(hidden_states).float()

    if not return_dict:
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
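Because this forward does not shift labels internally (see the NOTE in the source above), callers are expected to pass labels that are already aligned with next-token targets. A hedged sketch of that preprocessing, with hypothetical token ids:

import torch

input_ids = torch.tensor([[1, 306, 4700, 278, 2071]])   # hypothetical token ids

# Shift outside the model: position i predicts token i + 1; the last position
# gets -100 so it is excluded from the loss.
labels = torch.roll(input_ids, shifts=-1, dims=-1)
labels[:, -1] = -100

# outputs = model(input_ids=input_ids, labels=labels)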
prepare_inputs_for_generation
prepare_inputs_for_generation(input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    **kwargs,
):
    if past_key_values:
        input_ids = input_ids[:, -1:]

    # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
    if inputs_embeds is not None and past_key_values is None:
        model_inputs = {"inputs_embeds": inputs_embeds}
    else:
        model_inputs = {"input_ids": input_ids}

    model_inputs.update(
        {
            "position_ids": kwargs.get("position_ids", None),
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            "unpadded_lengths": (
                (attention_mask is not None) and (not attention_mask.all().item())
            ),
            "seq_parallel_group": kwargs.get("seq_parallel_group"),
        }
    )
    return model_inputs
LlamaForSequenceClassification
LlamaForSequenceClassification(config)

Bases: LlamaPreTrainedModel

Source code in swebench/inference/llamao/modeling_flash_llama.py
def __init__(self, config):
    super().__init__(config)
    self.num_labels = config.num_labels
    self.model = LlamaModel(config)
    self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

    # Initialize weights and apply final processing
    self.post_init()
num_labels instance-attribute
num_labels = num_labels
model instance-attribute
model = LlamaModel(config)
score instance-attribute
score = Linear(hidden_size, num_labels, bias=False)
get_input_embeddings
get_input_embeddings()
Source code in swebench/inference/llamao/modeling_flash_llama.py
def get_input_embeddings(self):
    return self.model.embed_tokens
set_input_embeddings
set_input_embeddings(value)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def set_input_embeddings(self, value):
    self.model.embed_tokens = value
forward
forward(input_ids: LongTensor = None, attention_mask: Optional[Tensor] = None, position_ids: Optional[LongTensor] = None, past_key_values: Optional[list[FloatTensor]] = None, inputs_embeds: Optional[FloatTensor] = None, labels: Optional[LongTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None) -> Union[tuple, SequenceClassifierOutputWithPast]

labels (torch.LongTensor of shape (batch_size,), optional): Labels for computing the sequence classification/regression loss. Indices should be in [0, ..., config.num_labels - 1]. If config.num_labels == 1 a regression loss is computed (Mean-Square loss), If config.num_labels > 1 a classification loss is computed (Cross-Entropy).

Source code in swebench/inference/llamao/modeling_flash_llama.py
def forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[list[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[tuple, SequenceClassifierOutputWithPast]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
        Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
        config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
        `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
    """
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )

    transformer_outputs = self.model(
        input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    hidden_states = transformer_outputs[0]
    logits = self.score(hidden_states)

    if input_ids is not None:
        batch_size = input_ids.shape[0]
    else:
        batch_size = inputs_embeds.shape[0]

    if self.config.pad_token_id is None and batch_size != 1:
        raise ValueError(
            "Cannot handle batch sizes > 1 if no padding token is defined."
        )
    if self.config.pad_token_id is None:
        sequence_lengths = -1
    else:
        if input_ids is not None:
            sequence_lengths = (
                torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1
            ).to(logits.device)
        else:
            sequence_lengths = -1

    pooled_logits = logits[
        torch.arange(batch_size, device=logits.device), sequence_lengths
    ]

    loss = None
    if labels is not None:
        labels = labels.to(logits.device)
        if self.config.problem_type is None:
            if self.num_labels == 1:
                self.config.problem_type = "regression"
            elif self.num_labels > 1 and (
                labels.dtype == torch.long or labels.dtype == torch.int
            ):
                self.config.problem_type = "single_label_classification"
            else:
                self.config.problem_type = "multi_label_classification"

        if self.config.problem_type == "regression":
            loss_fct = MSELoss()
            if self.num_labels == 1:
                loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
            else:
                loss = loss_fct(pooled_logits, labels)
        elif self.config.problem_type == "single_label_classification":
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                pooled_logits.view(-1, self.num_labels), labels.view(-1)
            )
        elif self.config.problem_type == "multi_label_classification":
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(pooled_logits, labels)
    if not return_dict:
        output = (pooled_logits,) + transformer_outputs[1:]
        return ((loss,) + output) if loss is not None else output

    return SequenceClassifierOutputWithPast(
        loss=loss,
        logits=pooled_logits,
        past_key_values=transformer_outputs.past_key_values,
        hidden_states=transformer_outputs.hidden_states,
        attentions=transformer_outputs.attentions,
    )
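The pooling above picks the logits of the last non-padding token in each sequence. A small illustration of how sequence_lengths selects that position, assuming a hypothetical pad_token_id of 0 and right-padded inputs:

import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 9, 2, 0, 0],
                          [7, 1, 4, 3, 6]])

sequence_lengths = torch.ne(input_ids, pad_token_id).sum(-1) - 1   # tensor([2, 4])

logits = torch.randn(2, 5, 3)                                      # (batch, seqlen, num_labels)
pooled = logits[torch.arange(2), sequence_lengths]                 # (batch, num_labels)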
rmsnorm_func
rmsnorm_func(hidden_states, weight, variance_epsilon)
Source code in swebench/inference/llamao/modeling_flash_llama.py
def rmsnorm_func(hidden_states, weight, variance_epsilon):
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    variance = hidden_states.pow(2).mean(-1, keepdim=True)
    hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
    return (weight * hidden_states).to(input_dtype)
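As a quick numerical sanity check (assuming the module's flash-attn dependency is installed so it can be imported), rmsnorm_func divides each hidden vector by its root-mean-square in fp32 and then applies the weight; the shapes below are arbitrary.

import torch

from swebench.inference.llamao.modeling_flash_llama import rmsnorm_func

x = torch.randn(2, 4, 16, dtype=torch.bfloat16)
weight = torch.ones(16)
eps = torch.tensor(1e-6)

out = rmsnorm_func(x, weight, eps)

x32 = x.float()
expected = (weight * x32 * torch.rsqrt(x32.pow(2).mean(-1, keepdim=True) + eps)).to(x.dtype)
assert torch.allclose(out.float(), expected.float())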
repeat_kv
repeat_kv(hidden_states: Tensor, n_rep: int) -> Tensor

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

Source code in swebench/inference/llamao/modeling_flash_llama.py
@torch.jit.script
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    if n_rep == 1:
        return hidden_states
    final_shape = list(hidden_states.shape[:-2]) + [-1] + [hidden_states.shape[-1]]
    expand_shape = [-1] * (len(hidden_states.shape) - 1) + [n_rep] + [-1]
    hidden_states = hidden_states.unsqueeze(-1).expand(expand_shape)
    return hidden_states.reshape(final_shape)
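In the packed key/value layout used by LlamaAttention above, (batch, seqlen, 2, num_key_value_heads, head_dim), the intended effect of repeating each key/value head for its group of query heads can be illustrated with torch.repeat_interleave along the head axis; the sizes below are made up.

import torch

batch, seqlen, n_kv_heads, head_dim, n_rep = 2, 7, 4, 8, 3

kv = torch.randn(batch, seqlen, 2, n_kv_heads, head_dim)
expanded = torch.repeat_interleave(kv, repeats=n_rep, dim=-2)
print(expanded.shape)   # torch.Size([2, 7, 2, 12, 8]) -> 12 = n_kv_heads * n_rep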

make_datasets

bm25_retrieval
logger module-attribute
logger = getLogger(__name__)
DOCUMENT_ENCODING_FUNCTIONS module-attribute
DOCUMENT_ENCODING_FUNCTIONS = {'file_name_and_contents': file_name_and_contents, 'file_name_and_documentation': file_name_and_documentation, 'file_name_and_docs_jedi': file_name_and_docs_jedi}
parser module-attribute
parser = ArgumentParser()
args module-attribute
args = parse_args()
ContextManager
ContextManager(repo_path, base_commit, verbose=False)

A context manager for managing a Git repository at a specific commit.

Parameters:

repo_path (str): The path to the Git repository. Required.
base_commit (str): The commit hash to switch to. Required.
verbose (bool): Whether to print verbose output. Defaults to False.

Attributes:

repo_path (str): The path to the Git repository.
base_commit (str): The commit hash to switch to.
verbose (bool): Whether to print verbose output.
repo (Repo): The Git repository object.

Methods:

__enter__: Switches to the specified commit and returns the context manager object.
get_readme_files: Returns a list of filenames for all README files in the repository.
__exit__: Does nothing.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def __init__(self, repo_path, base_commit, verbose=False):
    self.repo_path = Path(repo_path).resolve().as_posix()
    self.base_commit = base_commit
    self.verbose = verbose
    self.repo = Repo(self.repo_path)
repo_path instance-attribute
repo_path = as_posix()
base_commit instance-attribute
base_commit = base_commit
verbose instance-attribute
verbose = verbose
repo instance-attribute
repo = Repo(repo_path)
__enter__
__enter__()
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def __enter__(self):
    if self.verbose:
        print(f"Switching to {self.base_commit}")
    try:
        self.repo.git.reset("--hard", self.base_commit)
        self.repo.git.clean("-fdxq")
    except Exception as e:
        logger.error(f"Failed to switch to {self.base_commit}")
        logger.error(e)
        raise e
    return self
get_readme_files
get_readme_files()
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_readme_files(self):
    files = os.listdir(self.repo_path)
    files = list(filter(lambda x: os.path.isfile(x), files))
    files = list(filter(lambda x: x.lower().startswith("readme"), files))
    return files
__exit__
__exit__(exc_type, exc_val, exc_tb)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def __exit__(self, exc_type, exc_val, exc_tb):
    pass
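A typical way to use ContextManager, assuming the retrieval dependencies of this module (e.g. Pyserini) are installed and that the repository path and commit hash below are placeholders:

from swebench.inference.make_datasets.bm25_retrieval import ContextManager

repo_path = "/path/to/repo__owner__name"   # hypothetical clone produced by clone_repo
base_commit = "abc1234"                    # hypothetical commit hash

with ContextManager(repo_path, base_commit, verbose=True) as cm:
    # The repository working tree is now hard-reset to base_commit.
    readmes = cm.get_readme_files()
    print(readmes)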
file_name_and_contents
file_name_and_contents(filename, relative_path)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def file_name_and_contents(filename, relative_path):
    text = relative_path + "\n"
    with open(filename) as f:
        text += f.read()
    return text
file_name_and_documentation
file_name_and_documentation(filename, relative_path)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def file_name_and_documentation(filename, relative_path):
    text = relative_path + "\n"
    try:
        with open(filename) as f:
            node = ast.parse(f.read())
        data = ast.get_docstring(node)
        if data:
            text += f"{data}"
        for child_node in ast.walk(node):
            if isinstance(
                child_node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
            ):
                data = ast.get_docstring(child_node)
                if data:
                    text += f"\n\n{child_node.name}\n{data}"
    except Exception as e:
        logger.error(e)
        logger.error(f"Failed to parse file {str(filename)}. Using simple filecontent.")
        with open(filename) as f:
            text += f.read()
    return text
file_name_and_docs_jedi
file_name_and_docs_jedi(filename, relative_path)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def file_name_and_docs_jedi(filename, relative_path):
    text = relative_path + "\n"
    with open(filename) as f:
        source_code = f.read()
    try:
        script = jedi.Script(source_code, path=filename)
        module = script.get_context()
        docstring = module.docstring()
        text += f"{module.full_name}\n"
        if docstring:
            text += f"{docstring}\n\n"
        abspath = Path(filename).absolute()
        names = [
            name
            for name in script.get_names(
                all_scopes=True, definitions=True, references=False
            )
            if not name.in_builtin_module()
        ]
        for name in names:
            try:
                origin = name.goto(follow_imports=True)[0]
                if origin.module_name != module.full_name:
                    continue
                if name.parent().full_name != module.full_name:
                    if name.type in {"statement", "param"}:
                        continue
                full_name = name.full_name
                text += f"{full_name}\n"
                docstring = name.docstring()
                if docstring:
                    text += f"{docstring}\n\n"
            except:
                continue
    except Exception as e:
        logger.error(e)
        logger.error(f"Failed to parse file {str(filename)}. Using simple filecontent.")
        text = f"{relative_path}\n{source_code}"
        return text
    return text
clone_repo
clone_repo(repo, root_dir, token)

Clones a GitHub repository to a specified directory.

Parameters:

repo (str): The GitHub repository to clone. Required.
root_dir (str): The root directory to clone the repository to. Required.
token (str): The GitHub personal access token to use for authentication. Required.

Returns:

Path: The path to the cloned repository directory.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def clone_repo(repo, root_dir, token):
    """
    Clones a GitHub repository to a specified directory.

    Args:
        repo (str): The GitHub repository to clone.
        root_dir (str): The root directory to clone the repository to.
        token (str): The GitHub personal access token to use for authentication.

    Returns:
        Path: The path to the cloned repository directory.
    """
    repo_dir = Path(root_dir, f"repo__{repo.replace('/', '__')}")

    if not repo_dir.exists():
        repo_url = f"https://{token}@github.com/{repo}.git"
        logger.info(f"Cloning {repo} {os.getpid()}")
        Repo.clone_from(repo_url, repo_dir)
    return repo_dir
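A minimal usage sketch, assuming clone_repo is importable from swebench.inference.make_datasets.bm25_retrieval and that GitPython can reach GitHub; the repository name and root directory below are placeholders, not values from the source.

import os
from swebench.inference.make_datasets.bm25_retrieval import clone_repo

# Placeholder repository and root directory; "git" is the same anonymous-token
# fallback this module's main() uses when GITHUB_TOKEN is unset.
token = os.environ.get("GITHUB_TOKEN", "git")
repo_dir = clone_repo("octocat/Hello-World", "/tmp/bm25_indexes", token)
print(repo_dir)  # /tmp/bm25_indexes/repo__octocat__Hello-World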
build_documents
build_documents(repo_dir, commit, document_encoding_func)

Builds a dictionary of documents from a given repository directory and commit.

Parameters:

- repo_dir (str, required): The path to the repository directory.
- commit (str, required): The commit hash to use.
- document_encoding_func (function, required): A function that takes a filename and a relative path and returns the encoded document text.

Returns:

- dict: A dictionary where the keys are the relative paths of the documents and the values are the encoded document text.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def build_documents(repo_dir, commit, document_encoding_func):
    """
    Builds a dictionary of documents from a given repository directory and commit.

    Args:
        repo_dir (str): The path to the repository directory.
        commit (str): The commit hash to use.
        document_encoding_func (function): A function that takes a filename and a relative path and returns the encoded document text.

    Returns:
        dict: A dictionary where the keys are the relative paths of the documents and the values are the encoded document text.
    """
    documents = dict()
    with ContextManager(repo_dir, commit):
        filenames = list_files(repo_dir, include_tests=False)
        for relative_path in filenames:
            filename = os.path.join(repo_dir, relative_path)
            text = document_encoding_func(filename, relative_path)
            documents[relative_path] = text
    return documents
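A hedged sketch of encoding a checked-out repository with one of the encoding functions documented above; the repo_dir and commit values are placeholders.

from swebench.inference.make_datasets.bm25_retrieval import (
    build_documents,
    file_name_and_docs_jedi,
)

# Placeholder repo_dir (e.g. from clone_repo above) and commit hash.
documents = build_documents(
    repo_dir="/tmp/bm25_indexes/repo__octocat__Hello-World",
    commit="0123abcd",
    document_encoding_func=file_name_and_docs_jedi,
)
for relative_path, text in list(documents.items())[:3]:
    print(relative_path, len(text))  # relative path -> length of encoded document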
make_index
make_index(repo_dir, root_dir, query, commit, document_encoding_func, python, instance_id)

Builds an index for a given set of documents using Pyserini.

Parameters:

- repo_dir (str, required): The path to the repository directory.
- root_dir (str, required): The path to the root directory.
- query (str, required): The query to use for retrieval.
- commit (str, required): The commit hash to use for retrieval.
- document_encoding_func (function, required): The function to use for encoding documents.
- python (str, required): The path to the Python executable.
- instance_id (int, required): The ID of the current instance.

Returns:

- index_path (Path): The path to the built index.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def make_index(
    repo_dir,
    root_dir,
    query,
    commit,
    document_encoding_func,
    python,
    instance_id,
):
    """
    Builds an index for a given set of documents using Pyserini.

    Args:
        repo_dir (str): The path to the repository directory.
        root_dir (str): The path to the root directory.
        query (str): The query to use for retrieval.
        commit (str): The commit hash to use for retrieval.
        document_encoding_func (function): The function to use for encoding documents.
        python (str): The path to the Python executable.
        instance_id (int): The ID of the current instance.

    Returns:
        index_path (Path): The path to the built index.
    """
    index_path = Path(root_dir, f"index__{str(instance_id)}", "index")
    if index_path.exists():
        return index_path
    thread_prefix = f"(pid {os.getpid()}) "
    documents_path = Path(root_dir, instance_id, "documents.jsonl")
    if not documents_path.parent.exists():
        documents_path.parent.mkdir(parents=True)
    documents = build_documents(repo_dir, commit, document_encoding_func)
    with open(documents_path, "w") as docfile:
        for relative_path, contents in documents.items():
            print(
                json.dumps({"id": relative_path, "contents": contents}),
                file=docfile,
                flush=True,
            )
    cmd = [
        python,
        "-m",
        "pyserini.index",
        "--collection",
        "JsonCollection",
        "--generator",
        "DefaultLuceneDocumentGenerator",
        "--threads",
        "2",
        "--input",
        documents_path.parent.as_posix(),
        "--index",
        index_path.as_posix(),
        "--storePositions",
        "--storeDocvectors",
        "--storeRaw",
    ]
    try:
        proc = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
        )
        output, error = proc.communicate()
    except KeyboardInterrupt:
        proc.kill()
        raise KeyboardInterrupt
    if proc.returncode == 130:
        logger.warning(thread_prefix + "Process killed by user")
        raise KeyboardInterrupt
    if proc.returncode != 0:
        logger.error(f"return code: {proc.returncode}")
        raise Exception(
            thread_prefix
            + f"Failed to build index for {instance_id} with error {error}"
        )
    return index_path
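A usage sketch under the assumption that pyserini is installed in the interpreter passed as python (the function shells out to python -m pyserini.index); every other value below is a placeholder.

import sys
from swebench.inference.make_datasets.bm25_retrieval import (
    file_name_and_docs_jedi,
    make_index,
)

index_path = make_index(
    repo_dir="/tmp/bm25_indexes/repo__octocat__Hello-World",  # placeholder
    root_dir="/tmp/bm25_indexes",                             # placeholder
    query="example problem statement",                        # placeholder
    commit="0123abcd",                                        # placeholder
    document_encoding_func=file_name_and_docs_jedi,
    python=sys.executable,
    instance_id="octocat__Hello-World-1",                     # placeholder
)
print(index_path)  # /tmp/bm25_indexes/index__octocat__Hello-World-1/index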
get_remaining_instances
get_remaining_instances(instances, output_file)

Filters a list of instances to exclude those that have already been processed and saved in a file.

Parameters:

- instances (List[Dict], required): A list of instances, where each instance is a dictionary with an "instance_id" key.
- output_file (Path, required): The path to the file where the processed instances are saved.

Returns:

- List[Dict]: A list of instances that have not been processed yet.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_remaining_instances(instances, output_file):
    """
    Filters a list of instances to exclude those that have already been processed and saved in a file.

    Args:
        instances (List[Dict]): A list of instances, where each instance is a dictionary with an "instance_id" key.
        output_file (Path): The path to the file where the processed instances are saved.

    Returns:
        List[Dict]: A list of instances that have not been processed yet.
    """
    instance_ids = set()
    remaining_instances = list()
    if output_file.exists():
        with FileLock(output_file.as_posix() + ".lock"):
            with open(output_file) as f:
                for line in f:
                    instance = json.loads(line)
                    instance_id = instance["instance_id"]
                    instance_ids.add(instance_id)
            logger.warning(
                f"Found {len(instance_ids)} existing instances in {output_file}. Will skip them."
            )
    else:
        output_file.parent.mkdir(parents=True, exist_ok=True)
        return instances
    for instance in instances:
        instance_id = instance["instance_id"]
        if instance_id not in instance_ids:
            remaining_instances.append(instance)
    return remaining_instances
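A small sketch, assuming the same import path; the output path is a placeholder, and only the "instance_id" key matters for this filter.

from pathlib import Path
from swebench.inference.make_datasets.bm25_retrieval import get_remaining_instances

instances = [{"instance_id": "inst-1"}, {"instance_id": "inst-2"}]  # placeholders
output_file = Path("/tmp/bm25_out/file_name_and_docs.retrieval.jsonl")  # placeholder
remaining = get_remaining_instances(instances, output_file)
print([x["instance_id"] for x in remaining])  # ids not already written to output_file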
search
search(instance, index_path)

Searches for relevant documents in the given index for the given instance.

Parameters:

- instance (dict, required): The instance to search for.
- index_path (str, required): The path to the index to search in.

Returns:

- dict: A dictionary containing the instance ID and a list of hits, where each hit is a dictionary containing the document ID and its score.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def search(instance, index_path):
    """
    Searches for relevant documents in the given index for the given instance.

    Args:
        instance (dict): The instance to search for.
        index_path (str): The path to the index to search in.

    Returns:
        dict: A dictionary containing the instance ID and a list of hits, where each hit is a dictionary containing the
        document ID and its score.
    """
    try:
        instance_id = instance["instance_id"]
        searcher = LuceneSearcher(index_path.as_posix())
        cutoff = len(instance["problem_statement"])
        while True:
            try:
                hits = searcher.search(
                    instance["problem_statement"][:cutoff],
                    k=20,
                    remove_dups=True,
                )
            except Exception as e:
                if "maxClauseCount" in str(e):
                    cutoff = int(round(cutoff * 0.8))
                    continue
                else:
                    raise e
            break
        results = {"instance_id": instance_id, "hits": []}
        for hit in hits:
            results["hits"].append({"docid": hit.docid, "score": hit.score})
        return results
    except Exception:
        logger.error(f"Failed to process {instance_id}")
        logger.error(traceback.format_exc())
        return None
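A hedged sketch of querying one index; the instance values and index path are placeholders, and index_path must point at an index built by make_index above.

from pathlib import Path
from swebench.inference.make_datasets.bm25_retrieval import search

instance = {
    "instance_id": "octocat__Hello-World-1",           # placeholder
    "problem_statement": "Fix the greeting message",   # placeholder
}
index_path = Path("/tmp/bm25_indexes/index__octocat__Hello-World-1/index")
results = search(instance, index_path)
if results is not None:
    for hit in results["hits"][:5]:
        print(hit["docid"], hit["score"])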
search_indexes
search_indexes(remaining_instance, output_file, all_index_paths)

Searches the indexes for the given instances and writes the results to the output file.

Parameters:

- remaining_instance (list, required): A list of instances to search for.
- output_file (str, required): The path to the output file to write the results to.
- all_index_paths (dict, required): A dictionary mapping instance IDs to the paths of their indexes.
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def search_indexes(remaining_instance, output_file, all_index_paths):
    """
    Searches the indexes for the given instances and writes the results to the output file.

    Args:
        remaining_instance (list): A list of instances to search for.
        output_file (str): The path to the output file to write the results to.
        all_index_paths (dict): A dictionary mapping instance IDs to the paths of their indexes.
    """
    for instance in tqdm(remaining_instance, desc="Retrieving"):
        instance_id = instance["instance_id"]
        if instance_id not in all_index_paths:
            continue
        index_path = all_index_paths[instance_id]
        results = search(instance, index_path)
        if results is None:
            continue
        with FileLock(output_file.as_posix() + ".lock"):
            with open(output_file, "a") as out_file:
                print(json.dumps(results), file=out_file, flush=True)
get_missing_ids
get_missing_ids(instances, output_file)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_missing_ids(instances, output_file):
    with open(output_file) as f:
        written_ids = set()
        for line in f:
            instance = json.loads(line)
            instance_id = instance["instance_id"]
            written_ids.add(instance_id)
    missing_ids = set()
    for instance in instances:
        instance_id = instance["instance_id"]
        if instance_id not in written_ids:
            missing_ids.add(instance_id)
    return missing_ids
get_index_paths_worker
get_index_paths_worker(instance, root_dir_name, document_encoding_func, python, token)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_index_paths_worker(
    instance,
    root_dir_name,
    document_encoding_func,
    python,
    token,
):
    index_path = None
    repo = instance["repo"]
    commit = instance["base_commit"]
    instance_id = instance["instance_id"]
    try:
        repo_dir = clone_repo(repo, root_dir_name, token)
        query = instance["problem_statement"]
        index_path = make_index(
            repo_dir=repo_dir,
            root_dir=root_dir_name,
            query=query,
            commit=commit,
            document_encoding_func=document_encoding_func,
            python=python,
            instance_id=instance_id,
        )
    except:
        logger.error(f"Failed to process {repo}/{commit} (instance {instance_id})")
        logger.error(traceback.format_exc())
    return instance_id, index_path
get_index_paths
get_index_paths(remaining_instances: list[dict[str, Any]], root_dir_name: str, document_encoding_func: Any, python: str, token: str, output_file: str) -> dict[str, str]

Retrieves the index paths for the given instances, indexing each instance in turn.

Parameters:

- remaining_instances (list[dict[str, Any]], required): A list of instances for which to retrieve the index paths.
- root_dir_name (str, required): The root directory name.
- document_encoding_func (Any, required): A function for encoding documents.
- python (str, required): The path to the Python executable.
- token (str, required): The token to use for authentication.
- output_file (str, required): The output file.

Returns:

- dict[str, str]: A dictionary mapping instance IDs to index paths.

Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_index_paths(
    remaining_instances: list[dict[str, Any]],
    root_dir_name: str,
    document_encoding_func: Any,
    python: str,
    token: str,
    output_file: str,
) -> dict[str, str]:
    """
    Retrieves the index paths for the given instances using multiple processes.

    Args:
        remaining_instances: A list of instances for which to retrieve the index paths.
        root_dir_name: The root directory name.
        document_encoding_func: A function for encoding documents.
        python: The path to the Python executable.
        token: The token to use for authentication.
        output_file: The output file.
        num_workers: The number of worker processes to use.

    Returns:
        A dictionary mapping instance IDs to index paths.
    """
    all_index_paths = dict()
    for instance in tqdm(remaining_instances, desc="Indexing"):
        instance_id, index_path = get_index_paths_worker(
            instance=instance,
            root_dir_name=root_dir_name,
            document_encoding_func=document_encoding_func,
            python=python,
            token=token,
        )
        if index_path is None:
            continue
        all_index_paths[instance_id] = index_path
    return all_index_paths
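A hedged end-to-end sketch chaining the helpers above (indexing, then searching); every path and instance value is a placeholder.

import sys
from pathlib import Path
from swebench.inference.make_datasets.bm25_retrieval import (
    file_name_and_docs_jedi,
    get_index_paths,
    get_remaining_instances,
    search_indexes,
)

instances = [
    {
        "repo": "octocat/Hello-World",             # placeholder
        "base_commit": "0123abcd",                 # placeholder
        "instance_id": "octocat__Hello-World-1",   # placeholder
        "problem_statement": "Fix the greeting message",
    }
]
output_file = Path("/tmp/bm25_out/file_name_and_docs.retrieval.jsonl")
remaining = get_remaining_instances(instances, output_file)
all_index_paths = get_index_paths(
    remaining,
    root_dir_name="/tmp/bm25_indexes",
    document_encoding_func=file_name_and_docs_jedi,
    python=sys.executable,
    token="git",  # anonymous clone, same fallback as main()
    output_file=output_file,
)
search_indexes(remaining, output_file, all_index_paths)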
get_root_dir
get_root_dir(dataset_name, output_dir, document_encoding_style)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def get_root_dir(dataset_name, output_dir, document_encoding_style):
    root_dir = Path(output_dir, dataset_name, document_encoding_style + "_indexes")
    if not root_dir.exists():
        root_dir.mkdir(parents=True, exist_ok=True)
    root_dir_name = root_dir
    return root_dir, root_dir_name
main
main(dataset_name_or_path, document_encoding_style, output_dir, shard_id, num_shards, splits, leave_indexes)
Source code in swebench/inference/make_datasets/bm25_retrieval.py
def main(
    dataset_name_or_path,
    document_encoding_style,
    output_dir,
    shard_id,
    num_shards,
    splits,
    leave_indexes,
):
    document_encoding_func = DOCUMENT_ENCODING_FUNCTIONS[document_encoding_style]
    token = os.environ.get("GITHUB_TOKEN", "git")
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
        dataset_name = os.path.basename(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
        dataset_name = dataset_name_or_path.replace("/", "__")
    if shard_id is not None:
        for split in splits:
            dataset[split] = dataset[split].shard(num_shards, shard_id)
    instances = list()
    if set(splits) - set(dataset.keys()) != set():
        raise ValueError(f"Unknown splits {set(splits) - set(dataset.keys())}")
    for split in splits:
        instances += list(dataset[split])
    python = subprocess.run("which python", shell=True, capture_output=True)
    python = python.stdout.decode("utf-8").strip()
    output_file = Path(
        output_dir, dataset_name, document_encoding_style + ".retrieval.jsonl"
    )
    remaining_instances = get_remaining_instances(instances, output_file)
    root_dir, root_dir_name = get_root_dir(
        dataset_name, output_dir, document_encoding_style
    )
    try:
        all_index_paths = get_index_paths(
            remaining_instances,
            root_dir_name,
            document_encoding_func,
            python,
            token,
            output_file,
        )
    except KeyboardInterrupt:
        logger.info(f"Cleaning up {root_dir}")
        del_dirs = list(root_dir.glob("repo__*"))
        if leave_indexes:
            index_dirs = list(root_dir.glob("index__*"))
            del_dirs += index_dirs
        for dirname in del_dirs:
            shutil.rmtree(dirname, ignore_errors=True)
    logger.info(f"Finished indexing {len(all_index_paths)} instances")
    search_indexes(remaining_instances, output_file, all_index_paths)
    missing_ids = get_missing_ids(instances, output_file)
    logger.warning(f"Missing indexes for {len(missing_ids)} instances.")
    logger.info(f"Saved retrieval results to {output_file}")
    del_dirs = list(root_dir.glob("repo__*"))
    logger.info(f"Cleaning up {root_dir}")
    if leave_indexes:
        index_dirs = list(root_dir.glob("index__*"))
        del_dirs += index_dirs
    for dirname in del_dirs:
        shutil.rmtree(dirname, ignore_errors=True)
create_instance
logger module-attribute
logger = getLogger(__name__)
PATCH_EXAMPLE module-attribute
PATCH_EXAMPLE = '--- a/file.py\n+++ b/file.py\n@@ -1,27 +1,35 @@\n def euclidean(a, b):\n-    while b:\n-        a, b = b, a % b\n-    return a\n+    if b == 0:\n+        return a\n+    return euclidean(b, a % b)\n \n \n def bresenham(x0, y0, x1, y1):\n     points = []\n     dx = abs(x1 - x0)\n     dy = abs(y1 - y0)\n-    sx = 1 if x0 < x1 else -1\n-    sy = 1 if y0 < y1 else -1\n-    err = dx - dy\n+    x, y = x0, y0\n+    sx = -1 if x0 > x1 else 1\n+    sy = -1 if y0 > y1 else 1\n \n-    while True:\n-        points.append((x0, y0))\n-        if x0 == x1 and y0 == y1:\n-            break\n-        e2 = 2 * err\n-        if e2 > -dy:\n+    if dx > dy:\n+        err = dx / 2.0\n+        while x != x1:\n+            points.append((x, y))\n             err -= dy\n-            x0 += sx\n-        if e2 < dx:\n-            err += dx\n-            y0 += sy\n+            if err < 0:\n+                y += sy\n+                err += dx\n+            x += sx\n+    else:\n+        err = dy / 2.0\n+        while y != y1:\n+            points.append((x, y))\n+            err -= dx\n+            if err < 0:\n+                x += sx\n+                err += dy\n+            y += sy\n \n+    points.append((x, y))\n     return points'
FULL_GENERATION_EXAMPLE module-attribute
FULL_GENERATION_EXAMPLE = '[start of /src/this_file.py]\nimport os\n\ndef euclidean(a, b):\n    if b == 0:\n        return a\n    return euclidean(b, a % b)\n[end of /src/this_file.py]\n[start of /src/another_file.py]\ndef bresenham(x0, y0, x1, y1):\n    points = []\n    dx = abs(x1 - x0)\n    dy = abs(y1 - y0)\n    x, y = x0, y0\n    sx = -1 if x0 > x1 else 1\n    sy = -1 if y0 > y1 else 1\n    if dx > dy:\n        err = dx / 2.0\n        while x != x1:\n            points.append((x, y))\n            err -= dy\n            if err < 0:\n                y += sy\n                err += dx\n            x += sx\n    else:\n        err = dy / 2.0\n        while y != y1:\n            points.append((x\n            err -= dx\n            if err < 0:\n                x += sx\n                err += dy\n            y += sy\n    points.append((x, y))\n    return points\n[end of /src/another_file.py]'
PROMPT_FUNCTIONS module-attribute
PROMPT_FUNCTIONS = {'style-2': prompt_style_2, 'style-3': prompt_style_3, 'full_file_gen': full_file_gen, 'style-2-edits-only': prompt_style_2_edits_only}
add_lines_list
add_lines_list(content)
Source code in swebench/inference/make_datasets/create_instance.py
def add_lines_list(content):
    content_with_lines = list()
    for ix, line in enumerate(content.split("\n"), start=1):
        content_with_lines.append(f"{ix} {line}")
    return content_with_lines
add_lines
add_lines(content)
Source code in swebench/inference/make_datasets/create_instance.py
def add_lines(content):
    return "\n".join(add_lines_list(content))
make_code_text
make_code_text(files_dict, add_line_numbers=True)
Source code in swebench/inference/make_datasets/create_instance.py
def make_code_text(files_dict, add_line_numbers=True):
    all_text = ""
    for filename, contents in sorted(files_dict.items()):
        all_text += f"[start of {filename}]\n"
        if add_line_numbers:
            all_text += add_lines(contents)
        else:
            all_text += contents
        all_text += f"\n[end of {filename}]\n"
    return all_text.strip("\n")
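A sketch of the [start of ...] / [end of ...] framing this produces; the file contents are placeholders.

from swebench.inference.make_datasets.create_instance import make_code_text

files = {"pkg/a.py": "x = 1\ny = 2"}  # placeholder file contents
print(make_code_text(files))
# [start of pkg/a.py]
# 1 x = 1
# 2 y = 2
# [end of pkg/a.py]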
make_code_text_edits_only
make_code_text_edits_only(files_dict, patch, add_line_numbers=True)
Source code in swebench/inference/make_datasets/create_instance.py
def make_code_text_edits_only(files_dict, patch, add_line_numbers=True):
    files = dict()
    patch = unidiff.PatchSet(patch)
    for patched_file in patch:
        source_file = patched_file.source_file.split("a/", 1)[-1]
        files[source_file] = list()
        for hunk in patched_file:
            start = hunk.source_start - 15
            end = start + hunk.source_length + 15
            files[source_file].append((start, end))
    all_text = ""
    for filename, content in files_dict.items():
        all_text += f"[start of {filename}]\n"
        content_with_lines = add_lines_list(content)
        for start, end in files[filename]:
            if start > 0:
                all_text += "...\n"
            all_text += "\n".join(content_with_lines[start:end])
            all_text += "\n"
            if end < len(content_with_lines):
                all_text += "...\n"
        all_text = all_text.strip("\n")
        all_text += f"\n[end of {filename}]\n"
    return all_text.strip("\n")
prompt_style_2
prompt_style_2(instance)
Source code in swebench/inference/make_datasets/create_instance.py
def prompt_style_2(instance):
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
    readmes_text = make_code_text(instance["readmes"])
    code_text = make_code_text(instance["file_contents"])
    instructions = (
        "I need you to solve this issue by generating a single patch file that I can apply "
        + "directly to this repository using git apply. Please respond with a single patch "
        + "file in the following format."
    )
    problem_statement = instance["problem_statement"]
    final_text = [
        premise,
        "<issue>",
        problem_statement,
        "</issue>",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        instructions,
        "<patch>",
        PATCH_EXAMPLE,
        "</patch>",
    ]
    final_text = "\n".join(final_text)
    return final_text
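A hedged sketch of building this prompt; the instance below carries only the keys prompt_style_2 reads (readmes, file_contents, problem_statement), all with placeholder values.

from swebench.inference.make_datasets.create_instance import prompt_style_2

instance = {
    "readmes": {"README.md": "Toy project"},
    "file_contents": {"pkg/a.py": "x = 1"},
    "problem_statement": "x should be 2",
}
prompt = prompt_style_2(instance)
print(prompt.splitlines()[0])  # the premise line
print("<patch>" in prompt)     # True: the example patch is appended at the end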
prompt_style_2_edits_only
prompt_style_2_edits_only(instance)
Source code in swebench/inference/make_datasets/create_instance.py
def prompt_style_2_edits_only(instance):
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
    readmes_text = make_code_text(instance["readmes"])
    code_text = make_code_text_edits_only(instance["file_contents"], instance["patch"])
    instructions = (
        "I need you to solve this issue by generating a single patch file that I can apply "
        + "directly to this repository using git apply. Please respond with a single patch "
        + "file in the following format."
    )
    problem_statement = instance["problem_statement"]
    final_text = [
        premise,
        "<issue>",
        problem_statement,
        "</issue>",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        instructions,
        "<patch>",
        PATCH_EXAMPLE,
        "</patch>",
    ]
    final_text = "\n".join(final_text)
    return final_text
prompt_style_3
prompt_style_3(instance)
Source code in swebench/inference/make_datasets/create_instance.py
def prompt_style_3(instance):
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
    readmes_text = make_code_text(instance["readmes"])
    code_text = make_code_text(instance["file_contents"])
    example_explanation = (
        "Here is an example of a patch file. It consists of changes to the code base. "
        + "It specifies the file names, the line numbers of each change, and the removed and added lines. "
        + "A single patch file can contain changes to multiple files."
    )
    final_instruction = (
        "I need you to solve the provided issue by generating a single patch file that I can apply "
        + "directly to this repository using git apply. Please respond with a single patch "
        + "file in the format shown above."
    )
    problem_statement = instance["problem_statement"]
    final_text = [
        premise,
        "<issue>",
        problem_statement,
        "</issue>",
        "",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        "",
        example_explanation,
        "<patch>",
        PATCH_EXAMPLE,
        "</patch>",
        "",
        final_instruction,
        "Respond below:",
    ]
    final_text = "\n".join(final_text)
    return final_text
full_file_gen
full_file_gen(instance)
Source code in swebench/inference/make_datasets/create_instance.py
def full_file_gen(instance):
    premise = "You will be provided with a partial code base and an issue statement explaining a problem to resolve."
    readmes_text = make_code_text(instance["readmes"], add_line_numbers=False)
    code_text = make_code_text(instance["file_contents"], add_line_numbers=False)
    instructions = (
        "I need you to solve this issue by regenerating the full files in the code base that you would like to change. "
        + "You can change as many files as you like. "
        + "Please respond with a list of files and their revised contents in the following format."
    )
    problem_statement = instance["problem_statement"]
    final_text = [
        premise,
        "<issue>",
        problem_statement,
        "</issue>",
        "<code>",
        readmes_text,
        code_text,
        "</code>",
        instructions,
        "<example>",
        FULL_GENERATION_EXAMPLE,
        "</example>",
    ]
    final_text = "\n".join(final_text)
    return final_text
ingest_files
ingest_files(filenames)
Source code in swebench/inference/make_datasets/create_instance.py
def ingest_files(filenames):
    files_dict = dict()
    for filename in filenames:
        with open(filename) as f:
            content = f.read()
        files_dict[filename] = content
    return files_dict
add_retrieval_results
add_retrieval_results(input_instances, retrieval_file, k, file_source)

Adds retrieval results to input_instances in-place

Source code in swebench/inference/make_datasets/create_instance.py
def add_retrieval_results(input_instances, retrieval_file, k, file_source):
    """
    Adds retrieval results to input_instances in-place
    """
    retrieval_results_path = Path(retrieval_file)
    assert retrieval_results_path.exists(), (
        f"Retrieval results not found at {retrieval_results_path}"
    )
    retrieval_results = [json.loads(line) for line in open(retrieval_results_path)]
    retrieval_results = {x["instance_id"]: x["hits"] for x in retrieval_results}
    for instance_id, instance in tqdm(
        input_instances.items(),
        total=len(input_instances),
        desc="Adding retrieval results",
    ):
        try:
            instance["hits"] = retrieval_results[instance_id][:k]
        except KeyError:
            logger.warning(f"Instance {instance_id} not found in retrieval results")
            instance["hits"] = list()
get_oracle_filenames
get_oracle_filenames(instance)

Returns the filenames that are changed in the patch

Source code in swebench/inference/make_datasets/create_instance.py
def get_oracle_filenames(instance):
    """
    Returns the filenames that are changed in the patch
    """
    source_files = {
        patch_file.source_file.split("a/", 1)[-1]
        for patch_file in unidiff.PatchSet(instance["patch"])
    }
    gold_docs = set()
    for source_file in source_files:
        gold_docs.add(source_file)
    return gold_docs
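A small illustration with a toy patch (placeholder content):

from swebench.inference.make_datasets.create_instance import get_oracle_filenames

patch = (
    "--- a/pkg/a.py\n"
    "+++ b/pkg/a.py\n"
    "@@ -1,1 +1,1 @@\n"
    "-x = 1\n"
    "+x = 2\n"
)
print(get_oracle_filenames({"patch": patch}))  # {'pkg/a.py'}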
add_text_inputs
add_text_inputs(instances, retrieval_file, k, prompt_style, file_source, max_context_len=None, tokenizer_name=None, verbose=False, progress_file=None) -> None

Process instances and save results to progress file.

Args:

- instances: dictionary with unprocessed input instances
- retrieval_file: if using retrieval method for file_contents, specify retrieval_file
- k: if using retrieval, specifies the maximum number of files to include
- prompt_style: specify the function to generate instructions and prompt
- file_source: where to collect file_contents (e.g. oracle or bm25)
- verbose: set ContextManager verbose to True
- progress_file: required, path to save processed instances

Source code in swebench/inference/make_datasets/create_instance.py
def add_text_inputs(
    instances,
    retrieval_file,
    k,
    prompt_style,
    file_source,
    max_context_len=None,
    tokenizer_name=None,
    verbose=False,
    progress_file=None,
) -> None:
    """Process instances and save results to progress file.

    Args:
    - instances: dictionary with unprocessed input instances
    - retrieval_file: if using retrieval method for file_contents, specify retrieval_file
    - k: if using retrieval, specifies the maximum number of files to include
    - prompt_style: specify the function to generate instructions and prompt
    - file_source: where to collect file_contents (e.g. oracle or bm25)
    - verbose: set ContextManager verbose to True
    - progress_file: required, path to save processed instances
    """
    assert progress_file is not None, "progress_file is required"

    # Create progress file directory if it doesn't exist
    progress_path = Path(progress_file)
    progress_path.parent.mkdir(parents=True, exist_ok=True)

    # Load already processed instances
    processed_ids = set()
    file_exists = os.path.exists(progress_file)

    if file_exists:
        with open(progress_file) as f:
            for line in f:
                instance = json.loads(line)
                processed_ids.add(instance["instance_id"])
        logger.info(f"Found {len(processed_ids)} already processed instances")
        progress_file_handle = open(progress_file, "a")
    else:
        progress_file_handle = open(progress_file, "w")

    try:
        if max_context_len is not None:
            assert tokenizer_name is not None, (
                "Must specify tokenizer_name if using max_context_len"
            )
            tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name]

        # Add retrieval results if needed
        if file_source in {"bm25"}:
            instances = deepcopy(instances)
            add_retrieval_results(instances, retrieval_file, k, file_source)

        # Filter out already processed instances
        instances_to_process = {
            k: v for k, v in instances.items() if k not in processed_ids
        }
        logger.info(f"Processing {len(instances_to_process)} instances")

        orig_dir = os.getcwd()
        with TemporaryDirectory(
            dir="/scratch" if os.path.exists("/scratch") else "/tmp"
        ) as root_dir:
            for instance_id, instance in tqdm(
                instances_to_process.items(),
                total=len(instances_to_process),
                desc="Processing instances",
            ):
                try:
                    with AutoContextManager(instance, root_dir, verbose=verbose) as cm:
                        # Process instance
                        processed_instance = deepcopy(instance)

                        # Add readmes
                        readmes = cm.get_readme_files()
                        processed_instance["readmes"] = ingest_files(readmes)

                        # Handle file contents based on configuration
                        if max_context_len is not None:
                            processed_instance["file_contents"] = dict()
                            base_text_inputs = PROMPT_FUNCTIONS[prompt_style](
                                processed_instance
                            )
                            base_text_input_length = len(
                                tokenizer_func(base_text_inputs, tokenizer)
                            )

                        if file_source == "oracle":
                            processed_instance["file_contents"] = ingest_files(
                                get_oracle_filenames(processed_instance)
                            )
                        elif file_source == "bm25":
                            processed_instance["file_contents"] = ingest_files(
                                [x["docid"] for x in processed_instance["hits"]]
                            )
                        elif file_source == "all":
                            processed_instance["file_contents"] = (
                                ingest_directory_contents(cm.repo_path)
                            )
                        elif file_source == "none":
                            processed_instance["file_contents"] = dict()
                        else:
                            raise ValueError(f"Invalid file source {file_source}")

                        # Handle context length limits
                        if max_context_len is not None:
                            cur_input_len = base_text_input_length
                            include_files = []
                            for filename in [
                                x["docid"] for x in processed_instance["hits"]
                            ]:
                                content = make_code_text(
                                    {
                                        filename: processed_instance["file_contents"][
                                            filename
                                        ]
                                    }
                                )
                                if tokenizer_name == "llama":
                                    tokens = tokenizer_func("\n" + content, tokenizer)
                                    idx = tokens.index(13)
                                    tokens = tokens[idx + 1 :]
                                else:
                                    tokens = tokenizer_func(content, tokenizer)
                                if cur_input_len + len(tokens) < max_context_len:
                                    include_files.append(filename)
                                    cur_input_len += len(tokens)
                            processed_instance["file_contents"] = {
                                filename: processed_instance["file_contents"][filename]
                                for filename in include_files
                            }

                        # Generate final text inputs
                        processed_instance["text_inputs"] = PROMPT_FUNCTIONS[
                            prompt_style
                        ](processed_instance)

                        # Save to progress file
                        progress_file_handle.write(
                            json.dumps(processed_instance) + "\n"
                        )
                        progress_file_handle.flush()

                except Exception as e:
                    print(f"Failed on instance {instance_id}", e)
                    traceback.print_exc()
                    # Save failed instance
                    failed_instance = {**instance, "text_inputs": None}
                    progress_file_handle.write(json.dumps(failed_instance) + "\n")
                    progress_file_handle.flush()
                finally:
                    os.chdir(orig_dir)
        os.chdir(orig_dir)
    finally:
        progress_file_handle.close()
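A hedged usage sketch mirroring how create_text_dataset.main() calls this function; the dataset choice and progress path are placeholders.

from datasets import load_dataset
from swebench.inference.make_datasets.create_instance import add_text_inputs

dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")  # placeholder dataset
instances = {x["instance_id"]: x for x in dataset}
add_text_inputs(
    instances,
    retrieval_file=None,   # not used when file_source="oracle"
    k=None,
    prompt_style="style-3",
    file_source="oracle",
    progress_file="/tmp/swebench_text/oracle.progress.jsonl",  # placeholder
)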
create_text_dataset

Create a dataset for text-to-text training from the raw task instance outputs.

logger module-attribute
logger = getLogger(__name__)
parser module-attribute
parser = ArgumentParser(description=__doc__)
load_jsonl_file
load_jsonl_file(filename)
Source code in swebench/inference/make_datasets/create_text_dataset.py
def load_jsonl_file(filename):
    if type(filename) == str:
        filename = Path(filename)
    if filename.name.endswith(".jsonl") or filename.name.endswith(".jsonl.all"):
        with open(filename) as f:
            return [json.loads(line) for line in f]
    elif filename.name.endswith(".json"):
        with open(filename) as f:
            return json.load(f)
    else:
        raise ValueError(f"Unknown file type {filename}")
instances_generator
instances_generator(files)
Source code in swebench/inference/make_datasets/create_text_dataset.py
def instances_generator(files):
    all_data = list()
    for file in tqdm(files, desc="Loading instance files"):
        all_data.extend(load_jsonl_file(file))
    return all_data
get_training_and_eval_instances
get_training_and_eval_instances(raw_files, test_dataset)
Source code in swebench/inference/make_datasets/create_text_dataset.py
def get_training_and_eval_instances(raw_files, test_dataset):
    logger.info("Loading instances")
    raw_instances = list(instances_generator(raw_files))
    final_instances = list(test_dataset["test"])
    eval_repos = {x["repo"] for x in final_instances}
    train_instances = [x for x in raw_instances if x["repo"] not in eval_repos]
    train_instances = list(sorted(train_instances, key=lambda x: x["instance_id"]))
    eval_instances = list(sorted(final_instances, key=lambda x: x["instance_id"]))
    logger.info(f"Found {len(train_instances)} training ids")
    logger.info(f"Found {len(eval_instances)} eval ids")
    return train_instances, eval_instances
extract_fields
extract_fields(instance)
Source code in swebench/inference/make_datasets/create_text_dataset.py
def extract_fields(instance):
    instance_id = instance["instance_id"]
    if instance["text_inputs"] is None or instance["patch"] is None:
        print(f"No text for {instance_id}")
        return None
    text_inputs = instance["text_inputs"].strip() + "\n\n"
    if text_inputs is None or instance["patch"] is None:
        print(f"No inputs for {instance_id}")
        return None
    patch = "\n".join(["<patch>", instance["patch"], "</patch>"])
    return {**instance, "text": text_inputs, "patch": patch}
validate_arguments
validate_arguments(push_to_hub_user, output_dir, max_context_len, tokenizer_name, file_source, k)

Validate command line arguments and environment setup.

Source code in swebench/inference/make_datasets/create_text_dataset.py
def validate_arguments(
    push_to_hub_user, output_dir, max_context_len, tokenizer_name, file_source, k
):
    """Validate command line arguments and environment setup."""
    if push_to_hub_user is not None:
        hub_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
        assert hub_token is not None, (
            "Must provide HUGGING_FACE_HUB_TOKEN to push to the Hub"
        )
        assert output_dir is None, "Cannot provide output_dir if pushing to the Hub"
    if max_context_len is not None:
        assert tokenizer_name is not None
    if push_to_hub_user is None and not Path(output_dir).exists():
        Path(output_dir).mkdir(parents=True)
    if max_context_len is not None:
        assert file_source not in {"all", "oracle"}, (
            "Cannot use max_context_len with oracle or all file sources"
        )
        assert tokenizer_name is not None, (
            "Must provide tokenizer_name if max_context_len is not None"
        )
    if k is not None:
        assert file_source not in {"all", "oracle"}, (
            "Cannot use max_context_len with oracle or all file sources"
        )
    return hub_token if push_to_hub_user is not None else None
construct_output_filename
construct_output_filename(dataset_name, prompt_style, file_source, k, max_context_len, tokenizer_name)

Construct the output filename based on parameters.

Source code in swebench/inference/make_datasets/create_text_dataset.py
def construct_output_filename(
    dataset_name, prompt_style, file_source, k, max_context_len, tokenizer_name
):
    """Construct the output filename based on parameters."""
    if dataset_name.startswith("princeton-nlp"):
        dataset_name = dataset_name.split("/")[-1]
    dataset_name = dataset_name.replace("/", "__")
    output_file = f"{dataset_name}__{prompt_style}__fs-{file_source}"
    if k is not None:
        output_file += f"__k-{k}"
    if max_context_len is not None:
        output_file += f"__mcc-{max_context_len}-{tokenizer_name}"
    return output_file
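A worked example of the naming scheme (argument values are illustrative only):

from swebench.inference.make_datasets.create_text_dataset import construct_output_filename

name = construct_output_filename(
    dataset_name="princeton-nlp/SWE-bench_Lite",
    prompt_style="style-3",
    file_source="bm25",
    k=20,
    max_context_len=16000,
    tokenizer_name="cl100k",
)
print(name)  # SWE-bench_Lite__style-3__fs-bm25__k-20__mcc-16000-cl100k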
main
main(dataset_name_or_path, splits, validation_ratio, output_dir, retrieval_file, prompt_style, file_source, k, max_context_len, tokenizer_name, push_to_hub_user)
Source code in swebench/inference/make_datasets/create_text_dataset.py
def main(
    dataset_name_or_path,
    splits,
    validation_ratio,
    output_dir,
    retrieval_file,
    prompt_style,
    file_source,
    k,
    max_context_len,
    tokenizer_name,
    push_to_hub_user,
):
    # Validate arguments and setup
    hub_token = validate_arguments(
        push_to_hub_user, output_dir, max_context_len, tokenizer_name, file_source, k
    )
    output_file = construct_output_filename(
        dataset_name_or_path,
        prompt_style,
        file_source,
        k,
        max_context_len,
        tokenizer_name,
    )
    output_file = Path(output_dir, output_file)
    if push_to_hub_user is None:
        if output_file.exists():
            existing_dataset = load_from_disk(output_file)
            # if requested splits are in existing dataset, abort
            for split in splits:
                if split in existing_dataset:
                    logger.info(
                        f"{output_file.absolute().as_posix()} already exists for split {split}. Aborting"
                    )
                    return
            del existing_dataset  # don't store in memory

    # Load dataset
    dataset = (
        load_from_disk(dataset_name_or_path)
        if Path(dataset_name_or_path).exists()
        else load_dataset(dataset_name_or_path)
    )
    logger.info(f"Found {set(dataset.keys())} splits")
    if set(splits) - set(dataset.keys()) != set():
        raise ValueError(f"Unknown splits {set(splits) - set(dataset.keys())}")

    # Define columns for final dataset
    columns = [
        "instance_id",
        "text",
        "repo",
        "base_commit",
        "problem_statement",
        "hints_text",
        "created_at",
        "patch",
        "test_patch",
        "version",
        "FAIL_TO_PASS",
        "PASS_TO_PASS",
        "environment_setup_commit",
    ]

    # Process each split
    split_data = {}
    progress_files = {}
    for split in splits:
        logger.info(f"Processing {split} split")
        split_instances = {x["instance_id"]: x for x in dataset[split]}
        progress_file = f"{output_file}.{split}.progress.jsonl"
        progress_files[split] = progress_file
        # Process instances and save to progress file
        add_text_inputs(
            split_instances,
            retrieval_file=retrieval_file,
            k=k,
            prompt_style=prompt_style,
            file_source=file_source,
            max_context_len=max_context_len,
            tokenizer_name=tokenizer_name,
            progress_file=progress_file,
        )

    logger.info("Creating final dataset")
    # Create final dataset
    if output_file.exists():
        final_dataset = load_from_disk(output_file)
    else:
        final_dataset = DatasetDict()
    for split in splits:
        split_data = {key: [] for key in columns}
        valid_instance_ids = set(dataset[split]["instance_id"])
        invalid_instances = []

        with open(progress_files[split]) as f:
            for line in f:
                datum = json.loads(line)
                datum = extract_fields(datum)
                if datum["instance_id"] not in valid_instance_ids:
                    invalid_instances.append(datum["instance_id"])
                    continue
                for key in columns:
                    split_data[key].append(datum.get(key, ""))

        if invalid_instances:
            logger.warning(
                f"Found {len(invalid_instances)} instances in progress file that are not in the {split} dataset: {invalid_instances}. These will be removed from the final dataset."
            )

        final_dataset[split] = Dataset.from_dict(split_data)

    # Handle validation split
    if validation_ratio > 0 and "train" in final_dataset:
        train_val = final_dataset["train"].train_test_split(
            test_size=validation_ratio, seed=42
        )
        final_dataset["train"] = train_val["train"]
        final_dataset["validation"] = train_val["test"]

    # Log final dataset sizes
    for split in final_dataset:
        logger.info(f"Found {len(final_dataset[split])} {split} instances")

    # Save dataset
    if push_to_hub_user is not None:
        final_dataset.push_to_hub(
            f"{push_to_hub_user}/{output_file.name}", use_auth_token=hub_token
        )
    else:
        final_dataset.save_to_disk(output_file)

    # Cleanup progress files
    for progress_file in progress_files.values():
        if os.path.exists(progress_file):
            os.remove(progress_file)

    logger.info(f"Finished saving to {output_file}")
eval_retrieval

This script evaluates BM25 retrieval results for a dataset created by create_text_dataset.py using the --retrieval_file option and --file_source bm25.

logger module-attribute
logger = getLogger(__name__)
parser module-attribute
parser = ArgumentParser(description=__doc__)
args module-attribute
args = parse_args()
main
main(dataset_name_or_path, split)
Source code in swebench/inference/make_datasets/eval_retrieval.py
def main(dataset_name_or_path, split):
    try:
        dataset = load_dataset(dataset_name_or_path, split=split)
    except:
        dataset = load_from_disk(dataset_name_or_path, split=split)
    print(
        f"Evaluating {len(dataset)} instances from {dataset_name_or_path} {split} split"
    )
    instance_files_pattern = re.compile(
        r"\[start of ([\w\.\-\/]+)\]\n(?:.+?)\n\[end of \1\]", re.DOTALL
    )
    patch_files_pattern = re.compile(r"\-\-\- a/(.+)")
    patch_files = {instance["instance_id"]: instance["patch"] for instance in dataset}
    recalls_any = list()
    recalls_all = list()
    recalls = list()
    for datum in dataset:
        instance_id = datum["instance_id"]
        retrieved_files = instance_files_pattern.findall(datum["text"])
        if retrieved_files and "readme" in retrieved_files[0].lower():
            retrieved_files = retrieved_files[
                1:
            ]  # first file is usually the readme, we don't want to count that
        retrieved_files = set(retrieved_files)
        gold_files = set(patch_files_pattern.findall(patch_files[instance_id]))
        if len(gold_files) == 0:
            print(f"WARNING: Instance {datum['instance_id']} has no gold files")
            continue
        if len(retrieved_files) == 0:
            print(f"WARNING: Instance {datum['instance_id']} has no retrieved files")
            continue
        recall = len(retrieved_files.intersection(gold_files)) / len(gold_files)
        recalls.append(recall)
        recalls_any.append(int(recall > 0))
        recalls_all.append(int(recall == 1))
    recalls = np.array(recalls)
    recalls_any = np.array(recalls_any)
    recalls_all = np.array(recalls_all)
    print(f"Avg Recall: {np.mean(recalls) * 100:.2f}")
    print(f"All Recall: {np.mean(recalls_all) * 100:.2f}")
    print(f"Any Recall: {np.mean(recalls_any) * 100:.2f}")
tokenize_dataset

Tokenize a text dataset created with create_text_dataset.py using the specified tokenizer (cl100k or llama), producing input_ids and labels for training and evaluation splits.

logger module-attribute
logger = getLogger(__name__)
TOKENIZER_FUNCS module-attribute
TOKENIZER_FUNCS = {'cl100k': (get_encoding('cl100k_base'), cl100k), 'llama': (from_pretrained('togethercomputer/LLaMA-2-7B-32K'), llama)}
parser module-attribute
parser = ArgumentParser(description=__doc__)
cl100k
cl100k(text, tokenizer)
Source code in swebench/inference/make_datasets/tokenize_dataset.py
def cl100k(text, tokenizer):
    return tokenizer.encode(text, disallowed_special=())
llama
llama(text, tokenizer)
Source code in swebench/inference/make_datasets/tokenize_dataset.py
def llama(text, tokenizer):
    return tokenizer(text, add_special_tokens=False, return_attention_mask=False)[
        "input_ids"
    ]
extract_fields
extract_fields(instance, tokenizer_name, tokenizer, tokenizer_func, eos_token)
Source code in swebench/inference/make_datasets/tokenize_dataset.py
def extract_fields(instance, tokenizer_name, tokenizer, tokenizer_func, eos_token):
    instance_id = instance["instance_id"]
    if instance["text"] is None or instance["patch"] is None:
        print(f"No text for {instance_id}")
        return {"input_ids": [], "labels": [], "text": "", "patch": ""}
    text_inputs = instance["text"].strip() + "\n"
    if text_inputs is None or instance["patch"] is None:
        print(f"No inputs for {instance_id}")
        return None
    patch = instance["patch"].strip()
    if len(eos_token) > 0:
        patch += f"\n{eos_token}"
    input_ids = tokenizer_func(text_inputs, tokenizer)
    if tokenizer_name in {"llama"}:
        label_ids = tokenizer_func(
            "\n" + patch, tokenizer
        )  # add newline to tokenize patch
        idx = label_ids.index(13)
        assert idx <= 2, (
            "Expected newline token id (13) to be one of the first three tokens"
        )
        label_ids = label_ids[idx + 1 :]  # remove newline tokens
    else:
        label_ids = tokenizer_func(patch, tokenizer)
    inputs = input_ids + label_ids[:-1]
    cond_len = len(input_ids) - 1
    labels = [-100] * cond_len + label_ids
    assert len(inputs) == len(labels)
    return {
        **instance,
        "input_ids": inputs,
        "labels": labels,
        "text": text_inputs,
        "patch": patch,
    }
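The key step here is how inputs and labels line up for causal-LM training: inputs are the prompt tokens followed by all but the last patch token, and labels are the patch tokens preceded by -100 padding so the loss is only computed on patch predictions. A toy illustration of that arithmetic with made-up token ids (no real tokenizer involved):

input_ids = [101, 102, 103]        # made-up prompt token ids
label_ids = [201, 202, 203, 204]   # made-up patch (+ eos) token ids

inputs = input_ids + label_ids[:-1]
cond_len = len(input_ids) - 1
labels = [-100] * cond_len + label_ids
assert len(inputs) == len(labels)
print(inputs)  # [101, 102, 103, 201, 202, 203]
print(labels)  # [-100, -100, 201, 202, 203, 204]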
extract_test_fields
extract_test_fields(instance, tokenizer_name, tokenizer, tokenizer_func, eos_token)
Source code in swebench/inference/make_datasets/tokenize_dataset.py
def extract_test_fields(instance, tokenizer_name, tokenizer, tokenizer_func, eos_token):
    instance_id = instance["instance_id"]
    if instance["text"] is None or instance["patch"] is None:
        print(f"No text for {instance_id}")
        return None
    text_inputs = instance["text"].strip() + "\n"
    if text_inputs is None or instance["patch"] is None:
        print(f"No inputs for {instance_id}")
        return None
    patch = instance["patch"].strip()
    if len(eos_token) > 0:
        patch += f"\n{eos_token}"
    input_ids = tokenizer_func(text_inputs, tokenizer)
    label_ids = tokenizer_func(patch, tokenizer)
    inputs = input_ids
    labels = label_ids
    return {
        **instance,
        "input_ids": inputs,
        "labels": labels,
        "text": text_inputs,
        "patch": patch,
    }
add_columns_from_dict
add_columns_from_dict(dataset, dict_columns)

dict_columns is a list of dicts with keys that are columns in dataset

Source code in swebench/inference/make_datasets/tokenize_dataset.py
def add_columns_from_dict(dataset, dict_columns):
    """dict_columns is a list of dicts with keys that are columns in dataset"""
    for column in dict_columns[0].keys():
        values = [d[column] for d in dict_columns]
        if column in dataset.column_names:
            dataset = dataset.remove_columns(column)
        dataset = dataset.add_column(column, values)
    return dataset
main
main(dataset_name_or_path, output_dir, tokenizer_name, num_proc, push_to_hub_user)
Source code in swebench/inference/make_datasets/tokenize_dataset.py
def main(
    dataset_name_or_path,
    output_dir,
    tokenizer_name,
    num_proc,
    push_to_hub_user,
):
    if push_to_hub_user is not None:
        hub_token = os.environ.get("HUGGING_FACE_HUB_TOKEN", None)
        if hub_token is None:
            raise ValueError("Must provide HUGGING_FACE_HUB_TOKEN to push to the Hub")
    if not Path(output_dir).exists():
        Path(output_dir).mkdir(parents=True)

    if tokenizer_name is not None:
        tokenizer, tokenizer_func = TOKENIZER_FUNCS[tokenizer_name]
        eos_token = getattr(tokenizer, "eos_token", "")
        if num_proc > 0 and tokenizer_name == "cl100k":
            logger.warning(
                "cl100k tokenizer does not support multiprocessing. Ignoring num_proc"
            )
            num_proc = 0

    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    dataset = dataset.filter(
        lambda x: len(x["text"]) <= 5_000_000
    )  # filter out superlong instances
    for split in dataset.keys():
        if split == "test":
            continue
        if num_proc > 0:
            dataset[split] = dataset[split].map(
                lambda instance: extract_fields(
                    instance,
                    tokenizer_name,
                    tokenizer,
                    tokenizer_func,
                    eos_token,
                ),
                num_proc=num_proc,
                batched=False,
                desc=f"Tokenizing {split}",
            )
        elif len(dataset[split]) > 0:
            new_values = list(
                map(
                    lambda x: extract_fields(
                        x, tokenizer_name, tokenizer, tokenizer_func, eos_token
                    ),
                    tqdm(
                        dataset[split],
                        total=len(dataset[split]),
                        desc=f"Tokenizing {split}",
                    ),
                )
            )
            dataset[split] = add_columns_from_dict(dataset[split], new_values)
    for split in ["test"]:
        if split not in dataset:
            logger.warning(f"Split {split} not in dataset. Skipping")
            continue
        if num_proc > 0:
            dataset[split] = dataset[split].map(
                lambda instance: extract_test_fields(
                    instance,
                    tokenizer_name,
                    tokenizer,
                    tokenizer_func,
                    eos_token,
                ),
                num_proc=num_proc,
                batched=False,
                desc=f"Tokenizing {split}",
            )
        elif len(dataset[split]) > 0:
            new_values = list(
                map(
                    lambda x: extract_test_fields(
                        x, tokenizer_name, tokenizer, tokenizer_func, eos_token
                    ),
                    tqdm(
                        dataset[split],
                        total=len(dataset[split]),
                        desc=f"Tokenizing {split}",
                    ),
                )
            )
            dataset[split] = add_columns_from_dict(dataset[split], new_values)
    output_file = Path(dataset_name_or_path).name + f"__tok-{tokenizer_name}"
    if push_to_hub_user is not None:
        output_file = f"{push_to_hub_user}/{output_file}"
        dataset.push_to_hub(output_file, use_auth_token=hub_token)
    else:
        output_file = Path(output_dir) / output_file
        dataset.save_to_disk(output_file)
    logger.warning(f"Saved to {output_file}")
utils
DIFF_PATTERN module-attribute
DIFF_PATTERN = compile('^diff(?:.*)')
PATCH_PATTERN module-attribute
PATCH_PATTERN = compile('(?:diff[\\w\\_\\.\\ \\/\\-]+\\n)?\\-\\-\\-\\s+a\\/(?:.*?)\\n\\+\\+\\+\\s+b\\/(?:.*?)(?=diff\\ |\\-\\-\\-\\ a\\/|\\Z)', DOTALL)
PATCH_FILE_PATTERN module-attribute
PATCH_FILE_PATTERN = compile('\\-\\-\\-\\s+a\\/(?:.+)\\n\\+\\+\\+\\s+b\\/(?:.+)')
PATCH_HUNK_PATTERN module-attribute
PATCH_HUNK_PATTERN = compile('\\@\\@\\s+\\-(\\d+),(\\d+)\\s+\\+(\\d+),(\\d+)\\s+\\@\\@(.+?)(?=diff\\ |\\-\\-\\-\\ a\\/|\\@\\@\\ \\-|\\Z)', DOTALL)
ContextManager
ContextManager(repo_path, base_commit, verbose=False)
Source code in swebench/inference/make_datasets/utils.py
def __init__(self, repo_path, base_commit, verbose=False):
    self.repo_path = Path(repo_path).resolve().as_posix()
    self.old_dir = os.getcwd()
    self.base_commit = base_commit
    self.verbose = verbose
repo_path instance-attribute
repo_path = as_posix()
old_dir instance-attribute
old_dir = getcwd()
base_commit instance-attribute
base_commit = base_commit
verbose instance-attribute
verbose = verbose
__enter__
__enter__()
Source code in swebench/inference/make_datasets/utils.py
def __enter__(self):
    os.chdir(self.repo_path)
    cmd = f"git reset --hard {self.base_commit} && git clean -fdxq"
    if self.verbose:
        subprocess.run(cmd, shell=True, check=True)
    else:
        subprocess.run(
            cmd,
            shell=True,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    return self
get_environment
get_environment()
Source code in swebench/inference/make_datasets/utils.py
def get_environment(self):
    raise NotImplementedError()  # TODO: activate conda environment and return the environment file
get_readme_files
get_readme_files()
Source code in swebench/inference/make_datasets/utils.py
def get_readme_files(self):
    files = os.listdir(self.repo_path)
    files = list(filter(lambda x: os.path.isfile(x), files))
    files = list(filter(lambda x: x.lower().startswith("readme"), files))
    return files
__exit__
__exit__(exc_type, exc_val, exc_tb)
Source code in swebench/inference/make_datasets/utils.py
def __exit__(self, exc_type, exc_val, exc_tb):
    os.chdir(self.old_dir)
AutoContextManager
AutoContextManager(instance, root_dir=None, verbose=False, token=None)

Bases: ContextManager

Automatically clones the repo if it doesn't exist

Source code in swebench/inference/make_datasets/utils.py
def __init__(self, instance, root_dir=None, verbose=False, token=None):
    if token is None:
        token = os.environ.get("GITHUB_TOKEN", "git")
    self.tempdir = None
    if root_dir is None:
        self.tempdir = TemporaryDirectory()
        root_dir = self.tempdir.name
    self.root_dir = root_dir
    repo_dir = os.path.join(self.root_dir, instance["repo"].replace("/", "__"))
    if not os.path.exists(repo_dir):
        repo_url = (
            f"https://{token}@github.com/swe-bench-repos/"
            + instance["repo"].replace("/", "__")
            + ".git"
        )
        if verbose:
            print(f"Cloning {instance['repo']} to {root_dir}")
        Repo.clone_from(repo_url, repo_dir)
    super().__init__(repo_dir, instance["base_commit"], verbose=verbose)
    self.instance = instance
tempdir instance-attribute
tempdir = None
root_dir instance-attribute
root_dir = root_dir
instance instance-attribute
instance = instance
__exit__
__exit__(exc_type, exc_val, exc_tb)
Source code in swebench/inference/make_datasets/utils.py
def __exit__(self, exc_type, exc_val, exc_tb):
    if self.tempdir is not None:
        self.tempdir.cleanup()
    return super().__exit__(exc_type, exc_val, exc_tb)
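A sketch of using the context manager (the repository and commit below are placeholders; cloning requires network access and, for private repos, a GITHUB_TOKEN):

from swebench.inference.make_datasets.utils import (
    AutoContextManager,
    ingest_directory_contents,
)

# Placeholder instance; only "repo" and "base_commit" are read by the constructor.
instance = {"repo": "owner/project", "base_commit": "0123abcd"}

with AutoContextManager(instance, root_dir=None, verbose=True) as cm:
    # Inside the block, the working directory is the cloned repo
    # reset to the requested base commit.
    readmes = cm.get_readme_files()
    files = ingest_directory_contents(cm.repo_path)
print(readmes, len(files))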
get_first_idx
get_first_idx(charlist)
Source code in swebench/inference/make_datasets/utils.py
def get_first_idx(charlist):
    first_min = charlist.index("-") if "-" in charlist else len(charlist)
    first_plus = charlist.index("+") if "+" in charlist else len(charlist)
    return min(first_min, first_plus)
get_last_idx
get_last_idx(charlist)
Source code in swebench/inference/make_datasets/utils.py
def get_last_idx(charlist):
    char_idx = get_first_idx(charlist[::-1])
    last_idx = len(charlist) - char_idx
    return last_idx + 1
strip_content
strip_content(hunk)
Source code in swebench/inference/make_datasets/utils.py
def strip_content(hunk):
    first_chars = list(map(lambda x: None if not len(x) else x[0], hunk.split("\n")))
    first_idx = get_first_idx(first_chars)
    last_idx = get_last_idx(first_chars)
    new_lines = list(map(lambda x: x.rstrip(), hunk.split("\n")[first_idx:last_idx]))
    new_hunk = "\n" + "\n".join(new_lines) + "\n"
    return new_hunk, first_idx - 1
get_hunk_stats
get_hunk_stats(pre_start, pre_len, post_start, post_len, hunk, total_delta)
Source code in swebench/inference/make_datasets/utils.py
def get_hunk_stats(pre_start, pre_len, post_start, post_len, hunk, total_delta):
    stats = {"context": 0, "added": 0, "subtracted": 0}
    hunk = hunk.split("\n", 1)[-1].strip("\n")
    for line in hunk.split("\n"):
        if line.startswith("-"):
            stats["subtracted"] += 1
        elif line.startswith("+"):
            stats["added"] += 1
        else:
            stats["context"] += 1
    context = stats["context"]
    added = stats["added"]
    subtracted = stats["subtracted"]
    pre_len = context + subtracted
    post_start = pre_start + total_delta
    post_len = context + added
    total_delta = total_delta + (post_len - pre_len)
    return pre_start, pre_len, post_start, post_len, total_delta
repair_patch
repair_patch(model_patch)
Source code in swebench/inference/make_datasets/utils.py
def repair_patch(model_patch):
    if model_patch is None:
        return None
    model_patch = model_patch.lstrip("\n")
    new_patch = ""
    for patch in PATCH_PATTERN.findall(model_patch):
        total_delta = 0
        diff_header = DIFF_PATTERN.findall(patch)
        if diff_header:
            new_patch += diff_header[0] + "\n"
        patch_header = PATCH_FILE_PATTERN.findall(patch)[0]
        if patch_header:
            new_patch += patch_header + "\n"
        for hunk in PATCH_HUNK_PATTERN.findall(patch):
            pre_start, pre_len, post_start, post_len, content = hunk
            pre_start, pre_len, post_start, post_len, total_delta = get_hunk_stats(
                *list(map(lambda x: int(x) if x.isnumeric() else x, hunk)), total_delta
            )
            new_patch += (
                f"@@ -{pre_start},{pre_len} +{post_start},{post_len} @@{content}"
            )
    return new_patch
extract_minimal_patch
extract_minimal_patch(model_patch)
Source code in swebench/inference/make_datasets/utils.py
def extract_minimal_patch(model_patch):
    model_patch = model_patch.lstrip("\n")
    new_patch = ""
    for patch in PATCH_PATTERN.findall(model_patch):
        total_delta = 0
        diff_header = DIFF_PATTERN.findall(patch)
        patch_header = PATCH_FILE_PATTERN.findall(patch)[0]
        if patch_header:
            new_patch += patch_header + "\n"
        for hunk in PATCH_HUNK_PATTERN.findall(patch):
            pre_start, pre_len, post_start, post_len, content = hunk
            pre_start, pre_len, post_start, post_len, content = list(
                map(lambda x: int(x) if x.isnumeric() else x, hunk)
            )
            content, adjust_pre_start = strip_content(content)
            pre_start += adjust_pre_start
            pre_start, pre_len, post_start, post_len, total_delta = get_hunk_stats(
                pre_start, pre_len, post_start, post_len, content, total_delta
            )
            new_patch += (
                f"@@ -{pre_start},{pre_len} +{post_start},{post_len} @@{content}"
            )
    return new_patch
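A sketch of recomputing hunk headers on a model-generated diff (the patch below is made up):

from swebench.inference.make_datasets.utils import extract_minimal_patch, repair_patch

# Made-up model output whose hunk header counts may be wrong.
model_patch = """diff --git a/pkg/mod.py b/pkg/mod.py
--- a/pkg/mod.py
+++ b/pkg/mod.py
@@ -10,3 +10,4 @@ def f():
     x = 1
-    return x
+    x += 1
+    return x
"""

print(repair_patch(model_patch))           # header counts recomputed from the hunk body
print(extract_minimal_patch(model_patch))  # leading/trailing context trimmed, then headers recomputed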
extract_diff
extract_diff(response)

Extracts the diff from a response formatted in different ways

Source code in swebench/inference/make_datasets/utils.py
def extract_diff(response):
    """
    Extracts the diff from a response formatted in different ways
    """
    if response is None:
        return None
    diff_matches = []
    other_matches = []
    pattern = re.compile(r"\<([\w-]+)\>(.*?)\<\/\1\>", re.DOTALL)
    for code, match in pattern.findall(response):
        if code in {"diff", "patch"}:
            diff_matches.append(match)
        else:
            other_matches.append(match)
    pattern = re.compile(r"```(\w+)?\n(.*?)```", re.DOTALL)
    for code, match in pattern.findall(response):
        if code in {"diff", "patch"}:
            diff_matches.append(match)
        else:
            other_matches.append(match)
    if diff_matches:
        return diff_matches[0]
    if other_matches:
        return other_matches[0]
    return response.split("</s>")[0]
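For example, given a response that wraps the patch in a fenced block (hypothetical response text), the fenced diff is returned:

from swebench.inference.make_datasets.utils import extract_diff

response = (
    "Here is the fix:\n"
    "```diff\n"
    "--- a/foo.py\n"
    "+++ b/foo.py\n"
    "@@ -1,1 +1,1 @@\n"
    "-x = 1\n"
    "+x = 2\n"
    "```\n"
)
print(extract_diff(response))  # contents of the ```diff block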
is_test
is_test(name, test_phrases=None)
Source code in swebench/inference/make_datasets/utils.py
def is_test(name, test_phrases=None):
    if test_phrases is None:
        test_phrases = ["test", "tests", "testing"]
    words = set(re.split(r" |_|\/|\.", name.lower()))
    return any(word in words for word in test_phrases)
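For example, path components are split on separators and compared against the test phrases:

from swebench.inference.make_datasets.utils import is_test

print(is_test("tests/test_utils.py"))  # True
print(is_test("src/utils.py"))         # False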
get_imported_modules
get_imported_modules(filename)
Source code in swebench/inference/make_datasets/utils.py
def get_imported_modules(filename):
    with open(filename) as file:
        tree = ast.parse(file.read(), filename)
    return [
        node
        for node in ast.iter_child_nodes(tree)
        if isinstance(node, (ast.Import, ast.ImportFrom))
    ]
resolve_module_to_file
resolve_module_to_file(module, level, root_dir)
Source code in swebench/inference/make_datasets/utils.py
def resolve_module_to_file(module, level, root_dir):
    components = module.split(".")
    if level > 0:
        components = components[:-level]
    for dirpath, dirnames, filenames in os.walk(root_dir):
        if dirpath.endswith(os.sep.join(components)):
            return [
                os.path.join(dirpath, filename)
                for filename in filenames
                if filename.endswith(".py")
            ]
    return []
ingest_file_directory_contents
ingest_file_directory_contents(target_file, root_dir)
Source code in swebench/inference/make_datasets/utils.py
def ingest_file_directory_contents(target_file, root_dir):
    imported_files = []
    files_to_check = [target_file]
    while files_to_check:
        current_file = files_to_check.pop()
        imported_files.append(current_file)
        imports = get_imported_modules(current_file)
        for node in imports:
            if isinstance(node, ast.Import):
                for alias in node.names:
                    files = resolve_module_to_file(alias.name, 0, root_dir)
                    for file in files:
                        if file not in imported_files and file not in files_to_check:
                            files_to_check.append(file)
            elif isinstance(node, ast.ImportFrom):
                files = resolve_module_to_file(node.module, node.level, root_dir)
                for file in files:
                    if file not in imported_files and file not in files_to_check:
                        files_to_check.append(file)
    return imported_files
detect_encoding
detect_encoding(filename)

Detect the encoding of a file

Source code in swebench/inference/make_datasets/utils.py
def detect_encoding(filename):
    """
    Detect the encoding of a file
    """
    with open(filename, "rb") as file:
        rawdata = file.read()
    return chardet.detect(rawdata)["encoding"]
list_files
list_files(root_dir, include_tests=False)
Source code in swebench/inference/make_datasets/utils.py
def list_files(root_dir, include_tests=False):
    files = []
    for filename in Path(root_dir).rglob("*.py"):
        if not include_tests and is_test(filename.as_posix()):
            continue
        files.append(filename.relative_to(root_dir).as_posix())
    return files
ingest_directory_contents
ingest_directory_contents(root_dir, include_tests=False)
Source code in swebench/inference/make_datasets/utils.py
def ingest_directory_contents(root_dir, include_tests=False):
    files_content = {}
    for relative_path in list_files(root_dir, include_tests=include_tests):
        filename = os.path.join(root_dir, relative_path)
        encoding = detect_encoding(filename)
        if encoding is None:
            content = "[BINARY DATA FILE]"
        else:
            try:
                with open(filename, encoding=encoding) as file:
                    content = file.read()
            except (UnicodeDecodeError, LookupError):
                content = "[BINARY DATA FILE]"
        files_content[relative_path] = content
    return files_content
string_to_bool
string_to_bool(v)
Source code in swebench/inference/make_datasets/utils.py
def string_to_bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    elif v.lower() in ("no", "false", "f", "n", "0"):
        return False
    else:
        raise ArgumentTypeError(
            f"Truthy value expected: got {v} but expected one of yes/no, true/false, t/f, y/n, 1/0 (case insensitive)."
        )
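string_to_bool raises argparse's ArgumentTypeError, so it is shaped to be used as an argparse type converter; a minimal sketch:

from argparse import ArgumentParser
from swebench.inference.make_datasets.utils import string_to_bool

parser = ArgumentParser()
parser.add_argument("--include_tests", type=string_to_bool, default=False)
args = parser.parse_args(["--include_tests", "yes"])
print(args.include_tests)  # True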

run_api

This Python script runs inference on a dataset using either the OpenAI or Anthropic API, depending on the model specified. It sorts instances by length and continually writes the outputs to a specified file, so the script can be stopped and restarted without losing progress.

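A minimal sketch of calling main() directly with the parameters documented below (the dataset name, output directory, and cost cap are illustrative; an OPENAI_API_KEY or ANTHROPIC_API_KEY environment variable is required, and real API costs are incurred):

from swebench.inference.run_api import main

main(
    dataset_name_or_path="princeton-nlp/SWE-bench_oracle",  # or a local save_to_disk path
    split="test",
    model_name_or_path="gpt-4-0613",
    shard_id=None,
    num_shards=None,
    output_dir="./outputs",   # note: main() does not create this directory
    model_args=None,          # e.g. "temperature=0,top_p=1"
    max_cost=50.0,
)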
logger module-attribute
logger = getLogger(__name__)
MODEL_LIMITS module-attribute
MODEL_LIMITS = {'claude-instant-1': 100000, 'claude-2': 100000, 'claude-3-opus-20240229': 200000, 'claude-3-sonnet-20240229': 200000, 'claude-3-haiku-20240307': 200000, 'gpt-3.5-turbo-16k-0613': 16385, 'gpt-3.5-turbo-0613': 4097, 'gpt-3.5-turbo-1106': 16385, 'gpt-4-32k-0613': 32768, 'gpt-4-0613': 8192, 'gpt-4-1106-preview': 128000, 'gpt-4-0125-preview': 128000}
MODEL_COST_PER_INPUT module-attribute
MODEL_COST_PER_INPUT = {'claude-instant-1': 1.63e-06, 'claude-2': 1.102e-05, 'claude-3-opus-20240229': 1.5e-05, 'claude-3-sonnet-20240229': 3e-06, 'claude-3-haiku-20240307': 2.5e-07, 'gpt-3.5-turbo-16k-0613': 1.5e-06, 'gpt-3.5-turbo-0613': 1.5e-06, 'gpt-3.5-turbo-1106': 1e-06, 'gpt-35-turbo-0613': 1.5e-06, 'gpt-35-turbo': 1.5e-06, 'gpt-4-0613': 3e-05, 'gpt-4-32k-0613': 6e-05, 'gpt-4-32k': 6e-05, 'gpt-4-1106-preview': 1e-05, 'gpt-4-0125-preview': 1e-05}
MODEL_COST_PER_OUTPUT module-attribute
MODEL_COST_PER_OUTPUT = {'claude-instant-1': 5.51e-06, 'claude-2': 3.268e-05, 'claude-3-opus-20240229': 7.5e-05, 'claude-3-sonnet-20240229': 1.5e-05, 'claude-3-haiku-20240307': 1.25e-06, 'gpt-3.5-turbo-16k-0613': 2e-06, 'gpt-3.5-turbo-16k': 2e-06, 'gpt-3.5-turbo-1106': 2e-06, 'gpt-35-turbo-0613': 2e-06, 'gpt-35-turbo': 2e-06, 'gpt-4-0613': 6e-05, 'gpt-4-32k-0613': 0.00012, 'gpt-4-32k': 0.00012, 'gpt-4-1106-preview': 3e-05, 'gpt-4-0125-preview': 3e-05}
ENGINES module-attribute
ENGINES = {'gpt-3.5-turbo-16k-0613': 'gpt-35-turbo-16k', 'gpt-4-0613': 'gpt-4', 'gpt-4-32k-0613': 'gpt-4-32k'}
parser module-attribute
parser = ArgumentParser(description=__doc__)
args module-attribute
args = parse_args()
calc_cost
calc_cost(model_name, input_tokens, output_tokens)

Calculates the cost of a response from the OpenAI (or Anthropic) API given its token counts.

Args:
model_name (str): The name of the model used.
input_tokens (int): The number of input (prompt) tokens.
output_tokens (int): The number of output (completion) tokens.

Returns: float: The cost of the response.

Source code in swebench/inference/run_api.py
def calc_cost(model_name, input_tokens, output_tokens):
    """
    Calculates the cost of a response from the openai API.

    Args:
    response (openai.ChatCompletion): The response from the API.

    Returns:
    float: The cost of the response.
    """
    cost = (
        MODEL_COST_PER_INPUT[model_name] * input_tokens
        + MODEL_COST_PER_OUTPUT[model_name] * output_tokens
    )
    logger.info(
        f"input_tokens={input_tokens}, output_tokens={output_tokens}, cost={cost:.2f}"
    )
    return cost
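A small worked example using the per-token prices in MODEL_COST_PER_INPUT and MODEL_COST_PER_OUTPUT above:

from swebench.inference.run_api import calc_cost

# gpt-4-0613: 3e-05 dollars per input token, 6e-05 per output token (see the tables above)
cost = calc_cost("gpt-4-0613", input_tokens=1_000, output_tokens=500)
# 1_000 * 3e-05 + 500 * 6e-05 = 0.03 + 0.03 = 0.06
print(f"{cost:.2f}")  # 0.06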
call_chat
call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model_args)

Calls the openai API to generate completions for the given inputs.

Args:
model_name_or_path (str): The name or path of the model to use.
inputs (str): The inputs to generate completions for.
use_azure (bool): Whether to use the azure API.
temperature (float): The temperature to use.
top_p (float): The top_p to use.
**model_args (dict): A dictionary of model arguments.

Source code in swebench/inference/run_api.py
@retry(wait=wait_random_exponential(min=30, max=600), stop=stop_after_attempt(3))
def call_chat(model_name_or_path, inputs, use_azure, temperature, top_p, **model_args):
    """
    Calls the openai API to generate completions for the given inputs.

    Args:
    model_name_or_path (str): The name or path of the model to use.
    inputs (str): The inputs to generate completions for.
    use_azure (bool): Whether to use the azure API.
    temperature (float): The temperature to use.
    top_p (float): The top_p to use.
    **model_args (dict): A dictionary of model arguments.
    """
    system_messages = inputs.split("\n", 1)[0]
    user_message = inputs.split("\n", 1)[1]
    try:
        if use_azure:
            response = openai.chat.completions.create(
                engine=ENGINES[model_name_or_path] if use_azure else None,
                messages=[
                    {"role": "system", "content": system_messages},
                    {"role": "user", "content": user_message},
                ],
                temperature=temperature,
                top_p=top_p,
                **model_args,
            )
        else:
            response = openai.chat.completions.create(
                model=model_name_or_path,
                messages=[
                    {"role": "system", "content": system_messages},
                    {"role": "user", "content": user_message},
                ],
                temperature=temperature,
                top_p=top_p,
                **model_args,
            )
        input_tokens = response.usage.prompt_tokens
        output_tokens = response.usage.completion_tokens
        cost = calc_cost(response.model, input_tokens, output_tokens)
        return response, cost
    except openai.BadRequestError as e:
        if e.code == "context_length_exceeded":
            print("Context length exceeded")
            return None
        raise e
gpt_tokenize
gpt_tokenize(string: str, encoding) -> int

Returns the number of tokens in a text string.

Source code in swebench/inference/run_api.py
def gpt_tokenize(string: str, encoding) -> int:
    """Returns the number of tokens in a text string."""
    num_tokens = len(encoding.encode(string))
    return num_tokens
claude_tokenize
claude_tokenize(string: str, api) -> int

Returns the number of tokens in a text string.

Source code in swebench/inference/run_api.py
def claude_tokenize(string: str, api) -> int:
    """Returns the number of tokens in a text string."""
    num_tokens = api.count_tokens(string)
    return num_tokens
openai_inference
openai_inference(test_dataset, model_name_or_path, output_file, model_args, existing_ids, max_cost)

Runs inference on a dataset using the openai API.

Args:
test_dataset (datasets.Dataset): The dataset to run inference on.
model_name_or_path (str): The name or path of the model to use.
output_file (str): The path to the output file.
model_args (dict): A dictionary of model arguments.
existing_ids (set): A set of ids that have already been processed.
max_cost (float): The maximum cost to spend on inference.

Source code in swebench/inference/run_api.py
def openai_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    model_args,
    existing_ids,
    max_cost,
):
    """
    Runs inference on a dataset using the openai API.

    Args:
    test_dataset (datasets.Dataset): The dataset to run inference on.
    model_name_or_path (str): The name or path of the model to use.
    output_file (str): The path to the output file.
    model_args (dict): A dictionary of model arguments.
    existing_ids (set): A set of ids that have already been processed.
    max_cost (float): The maximum cost to spend on inference.
    """
    encoding = tiktoken.encoding_for_model(model_name_or_path)
    test_dataset = test_dataset.filter(
        lambda x: gpt_tokenize(x["text"], encoding) <= MODEL_LIMITS[model_name_or_path],
        desc="Filtering",
        load_from_cache_file=False,
    )
    openai_key = os.environ.get("OPENAI_API_KEY", None)
    if openai_key is None:
        raise ValueError(
            "Must provide an api key. Expected in OPENAI_API_KEY environment variable."
        )
    openai.api_key = openai_key
    print(f"Using OpenAI key {'*' * max(0, len(openai_key) - 5) + openai_key[-5:]}")
    use_azure = model_args.pop("use_azure", False)
    if use_azure:
        openai.api_type = "azure"
        openai.api_base = "https://pnlpopenai3.openai.azure.com/"
        openai.api_version = "2023-05-15"
    temperature = model_args.pop("temperature", 0.2)
    top_p = model_args.pop("top_p", 0.95 if temperature > 0 else 1)
    print(f"Using temperature={temperature}, top_p={top_p}")
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    total_cost = 0
    print(f"Filtered to {len(test_dataset)} instances")
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            output_dict["text"] = f"{datum['text']}\n\n"
            response, cost = call_chat(
                output_dict["model_name_or_path"],
                output_dict["text"],
                use_azure,
                temperature,
                top_p,
            )
            completion = response.choices[0].message.content
            total_cost += cost
            print(f"Total Cost: {total_cost:.2f}")
            output_dict["full_output"] = completion
            output_dict["model_patch"] = extract_diff(completion)
            print(json.dumps(output_dict), file=f, flush=True)
            if max_cost is not None and total_cost >= max_cost:
                print(f"Reached max cost {max_cost}, exiting")
                break
call_anthropic
call_anthropic(inputs, anthropic, model_name_or_path, temperature, top_p, **model_args)

Calls the anthropic API to generate completions for the given inputs.

Args:
inputs (str): The inputs to generate completions for.
anthropic (Anthropic): The anthropic API object.
model_name_or_path (str): The name or path of the model to use.
temperature (float): The temperature to use.
top_p (float): The top_p to use.
model_args (dict): A dictionary of model arguments.

Source code in swebench/inference/run_api.py
@retry(wait=wait_random_exponential(min=60, max=600), stop=stop_after_attempt(6))
def call_anthropic(
    inputs, anthropic, model_name_or_path, temperature, top_p, **model_args
):
    """
    Calls the anthropic API to generate completions for the given inputs.

    Args:
    inputs (str): The inputs to generate completions for.
    anthropic (Anthropic): The anthropic API object.
    model_name_or_path (str): The name or path of the model to use.
    temperature (float): The temperature to use.
    top_p (float): The top_p to use.
    model_args (dict): A dictionary of model arguments.
    """
    try:
        completion = anthropic.completions.create(
            model=model_name_or_path,
            max_tokens_to_sample=6000,
            prompt=inputs,
            temperature=temperature,
            top_p=top_p,
            **model_args,
        )
        response = completion.completion
        input_tokens = anthropic.count_tokens(inputs)
        output_tokens = anthropic.count_tokens(response)
        cost = calc_cost(model_name_or_path, input_tokens, output_tokens)
        return completion, cost
    except Exception as e:
        logger.error(e)
        logger.error(f"Inputs: {inputs}")
        traceback.print_exc()
        time.sleep(20)
        return None
call_anthropic_v2
call_anthropic_v2(inputs, anthropic, model_name_or_path, temperature, top_p, **model_args)

Calls the anthropic API to generate completions for the given inputs.

Args:
inputs list(str): The inputs to generate completions for.
anthropic (Anthropic): The anthropic API object.
model_name_or_path (str): The name or path of the model to use.
temperature (float): The temperature to use.
top_p (float): The top_p to use.
model_args (dict): A dictionary of model arguments.

Source code in swebench/inference/run_api.py
@retry(wait=wait_random_exponential(min=60, max=600), stop=stop_after_attempt(6))
def call_anthropic_v2(
    inputs, anthropic, model_name_or_path, temperature, top_p, **model_args
):
    """
    Calls the anthropic API to generate completions for the given inputs.

    Args:
    inputs list(str): The inputs to generate completions for.
    anthropic (Anthropic): The anthropic API object.
    model_name_or_path (str): The name or path of the model to use.
    temperature (float): The temperature to use.
    top_p (float): The top_p to use.
    model_args (dict): A dictionary of model arguments.
    """
    system_messages = inputs.split("\n", 1)[0]
    user_message = inputs.split("\n", 1)[1]
    try:
        messages = [
            {"role": "user", "content": user_message},
        ]
        response = anthropic.messages.create(
            messages=messages,
            max_tokens=4096,
            model=model_name_or_path,
            temperature=temperature,
            top_p=top_p,
            system=system_messages,
        )
        input_tokens = response.usage.input_tokens
        output_tokens = response.usage.output_tokens
        cost = calc_cost(response.model, input_tokens, output_tokens)
        return response, cost
    except Exception as e:
        logger.error(e)
        logger.error(f"Inputs: {inputs}")
        traceback.print_exc()
        time.sleep(20)
        return None
anthropic_inference
anthropic_inference(test_dataset, model_name_or_path, output_file, model_args, existing_ids, max_cost)

Runs inference on a dataset using the anthropic API.

Args:
test_dataset (datasets.Dataset): The dataset to run inference on.
model_name_or_path (str): The name or path of the model to use.
output_file (str): The path to the output file.
model_args (dict): A dictionary of model arguments.
existing_ids (set): A set of ids that have already been processed.
max_cost (float): The maximum cost to spend on inference.

Source code in swebench/inference/run_api.py
def anthropic_inference(
    test_dataset,
    model_name_or_path,
    output_file,
    model_args,
    existing_ids,
    max_cost,
):
    """
    Runs inference on a dataset using the anthropic API.

    Args:
    test_dataset (datasets.Dataset): The dataset to run inference on.
    model_name_or_path (str): The name or path of the model to use.
    output_file (str): The path to the output file.
    model_args (dict): A dictionary of model arguments.
    existing_ids (set): A set of ids that have already been processed.
    max_cost (float): The maximum cost to spend on inference.
    """
    api_key = os.environ.get("ANTHROPIC_API_KEY", None)
    if api_key is None:
        raise ValueError(
            "Must provide an api key. Expected in ANTHROPIC_API_KEY environment variable."
        )
    print(f"Using Anthropic key {'*' * max(0, len(api_key) - 5) + api_key[-5:]}")
    anthropic = Anthropic(api_key=api_key)
    test_dataset = test_dataset.filter(
        lambda x: claude_tokenize(x["text"], anthropic)
        <= MODEL_LIMITS[model_name_or_path],
        desc="Filtering",
        load_from_cache_file=False,
    )
    temperature = model_args.pop("temperature", 0.2)
    top_p = model_args.pop("top_p", 0.95 if temperature > 0 else 1)
    print(f"Using temperature={temperature}, top_p={top_p}")
    basic_args = {
        "model_name_or_path": model_name_or_path,
    }
    total_cost = 0
    print(f"Filtered to {len(test_dataset)} instances")
    if "claude-3" in model_name_or_path.lower():
        call_api = call_anthropic_v2
    else:
        call_api = call_anthropic
    with open(output_file, "a+") as f:
        for datum in tqdm(test_dataset, desc=f"Inference for {model_name_or_path}"):
            instance_id = datum["instance_id"]
            if instance_id in existing_ids:
                continue
            output_dict = {"instance_id": instance_id}
            output_dict.update(basic_args)
            if "claude-3" in model_name_or_path.lower():
                output_dict["text_inputs"] = f"{datum['text']}\n"
            else:
                output_dict["text_inputs"] = (
                    f"{HUMAN_PROMPT} {datum['text']}\n\n{AI_PROMPT}"
                )
            try:
                completion, cost = call_api(
                    output_dict["text_inputs"],
                    anthropic,
                    model_name_or_path,
                    temperature,
                    top_p,
                    **model_args,
                )
            except Exception as e:
                logger.error(e)
                traceback.print_exc()
                continue
            total_cost += cost
            print(f"Total Cost: {total_cost:.2f}")
            if "claude-3" in model_name_or_path.lower():
                output_dict["full_output"] = completion.content[0].text
            else:
                output_dict["full_output"] = completion.completion
            output_dict["model_patch"] = extract_diff(output_dict["full_output"])
            print(json.dumps(output_dict), file=f, flush=True)
            if max_cost is not None and total_cost >= max_cost:
                print(f"Reached max cost {max_cost}, exiting")
                break
parse_model_args
parse_model_args(model_args)

Parses a string of model arguments and returns a dictionary of keyword arguments.

Parameters:

Name Type Description Default
model_args str

A string of comma-separated key-value pairs representing model arguments.

required

Returns:

Name Type Description
dict

A dictionary of keyword arguments parsed from the input string.

Source code in swebench/inference/run_api.py
def parse_model_args(model_args):
    """
    Parses a string of model arguments and returns a dictionary of keyword arguments.

    Args:
        model_args (str): A string of comma-separated key-value pairs representing model arguments.

    Returns:
        dict: A dictionary of keyword arguments parsed from the input string.
    """
    kwargs = dict()
    if model_args is not None:
        for arg in model_args.split(","):
            key, value = arg.split("=")
            # infer value type
            if value in {"True", "False"}:
                kwargs[key] = value == "True"
            elif value.isnumeric():
                kwargs[key] = int(value)
            elif value.replace(".", "", 1).isnumeric():
                kwargs[key] = float(value)
            elif value in {"None"}:
                kwargs[key] = None
            elif value in {"[]"}:
                kwargs[key] = []
            elif value in {"{}"}:
                kwargs[key] = {}
            elif value.startswith("'") and value.endswith("'"):
                kwargs[key] = value[1:-1]
            elif value.startswith('"') and value.endswith('"'):
                kwargs[key] = value[1:-1]
            else:
                kwargs[key] = value
    return kwargs
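For example, a comma-separated key=value string is parsed into typed keyword arguments:

from swebench.inference.run_api import parse_model_args

kwargs = parse_model_args("temperature=0.2,top_p=0.95,use_azure=False,stop=None")
print(kwargs)  # {'temperature': 0.2, 'top_p': 0.95, 'use_azure': False, 'stop': None}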
main
main(dataset_name_or_path, split, model_name_or_path, shard_id, num_shards, output_dir, model_args, max_cost)
Source code in swebench/inference/run_api.py
def main(
    dataset_name_or_path,
    split,
    model_name_or_path,
    shard_id,
    num_shards,
    output_dir,
    model_args,
    max_cost,
):
    if shard_id is None and num_shards is not None:
        logger.warning(
            f"Received num_shards={num_shards} but shard_id is None, ignoring"
        )
    if shard_id is not None and num_shards is None:
        logger.warning(f"Received shard_id={shard_id} but num_shards is None, ignoring")
    model_args = parse_model_args(model_args)
    model_nickname = model_name_or_path
    if "checkpoint" in Path(model_name_or_path).name:
        model_nickname = Path(model_name_or_path).parent.name
    else:
        model_nickname = Path(model_name_or_path).name
    output_file = f"{model_nickname}__{dataset_name_or_path.split('/')[-1]}__{split}"
    if shard_id is not None and num_shards is not None:
        output_file += f"__shard-{shard_id}__num_shards-{num_shards}"
    output_file = Path(output_dir, output_file + ".jsonl")
    logger.info(f"Will write to {output_file}")
    existing_ids = set()
    if os.path.exists(output_file):
        with open(output_file) as f:
            for line in f:
                data = json.loads(line)
                instance_id = data["instance_id"]
                existing_ids.add(instance_id)
    logger.info(f"Read {len(existing_ids)} already completed ids from {output_file}")
    if Path(dataset_name_or_path).exists():
        dataset = load_from_disk(dataset_name_or_path)
    else:
        dataset = load_dataset(dataset_name_or_path)
    if split not in dataset:
        raise ValueError(f"Invalid split {split} for dataset {dataset_name_or_path}")
    dataset = dataset[split]
    lens = np.array(list(map(len, dataset["text"])))
    dataset = dataset.select(np.argsort(lens))
    if len(existing_ids) > 0:
        dataset = dataset.filter(
            lambda x: x["instance_id"] not in existing_ids,
            desc="Filtering out existing ids",
            load_from_cache_file=False,
        )
    if shard_id is not None and num_shards is not None:
        dataset = dataset.shard(num_shards, shard_id, contiguous=True)
    inference_args = {
        "test_dataset": dataset,
        "model_name_or_path": model_name_or_path,
        "output_file": output_file,
        "model_args": model_args,
        "existing_ids": existing_ids,
        "max_cost": max_cost,
    }
    if model_name_or_path.startswith("claude"):
        anthropic_inference(**inference_args)
    elif model_name_or_path.startswith("gpt"):
        openai_inference(**inference_args)
    else:
        raise ValueError(f"Invalid model name or path {model_name_or_path}")
    logger.info("Done!")

run_live

This module contains functions for running a live inference session on a GitHub issue. It clones the repository associated with the issue, builds a BM25 retrieval index, constructs a prompt from the retrieved files, and queries the model. The output is saved to a specified directory.

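A sketch of calling main() with the parameters documented below (the issue URL and output paths are placeholders, and the prompt style and document encoding names are assumed keys of PROMPT_FUNCTIONS and DOCUMENT_ENCODING_FUNCTIONS; an OPENAI_API_KEY or ANTHROPIC_API_KEY is required):

from swebench.inference.run_live import main

main(
    model_name="gpt-4-0613",
    prompt_style="style-3",                                    # assumed prompt style name
    issue_url=["https://github.com/owner/repo/issues/12345"],  # placeholder issue
    base_commit=None,
    max_context_length=16000,
    document_encoding_func="file_name_and_contents",           # assumed encoding name
    output_dir="./live_outputs",
    root_dir="./repos",
    include_readmes=False,
)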
logger module-attribute
logger = getLogger(__name__)
parser module-attribute
parser = ArgumentParser(description=__doc__)
args module-attribute
args = parse_args()
get_problem_statement
get_problem_statement(owner, repo, issue_num, ghapi, include_comments=False)
Source code in swebench/inference/run_live.py
def get_problem_statement(owner, repo, issue_num, ghapi, include_comments=False):
    issue = ghapi.issues.get(owner, repo, issue_num)
    issue_text = "\n".join([issue.title, issue.body])
    # Solved issues may include comments that give answers away too much
    if include_comments:
        all_comments = list(ghapi.issues.list_comments(owner, repo, issue_num))
        comments = [comment.body for comment in all_comments]
        comment_text = "Comment: " if comments else "" + "\nComment:".join(comments)
        issue_text += "\n" + comment_text
    return issue_text
get_readme_files
get_readme_files(repo_path)
Source code in swebench/inference/run_live.py
def get_readme_files(repo_path):
    files = list(Path(repo_path).iterdir())
    files = list(filter(lambda x: x.is_file(), files))
    files = list(filter(lambda x: x.name.lower().startswith("readme"), files))
    if files:
        files = sorted(files, key=lambda x: len(x.name))
        files = [files[0]]
    return [Path(file).relative_to(repo_path).as_posix() for file in files]
make_instance
make_instance(owner, repo, query, commit, root_dir, token, document_encoding_func, python, instance_id, tokenizer, tokenizer_func, prompt_style, max_context_len, include_readmes)

Creates an instance for a given query and repository.

Parameters:

Name Type Description Default
owner str

The owner of the repository.

required
repo str

The name of the repository.

required
query str

The query to search for.

required
commit str

The commit hash to use.

required
root_dir str

The root directory to clone the repository to.

required
token str

The GitHub token to use for authentication.

required
document_encoding_func function

The function to use for encoding documents.

required
python str

The path to the Python executable.

required
instance_id int

The ID of the instance.

required
tokenizer str

The name of the tokenizer to use.

required
tokenizer_func function

The function to use for tokenization.

required
prompt_style str

The style of prompt to use.

required
max_context_len int

The maximum length of the context.

required
include_readmes bool

Whether to include README files in the instance.

required

Returns:

Name Type Description
dict

The instance.

Source code in swebench/inference/run_live.py
def make_instance(
    owner,
    repo,
    query,
    commit,
    root_dir,
    token,
    document_encoding_func,
    python,
    instance_id,
    tokenizer,
    tokenizer_func,
    prompt_style,
    max_context_len,
    include_readmes,
):
    """
    Creates an instance for a given query and repository.

    Args:
        owner (str): The owner of the repository.
        repo (str): The name of the repository.
        query (str): The query to search for.
        commit (str): The commit hash to use.
        root_dir (str): The root directory to clone the repository to.
        token (str): The GitHub token to use for authentication.
        document_encoding_func (function): The function to use for encoding documents.
        python (str): The path to the Python executable.
        instance_id (int): The ID of the instance.
        tokenizer (str): The name of the tokenizer to use.
        tokenizer_func (function): The function to use for tokenization.
        prompt_style (str): The style of prompt to use.
        max_context_len (int): The maximum length of the context.
        include_readmes (bool): Whether to include README files in the instance.

    Returns:
        dict: The instance.
    """
    thread_id = 0
    instance = {"instance_id": instance_id, "problem_statement": query}
    logger.info(f"Cloning repo {owner}/{repo}")
    repo_dir = clone_repo(f"{owner}/{repo}", root_dir, token)
    if commit is None:
        commit = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=repo_dir)
            .decode("utf-8")
            .strip()
        )
    logger.info(f"Building BM25 retrieval index for {owner}/{repo}@{commit}")
    index_dir = make_index(
        repo_dir=repo_dir,
        root_dir=root_dir,
        query=query,
        commit=commit,
        document_encoding_func=document_encoding_func,
        python=python,
        instance_id=instance_id,
    )
    results = search(instance, index_dir)
    hits = results["hits"]
    logger.info(f"Retrieved {len(hits)} documents")
    with ContextManager(repo_dir, commit) as cm:
        if include_readmes:
            readmes = get_readme_files(cm.repo_path)
        else:
            readmes = list()
        instance["readmes"] = ingest_files(readmes)
        for hit in hits:
            hit["file_contents"] = open(hit["docid"]).read()
        instance["file_contents"] = dict()
        base_text_inputs = PROMPT_FUNCTIONS[prompt_style](instance)
        base_text_input_length = len(tokenizer_func(base_text_inputs, tokenizer))
        instance["file_contents"] = {x["docid"]: x["file_contents"] for x in hits}
        cur_input_len = base_text_input_length
        include_files = list()
        for filename in [x["docid"] for x in hits]:
            content = make_code_text({filename: instance["file_contents"][filename]})
            tokens = tokenizer_func(content, tokenizer)
            if cur_input_len + len(tokens) < max_context_len:
                include_files.append(filename)
                cur_input_len += len(tokens)
        logger.info(
            f"Including {len(include_files)} files in context with {cur_input_len} tokens:\n"
            + "\n\t".join(sorted(include_files))
        )
        instance["file_contents"] = {
            filename: instance["file_contents"][filename] for filename in include_files
        }
        instance["text_inputs"] = PROMPT_FUNCTIONS[prompt_style](instance)
        return instance
parse_issue_url
parse_issue_url(issue_url)
Source code in swebench/inference/run_live.py
def parse_issue_url(issue_url):
    issue_pat = re.compile(r"github\.com\/(.+?)\/(.+?)\/issues\/(\d+)")
    match = issue_pat.search(issue_url)
    if not match:
        raise ValueError(
            f"issue_url ({issue_url}) does not seem to be a valid issue url."
            + "\nPlease use url like https://github.com/owner/repo/issues/12345"
        )
    owner, repo, issue_num = match.groups()
    return owner, repo, issue_num
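For example:

from swebench.inference.run_live import parse_issue_url

owner, repo, issue_num = parse_issue_url("https://github.com/owner/repo/issues/12345")
print(owner, repo, issue_num)  # owner repo 12345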
main
main(model_name, prompt_style, issue_url, base_commit, max_context_length, document_encoding_func, output_dir, root_dir, include_readmes)
Source code in swebench/inference/run_live.py
def main(
    model_name,
    prompt_style,
    issue_url,
    base_commit,
    max_context_length,
    document_encoding_func,
    output_dir,
    root_dir,
    include_readmes,
):
    if base_commit is not None and len(issue_url) != len(base_commit):
        raise ValueError(
            "Must provide either no base commits or one base commit per issue url"
        )
    if base_commit is None:
        base_commit = [None] * len(issue_url)
    gh_token = os.environ.get("GITHUB_TOKEN", None)
    if gh_token is not None:
        logger.warning(f"Using GitHub token: {'*' * 8}{gh_token[-4:]}")
    gh = GhApi(token=gh_token)
    tokenizer, tokenizer_func = TOKENIZER_FUNCS["cl100k"]
    document_encoding_func = DOCUMENT_ENCODING_FUNCTIONS[document_encoding_func]
    python = subprocess.check_output(["which", "python"]).decode("utf-8").strip()
    outputs = list()
    for issue, commit in tqdm(zip(issue_url, base_commit), total=len(issue_url)):
        owner, repo, issue_num = parse_issue_url(issue)
        problem_statement = get_problem_statement(owner, repo, int(issue_num), gh)
        instance_id = f"{owner}__{repo}-{issue_num}"
        logger.info(f"Creating instance {instance_id}")
        instance = make_instance(
            owner=owner,
            repo=repo,
            query=problem_statement,
            commit=commit,
            root_dir=root_dir,
            token=gh_token,
            document_encoding_func=document_encoding_func,
            python=python,
            instance_id=instance_id,
            tokenizer=tokenizer,
            tokenizer_func=tokenizer_func,
            prompt_style=prompt_style,
            max_context_len=max_context_length,
            include_readmes=include_readmes,
        )
        logger.info(f"Calling model {model_name}")
        start = time.time()
        if model_name.startswith("gpt"):
            inputs = instance["text_inputs"]
            response, _ = call_chat(
                model_name, inputs, use_azure=False, temperature=0, top_p=1
            )
            completion = response.choices[0].message.content
            logger.info(
                f"Generated {response.usage.completion_tokens} tokens in {(time.time() - start):.2f} seconds"
            )
        else:
            from anthropic import Anthropic

            api_key = os.environ.get("ANTHROPIC_API_KEY", None)
            anthropic = Anthropic(api_key=api_key)
            response = call_anthropic(
                inputs, anthropic, model_name, temperature=0, top_p=1
            )
            completion = response.completion
        model_patch = extract_diff(completion)
        minimal_patch = extract_minimal_patch(model_patch)
        outputs.append(
            {
                "instance_id": instance_id,
                "response": completion,
                "problem_statement": problem_statement,
                "text_inputs": inputs,
                "model_patch": model_patch,
                "minimal_patch": minimal_patch,
            }
        )
    os.makedirs(output_dir, exist_ok=True)
    output_file = Path(
        output_dir,
        f"{model_name}__{prompt_style}__{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.jsonl",
    )
    with open(output_file, "+a") as f:
        for output in outputs:
            print(json.dumps(output), file=f, flush=True)
    logger.info(f"Wrote output to {output_file}")

run_llama

logger module-attribute
logger = getLogger(__name__)
DEVICE_MAPS module-attribute
DEVICE_MAPS = load(open(parent / 'codellama_device_maps.json'))
parser module-attribute
parser = ArgumentParser()
args module-attribute
args = parse_args()
get_output_file
get_output_file(output_dir, model_name_or_path, peft_path, dataset_path, split, temperature, top_p, min_len, max_len, shard_id, num_shards)

Constructs the output file path based on the provided parameters.

Parameters:

Name Type Description Default
output_dir str

The directory where the output file will be saved.

required
model_name_or_path str

The name or path of the model.

required
peft_path str

The path to the PEFT file.

required
dataset_path str

The path to the dataset.

required
split str

The dataset split.

required
temperature float

The temperature value.

required
top_p float

The top-p value.

required
min_len int

The minimum length of the output.

required
max_len int

The maximum length of the output.

required
shard_id int

The shard ID.

required
num_shards int

The total number of shards.

required

Returns:

Name Type Description
str

The constructed output file path.

Source code in swebench/inference/run_llama.py
def get_output_file(
    output_dir,
    model_name_or_path,
    peft_path,
    dataset_path,
    split,
    temperature,
    top_p,
    min_len,
    max_len,
    shard_id,
    num_shards,
):
    """
    Constructs the output file path based on the provided parameters.

    Args:
        output_dir (str): The directory where the output file will be saved.
        model_name_or_path (str): The name or path of the model.
        peft_path (str): The path to the PEFT file.
        dataset_path (str): The path to the dataset.
        split (str): The dataset split.
        temperature (float): The temperature value.
        top_p (float): The top-p value.
        min_len (int): The minimum length of the output.
        max_len (int): The maximum length of the output.
        shard_id (int): The shard ID.
        num_shards (int): The total number of shards.

    Returns:
        str: The constructed output file path.
    """
    suffix = ""
    if min_len is not None:
        suffix += f"__min-{min_len}"
    if max_len is not None:
        suffix += f"__max-{max_len}"
    if shard_id is not None and num_shards is not None:
        suffix += f"__shard-{shard_id}-{num_shards}"
    if Path(dataset_path).exists():
        dset_nickname = Path(dataset_path).name + "__" + split
    else:
        dset_nickname = dataset_path.replace("/", "__") + "__" + split
    if peft_path is not None and "checkpoint" in Path(peft_path).name:
        model_nickname = Path(peft_path).parent.name + "__" + Path(peft_path).name
    elif peft_path is not None:
        model_nickname = Path(peft_path).name
    elif Path(model_name_or_path).exists():
        if "checkpoint" in Path(model_name_or_path).name:
            model_nickname = (
                Path(model_name_or_path).parent.name
                + "__"
                + Path(model_name_or_path).name
            )
        else:
            model_nickname = Path(model_name_or_path).name
    else:
        model_nickname = model_name_or_path.replace("/", "__")
    output_file = Path(
        output_dir,
        dset_nickname
        + "__"
        + model_nickname
        + "__temp-"
        + str(temperature)
        + "__top-p-"
        + str(top_p)
        + suffix
        + ".jsonl",
    )
    if not output_file.parent.exists():
        output_file.parent.mkdir(
            parents=True, exist_ok=True
        )  # exists_ok=True for parallel
    return output_file
load_model
load_model(model_name_or_path, peft_path)

Loads a base model and optionally PEFT adapters.

Parameters:

Name Type Description Default
model_name_or_path str

The name or path of the base model.

required
peft_path str or None

The path to the PEFT adapters. If None, no PEFT adapters will be loaded.

required

Returns:

Name Type Description
model

The loaded model.

Raises:

Type Description
ValueError

If there is no device map for the specified model_name_or_path.

Source code in swebench/inference/run_llama.py
def load_model(model_name_or_path, peft_path):
    """
    Loads a base model and optionally PEFT adapters.

    Args:
        model_name_or_path (str): The name or path of the base model.
        peft_path (str or None): The path to the PEFT adapters. If None, no PEFT adapters will be loaded.

    Returns:
        model: The loaded model.

    Raises:
        ValueError: If there is no device map for the specified model_name_or_path.
    """
    logger.info(f"Loading base model from {model_name_or_path}")
    max_memory = {
        **{
            k: f"{torch.cuda.get_device_properties(k).total_memory // 1_010_000_000:d}GIB"
            for k in range(torch.cuda.device_count())
        },
        "cpu": "20GIB",
    }
    logger.info(f"Using max memory {max_memory}")
    if "-7b" in model_name_or_path:
        device_map = DEVICE_MAPS["7b"][str(torch.cuda.device_count())]
    elif "-13b" in model_name_or_path:
        device_map = DEVICE_MAPS["13b"][str(torch.cuda.device_count())]
    elif "-34b" in model_name_or_path:
        device_map = DEVICE_MAPS["34b"][str(torch.cuda.device_count())]
    else:
        raise ValueError(f"No device map for {model_name_or_path}")
    logger.info(f"Using device_map {device_map}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        max_memory=max_memory,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
    ).eval()
    if peft_path is None:
        logger.info("No PEFT adapters to load")
        return model
    logger.info(f"Loading PEFT adapters from {peft_path}")
    model = PeftModel.from_pretrained(
        model,
        peft_path,
        device_map=device_map,
        torch_dtype=torch.bfloat16,
        max_memory=max_memory,
    )
    return model
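A hedged usage sketch: the base model path must contain one of the size markers that DEVICE_MAPS covers ("-7b", "-13b", or "-34b"), otherwise the ValueError above is raised. The model name below is an assumption.

# Assumed model name; any path containing "-13b" selects the 13b device map.
model = load_model("princeton-nlp/SWE-Llama-13b", peft_path=None)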
load_tokenizer
load_tokenizer(model_name_or_path)
Source code in swebench/inference/run_llama.py, lines 157-160
def load_tokenizer(model_name_or_path):
    logger.info(f"Loading tokenizer {model_name_or_path}")
    tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path)
    return tokenizer
load_data
load_data(dataset_path, split, tokenizer, min_len, max_len, model_name_or_path, peft_path, existing_ids, shard_id, num_shards)

Load and preprocess the dataset for model inference.

Parameters:

Name Type Description Default
dataset_path str

The path to the dataset.

required
split str

The split of the dataset to load.

required
tokenizer

The tokenizer used to tokenize the text.

required
min_len int

The minimum length of input sequences to include in the dataset.

required
max_len int

The maximum length of input sequences to include in the dataset.

required
model_name_or_path str

The name or path of the model.

required
peft_path str

The path to the PEFT file.

required
existing_ids

The list of existing instance IDs to filter out from the dataset.

required
shard_id int

The ID of the shard to load.

required
num_shards int

The total number of shards.

required

Returns:

Name Type Description
dataset

The preprocessed dataset for model inference.

Source code in swebench/inference/run_llama.py, lines 163-243
def load_data(
    dataset_path,
    split,
    tokenizer,
    min_len,
    max_len,
    model_name_or_path,
    peft_path,
    existing_ids,
    shard_id,
    num_shards,
):
    """
    Load and preprocess the dataset for model inference.

    Args:
        dataset_path (str): The path to the dataset.
        split (str): The split of the dataset to load.
        tokenizer: The tokenizer used to tokenize the text.
        min_len (int): The minimum length of input sequences to include in the dataset.
        max_len (int): The maximum length of input sequences to include in the dataset.
        model_name_or_path (str): The name or path of the model.
        peft_path (str): The path to the PEFT file.
        existing_ids: The list of existing instance IDs to filter out from the dataset.
        shard_id (int): The ID of the shard to load.
        num_shards (int): The total number of shards.

    Returns:
        dataset: The preprocessed dataset for model inference.
    """
    logger.info(f"Loading dataset from {dataset_path}")
    if not Path(dataset_path).exists():
        dataset = load_dataset(dataset_path, split=split)
    elif Path(dataset_path, split).exists():
        dataset = load_from_disk(Path(dataset_path) / split)
    else:
        dataset = load_dataset(dataset_path)[split]
    if peft_path is not None:
        model_nickname = "__".join(peft_path.split("/")[-2:])
    else:
        model_nickname = "__".join(model_name_or_path.split("/")[-2:])
    if "input_ids" not in dataset.column_names:
        dataset = dataset.map(
            lambda x: tokenizer(x["text"], truncation=False),
            batched=False,
            desc="tokenizing",
        )
    if "SWE-Llama" in model_name_or_path and dataset[0]["input_ids"][-2:] != [13, 13]:
        # SWE-Llama prompts must end with exactly two newline tokens (id 13)
        dataset = dataset.map(
            lambda x: {"input_ids": x["input_ids"] + [13]}, batched=False
        )
    filter_func = None
    if min_len is not None and max_len is None:
        filter_func = lambda x: x >= min_len
    elif min_len is None and max_len is not None:
        filter_func = lambda x: x < max_len
    elif min_len is not None and max_len is not None:
        filter_func = lambda x: min_len <= x < max_len
    if filter_func is not None:
        dataset = dataset.filter(
            lambda x: filter_func(len(x["input_ids"])), desc="filtering for length"
        )
    lens = torch.tensor(list(map(lambda x: len(x["input_ids"]), dataset)))
    dataset = dataset.select(lens.argsort())
    if shard_id is not None and num_shards is not None:
        dataset = dataset.shard(num_shards, shard_id, contiguous=True)
    dataset = dataset.filter(
        lambda x: x["instance_id"] not in existing_ids,
        desc="filtering for existing ids",
    )
    lens = torch.tensor(list(map(lambda x: len(x["input_ids"]), dataset)))  # recompute
    if shard_id is not None and num_shards is not None:
        logger.info(
            f"filtered dataset - {len(dataset)} examples, min length: {min(lens):_}, max length: {max(lens):_} (shard {shard_id} of {num_shards})"
        )
    else:
        logger.info(
            f"filtered dataset - {len(dataset)} examples, min length: {min(lens):_}, max length: {max(lens):_}"
        )
    return dataset
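A hedged call sketch showing how the length filter and sharding interact: examples are sorted by input length before sharding, so each contiguous shard receives inputs of similar size. The dataset and model identifiers are illustrative assumptions; tokenizer comes from load_tokenizer above.

# Illustrative arguments; keep only inputs shorter than 16,384 tokens and
# take shard 0 of 4 (shards are contiguous after the length sort).
dataset = load_data(
    dataset_path="princeton-nlp/SWE-bench_oracle",
    split="test",
    tokenizer=tokenizer,
    min_len=None,
    max_len=16_384,
    model_name_or_path="princeton-nlp/SWE-Llama-13b",
    peft_path=None,
    existing_ids=set(),
    shard_id=0,
    num_shards=4,
)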
generate
generate(model, dataset, tokenizer, temperature, top_p, fileobj, model_name_or_path, peft_path)
Source code in swebench/inference/run_llama.py, lines 246-334
def generate(
    model,
    dataset,
    tokenizer,
    temperature,
    top_p,
    fileobj,
    model_name_or_path,
    peft_path,
):
    class RepeatingTokensCriteria(StoppingCriteria):
        """
        Stopping criteria based on repeating tokens in the generated sequence.

        Attributes:
            min_length (int): Minimum sequence length before the repetition check applies.
            min_tokens (int): Stop once the last min_length tokens contain at most this many distinct token ids.
        """

        def __init__(self, min_length=100, min_tokens=10):
            super().__init__()
            self.min_length = min_length
            self.min_tokens = min_tokens

        def __call__(self, input_ids, scores, **kwargs):
            """
            Check if the stopping criteria is met based on repeating tokens.

            Args:
                input_ids (torch.Tensor): The input token IDs of the generated sequence.
                scores (torch.Tensor): The scores of the generated sequence.
                **kwargs: Additional keyword arguments.

            Returns:
                bool: True if the stopping criteria is met, False otherwise.
            """
            if input_ids[0, -1].cpu().item() == tokenizer.eos_token_id:
                return True
            if input_ids.shape[-1] < self.min_length:
                return False
            suffix = input_ids[0, -self.min_length :].cpu().tolist()
            if len(set(suffix)) <= self.min_tokens:
                return True
            return False

    stopping_criteria = StoppingCriteriaList([RepeatingTokensCriteria()])
    fail_count = 0
    with torch.no_grad():
        for ix, instance in enumerate(tqdm(dataset, desc="Generating patches")):
            try:
                input_ids = instance["input_ids"]
                input_ids = torch.tensor(
                    [input_ids], dtype=torch.long, device=model.device
                )
                logger.info(f"Processing {input_ids.shape[-1]} tokens")
                start = datetime.now()
                output = model.generate(
                    input_ids=input_ids,
                    attention_mask=torch.ones_like(input_ids),
                    temperature=1.0 if temperature == 0 else temperature,
                    top_p=top_p,
                    do_sample=False if temperature == 0 else True,
                    max_new_tokens=200,
                    stopping_criteria=stopping_criteria,
                )
                total_len = output.shape[-1]
                output = output[0].cpu()[input_ids.shape[-1] :]
                new_len = len(output)
                logger.info(
                    f"Generated {new_len} tokens ({total_len} total) in {(datetime.now() - start).total_seconds()} "
                    + f"seconds (speed: {new_len / (datetime.now() - start).total_seconds()} tps)"
                )
                output = tokenizer.decode(output, skip_special_tokens=False)
                logger.info(output[:200])
                diff = extract_diff(output)
                model_name_or_path += f"__{peft_path}" if peft_path is not None else ""
                res = {
                    "instance_id": instance["instance_id"],
                    "full_output": output,
                    "model_patch": diff,
                    "model_name_or_path": model_name_or_path,
                }
                print(json.dumps(res), file=fileobj, flush=True)
            except Exception as e:
                logger.exception(e)
                print(f"failed on {ix} with {len(input_ids)} tokens")
                fail_count += 1
                if fail_count >= 3:
                    raise ValueError("too many failures")
get_all_existing_ids
get_all_existing_ids(output_file)
Source code in swebench/inference/run_llama.py, lines 337-356
def get_all_existing_ids(output_file):
    stub_pattern = re.compile(
        r"((?:[\w\-\.]+)\_\_temp\-((\d+(\.\d+)?)|None)\_\_top\-p\-((\d+(\.\d+)?)|None))(\_\_|\.jsonl)"
    )
    match = stub_pattern.match(output_file.name)
    if not output_file.exists():
        return set()
    if match is None:
        raise ValueError(f"output_file {output_file} doesn't match pattern")
    stub = match[1]
    existing_ids = set()
    output_files = list(Path(output_file.parent).glob(stub + "*"))
    for filename in output_files:
        logger.info(f"Loading existing ids from existing {filename}")
        with open(filename) as f:
            for line in f:
                datum = json.loads(line)
                existing_ids.add(datum["instance_id"])
    logger.info(f"Found {len(existing_ids)} existing ids")
    return existing_ids
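A worked example of the stub matching, with a hypothetical file name: everything up to and including the top-p value is treated as the stub, so ids are collected from every sibling file that shares it (other shards and min/max-length variants included).

# For an output_file named, hypothetically,
#   ds__test__model__temp-0.0__top-p-1.0__shard-0-4.jsonl
# the stub is "ds__test__model__temp-0.0__top-p-1.0", and ids are merged
# from every file in the same directory whose name starts with that stub.
existing_ids = get_all_existing_ids(output_file)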
main
main(model_name_or_path, peft_path, dataset_path, split, temperature, top_p, output_dir, min_len, max_len, shard_id, num_shards)
Source code in swebench/inference/run_llama.py, lines 359-423
def main(
    model_name_or_path,
    peft_path,
    dataset_path,
    split,
    temperature,
    top_p,
    output_dir,
    min_len,
    max_len,
    shard_id,
    num_shards,
):
    if shard_id is not None and num_shards is None:
        raise ValueError("num_shards must be specified with shard_id")
    if shard_id is None and num_shards is not None:
        raise ValueError("shard_id must be specified with num_shards")
    peft_config = None
    if peft_path is not None:
        peft_config = PeftConfig.from_pretrained(peft_path)
        if peft_config.base_model_name_or_path != model_name_or_path:
            logger.warning(
                f"model_name_or_path {model_name_or_path} does not match peft_path base_model {peft_config.base_model_name_or_path}"
            )
    output_file = get_output_file(
        output_dir=output_dir,
        model_name_or_path=model_name_or_path,
        peft_path=peft_path,
        dataset_path=dataset_path,
        split=split,
        temperature=temperature,
        top_p=top_p,
        min_len=min_len,
        max_len=max_len,
        shard_id=shard_id,
        num_shards=num_shards,
    )
    logger.warning(f"output_file: {output_file}")
    model = load_model(model_name_or_path, peft_path)
    tokenizer = load_tokenizer(model_name_or_path)
    existing_ids = get_all_existing_ids(output_file)
    dataset = load_data(
        dataset_path=dataset_path,
        split=split,
        tokenizer=tokenizer,
        min_len=min_len,
        max_len=max_len,
        model_name_or_path=model_name_or_path,
        peft_path=peft_path,
        existing_ids=existing_ids,
        shard_id=shard_id,
        num_shards=num_shards,
    )
    with open(output_file, "a") as f:
        generate(
            model=model,
            dataset=dataset,
            tokenizer=tokenizer,
            temperature=temperature,
            top_p=top_p,
            fileobj=f,
            model_name_or_path=model_name_or_path,
            peft_path=peft_path,
        )
    logger.info("Done")